Coverage for lapspython/translation.py: 88%
191 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-27 20:15 +0200
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-27 20:15 +0200
1"""Implements functions for translation from lambda calculus to Python."""
3import logging
4import re
5import traceback
7from dreamcoder.program import (Abstraction, Application, Index, Invented,
8 Primitive, Program)
9from lapspython.types import (ParsedGrammar, ParsedProgram, ParsedProgramBase,
10 ParsedRProgram, ParsedType)
class Translator:
    """Translate lambda programs to Python or R code.

    The target language is taken from ``grammar.mode``: ``'python'`` selects
    the ``' = '`` assignment separator, any other mode selects ``' <- '``.
    """

    # Tag of the file handler attached by setup_logger; used to attach the
    # handler to the shared module logger only once.
    _LOG_HANDLER_NAME = 'lapspython-translation-log'

    def __init__(self, grammar: ParsedGrammar) -> None:
        """Init grammar used for translation and empty containers.

        :param grammar: Grammar used for translation
        :type grammar: lapspython.types.ParsedGrammar
        """
        self.mode = grammar.mode
        self.sep = ' = ' if self.mode == 'python' else ' <- '

        self.grammar = grammar
        # One call counter per primitive/invented handle; used to build
        # unique variable names of the form "<name>_<count>".
        self.call_counts = {p: 0 for p in self.grammar.primitives}
        self.call_counts.update({i: 0 for i in self.grammar.invented})
        self.name: str = ''
        self.code: list = []
        self.args: list = []
        self.imports: set = set()
        self.dependencies: set = set()
        self.debug_stack: list = []
        self.logger = self.setup_logger()

    def setup_logger(self) -> logging.Logger:
        """Set up the logger for exceptions caught during translation.

        The module logger is shared by all Translator instances, so the file
        handler is attached only once; later calls reuse it instead of
        reopening (and truncating) translation.log.

        :returns: Module logger writing debug records to translation.log
        :rtype: logging.Logger
        """
        logger = logging.getLogger(__name__)
        # Without an explicit level the logger inherits the root level
        # (WARNING by default) and drops the debug records that
        # log_exception emits.
        logger.setLevel(logging.DEBUG)

        already_attached = any(
            handler.get_name() == self._LOG_HANDLER_NAME
            for handler in logger.handlers
        )
        if not already_attached:
            handler = logging.FileHandler('translation.log', 'w')
            handler.set_name(self._LOG_HANDLER_NAME)
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(logging.Formatter('%(message)s'))
            logger.addHandler(handler)
        return logger

    def log_exception(self):
        """Write current debug stack into translation.log.

        Logs the task name, every visited node recorded in the debug stack,
        the code generated so far (if any) and the active traceback.
        """
        self.logger.debug(f'{self.name}\n')
        for entry in self.debug_stack:
            self.logger.debug(entry)
        if self.code:
            code = '\n'.join(self.code)
            self.logger.debug(f'\n{code}')
        self.logger.debug(f'\n{traceback.format_exc()}\n')

    def translate(self, program: Program, name: str) -> ParsedProgramBase:
        """Translate a synthesized program under the current grammar.

        :param program: Abstraction/Invented at any depth of lambda expression
        :type program: subclass of dreamcoder.program.Program
        :param name: Task/Function name
        :type name: string
        :returns: Translated program (ParsedProgram in python mode,
            ParsedRProgram otherwise)
        :rtype: ParsedProgramBase
        """
        # Reset all per-translation state so the instance can be reused.
        for call in self.call_counts:
            self.call_counts[call] = 0
        self.code = []
        self.name = name
        arg_types = ParsedType.parse_argument_types(program.infer())
        # The last parsed type is not turned into an argument name.
        n_args = len(arg_types) - 1
        self.args = [f'arg{i + 1}' for i in range(n_args)]
        self.imports = set()
        self.dependencies = set()
        self.debug_stack = []

        self.translate_wrapper(program)

        source = '\n'.join(self.code)

        if self.mode == 'python':
            source = self._return_last_assignment(source)
            return ParsedProgram(
                name,
                source,
                self.args,
                self.imports,
                self.dependencies
            )

        return ParsedRProgram(
            name,
            source,
            self.args,
            self.imports,
            self.dependencies
        )

    @staticmethod
    def _return_last_assignment(source: str) -> str:
        """Turn the last ``<variable> = `` in source into ``return ``.

        Only the final match is rewritten; earlier text that happens to
        contain the same assignment string is left untouched.

        :param source: Generated Python source
        :type source: string
        :returns: Source with the last assignment replaced by a return
        :rtype: string
        """
        last = None
        for last in re.finditer(r'\w+ = ', source):
            pass
        if last is None:
            return source
        return f'{source[:last.start()]}return {source[last.end():]}'

    def _fallback_assignment(self, name: str) -> str:
        """Build the placeholder line used when variable resolution fails.

        Uses the separator of the target language so that get_last_variable
        can still split the line, and the matching null literal.

        :param name: Variable name to declare
        :type name: string
        :returns: Assignment of a null value to name
        :rtype: string
        """
        null = 'None' if self.mode == 'python' else 'NULL'
        return f'{name}{self.sep}{null}'

    def translate_wrapper(self, program: Program, node_type: str = 'body'):
        """Redirect node to corresponding translation procedure.

        :param program: Node of program tree.
        :type program: Subclass of dreamcoder.program.Program
        :param node_type: Role of the node in its parent: 'f', 'x' or 'body'
        :type node_type: string
        :returns: Result of the selected translation procedure
        :rtype: tuple
        :raises ValueError: If the node is of no known program type
        """
        debug = (str(program), str(type(program)), node_type)
        self.debug_stack.append(', '.join(debug))

        if program.isAbstraction:
            if node_type == 'x':
                return self._translate_abstraction_x(program)
            return self._translate_abstraction_body(program)
        if program.isApplication:
            if node_type == 'f':
                return self._translate_application_f(program)
            if node_type == 'x':
                return self._translate_application_x(program)
            return self._translate_application_body(program)
        if program.isIndex:
            return self._translate_index(program)
        if program.isInvented:
            if node_type == 'f':
                return self._translate_invented(program)
            # Outside of function position an invented node is translated
            # through its body, like an abstraction.
            return self._translate_abstraction_body(program)
        if program.isPrimitive:
            if node_type == 'f':
                return self._translate_primitive_f(program)
            if node_type == 'x':
                return self._translate_primitive_x(program)
            return self._translate_primitive_body(program)
        raise ValueError(f'{node_type} node of type {type(program)}')

    def _translate_abstraction_body(self, abstraction: Abstraction) -> tuple:
        """Translate the body and wrap its first argument in a lambda."""
        parsed, args = self.translate_wrapper(abstraction.body)
        args = [f'lambda x: {args[0]}']
        return parsed, args

    def _translate_abstraction_x(self, abstraction: Abstraction) -> tuple:
        """Translate an abstraction that is passed as an argument.

        If the abstraction contains an index and code has been generated, the
        right-hand side of the last code line becomes the lambda body with
        every ``arg<N>`` replaced by ``lx``. Otherwise the first translated
        argument is used as the body. The ``lambda lx: `` head is only
        emitted in python mode.

        :returns: Parsed body and single-element argument list, or
            ``('# ERROR', ['# ERROR'])`` if an IndexError occurs
        :rtype: tuple
        """
        parsed, args = self.translate_wrapper(abstraction.body)

        lambda_head = 'lambda lx: ' if self.mode == 'python' else ''
        try:
            if self.contains_index(abstraction) and self.code:
                body = self.code[-1].split(self.sep)[1]
                # \d+ so that multi-digit names such as arg10 are replaced
                # as a whole instead of leaving a trailing digit behind.
                body = re.sub(r'arg\d+', 'lx', body)
            else:
                body = args[0]
        except IndexError:
            return '# ERROR', ['# ERROR']

        return parsed, [f'{lambda_head}{body}']

    def _translate_application_f(self, application: Application) -> tuple:
        """Translate an application in function position.

        Collects the arguments of the nested application: x is translated
        first, then f, and the argument lists are concatenated.
        """
        _, x_args = self.translate_wrapper(application.x, 'x')
        f_parsed, f_args = self.translate_wrapper(application.f, 'f')

        return f_parsed, f_args + x_args

    def _translate_application_x(self, application: Application) -> tuple:
        """Translate an application that is passed as an argument.

        Unless f is an index, a new code line assigning the call to a fresh
        variable is appended and that variable name is returned as argument.
        """
        f = application.f
        x = application.x

        x_parsed, x_args = self.translate_wrapper(x, 'x')
        f_parsed, f_args = self.translate_wrapper(f, 'f')

        # Slice comparison (not startswith) because the last entry is not
        # necessarily a string.
        if x_args[-1][:3] == 'arg' and x_args[-1] not in self.args:
            new_x_arg = self.get_last_variable()
            # get_last_variable returns '' when no code exists yet; keep the
            # original argument then instead of substituting an empty name.
            if new_x_arg not in ('', 'x'):
                x_args[-1] = new_x_arg

        if not f.isInvented:
            x_args = f_args + x_args

        if not f.isIndex:
            self.call_counts[f_parsed.handle] += 1
            name = f'{f_parsed.name}_{self.call_counts[f_parsed.handle]}'
            try:
                f_parsed_resolved = f_parsed.resolve_variables(x_args, name)
            except ValueError:
                self.log_exception()
                f_parsed_resolved = self._fallback_assignment(name)
            self.code.append(f_parsed_resolved)
            x_args = name

        return f_parsed, [x_args]

    def _translate_application_body(self, application: Application) -> tuple:
        """Translate an application in body position.

        Appends a code line assigning the call to a fresh variable. Missing
        arguments are filled with function arguments ``arg1``, ``arg2``, ...
        as far as these are declared in ``self.args``.

        :returns: ``None`` and the name of the declared variable
        :rtype: tuple
        """
        x_parsed, x_args = self.translate_wrapper(application.x, 'x')
        f_parsed, f_args = self.translate_wrapper(application.f, 'f')

        x_args = f_args + x_args

        self.call_counts[f_parsed.handle] += 1
        name = f'{f_parsed.name}_{self.call_counts[f_parsed.handle]}'

        missing_args = len(f_parsed.args) - len(x_args)
        for i in range(missing_args):
            new_arg = f'arg{i + 1}'
            if new_arg in self.args:
                x_args.append(new_arg)

        try:
            f_parsed_resolved = f_parsed.resolve_variables(x_args, name)
        except ValueError:
            self.log_exception()
            f_parsed_resolved = self._fallback_assignment(name)

        self.code.append(f_parsed_resolved)

        return None, [name]

    def _translate_index(self, index: Index) -> tuple:
        """Translate index ``i`` to the argument name ``arg<i + 1>``."""
        arg = f'arg{index.i + 1}'
        f_parsed = f'lambda lx: {arg}'
        return f_parsed, [arg]

    def _translate_invented(self, invented: Invented) -> tuple:
        """Look up an invented primitive, translating it first if needed.

        An invented primitive whose source is still empty is translated with
        a separate Translator and the result is stored on the grammar entry.
        Its imports and dependencies are added to the current translation.
        """
        handle = str(invented)
        f_parsed = self.grammar.invented[handle]
        if f_parsed.source == '':
            translator = Translator(self.grammar)
            f_trans = translator.translate(f_parsed.program, f_parsed.name)
            f_parsed.source = f_trans.source
            f_parsed.args = f_trans.args
            f_parsed.dependencies = f_trans.dependencies
        self.imports.update(f_parsed.imports)
        self.dependencies.update(f_parsed.dependencies)
        self.dependencies.add(str(f_parsed))
        name = f'{f_parsed.name}_{self.call_counts[handle]}'
        x_args = [name]

        return f_parsed, x_args

    def _translate_primitive_f(self, primitive: Primitive) -> tuple:
        """Look up a primitive in function position and record its needs."""
        parsed = self.grammar.primitives[primitive.name].resolve_lambdas()
        self.imports.update(parsed.imports)
        self.dependencies.update(parsed.dependencies)
        return parsed, []

    def _translate_primitive_x(self, primitive: Primitive) -> tuple:
        """Translate a primitive argument to its quoted value."""
        return None, [f"'{primitive.value}'"]

    def _translate_primitive_body(self, primitive: Primitive) -> tuple:
        """Translate a primitive body to its quoted value."""
        return None, [f"'{primitive.value}'"]

    def contains_index(self, program: Program) -> bool:
        """Test whether the subprogram contains a de Bruijn index.

        :param program: Node of program tree.
        :type program: Subclass of dreamcoder.program.Program
        :returns: True if any node below (or including) program is an index
        :rtype: bool
        """
        if program.isIndex:
            return True
        if program.isPrimitive:
            return False
        if hasattr(program, 'body'):
            return self.contains_index(program.body)
        # Remaining nodes are expected to expose f and x.
        return (
            self.contains_index(program.f)
            or self.contains_index(program.x)
        )

    def get_last_variable(self) -> str:
        """Return the declared variable in the last line of code.

        :returns: Text before the first separator of the last code line, or
            an empty string if no code has been generated
        :rtype: string
        """
        if not self.code:
            return ''
        return self.code[-1].split(self.sep)[0]