Coverage for lapspython/extraction.py: 96%
79 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-27 20:15 +0200
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-27 20:15 +0200
1"""Implements classes to extract primitives and lambda expressions."""
3from tqdm import tqdm
5from dreamcoder.dreamcoder import ECResult
6from dreamcoder.grammar import Grammar
7from dreamcoder.program import Invented, Primitive
8from lapspython.translation import Translator
9from lapspython.types import (CompactFrontier, CompactResult, ParsedGrammar,
10 ParsedInvented, ParsedPrimitive, ParsedRInvented,
11 ParsedRPrimitive)
class GrammarParser:
    """Extract, parse, and store all primitives from grammar.

    After construction ``self.parsed_grammar`` always holds a
    ``ParsedGrammar``: either the result of parsing the grammar given to the
    constructor, or an empty grammar created with ``self.mode``.
    """

    def __init__(self, grammar: Grammar = None, mode='python') -> None:
        """Optionally parse grammar if passed during construction.

        :param grammar: A grammar induced by LAPS.
        :type grammar: dreamcoder.grammar.Grammar, optional
        :param mode: Whether to extract Python or R code, can be 'python' or
            'r' (case-insensitive).
        :type mode: string, optional
        :raises ValueError: If mode is neither 'python' nor 'r'.
        """
        self.mode = mode.lower()
        if self.mode not in ('python', 'r'):
            raise ValueError('mode must be "Python" or "R".')

        if grammar is not None:
            self.parse(grammar)
        else:
            # Pass the mode explicitly, exactly as parse() does, so an empty
            # R parser is not tagged with ParsedGrammar's default mode.
            self.parsed_grammar = ParsedGrammar({}, {}, self.mode)

    def parse(self, grammar: Grammar) -> ParsedGrammar:
        """Convert Primitive objects to simplified ParsedPrimitive objects.

        Built-in primitives are keyed by their name. Invented primitives are
        keyed by their string form and named ``f0``, ``f1``, ... in order of
        first appearance. Afterwards, every invented primitive whose
        ``source`` is still empty is translated with a ``Translator`` built
        on the freshly parsed grammar.

        :param grammar: A grammar induced inside main() or ecIterator().
        :type grammar: dreamcoder.grammar.Grammar
        :raises TypeError: If a production is neither a Primitive nor an
            Invented.
        :returns: The parsed grammar, also stored in ``self.parsed_grammar``.
        :rtype: ParsedGrammar
        """
        parsed_primitives: dict = {}
        parsed_invented: dict = {}

        for _, _, primitive in tqdm(grammar.productions):
            if isinstance(primitive, Primitive):
                name = primitive.name
                if name in parsed_primitives:
                    continue
                if self.mode == 'python':
                    parsed_primitive = ParsedPrimitive(primitive)
                    parsed_primitives[name] = parsed_primitive.resolve_lambdas()
                else:
                    parsed_primitives[name] = ParsedRPrimitive(primitive)
            elif isinstance(primitive, Invented):
                handle = str(primitive)
                if handle in parsed_invented:
                    continue
                # Sequential names keep generated function names stable and
                # unique within this grammar.
                name = f'f{len(parsed_invented)}'
                if self.mode == 'python':
                    parsed_invented[handle] = ParsedInvented(primitive, name)
                else:
                    parsed_invented[handle] = ParsedRInvented(primitive, name)
            else:
                raise TypeError(f'Encountered unknown type {type(primitive)}.')

        self.parsed_grammar = ParsedGrammar(
            parsed_primitives,
            parsed_invented,
            self.mode
        )

        # Invented primitives may reference each other, so they are translated
        # only once the whole grammar is available to the translator.
        translator = Translator(self.parsed_grammar)
        for invented in self.parsed_grammar.invented.values():
            if invented.source == '':
                trans = translator.translate(invented.program, invented.name)
                invented.source = trans.source
                invented.args = trans.args
                invented.dependencies = trans.dependencies

        return self.parsed_grammar

    def fix_invented(self, new_invented: dict) -> None:
        """Replace invented primitives implementations.

        :param new_invented: Invented primitives from JSON file, keyed by the
            same handles as ``self.parsed_grammar.invented``; each value must
            provide 'name', 'source', 'args' and 'dependencies'.
        :type new_invented: dict
        :raises ValueError: If the two sets of handles differ.
        """
        this_invented = self.parsed_grammar.invented
        if set(this_invented.keys()) != set(new_invented.keys()):
            raise ValueError('Keys of the two grammars are not equal.')

        for handle, invented in this_invented.items():
            new_data = new_invented[handle]
            invented.name = new_data['name']
            invented.source = new_data['source']
            invented.args = new_data['args']
            invented.dependencies = new_data['dependencies']
class ProgramExtractor:
    """Extract, parse and translate synthesized programs."""

    def __init__(self, result: ECResult = None,
                 translator: Translator = None) -> None:
        """Optionally extract programs if passed during construction.

        :param result: A result produced by LAPS or checkpoint.
        :type result: dreamcoder.dreamcoder.ECResult, optional
        :param translator: Translator to translate programs during extraction.
        :type translator: lapspython.translation.Translator, optional
        """
        if result is not None:
            self.extract(result, translator)
        else:
            self.compact_result = CompactResult({}, {})

    def extract(self, result: ECResult,
                translator: Translator = None) -> CompactResult:
        """Extract all frontiers together with their task descriptions.

        Frontiers with at least one program are stored as hits, empty ones as
        misses. If a translator is given, every program of a hit frontier is
        translated and verified against the frontier's examples.

        :param result: Result of dreamcoder execution (checkpoint)
        :type result: dreamcoder.dreamcoder.ECResult
        :param translator: Translator to translate programs during extraction.
        :type translator: lapspython.translation.Translator, optional
        :returns: The extracted result, also stored in ``self.compact_result``.
        :rtype: lapspython.types.CompactResult
        """
        hit_frontiers = {}
        miss_frontiers = {}

        for frontier in tqdm(result.allFrontiers.values()):
            name = frontier.task.name
            # Presumably taskLanguage maps task names to lists of descriptions
            # (TODO confirm). Tasks without language -- missing key or empty
            # list -- get '' instead of raising IndexError.
            descriptions = result.taskLanguage.get(name) or ['']
            annotation = descriptions[0]
            compact_frontier = CompactFrontier(frontier, annotation)

            if frontier.empty:
                miss_frontiers[name] = compact_frontier
                continue

            hit_frontiers[name] = compact_frontier
            if translator is not None:
                self._translate_frontier(translator, compact_frontier, name)

        self.compact_result = CompactResult(hit_frontiers, miss_frontiers)
        return self.compact_result

    @staticmethod
    def _translate_frontier(translator: Translator,
                            compact_frontier: CompactFrontier,
                            name: str) -> None:
        """Translate each program and sort it into translations or failed."""
        for program in compact_frontier.programs:
            transl = translator.translate(program, name)
            try:
                verified = transl.verify(compact_frontier.examples)
            except Exception:
                # Running a translated program may raise arbitrary errors;
                # count it as a failed translation. Exception (not
                # BaseException) lets KeyboardInterrupt/SystemExit propagate.
                verified = False
            if verified:
                compact_frontier.translations.append(transl)
            else:
                compact_frontier.failed.append(transl)