Coverage for lapspython/extraction.py: 96%

79 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-27 20:15 +0200

1"""Implements classes to extract primitives and lambda expressions.""" 

2 

3from tqdm import tqdm 

4 

5from dreamcoder.dreamcoder import ECResult 

6from dreamcoder.grammar import Grammar 

7from dreamcoder.program import Invented, Primitive 

8from lapspython.translation import Translator 

9from lapspython.types import (CompactFrontier, CompactResult, ParsedGrammar, 

10 ParsedInvented, ParsedPrimitive, ParsedRInvented, 

11 ParsedRPrimitive) 

12 

13 

class GrammarParser:
    """Extract, parse, and store all primitives from grammar."""

    def __init__(self, grammar: Grammar = None, mode='python') -> None:
        """Optionally parse grammar if passed during construction.

        :param grammar: A grammar induced by LAPS.
        :type grammar: dreamcoder.grammar.Grammar, optional
        :param mode: Whether to extract Python or R code, can be 'python' or
            'r' (case-insensitive).
        :type mode: string, optional
        :raises ValueError: If mode is neither 'python' nor 'r'.
        """
        self.mode = mode.lower()
        if self.mode not in ('python', 'r'):
            raise ValueError('mode must be "Python" or "R".')

        if grammar is not None:
            self.parse(grammar)
        else:
            # Pass the mode explicitly, exactly as parse() does, so that an
            # empty grammar agrees with the mode this parser was built for.
            self.parsed_grammar = ParsedGrammar({}, {}, self.mode)

    def parse(self, grammar: Grammar) -> ParsedGrammar:
        """Convert Primitive objects to simplified ParsedPrimitive objects.

        Base primitives are wrapped in ParsedPrimitive (Python mode, with
        lambdas resolved) or ParsedRPrimitive (R mode) and keyed by name.
        Invented primitives are wrapped in ParsedInvented / ParsedRInvented,
        keyed by their string form and named ``f0``, ``f1``, ... in order of
        first appearance. Invented primitives without source are then
        translated with a Translator built on the freshly parsed grammar.

        :param grammar: A grammar induced inside main() or ecIterator().
        :type grammar: dreamcoder.grammar.Grammar
        :raises TypeError: If a production is neither Primitive nor Invented.
        :returns: The parsed grammar, also stored in ``self.parsed_grammar``.
        :rtype: ParsedGrammar
        """
        parsed_primitives: dict = {}
        parsed_invented: dict = {}

        for _, _, primitive in tqdm(grammar.productions):
            if isinstance(primitive, Primitive):
                name = primitive.name
                if name not in parsed_primitives:
                    if self.mode == 'python':
                        parsed_primitive = ParsedPrimitive(primitive)
                        parsed_primitive = parsed_primitive.resolve_lambdas()
                        parsed_primitives[name] = parsed_primitive
                    else:
                        parsed_primitives[name] = ParsedRPrimitive(primitive)

            elif not isinstance(primitive, Invented):
                raise TypeError(f'Encountered unknown type {type(primitive)}.')

            elif str(primitive) not in parsed_invented:
                handle = str(primitive)
                # Sequential names: f0, f1, ... by order of first appearance.
                name = f'f{len(parsed_invented)}'
                if self.mode == 'python':
                    parsed_invented[handle] = ParsedInvented(primitive, name)
                else:
                    parsed_invented[handle] = ParsedRInvented(primitive, name)

        self.parsed_grammar = ParsedGrammar(
            parsed_primitives,
            parsed_invented,
            self.mode
        )

        # Translation needs the complete grammar, because invented primitives
        # may refer to other primitives, so it runs only after parsing.
        translator = Translator(self.parsed_grammar)
        for invented in self.parsed_grammar.invented.values():
            # Only fill in entries that have no source yet (presumably the
            # Parsed*Invented constructors leave it empty — verify).
            if invented.source == '':
                trans = translator.translate(invented.program, invented.name)
                invented.source = trans.source
                invented.args = trans.args
                invented.dependencies = trans.dependencies

        return self.parsed_grammar

    def fix_invented(self, new_invented: dict) -> None:
        """Replace invented primitives implementations.

        The update is all-or-nothing: every replacement is read before any
        invented primitive is modified, so a malformed entry does not leave
        the grammar partially updated.

        :param new_invented: Invented primitives from JSON file, mapping each
            handle to a dict with 'name', 'source', 'args' and 'dependencies'.
        :type new_invented: dict
        :raises ValueError: If the handles differ from the parsed grammar's.
        :raises KeyError: If an entry lacks one of the required fields.
        """
        this_invented = self.parsed_grammar.invented
        if set(this_invented.keys()) != set(new_invented.keys()):
            raise ValueError('Keys of the two grammars are not equal.')

        # Read all replacement values first so that a KeyError is raised
        # before any invented primitive has been modified.
        replacements = {}
        for handle in this_invented:
            new_data = new_invented[handle]
            replacements[handle] = (
                new_data['name'],
                new_data['source'],
                new_data['args'],
                new_data['dependencies'],
            )

        for handle, (name, source, args, deps) in replacements.items():
            invented = this_invented[handle]
            invented.name = name
            invented.source = source
            invented.args = args
            invented.dependencies = deps

97 this_invented[handle].dependencies = new_data['dependencies'] 

98 

99 

class ProgramExtractor:
    """Extract, parse and translate synthesized programs."""

    def __init__(self, result: ECResult = None,
                 translator: Translator = None) -> None:
        """Optionally extract programs if passed during construction.

        :param result: A result produced by LAPS or checkpoint.
        :type result: dreamcoder.dreamcoder.ECResult, optional
        :param translator: Translator to translate programs during extraction.
        :type translator: lapspython.translation.Translator, optional
        """
        if result is not None:
            self.extract(result, translator)
        else:
            self.compact_result = CompactResult({}, {})

    def extract(self, result: ECResult,
                translator: Translator = None) -> CompactResult:
        """Extract all frontiers with descriptions and frontiers.

        Frontiers without programs are stored as misses, all others as hits.
        If a translator is given, every program of every frontier is
        translated. Translations whose verification against the frontier's
        examples succeeds are stored in ``translations``. All others,
        including those whose verification raises, go to ``failed``.

        :param result: Result of dreamcoder execution (checkpoint)
        :type result: dreamcoder.dreamcoder.ECResult
        :param translator: Translator to translate programs during extraction.
        :type translator: lapspython.translation.Translator, optional
        :returns: The extracted result, also stored in ``self.compact_result``.
        :rtype: lapspython.types.CompactResult
        """
        hit_frontiers = {}
        miss_frontiers = {}

        for frontier in tqdm(result.allFrontiers.values()):
            name = frontier.task.name
            annotation = self._first_annotation(result, name)
            compact_frontier = CompactFrontier(frontier, annotation)
            if frontier.empty:
                miss_frontiers[name] = compact_frontier
            else:
                hit_frontiers[name] = compact_frontier

            if translator is not None:
                self._translate_frontier(compact_frontier, translator, name)

        self.compact_result = CompactResult(hit_frontiers, miss_frontiers)
        return self.compact_result

    @staticmethod
    def _first_annotation(result: ECResult, name: str) -> str:
        """Return the first language description of a task, or ''.

        A task missing from ``taskLanguage`` or with an empty description
        list yields '' instead of raising IndexError.
        """
        descriptions = result.taskLanguage.get(name)
        if not descriptions:
            return ''
        return descriptions[0]

    @staticmethod
    def _translate_frontier(compact_frontier: CompactFrontier,
                            translator: Translator, name: str) -> None:
        """Translate and verify all programs of a frontier in place."""
        for program in compact_frontier.programs:
            transl = translator.translate(program, name)
            try:
                verified = transl.verify(compact_frontier.examples)
            except Exception:
                # Executing translated code may fail in arbitrary ways. Count
                # such a failure as unverified, but let KeyboardInterrupt and
                # SystemExit propagate so a long run can still be aborted.
                verified = False
            if verified:
                compact_frontier.translations.append(transl)
            else:
                compact_frontier.failed.append(transl)