Coverage for lapspython/types.py: 95%
266 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-27 20:21 +0200
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-27 20:21 +0200
1"""Implements types for parsed primitives and lambda expressions."""
3import copy
4import inspect
5import random
6import re
7from abc import ABC, abstractmethod
8from typing import Dict, List
10from dreamcoder.frontier import Frontier
11from dreamcoder.program import Invented, Primitive
12from dreamcoder.type import TypeConstructor, TypeVariable
class ParsedType(ABC):
    """Abstract base class for program parsing.

    Subclasses hold the translated source of a primitive, an invented
    primitive or a whole program together with the metadata needed to
    render it.
    """

    @abstractmethod
    def __init__(self) -> None:  # pragma: no cover
        """Parse input primitive and initialize members."""
        self.name: str = ''
        self.handle: str = ''
        self.source: str = ''
        self.args: list = []
        self.imports: set = set()
        self.dependencies: set = set()
        self.arg_types: list = []
        self.return_type = type

    @abstractmethod
    def __str__(self) -> str:  # pragma: no cover
        """Convert object to clean code."""
        pass

    def as_dict(self) -> dict:
        """Return member attributes as dict for json dumping.

        ``imports`` and ``dependencies`` are converted to lists because the
        json module cannot serialize sets.
        """
        return {
            'name': self.name,
            'handle': self.handle,
            'source': self.source,
            'args': self.args,
            'imports': list(self.imports),
            'dependencies': list(self.dependencies)
        }

    @classmethod
    def parse_argument_types(cls, arg_types: 'TypeConstructor') -> list:
        """Flatten inferred nested type structure of primitive.

        An arrow type ``a -> (b -> c)`` becomes ``[a, b, c]``; callers in
        this module pop the last element as the return type.

        :param arg_types: Inferred types.
        :type arg_types: dreamcoder.type.TypeConstructor
        :returns: Flat list of inferred types.
        :rtype: list
        """
        if not isinstance(arg_types, TypeVariable) and arg_types.name == '->':
            arguments = arg_types.arguments
            return [arguments[0]] + cls.parse_argument_types(arguments[1])
        return [arg_types]

    def resolve_variables(self, args: list, return_name: str) -> str:
        """Substitute default arguments in source.

        Each formal parameter in ``self.args`` is renamed to the entry of
        ``args`` at the same position. A parameter is recognised where it is
        preceded by ``(``, ``[`` or a space and followed by ``,``, a space,
        ``)``, ``[``, ``]`` or the end of the source. ``{param}`` placeholders
        are replaced by the bare new name. All parameters are substituted in
        a single pass, so swapped names (``x, y`` -> ``y, x``) do not chain
        into each other, and numeric literals are safe as new names.

        :param args: List of new argument names.
        :type args: list
        :param return_name: Variable name to replace the return statement
            with; the empty string keeps the return statement.
        :type return_name: string
        :returns: Source with replaced variable names
        :rtype: string
        :raises ValueError: If ``args`` and ``self.args`` differ in length.
        """
        if len(args) != len(self.args):
            func = f'{self.name}({", ".join(self.args)})'
            raise ValueError(f'Wrong number of arguments for {func}: {args}.')

        # If a parameter name appears twice (possible after resolve_lambdas),
        # the first position wins, as with the previous sequential rewrite.
        mapping: dict = {}
        for old, new in zip(self.args, args):
            mapping.setdefault(str(old), str(new))

        new_source = self.source
        if mapping:
            names = '|'.join(re.escape(name) for name in mapping)
            # Lookarounds leave the delimiters unconsumed, so adjacent
            # occurrences ("x x") are all renamed. Replacement functions keep
            # new names from being parsed as group references (a "\1"
            # template followed by "0" would read as group 10).
            identifier = rf'(?<=[(\[ ])(?:{names})(?=[, )\[\]]|$)'
            placeholder = rf'\{{({names})\}}'
            new_source = re.sub(
                identifier, lambda m: mapping[m.group(0)], new_source)
            new_source = re.sub(
                placeholder, lambda m: mapping[m.group(1)], new_source)

        if return_name != '':
            return self.replace_return_statement(return_name, new_source)
        return new_source

    def replace_return_statement(self, return_name, source):
        """Substitute return statement with variable assignment.

        The base implementation does nothing and returns None; subclasses
        that support return replacement override it.

        :param return_name: Variable name to replace return with.
        :type return_name: string
        :param source: Source code to apply substitution to.
        :type source: string
        :rtype: string
        """
        pass  # pragma: no cover
class ParsedPythonType(ParsedType):
    """Abstract base class for python parsing.

    Renders the parsed object as a standalone ``def`` statement.
    """

    def __str__(self) -> str:
        """Render the object as a Python function definition.

        :returns: ``def name(args):`` followed by every line of
            ``self.source`` prefixed with the indent string, plus a final
            newline.
        :rtype: string
        """
        # NOTE(review): indent literal copied as it appears in the recovered
        # listing, where whitespace runs may have been collapsed; confirm its
        # width against the upstream file.
        indent = ' '
        signature = f'def {self.name}({", ".join(self.args)}):'
        body_lines = [indent + line for line in self.source.split('\n')]
        return signature + '\n' + '\n'.join(body_lines) + '\n'
class ParsedRType(ParsedType):
    """Abstract base class for R parsing.

    Renders the parsed object as an R function assigned to its name.
    """

    def __str__(self) -> str:
        """Render the object as an R function definition.

        :returns: ``name <- function(args) {``, the indented source lines,
            and a closing ``}`` line.
        :rtype: string
        """
        # NOTE(review): indent literal copied as it appears in the recovered
        # listing, where whitespace runs may have been collapsed; confirm its
        # width against the upstream file.
        indent = ' '
        params = ', '.join(self.args)
        opening = f'{self.name} <- function({params}) ' + '{\n'
        body = '\n'.join(indent + line for line in self.source.split('\n'))
        return opening + body + '\n}\n'
class ParsedPrimitive(ParsedPythonType):
    """Class parsing primitives for translation to clean Python code."""

    def __init__(self, primitive: Primitive) -> None:
        """Construct ParsedPrimitive object with parsed function specs.

        Function-valued primitives are parsed from their Python source; any
        other value is converted to a string and used as the source.

        :param primitive: A Primitive object
        :type primitive: dreamcoder.program.Primitive
        """
        implementation = primitive.value

        if inspect.isfunction(implementation):
            args = inspect.getfullargspec(implementation).args
            source = self.parse_source(implementation)
            imports = self.get_imports(implementation)
            dependencies = self.get_dependencies(implementation)
        else:
            args = []
            # str() so non-string constants (e.g. ints) survive .strip().
            source = str(implementation)
            imports = set()
            dependencies = []

        self.handle = primitive.name
        # Drop leading characters that are not lowercase letters (e.g. '_').
        self.name = re.sub(r'^[^a-z]+', '', self.handle)
        # Keep a module only if its name occurs as a whole word; a substring
        # test matched e.g. 're' inside 'return'.
        self.imports = {
            module for module in imports
            if re.search(rf'\b{re.escape(module)}\b', source)
        }
        self.dependencies = {d[1] for d in dependencies if d[0] in source}
        self.source = source.strip()
        self.args = args
        self.arg_types = self.parse_argument_types(primitive.tp)
        self.return_type = self.arg_types.pop()

    def parse_source(self, implementation) -> str:
        """Extract the body of a primitive's implementation.

        Everything after the first ``:`` of the source is kept. One-line
        bodies lose their leading space. Multi-line bodies lose the leading
        newline and one four-space indentation level. A ``' #...'`` comment
        on the last line is removed.

        :param implementation: The function referenced by primitive
        :type implementation: callable
        :returns: New source code
        :rtype: string
        """
        source = inspect.getsource(implementation)
        # Drop the signature: keep everything after the first colon.
        source = source[source.find(':') + 1:]

        # Offset of the first word character: 1 for "def f(x): return x".
        indent_match = re.search(r'\w', source)
        # With no word character at all, take the multi-line branch instead
        # of leaving `indent` unbound.
        indent = indent_match.start() if indent_match is not None else 0

        if indent == 1:
            source = source[indent:]
        else:
            source = re.sub(r'^\n', '', source)
            source = re.sub(r'^ {4}', '', source, flags=re.MULTILINE)

        return re.sub(' #.+$', '', source)

    def get_imports(self, implementation) -> set:
        """Find import modules that might be required by primitives.

        :param implementation: The function referenced by a primitive
        :type implementation: function
        :returns: Names of all modules bound in the function's module
        :rtype: set
        """
        main_module = inspect.getmodule(implementation)
        imports = inspect.getmembers(main_module, inspect.ismodule)
        return {module[0] for module in imports}

    def get_dependencies(self, implementation) -> list:
        """Collect helper functions that primitives may call.

        Every function of the implementation's module whose name starts with
        ``__`` is returned; the constructor keeps those whose name occurs in
        the primitive's source.

        :param implementation: The function referenced by a primitive
        :type implementation: function
        :returns: A list of (function name, source) tuples
        :rtype: list
        """
        module = inspect.getmodule(implementation)
        functions = inspect.getmembers(module, inspect.isfunction)
        dependent_functions = [f for f in functions if f[0][:2] == '__']
        return [(f[0], inspect.getsource(f[1])) for f in dependent_functions]

    def resolve_lambdas(self) -> "ParsedPrimitive":
        """Remove lambda functions from source and extend list of arguments.

        ``lambda x: `` prefixes are deleted and their parameter names are
        appended to the argument list of a shallow copy.

        :returns: New, cleaner parsed primitive
        :rtype: lapspython.types.ParsedPrimitive
        """
        new_primitive = copy.copy(self)
        pattern = r'lambda (\S+): '
        new_primitive.args = self.args + re.findall(pattern, self.source)
        new_primitive.source = re.sub(pattern, '', self.source)
        return new_primitive

    def replace_return_statement(self, return_name, source) -> str:
        """Substitute return statement with variable assignment.

        Only the standalone keyword ``return`` is replaced, not identifiers
        that merely contain it (e.g. ``returned``).

        :param return_name: Variable name to replace return with.
        :type return_name: string
        :param source: Source code to apply substitution to.
        :type source: string
        :rtype: string
        """
        return re.sub(r'\breturn\b', lambda _: f'{return_name} =', source)
class ParsedRPrimitive(ParsedRType):
    """Class parsing primitives for translation to clean R code."""

    def __init__(self, primitive: Primitive):
        """Extract name, path and source of R primitive.

        For function-valued primitives the R implementation is read from the
        file next to the defining Python module, with the trailing ``py`` of
        the path replaced by ``R`` (assumes a ``.py`` source file).

        :param primitive: A primitive extracted from LAPS.
        :type primitive: dreamcoder.program.Primitive
        :raises ValueError: If the Python source file cannot be determined.
        """
        self.handle = primitive.name
        self.name = re.sub(r'^[^a-z]+', '', self.handle)
        py_implementation = primitive.value

        if inspect.isfunction(py_implementation):
            py_path = inspect.getsourcefile(py_implementation)
            if not isinstance(py_path, str):  # pragma: no cover
                msg = f'Cannot get source of primitive {self.name}.'
                raise ValueError(msg)
            self.path = py_path[:-2] + 'R'
            # parse_source also sets self.args from the R function header.
            source = self.parse_source(self.name, self.path)
            imports = self.get_imports(self.path)
            dependencies = self.get_dependencies(primitive.value)
        else:
            # str() so non-string constants survive .strip() below.
            source = str(py_implementation)
            imports = set()
            dependencies = set()
            self.args = []

        self.imports = imports
        self.dependencies = {d[1] for d in dependencies if d[0] in source}
        self.source = source.strip()
        self.arg_types = self.parse_argument_types(primitive.tp)
        self.return_type = self.arg_types.pop()

    def parse_source(self, name: str, path: str, is_dep=False) -> str:
        """Extract source code of a function from an R file.

        The definition is the first line starting with ``'<name> <- '``.
        - One-line definition: the text after that prefix is returned; with
          ``is_dep`` the whole line is returned, so it stays a definition.
        - Block definition (line ends with ``{``): the body up to the first
          line that is just ``}`` is returned; with ``is_dep`` the header
          and closing brace are included.
        Only a primitive (``is_dep`` false) updates ``self.args``.

        :param name: Function name in source file.
        :type name: string
        :param path: Path of the R source file.
        :type path: string
        :param is_dep: Parse a dependency rather than the primitive itself.
        :type is_dep: bool
        :returns: Source code of corresponding function.
        :rtype: string
        :raises ValueError: If ``name`` is not defined in ``path``.
        """
        with open(path, 'r', encoding='utf-8') as r_file:
            lines = r_file.readlines()

        prefix = f'{name} <- '
        for i, line in enumerate(lines):
            if not line.startswith(prefix):
                continue
            if not line.endswith('{\n'):
                # A dependency must neither reset the primitive's own
                # argument list nor lose its assignment.
                if is_dep:
                    return line
                self.args = []
                return line[len(prefix):]
            if not is_dep:
                self.args = self.get_args(line)
            body = lines[i + 1 - is_dep:]
            break
        else:
            raise ValueError(f'Cannot find definition of {name} in {path}.')

        # Closing brace index; without one, fall back to the last line.
        end = len(body) - 1
        for j, body_line in enumerate(body):
            if body_line.rstrip() == '}':
                end = j
                break
        return ''.join(body[:end + is_dep])

    def get_imports(self, path) -> set:
        """Find R packages loaded with ``library(...)`` in an R file.

        :param path: Path of the R source file.
        :type path: string
        :returns: A set of package names as strings
        :rtype: set
        """
        pattern = r'library\((.+)\)'
        with open(path, 'r', encoding='utf-8') as r_file:
            return set(re.findall(pattern, r_file.read()))

    def get_dependencies(self, implementation):
        """Collect R helper functions that primitives may call.

        Every function of the implementation's Python module whose name
        starts with ``__`` is looked up, without that prefix, in the R file
        next to the module that defines it.

        :param implementation: The function referenced by a primitive
        :type implementation: function
        :returns: A list of (function name, source) tuples
        :rtype: list
        :raises ValueError: If a helper's source file cannot be determined.
        """
        module = inspect.getmodule(implementation)
        functions = inspect.getmembers(module, inspect.isfunction)
        dependencies = []
        for func_name, func in functions:
            if not func_name.startswith('__'):
                continue
            r_name = func_name[2:]
            py_path = inspect.getsourcefile(func)
            if not isinstance(py_path, str):  # pragma: no cover
                raise ValueError(f'Cannot get source of dependency {r_name}.')
            r_path = py_path[:-2] + 'R'
            dependencies.append(
                (r_name, self.parse_source(r_name, r_path, True)))
        return dependencies

    def get_args(self, header: str):
        """Get list of arguments from an R function header.

        :param header: Header line, e.g. ``'f <- function(x, y) {'``.
        :type header: string
        :returns: Whitespace-trimmed argument strings; empty if none.
        :rtype: list
        """
        match = re.search(r'\(.+\)', header)
        if match is None:  # pragma: no cover
            return []
        return [arg.strip() for arg in match[0][1:-1].split(',')]

    def resolve_lambdas(self) -> "ParsedRPrimitive":
        """No lambdas in R, but required for backwards compatibility."""
        return self

    def replace_return_statement(self, return_name, source):
        """Substitute return statement with variable assignment.

        ``return(expr)`` becomes ``return_name <- expr`` on each line; only
        the standalone keyword ``return`` is matched.

        :param return_name: Variable name to replace return with.
        :type return_name: string
        :param source: Source code to apply substitution to.
        :type source: string
        :rtype: string
        """
        return re.sub(
            r'\breturn\((.+)\)',
            lambda m: f'{return_name} <- {m.group(1)}',
            source,
        )
class ParsedInvented(ParsedPythonType):
    """Invented primitive rendered as a named Python function.

    Invented primitives carry no name of their own, so the caller supplies
    one. The source stays empty until a grammar parser fills it in.
    """

    def __init__(self, invented: Invented, name: str):
        """Store the invented primitive under a custom name.

        :param invented: An invented primitive object
        :type invented: dreamcoder.program.Invented
        :param name: A custom name since invented primitives are unnamed
        :type name: string
        """
        self.handle = str(invented)
        self.name = name
        self.program = invented
        flat_types = self.parse_argument_types(invented.tp)
        self.return_type = flat_types.pop()
        self.arg_types = flat_types

        # Source translation is done by lapspython.extraction.GrammarParser,
        # not during construction, to avoid circular imports.
        self.source = ''
        self.args: list = []
        self.imports: set = set()
        self.dependencies: set = set()

    def resolve_variables(self, args: list, return_name: str) -> str:
        """Build an assignment that calls this function with ``args``.

        Unlike primitives, invented functions are called by name instead of
        being inlined.

        :param args: Argument expressions for the call.
        :type args: list
        :param return_name: Variable receiving the result.
        :type return_name: string
        :returns: ``return_name = name(arg1, arg2, ...)``
        :rtype: string
        """
        call = f'{self.name}({", ".join(args)})'
        return f'{return_name} = {call}'
class ParsedRInvented(ParsedRType):
    """Invented primitive rendered as a named R function.

    Invented primitives carry no name of their own, so the caller supplies
    one. The source stays empty until a grammar parser fills it in.
    """

    def __init__(self, invented: Invented, name: str):
        """Store the invented primitive under a custom name.

        :param invented: An invented primitive object
        :type invented: dreamcoder.program.Invented
        :param name: A custom name since invented primitives are unnamed
        :type name: string
        """
        self.handle = str(invented)
        self.name = name
        self.program = invented
        flat_types = self.parse_argument_types(invented.tp)
        self.return_type = flat_types.pop()
        self.arg_types = flat_types

        # Source translation is done by lapspython.extraction.GrammarParser,
        # not during construction, to avoid circular imports.
        self.source = ''
        self.args: list = []
        self.imports: set = set()
        self.dependencies: set = set()

    def resolve_variables(self, args: list, return_name: str) -> str:
        """Build an R assignment that calls this function with ``args``.

        :param args: Argument expressions for the call.
        :type args: list
        :param return_name: Variable receiving the result.
        :type return_name: string
        :returns: ``return_name <- name(arg1, arg2, ...)``
        :rtype: string
        """
        call = f'{self.name}({", ".join(args)})'
        return f'{return_name} <- {call}'
class ParsedProgramBase(ParsedType):
    """Common base for translated synthesized programs."""

    def __init__(
        self,
        name: str,
        source: str,
        args: list,
        imports: set,
        dependencies: set
    ):
        """Store a translated program with its name, arguments and context.

        :param name: Task name or invented primitive handle; also the handle
        :type name: string
        :param source: The translation of a given program
        :type source: string
        :param args: List of arguments to be resolved when used
        :type args: list
        :param imports: Modules/packages the translation needs
        :type imports: set
        :param dependencies: Source codes of called functions
        :type dependencies: set
        """
        self.handle = name
        self.name = name
        self.source, self.args = source, args
        self.imports, self.dependencies = imports, dependencies

    @abstractmethod
    def __str__(self) -> str:  # pragma: no cover
        """Return imports, dependencies and source code as one string.

        :returns: Full source code of translated program
        :rtype: string
        """

    @abstractmethod
    def verify(self, examples: list) -> bool:  # pragma: no cover
        """Check the translation against a task's examples.

        :param examples: A list of (input, output) tuples
        :type examples: list
        :returns: Whether the translated program is correct.
        :rtype: bool
        """
class ParsedProgram(ParsedProgramBase, ParsedType):
    """Class parsing synthesized programs."""

    def __str__(self) -> str:
        """Return imports, dependencies and source code as string.

        Imports and dependencies are sorted so the output is identical
        across runs (set iteration order is not).

        :returns: Full source code of translated program
        :rtype: string
        """
        imports = '\n'.join(
            f'import {module}' for module in sorted(self.imports))
        dependencies = '\n'.join(sorted(self.dependencies)) + '\n'
        header = f'def {self.name}({", ".join(self.args)}):\n'
        # NOTE(review): indent literal copied as it appears in the recovered
        # listing, where whitespace runs may have been collapsed; confirm its
        # width against the upstream file.
        indent_source = re.sub(r'^', ' ', self.source, flags=re.MULTILINE)
        return imports + '\n\n' + dependencies + header + indent_source

    def verify(self, examples: list) -> bool:
        """Verify code for a list of examples from task.

        The translation is run with ``exec`` in a fresh namespace for each
        example. Only use this on translations produced by this package,
        never on untrusted code.

        :param examples: A list of (input, output) tuples
        :type examples: list
        :returns: Whether the translated program is correct.
        :rtype: bool
        :raises RuntimeError: If executing the translation fails. The message
            contains the executed code, and the original error is chained.
        """
        exec_translation = str(self) + '\n\n'

        for example in examples:
            # Inputs are passed as string literals; repr() quotes them and
            # escapes embedded quotes/backslashes that "'...'" broke on.
            example_inputs = [repr(str(x)) for x in example[0]]
            example_output = str(example[1])

            joined_inputs = ', '.join(example_inputs)
            exec_example = f'python_output = {self.name}({joined_inputs})'
            exec_string = exec_translation + exec_example

            loc: dict = {}
            try:
                exec(exec_string, loc)
            except Exception as err:
                # RuntimeError is still a BaseException, so existing
                # handlers keep working. KeyboardInterrupt/SystemExit are no
                # longer masked.
                raise RuntimeError('\n' + exec_string) from err
            if loc['python_output'] != example_output:
                return False
        return True
class ParsedRProgram(ParsedProgramBase, ParsedRType):
    """Class parsing synthesized programs."""

    def __str__(self) -> str:
        """Return library calls, dependencies and source code as string.

        Packages and dependencies are sorted so the output is identical
        across runs (set iteration order is not).

        :returns: Full source code of translated program
        :rtype: string
        """
        imports = '\n'.join(
            f'library({module})' for module in sorted(self.imports))
        dependencies = '\n'.join(sorted(self.dependencies)) + '\n'
        header = f'{self.name} <- function({", ".join(self.args)}) \u007b\n'
        # NOTE(review): indent literal copied as it appears in the recovered
        # listing, where whitespace runs may have been collapsed; confirm its
        # width against the upstream file.
        indent_source = re.sub(r'^', ' ', self.source, flags=re.MULTILINE)
        return imports + '\n\n' + dependencies + header + indent_source + '\n}'

    def verify(self, examples: list) -> bool:
        """Verify code for a list of examples from task.

        Not implemented yet: R code is never executed and every program is
        reported as correct.

        :param examples: A list of (input, output) tuples
        :type examples: list
        :returns: Always True for now.
        :rtype: bool
        """
        return True  # TODO: execute the R translation and compare outputs
class ParsedGrammar:
    """Container for the parsed primitives and invented primitives."""

    def __init__(
        self,
        primitives: dict,
        invented: dict,
        mode: str = 'python'
    ) -> None:
        """Keep the parsed grammar entries and the target language.

        :param primitives: A (name, ParsedPrimitive) dictionary.
        :type primitives: dict
        :param invented: A (name, ParsedInvented) dictionary.
        :type invented: dict
        :param mode: Target language label, 'python' by default.
        :type mode: string
        """
        self.mode: str = mode
        self.primitives: dict = primitives
        self.invented: dict = invented

    def as_dict(self):
        """Return both collections keyed by handle, ready for json dumping."""
        def by_handle(entries: dict) -> dict:
            return {item.handle: item.as_dict() for item in entries.values()}

        return {
            'primitives': by_handle(self.primitives),
            'invented': by_handle(self.invented),
        }
class CompactFrontier:
    """Condensed view of a frontier: task metadata plus ranked programs."""

    def __init__(self, frontier: Frontier, annotation: str = '') -> None:
        """Construct condensed frontier object with optional annotation.

        :param frontier: Frontier whose task data and programs are kept.
        :type frontier: dreamcoder.frontier.Frontier
        :param annotation: Optional annotation stored as given.
        :type annotation: string
        """
        task = frontier.task
        # Highest log posterior first; ties keep their original order.
        ranked = sorted(frontier.entries, key=lambda e: -e.logPosterior)

        self.annotation = annotation
        self.name = task.name
        self.requested_types = task.request
        self.examples = task.examples
        self.programs = [e.program for e in ranked]
        # Filled later by lapspython.extraction.ProgramExtractor; translation
        # is not done here to avoid circular imports.
        self.translations: list = []
        self.failed: list = []
class CompactResult:
    """Holds the HIT and MISS CompactFrontiers of an extraction run."""

    def __init__(self, hit: dict, miss: dict) -> None:
        """Store HIT and MISS CompactFrontiers in member variables.

        :param hit: A (name, HIT CompactFrontier) dictionary.
        :type hit: dict
        :param miss: A (name, MISS CompactFrontier) dictionary.
        :type miss: dict
        """
        self.hit_frontiers: dict = hit
        self.miss_frontiers: dict = miss

    def get_best(self) -> List[Dict]:
        """Summarize each HIT frontier by its top-ranked entries.

        :returns: One dict per HIT frontier with the annotation, the first
            program and the first valid/invalid translation (None if absent).
        :rtype: List[Dict]
        """
        def first_as_str(items):
            return str(items[0]) if len(items) > 0 else None

        summaries = []
        for frontier in self.hit_frontiers.values():
            valid = first_as_str(frontier.translations)
            invalid = first_as_str(frontier.failed)
            summaries.append({
                'annotation': frontier.annotation,
                'best_program': str(frontier.programs[0]),
                'best_valid_translation': valid,
                'best_invalid_translation': invalid,
            })
        return summaries

    def sample(self) -> dict:
        """Pick a random HIT summary that has a valid translation.

        :returns: A summary dict from get_best, or {} if none qualifies.
        :rtype: dict
        """
        candidates = [
            summary for summary in self.get_best()
            if summary['best_valid_translation'] is not None
        ]
        if not candidates:
            return {}
        return random.choice(candidates)