Coverage for lapspython/pipeline.py: 98%

54 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-27 20:15 +0200

1"""Pipe all necessary steps to extract, translate and store programs.""" 

2 

3from dreamcoder.dreamcoder import ECResult 

4from lapspython.extraction import GrammarParser, ProgramExtractor 

5from lapspython.stats import Statistics 

6from lapspython.translation import Translator 

7from lapspython.types import CompactResult 

8from lapspython.utils import json_dump, json_read, load_checkpoint 

9 

10 

class Pipeline:
    """Pipelines the entire extraction/translation process of LapsPython."""

    @classmethod
    def extract_translate(
        cls,
        result: ECResult,
        json_path: str = '',
        mode: str = 'python',
        verbose: bool = True
    ) -> CompactResult:
        """Extract and translate programs from a LAPS result.

        The last grammar of ``result`` is parsed for the requested language.
        If ``json_path`` is non-empty, it is first read with ``json_read``.
        When that returns a non-empty mapping, the invented primitives stored
        under its ``'grammar'`` entry are passed to the parser. Afterwards,
        the parsed grammar and the extracted programs are dumped back to
        ``json_path``.

        :param result: A LAPS result.
        :type result: dreamcoder.dreamcoder.ECResult
        :param json_path: Path to dump or read from json. An empty string
            disables both reading and writing.
        :type json_path: str
        :param mode: Target language, 'python' or 'r' (case-insensitive).
        :type mode: str
        :param verbose: Whether to print descriptive statistics (Python mode
            only) and one sample translation.
        :type verbose: bool
        :raises ValueError: If ``mode`` is neither 'python' nor 'r'.
        :returns: Extracted and translated programs
        :rtype: lapspython.types.CompactResult
        """
        mode = mode.lower()
        if mode not in ('python', 'r'):
            raise ValueError('mode must be "Python" or "R".')

        print(f'Language Mode: {mode.upper()}')
        if mode == 'r':
            print('WARNING: Code verification for R not implemented')
        print('\nParsing library...', flush=True)
        parser = GrammarParser(result.grammars[-1], mode)

        # An empty path means "no file". This mirrors the guard on json_dump
        # below, so nothing is read when nothing will be written.
        cached = json_read(json_path) if json_path else {}
        # Truthiness (instead of `!= {}`) also skips a None/empty return
        # rather than failing on the ['grammar'] lookup.
        if cached:
            parser.fix_invented(cached['grammar'].invented)
        grammar = parser.parsed_grammar

        print('\nTranslating synthesized programs...', flush=True)
        translator = Translator(grammar)
        extractor = ProgramExtractor(result, translator)
        # Distinct name so the ECResult parameter is not rebound to a
        # CompactResult of a different type.
        compact = extractor.compact_result

        if json_path:
            # Flush so the message is visible before json_dump runs.
            print(f'\nSaving results to {json_path}...', end=' ', flush=True)
            json_dump(json_path, grammar, compact)
            print('Done')

        if not verbose:
            return compact

        if mode == 'python':
            print('\nCollecting descriptive statistics:')
            stats = Statistics(compact)
            print(stats)
            stats.plot_histogram(compact)

        # "valid " is only inserted in Python mode, where translations are
        # verified; R verification is not implemented (see warning above).
        print(f'\nSampling 1 {"valid "*(mode == "python")}translation:')
        sample = compact.sample()
        if sample:
            print(sample['annotation'])
            print(sample['best_program'], end='\n\n')
            print(sample['best_valid_translation'])
        else:
            print('No validated translation found')

        return compact

    @classmethod
    def from_checkpoint(
        cls,
        filepath: str,
        mode: str = 'python',
        verbose: bool = True,
        save: bool = True
    ) -> CompactResult:
        """Load checkpoint, then extract and translate.

        :param filepath: Checkpoint name in checkpoints directory.
        :type filepath: str
        :param mode: Translate to 'python' or 'r'.
        :type mode: str
        :param verbose: Whether to print statistics and sample translations.
        :type verbose: bool
        :param save: Whether to save the results in a JSON file. The path
            used is ``f'{filepath}_{mode.lower()}'``.
        :type save: bool
        :raises ValueError: If ``mode`` is neither 'python' nor 'r'
            (raised by ``extract_translate``).
        :returns: Extracted and translated programs
        :rtype: lapspython.types.CompactResult
        """
        # Flush so the message is visible while load_checkpoint runs.
        print(f'Loading checkpoint {filepath}...', end=' ', flush=True)
        result = load_checkpoint(filepath)
        print('Done\n')
        json_path = f'{filepath}_{mode.lower()}' if save else ''
        return cls.extract_translate(result, json_path, mode, verbose)