Coverage for lapspython/translation.py: 88%

191 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-27 20:15 +0200

1"""Implements functions for translation from lambda calculus to Python.""" 

2 

3import logging 

4import re 

5import traceback 

6 

7from dreamcoder.program import (Abstraction, Application, Index, Invented, 

8 Primitive, Program) 

9from lapspython.types import (ParsedGrammar, ParsedProgram, ParsedProgramBase, 

10 ParsedRProgram, ParsedType) 

11 

12 

class Translator:
    """Translate lambda programs to Python (or R) code.

    The target language is taken from ``grammar.mode``: in ``'python'`` mode
    generated lines are ``name = expr`` assignments, in any other mode they
    are R-style ``name <- expr`` assignments.
    """

    def __init__(self, grammar: ParsedGrammar) -> None:
        """Init grammar used for translation and empty containers.

        :param grammar: Grammar used for translation
        :type grammar: lapspython.types.ParsedGrammar
        """
        self.mode = grammar.mode

        # Assignment operator used in every generated line of code.
        if self.mode == 'python':
            self.sep = ' = '
        else:
            self.sep = ' <- '

        self.grammar = grammar
        self.call_counts = {p: 0 for p in self.grammar.primitives}
        self.call_counts.update({i: 0 for i in self.grammar.invented})
        self.code: list = []
        self.args: list = []
        self.imports: set = set()
        self.dependencies: set = set()
        self.debug_stack: list = []
        # Defined up front so log_exception() also works before translate().
        self.name: str = ''
        self.logger = self.setup_logger()

    def setup_logger(self) -> logging.Logger:
        """Set up a logger for exceptions caught during translation.

        Records are written to ``translation.log`` in the current working
        directory. All Translator instances (including the nested ones
        created for invented primitives) share the module-level logger, so
        the file handler is attached only once per path; attaching one per
        instance would duplicate every record and reopen/truncate the file.

        :rtype: logging.Logger
        """
        import os

        logger = logging.getLogger(__name__)
        # log_exception() logs at DEBUG; without an explicit level the logger
        # inherits the root level (WARNING by default) and drops the records.
        logger.setLevel(logging.DEBUG)
        # Keep the debug dump in translation.log rather than also forwarding
        # it to whatever handlers the root logger may have.
        logger.propagate = False

        log_path = os.path.abspath('translation.log')
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not already_attached:
            handler = logging.FileHandler(log_path, 'w')
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(logging.Formatter('%(message)s'))
            logger.addHandler(handler)
        return logger

    def log_exception(self):
        """Write the program name, debug stack, code and traceback to the log.

        Intended to be called from inside an ``except`` block, because it
        logs ``traceback.format_exc()``.
        """
        self.logger.debug(f'{self.name}\n')
        for entry in self.debug_stack:
            self.logger.debug(entry)
        if self.code:
            code = '\n'.join(self.code)
            self.logger.debug(f'\n{code}')
        self.logger.debug(f'\n{traceback.format_exc()}\n')

    def translate(self, program: Program, name: str) -> ParsedProgramBase:
        """Translate a synthesized program under the current grammar.

        :param program: Abstraction/Invented at any depth of lambda expression
        :type program: subclass of dreamcoder.program.Program
        :param name: Task/Function name
        :type name: string
        :returns: Translated program
        :rtype: ParsedProgram in python mode, ParsedRProgram otherwise
        """
        # Reset all per-translation state.
        for call in self.call_counts:
            self.call_counts[call] = 0
        self.code = []
        self.name = name
        arg_types = ParsedType.parse_argument_types(program.infer())
        # The last parsed type is the return type; the rest are arguments.
        n_args = len(arg_types) - 1
        self.args = [f'arg{i + 1}' for i in range(n_args)]
        self.imports = set()
        self.dependencies = set()
        self.debug_stack = []

        self.translate_wrapper(program)

        source = '\n'.join(self.code)

        if self.mode == 'python':
            # Turn the final assignment into the function's return statement.
            # Only that one occurrence is rewritten: replacing every copy of
            # its text would also corrupt earlier lines that merely contain
            # it (e.g. ``str_index_2 = `` when the last is ``index_2 = ``).
            assignments = list(re.finditer(r'\w+ = ', source))
            if assignments:
                last = assignments[-1]
                source = (
                    source[:last.start()] + 'return ' + source[last.end():]
                )

            return ParsedProgram(
                name,
                source,
                self.args,
                self.imports,
                self.dependencies
            )

        return ParsedRProgram(
            name,
            source,
            self.args,
            self.imports,
            self.dependencies
        )

    def translate_wrapper(self, program: Program,
                          node_type: str = 'body') -> tuple:
        """Redirect node to corresponding translation procedure.

        :param program: Node of program tree.
        :type program: Subclass of dreamcoder.program.Program
        :param node_type: Position of the node in its parent: ``'f'``
            (function of an application), ``'x'`` (argument of an
            application) or ``'body'`` (anything else).
        :type node_type: str
        :returns: ``(parsed, args)`` pair produced by the handler.
        :raises ValueError: if the node matches no known program type.
        """
        debug = (str(program), str(type(program)), node_type)
        self.debug_stack.append(', '.join(debug))

        if program.isAbstraction:
            if node_type == 'x':
                return self._translate_abstraction_x(program)
            return self._translate_abstraction_body(program)
        if program.isApplication:
            if node_type == 'f':
                return self._translate_application_f(program)
            if node_type == 'x':
                return self._translate_application_x(program)
            return self._translate_application_body(program)
        if program.isIndex:
            return self._translate_index(program)
        if program.isInvented:
            if node_type == 'f':
                return self._translate_invented(program)
            return self._translate_abstraction_body(program)
        if program.isPrimitive:
            if node_type == 'f':
                return self._translate_primitive_f(program)
            if node_type == 'x':
                return self._translate_primitive_x(program)
            return self._translate_primitive_body(program)
        raise ValueError(f'{node_type} node of type {type(program)}')

    def _translate_abstraction_body(self, abstraction: Abstraction) -> tuple:
        # Wrap the translated body's first argument into a lambda.
        parsed, args = self.translate_wrapper(abstraction.body)
        args = [f'lambda x: {args[0]}']
        return parsed, args

    def _translate_abstraction_x(self, abstraction: Abstraction) -> tuple:
        parsed, args = self.translate_wrapper(abstraction.body)

        try:
            lambda_head = ''
            if self.mode == 'python':
                lambda_head = 'lambda lx: '
            if self.contains_index(abstraction) and self.code:
                # Reuse the expression of the last emitted assignment as the
                # lambda body, with argument references rebound to ``lx``.
                last_row = self.code[-1]
                # Everything after the first separator is the expression.
                body = last_row.split(self.sep, 1)[1]
                body = re.sub(r'\barg\d+', 'lx', body)
                args = [f'{lambda_head}{body}']
            else:
                args = [f'{lambda_head}{args[0]}']

            return parsed, args

        except IndexError:
            return '# ERROR', ['# ERROR']

    def _translate_application_f(self, application: Application) -> tuple:
        # Curried application in function position: collect arguments
        # (x first, then those gathered while descending into f).
        _, x_args = self.translate_wrapper(application.x, 'x')
        f_parsed, f_args = self.translate_wrapper(application.f, 'f')

        return f_parsed, f_args + x_args

    def _translate_application_x(self, application: Application) -> tuple:
        f = application.f
        x = application.x

        _, x_args = self.translate_wrapper(x, 'x')
        f_parsed, f_args = self.translate_wrapper(f, 'f')

        # An ``argN`` that is not a parameter of the current function refers
        # to an intermediate result: use the last assigned variable instead.
        if x_args[-1][:3] == 'arg' and x_args[-1] not in self.args:
            new_x_arg = self.get_last_variable()
            if new_x_arg != 'x':
                x_args[-1] = new_x_arg

        if not f.isInvented:
            x_args = f_args + x_args

        if not f.isIndex:
            self.call_counts[f_parsed.handle] += 1
            name = f'{f_parsed.name}_{self.call_counts[f_parsed.handle]}'
            try:
                f_parsed_resolved = f_parsed.resolve_variables(x_args, name)
            except ValueError:
                self.log_exception()
                f_parsed_resolved = f'{name} = None'
            self.code.append(f_parsed_resolved)
            x_args = name

        # NOTE(review): when f is an Index, x_args is still a list here, so
        # the caller receives a nested list.
        return f_parsed, [x_args]

    def _translate_application_body(self, application: Application) -> tuple:
        _, x_args = self.translate_wrapper(application.x, 'x')
        f_parsed, f_args = self.translate_wrapper(application.f, 'f')

        x_args = f_args + x_args

        # NOTE(review): if application.f is an Index, f_parsed is a str and
        # has no ``handle``; this path would raise AttributeError.
        self.call_counts[f_parsed.handle] += 1
        name = f'{f_parsed.name}_{self.call_counts[f_parsed.handle]}'

        # Pad missing arguments with the function's own parameters.
        missing_args = len(f_parsed.args) - len(x_args)
        for i in range(missing_args):
            new_arg = f'arg{i + 1}'
            if new_arg in self.args:
                x_args.append(new_arg)

        try:
            f_parsed_resolved = f_parsed.resolve_variables(x_args, name)
        except ValueError:
            self.log_exception()
            f_parsed_resolved = f'{name} = None'

        self.code.append(f_parsed_resolved)

        return None, [name]

    def _translate_index(self, index: Index) -> tuple:
        # De Bruijn index i refers to the (i + 1)-th argument.
        arg = f'arg{index.i + 1}'
        f_parsed = f'lambda lx: {arg}'
        return f_parsed, [arg]

    def _translate_invented(self, invented: Invented) -> tuple:
        handle = str(invented)
        f_parsed = self.grammar.invented[handle]
        if f_parsed.source == '':
            # Translate the invented primitive on first use with a fresh
            # translator and cache the result on the grammar entry.
            translator = Translator(self.grammar)
            f_trans = translator.translate(f_parsed.program, f_parsed.name)
            f_parsed.source = f_trans.source
            f_parsed.args = f_trans.args
            f_parsed.dependencies = f_trans.dependencies
            # Also keep the imports collected by the nested translation;
            # otherwise they are lost and the update below sees stale ones.
            f_parsed.imports = set(f_parsed.imports) | translator.imports
        self.imports.update(f_parsed.imports)
        self.dependencies.update(f_parsed.dependencies)
        self.dependencies.add(str(f_parsed))
        name = f'{f_parsed.name}_{self.call_counts[handle]}'
        x_args = [name]

        return f_parsed, x_args

    def _translate_primitive_f(self, primitive: Primitive) -> tuple:
        parsed = self.grammar.primitives[primitive.name].resolve_lambdas()
        self.imports.update(parsed.imports)
        self.dependencies.update(parsed.dependencies)
        return parsed, []

    def _translate_primitive_x(self, primitive: Primitive) -> tuple:
        # Primitive used as a value: emit its value as a quoted literal.
        return None, [f"'{primitive.value}'"]

    def _translate_primitive_body(self, primitive: Primitive) -> tuple:
        return None, [f"'{primitive.value}'"]

    def contains_index(self, program: Program) -> bool:
        """Test whether the subprogram contains a de Bruijn index."""
        if program.isIndex:
            return True
        if program.isPrimitive:
            return False
        if 'body' in dir(program):
            # Abstraction / Invented
            return self.contains_index(program.body)
        # Application
        return self.contains_index(program.f) or self.contains_index(program.x)

    def get_last_variable(self) -> str:
        """Return the declared variable in the last line of code.

        Returns ``''`` when no code has been emitted yet.
        """
        if not self.code:
            return ''
        return self.code[-1].split(self.sep)[0]

140 raise ValueError(f'{node_type} node of type {type(program)}') 

141 

142 def _translate_abstraction_body(self, abstraction: Abstraction) -> tuple: 

143 parsed, args = self.translate_wrapper(abstraction.body) 

144 args = [f'lambda x: {args[0]}'] 

145 return parsed, args 

146 

147 def _translate_abstraction_x(self, abstraction: Abstraction) -> tuple: 

148 parsed, args = self.translate_wrapper(abstraction.body) 

149 

150 try: 

151 lambda_head = '' 

152 if self.mode == 'python': 

153 lambda_head = 'lambda lx: ' 

154 if self.contains_index(abstraction) and len(self.code) > 0: 

155 last_row = self.code[-1] 

156 body = last_row.split(self.sep)[1] 

157 body = re.sub(r'arg\d', 'lx', body) 

158 args = [f'{lambda_head}{body}'] 

159 else: 

160 args = [f'{lambda_head}{args[0]}'] 

161 

162 return parsed, args 

163 

164 except IndexError: 

165 return '# ERROR', ['# ERROR'] 

166 

167 def _translate_application_f(self, application: Application) -> tuple: 

168 f = application.f 

169 x = application.x 

170 

171 _, x_args = self.translate_wrapper(x, 'x') 

172 f_parsed, f_args = self.translate_wrapper(f, 'f') 

173 

174 return f_parsed, f_args + x_args 

175 

176 def _translate_application_x(self, application: Application) -> tuple: 

177 f = application.f 

178 x = application.x 

179 

180 x_parsed, x_args = self.translate_wrapper(x, 'x') 

181 f_parsed, f_args = self.translate_wrapper(f, 'f') 

182 

183 if x_args[-1][:3] == 'arg' and x_args[-1] not in self.args: 

184 new_x_arg = self.get_last_variable() 

185 if new_x_arg != 'x': 

186 x_args[-1] = new_x_arg 

187 

188 if not f.isInvented: 

189 x_args = f_args + x_args 

190 

191 if not f.isIndex: 

192 self.call_counts[f_parsed.handle] += 1 

193 name = f'{f_parsed.name}_{self.call_counts[f_parsed.handle]}' 

194 try: 

195 f_parsed_resolved = f_parsed.resolve_variables(x_args, name) 

196 except ValueError: 

197 self.log_exception() 

198 f_parsed_resolved = f'{name} = None' 

199 self.code.append(f_parsed_resolved) 

200 x_args = name 

201 

202 return f_parsed, [x_args] 

203 

204 def _translate_application_body(self, application: Application) -> tuple: 

205 f = application.f 

206 x = application.x 

207 

208 x_parsed, x_args = self.translate_wrapper(x, 'x') 

209 f_parsed, f_args = self.translate_wrapper(f, 'f') 

210 

211 x_args = f_args + x_args 

212 

213 self.call_counts[f_parsed.handle] += 1 

214 name = f'{f_parsed.name}_{self.call_counts[f_parsed.handle]}' 

215 

216 missing_args = len(f_parsed.args) - len(x_args) 

217 for i in range(missing_args): 

218 new_arg = f'arg{i + 1}' 

219 if new_arg in self.args: 

220 x_args.append(new_arg) 

221 

222 try: 

223 f_parsed_resolved = f_parsed.resolve_variables(x_args, name) 

224 except ValueError: 

225 self.log_exception() 

226 f_parsed_resolved = f'{name} = None' 

227 

228 self.code.append(f_parsed_resolved) 

229 

230 return None, [name] 

231 

232 def _translate_index(self, index: Index) -> tuple: 

233 arg = f'arg{index.i + 1}' 

234 f_parsed = f'lambda lx: {arg}' 

235 return f_parsed, [arg] 

236 

237 def _translate_invented(self, invented: Invented) -> tuple: 

238 handle = str(invented) 

239 f_parsed = self.grammar.invented[handle] 

240 if f_parsed.source == '': 

241 translator = Translator(self.grammar) 

242 f_trans = translator.translate(f_parsed.program, f_parsed.name) 

243 f_parsed.source = f_trans.source 

244 f_parsed.args = f_trans.args 

245 f_parsed.dependencies = f_trans.dependencies 

246 self.imports.update(f_parsed.imports) 

247 self.dependencies.update(f_parsed.dependencies) 

248 self.dependencies.add(str(f_parsed)) 

249 name = f'{f_parsed.name}_{self.call_counts[handle]}' 

250 x_args = [name] 

251 

252 return f_parsed, x_args 

253 

254 def _translate_primitive_f(self, primitive: Primitive) -> tuple: 

255 parsed = self.grammar.primitives[primitive.name].resolve_lambdas() 

256 self.imports.update(parsed.imports) 

257 self.dependencies.update(parsed.dependencies) 

258 return parsed, [] 

259 

260 def _translate_primitive_x(self, primitive: Primitive) -> tuple: 

261 return None, [f"'{primitive.value}'"] 

262 

263 def _translate_primitive_body(self, primitive: Primitive) -> tuple: 

264 return None, [f"'{primitive.value}'"] 

265 

266 def contains_index(self, program: Program) -> bool: 

267 """Test whether the subprogram contains a de Bruijin index.""" 

268 if program.isIndex: 

269 return True 

270 if program.isPrimitive: 

271 return False 

272 if 'body' in dir(program): 

273 return self.contains_index(program.body) 

274 else: 

275 f = self.contains_index(program.f) 

276 x = self.contains_index(program.x) 

277 return (f or x) 

278 

279 def get_last_variable(self) -> str: 

280 """Return the declared variable in the last line of code.""" 

281 if len(self.code) == 0: 

282 return '' 

283 return self.code[-1].split(self.sep)[0]