Coverage for lapspython/pipeline.py: 98%
54 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-27 20:15 +0200
1"""Pipe all necessary steps to extract, translate and store programs."""
3from dreamcoder.dreamcoder import ECResult
4from lapspython.extraction import GrammarParser, ProgramExtractor
5from lapspython.stats import Statistics
6from lapspython.translation import Translator
7from lapspython.types import CompactResult
8from lapspython.utils import json_dump, json_read, load_checkpoint
class Pipeline:
    """Pipelines the entire extraction/translation process of LapsPython."""

    # Target languages accepted by the ``mode`` parameter (lower-case form).
    _MODES = ('python', 'r')

    @classmethod
    def _normalize_mode(cls, mode: str) -> str:
        """Return ``mode`` lower-cased, raising if it is not supported.

        :param mode: Target language, case-insensitive ('python' or 'r').
        :type mode: str
        :returns: The lower-cased mode.
        :rtype: str
        :raises ValueError: If ``mode`` is neither 'python' nor 'r'.
        """
        normalized = mode.lower()
        if normalized not in cls._MODES:
            raise ValueError('mode must be "Python" or "R".')
        return normalized

    @classmethod
    def extract_translate(
        cls,
        result: ECResult,
        json_path: str = '',
        mode: str = 'python',
        verbose: bool = True
    ) -> CompactResult:
        """Extract and translate programs from a LAPS result.

        :param result: A LAPS result.
        :type result: dreamcoder.dreamcoder.ECResult
        :param json_path: Path to dump or read from json. An empty string
            disables dumping.
        :type json_path: str
        :param mode: Translate to 'python' or 'r' (case-insensitive).
        :type mode: str
        :param verbose: Whether to print statistics and a sample translation.
        :type verbose: bool
        :returns: Extracted and translated programs
        :rtype: lapspython.types.CompactResult
        :raises ValueError: If ``mode`` is neither 'python' nor 'r'.
        """
        mode = cls._normalize_mode(mode)

        print(f'Language Mode: {mode.upper()}')
        if mode == 'r':
            print('WARNING: Code verification for R not implemented')
        print('\nParsing library...', flush=True)
        # Only the last grammar of the run is parsed.
        parser = GrammarParser(result.grammars[-1], mode)

        # If a previous dump exists at json_path, its grammar's invented
        # primitives are passed to the parser. Presumably this keeps their
        # names stable across runs -- confirm against GrammarParser.fix_invented.
        # (Named `cached` rather than `json` to avoid shadowing the stdlib
        # module name; a truthiness test also tolerates a None/empty return.)
        cached = json_read(json_path)
        if cached:
            parser.fix_invented(cached['grammar'].invented)
        grammar = parser.parsed_grammar

        print('\nTranslating synthesized programs...', flush=True)
        translator = Translator(grammar)
        extractor = ProgramExtractor(result, translator)
        compact = extractor.compact_result

        if json_path:
            print(f'\nSaving results to {json_path}...', end=' ')
            json_dump(json_path, grammar, compact)
            print('Done')

        if verbose:
            cls._report(compact, mode)
        return compact

    @staticmethod
    def _report(result: CompactResult, mode: str) -> None:
        """Print statistics (Python mode only) and one sampled translation.

        :param result: Extracted and translated programs.
        :type result: lapspython.types.CompactResult
        :param mode: Lower-cased target language ('python' or 'r').
        :type mode: str
        """
        if mode == 'python':
            print('\nCollecting descriptive statistics:')
            stats = Statistics(result)
            print(stats)
            stats.plot_histogram(result)

        # The word "valid " is only printed in Python mode, since code
        # verification is not implemented for R.
        print(f'\nSampling 1 {"valid "*(mode == "python")}translation:')
        sample = result.sample()
        if len(sample) > 0:
            print(sample['annotation'])
            print(sample['best_program'], end='\n\n')
            print(sample['best_valid_translation'])
        else:
            print('No validated translation found')

    @classmethod
    def from_checkpoint(
        cls,
        filepath: str,
        mode: str = 'python',
        verbose: bool = True,
        save: bool = True
    ) -> CompactResult:
        """Load checkpoint, then extract and translate.

        :param filepath: Checkpoint name in checkpoints directory.
        :type filepath: str
        :param mode: Translate to 'python' or 'r' (case-insensitive).
        :type mode: str
        :param verbose: Whether to print statistics and sample translations.
        :type verbose: bool
        :param save: Whether to save the results in a JSON file named
            ``<filepath>_<mode>``.
        :type save: bool
        :returns: Extracted and translated programs
        :rtype: lapspython.types.CompactResult
        :raises ValueError: If ``mode`` is neither 'python' nor 'r'.
        """
        # Validate before loading so a typo in `mode` does not cost a full
        # checkpoint load first.
        mode = cls._normalize_mode(mode)

        print(f'Loading checkpoint {filepath}...', end=' ')
        result = load_checkpoint(filepath)
        print('Done\n')

        json_path = f'{filepath}_{mode}' if save else ''
        return cls.extract_translate(result, json_path, mode, verbose)