"""Implements functions for translation from lambda calculus to Python."""
import logging
import re
import traceback
from dreamcoder.program import (Abstraction, Application, Index, Invented,
Primitive, Program)
from lapspython.types import (ParsedGrammar, ParsedProgram, ParsedProgramBase,
ParsedRProgram, ParsedType)
class Translator:
    """Translate lambda programs to Python (or R) code.

    A single instance can translate several programs in turn: the per-program
    state (emitted code lines, arguments, imports, dependencies, call counters
    and debug stack) is reset at the start of every :meth:`translate` call.
    """

    def __init__(self, grammar: ParsedGrammar) -> None:
        """Init grammar used for translation and empty containers.

        :param grammar: Grammar used for translation
        :type grammar: lapspython.types.ParsedGrammar
        """
        self.mode = grammar.mode
        # Assignment operator of the target language (Python ``=``, R ``<-``).
        if self.mode == 'python':
            self.sep = ' = '
        else:
            self.sep = ' <- '
        self.grammar = grammar
        # How often each primitive/invented handle has been emitted; used to
        # build unique variable names such as ``<name>_<count>``.
        self.call_counts = {p: 0 for p in self.grammar.primitives}
        self.call_counts.update({i: 0 for i in self.grammar.invented})
        # Set here as well so log_exception never hits a missing attribute.
        self.name = ''
        self.code: list = []
        self.args: list = []
        self.imports: set = set()
        self.dependencies: set = set()
        self.debug_stack: list = []
        self.logger = self.setup_logger()

    def setup_logger(self) -> logging.Logger:
        """Set up a logger for exceptions caught during translation.

        Records are written to ``translation.log`` in the current working
        directory. The file handler is attached only once per process, so
        creating many translators (one is created per untranslated invented
        primitive) neither duplicates log lines nor truncates the file again.

        :rtype: logging.Logger
        """
        logger = logging.getLogger(__name__)
        # Without an explicit level the logger inherits the root level
        # (WARNING by default) and every ``debug`` record below would be
        # discarded before reaching the file handler.
        logger.setLevel(logging.DEBUG)
        # Keep these verbose dumps out of the application's root handlers.
        logger.propagate = False
        handler_name = 'translation.log'
        if not any(getattr(h, 'name', None) == handler_name
                   for h in logger.handlers):
            handler = logging.FileHandler('translation.log', 'w')
            handler.set_name(handler_name)
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(logging.Formatter('%(message)s'))
            logger.addHandler(handler)
        return logger

    def log_exception(self):
        """Write current debug stack, partial code and traceback to the log.

        Meant to be called from inside an ``except`` block, because the
        traceback of the exception currently being handled is included.
        """
        self.logger.debug(f'{self.name}\n')
        for entry in self.debug_stack:
            self.logger.debug(entry)
        if self.code:
            code = '\n'.join(self.code)
            self.logger.debug(f'\n{code}')
        self.logger.debug(f'\n{traceback.format_exc()}\n')

    def translate(self, program: Program, name: str) -> ParsedProgramBase:
        """Translate a synthesized program under the current grammar.

        :param program: Abstraction/Invented at any depth of lambda expression
        :type program: subclass of dreamcoder.program.Program
        :param name: Task/Function name
        :type name: string
        :returns: Translated program
        :rtype: ParsedProgram in Python mode, ParsedRProgram otherwise
        """
        self.call_counts = dict.fromkeys(self.call_counts, 0)
        self.code = []
        self.name = name
        # The last parsed type is not an argument (presumably the return type).
        arg_types = ParsedType.parse_argument_types(program.infer())
        n_args = len(arg_types) - 1
        self.args = [f'arg{i + 1}' for i in range(n_args)]
        self.imports = set()
        self.dependencies = set()
        self.debug_stack = []
        self.translate_wrapper(program)
        source = '\n'.join(self.code)
        if self.mode == 'python':
            # Turn the last assignment into the function's return statement.
            # Only that occurrence is rewritten: replacing every occurrence of
            # the matched text would also corrupt longer names ending in it.
            matches = list(re.finditer(r'\w+ = ', source))
            if matches:
                last = matches[-1]
                source = (source[:last.start()] + 'return '
                          + source[last.end():])
            return ParsedProgram(
                name,
                source,
                self.args,
                self.imports,
                self.dependencies
            )
        return ParsedRProgram(
            name,
            source,
            self.args,
            self.imports,
            self.dependencies
        )

    def translate_wrapper(self, program: Program, node_type: str = 'body'):
        """Redirect node to corresponding translation procedure.

        :param program: Node of program tree.
        :type program: Subclass of dreamcoder.program.Program
        :param node_type: Role of the node in its parent: ``'f'`` for the
            function of an application, ``'x'`` for its argument, ``'body'``
            otherwise.
        :type node_type: str
        :returns: ``(parsed, args)`` pair; ``parsed`` is a grammar entry, a
            string or ``None`` depending on the node, ``args`` a list.
        :rtype: tuple
        :raises ValueError: If the node type is not supported.
        """
        debug = (str(program), str(type(program)), node_type)
        self.debug_stack.append(', '.join(debug))
        if program.isAbstraction:
            if node_type == 'x':
                return self._translate_abstraction_x(program)
            return self._translate_abstraction_body(program)
        if program.isApplication:
            if node_type == 'f':
                return self._translate_application_f(program)
            if node_type == 'x':
                return self._translate_application_x(program)
            return self._translate_application_body(program)
        if program.isIndex:
            return self._translate_index(program)
        if program.isInvented:
            if node_type == 'f':
                return self._translate_invented(program)
            return self._translate_abstraction_body(program)
        if program.isPrimitive:
            if node_type == 'f':
                return self._translate_primitive_f(program)
            if node_type == 'x':
                return self._translate_primitive_x(program)
            return self._translate_primitive_body(program)
        raise ValueError(f'{node_type} node of type {type(program)}')

    def _translate_abstraction_body(self, abstraction: Abstraction) -> tuple:
        # Wrap the translated body's first argument expression in a lambda.
        parsed, args = self.translate_wrapper(abstraction.body)
        args = [f'lambda x: {args[0]}']
        return parsed, args

    def _translate_abstraction_x(self, abstraction: Abstraction) -> tuple:
        # An abstraction used as an argument becomes an inline lambda
        # (Python) or a bare expression (R).
        parsed, args = self.translate_wrapper(abstraction.body)
        try:
            lambda_head = 'lambda lx: ' if self.mode == 'python' else ''
            if self.contains_index(abstraction) and self.code:
                # The body was emitted as its own statement; inline its
                # right-hand side and rebind every ``argN`` placeholder
                # (including multi-digit N) to the lambda parameter ``lx``.
                body = self.code[-1].split(self.sep)[1]
                body = re.sub(r'\barg\d+\b', 'lx', body)
                args = [f'{lambda_head}{body}']
            else:
                args = [f'{lambda_head}{args[0]}']
            return parsed, args
        except IndexError:
            self.log_exception()
            return '# ERROR', ['# ERROR']

    def _translate_application_f(self, application: Application) -> tuple:
        # Curried application in function position: collect the arguments and
        # hand them up together with the parsed function.
        f = application.f
        x = application.x
        _, x_args = self.translate_wrapper(x, 'x')
        f_parsed, f_args = self.translate_wrapper(f, 'f')
        return f_parsed, f_args + x_args

    def _translate_application_x(self, application: Application) -> tuple:
        # Application used as an argument: emit it as its own statement and
        # pass the resulting variable name up.
        f = application.f
        x = application.x
        _, x_args = self.translate_wrapper(x, 'x')
        f_parsed, f_args = self.translate_wrapper(f, 'f')
        # An ``arg`` placeholder that is not one of the program's arguments
        # presumably stands for the value produced by the last emitted line.
        if x_args[-1][:3] == 'arg' and x_args[-1] not in self.args:
            new_x_arg = self.get_last_variable()
            if new_x_arg != 'x':
                x_args[-1] = new_x_arg
        if not f.isInvented:
            x_args = f_args + x_args
        # NOTE(review): when f is an Index no code is emitted and x_args is
        # still a list, so the returned args is a nested list — confirm.
        if not f.isIndex:
            self.call_counts[f_parsed.handle] += 1
            name = f'{f_parsed.name}_{self.call_counts[f_parsed.handle]}'
            try:
                f_parsed_resolved = f_parsed.resolve_variables(x_args, name)
            except ValueError:
                self.log_exception()
                f_parsed_resolved = f'{name} = None'
            self.code.append(f_parsed_resolved)
            x_args = name
        return f_parsed, [x_args]

    def _translate_application_body(self, application: Application) -> tuple:
        # Top-level application: emit one statement, filling arguments the
        # call does not supply with the program's own ``argN`` parameters.
        f = application.f
        x = application.x
        _, x_args = self.translate_wrapper(x, 'x')
        f_parsed, f_args = self.translate_wrapper(f, 'f')
        x_args = f_args + x_args
        self.call_counts[f_parsed.handle] += 1
        name = f'{f_parsed.name}_{self.call_counts[f_parsed.handle]}'
        missing_args = len(f_parsed.args) - len(x_args)
        for i in range(missing_args):
            new_arg = f'arg{i + 1}'
            if new_arg in self.args:
                x_args.append(new_arg)
        try:
            f_parsed_resolved = f_parsed.resolve_variables(x_args, name)
        except ValueError:
            self.log_exception()
            f_parsed_resolved = f'{name} = None'
        self.code.append(f_parsed_resolved)
        return None, [name]

    def _translate_index(self, index: Index) -> tuple:
        # De Bruijn index i maps to the placeholder ``arg{i + 1}``.
        arg = f'arg{index.i + 1}'
        f_parsed = f'lambda lx: {arg}'
        return f_parsed, [arg]

    def _translate_invented(self, invented: Invented) -> tuple:
        handle = str(invented)
        f_parsed = self.grammar.invented[handle]
        if f_parsed.source == '':
            # Translate the invention once and cache the result on its
            # grammar entry so later uses skip this step.
            translator = Translator(self.grammar)
            f_trans = translator.translate(f_parsed.program, f_parsed.name)
            f_parsed.source = f_trans.source
            f_parsed.args = f_trans.args
            # Imports needed by the invention's body must propagate too,
            # exactly like its dependencies.
            f_parsed.imports = f_trans.imports
            f_parsed.dependencies = f_trans.dependencies
        self.imports.update(f_parsed.imports)
        self.dependencies.update(f_parsed.dependencies)
        self.dependencies.add(str(f_parsed))
        name = f'{f_parsed.name}_{self.call_counts[handle]}'
        x_args = [name]
        return f_parsed, x_args

    def _translate_primitive_f(self, primitive: Primitive) -> tuple:
        # Primitive in function position: fetch its grammar entry and record
        # what it needs; arguments are supplied by the caller.
        parsed = self.grammar.primitives[primitive.name].resolve_lambdas()
        self.imports.update(parsed.imports)
        self.dependencies.update(parsed.dependencies)
        return parsed, []

    def _translate_primitive_x(self, primitive: Primitive) -> tuple:
        # Primitive used as a value: emit it as a quoted literal.
        return None, [f"'{primitive.value}'"]

    def _translate_primitive_body(self, primitive: Primitive) -> tuple:
        # Same literal rendering as in argument position.
        return None, [f"'{primitive.value}'"]

    def contains_index(self, program: Program) -> bool:
        """Test whether the subprogram contains a de Bruijn index."""
        if program.isIndex:
            return True
        if program.isPrimitive:
            return False
        if hasattr(program, 'body'):
            return self.contains_index(program.body)
        return self.contains_index(program.f) or self.contains_index(program.x)

    def get_last_variable(self) -> str:
        """Return the declared variable in the last line of code.

        :returns: Text before the first assignment separator of the last
            emitted line, or ``''`` when no code has been emitted yet.
        :rtype: str
        """
        if not self.code:
            return ''
        return self.code[-1].split(self.sep)[0]