Coverage for lapspython/types.py: 95%

266 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-27 20:21 +0200

1"""Implements types for parsed primitives and lambda expressions.""" 

2 

3import copy 

4import inspect 

5import random 

6import re 

7from abc import ABC, abstractmethod 

8from typing import Dict, List 

9 

10from dreamcoder.frontier import Frontier 

11from dreamcoder.program import Invented, Primitive 

12from dreamcoder.type import TypeConstructor, TypeVariable 

13 

14 

class ParsedType(ABC):
    """Abstract base class for program parsing.

    Concrete subclasses populate the attributes initialised in ``__init__``
    and implement ``__str__``; language-specific subclasses additionally
    override ``replace_return_statement``.
    """

    @abstractmethod
    def __init__(self) -> None:  # pragma: no cover
        """Parse input primitive and initialize members."""
        self.name: str = ''
        self.handle: str = ''
        self.source: str = ''
        self.args: list = []
        self.imports: set = set()
        self.dependencies: set = set()
        self.arg_types: list = []
        self.return_type = type

    @abstractmethod
    def __str__(self) -> str:  # pragma: no cover
        """Convert object to clean code."""
        pass

    def as_dict(self) -> dict:
        """Return member attributes as dict for json dumping.

        Sets are converted to lists because sets are not JSON serializable.

        :returns: Dictionary with name, handle, source, args, imports and
            dependencies.
        :rtype: dict
        """
        return {
            'name': self.name,
            'handle': self.handle,
            'source': self.source,
            'args': self.args,
            'imports': list(self.imports),
            'dependencies': list(self.dependencies)
        }

    @classmethod
    def parse_argument_types(cls, arg_types: 'TypeConstructor') -> list:
        """Flatten inferred nested type structure of primitive.

        An arrow type (name ``'->'``) is unrolled recursively: its first
        argument is kept and its second argument is flattened further.
        Anything else (including type variables) ends the recursion.

        :param arg_types: Inferred types.
        :type arg_types: dreamcoder.type.TypeConstructor
        :returns: Flat list of inferred types.
        :rtype: list
        """
        # The annotation above is quoted so it is not evaluated when the
        # class body is executed.
        if isinstance(arg_types, TypeVariable) or arg_types.name != '->':
            return [arg_types]
        first, rest = arg_types.arguments[0], arg_types.arguments[1]
        return [first] + cls.parse_argument_types(rest)

    def resolve_variables(self, args: list, return_name: str) -> str:
        """Substitute default arguments in source.

        Each occurrence of an original argument name that is preceded by
        ``(``, ``[`` or a space and followed by ``,``, a space, ``)``,
        ``[``, ``]`` or the end of the source is replaced by the matching
        entry of ``args``.  Occurrences of the form ``{name}`` are replaced
        by the new value without the braces.

        :param args: List of new argument names.
        :type args: list
        :param return_name: Variable name to replace the return statement with
        :type return_name: string
        :returns: Source with replaced variable names
        :rtype: string
        :raises ValueError: If ``args`` has a different length than the
            arguments of this object.
        """
        if len(args) != len(self.args):
            func = f'{self.name}({", ".join(self.args)})'
            raise ValueError(f'Wrong number of arguments for {func}: {args}.')

        new_source = self.source
        for old_name, new_value in zip(self.args, args):
            replacement = str(new_value)
            pattern = fr'(\(|\[| )({old_name})(,| |\)|\[|\]|$)'

            # A callable replacement inserts the new value literally.  A
            # template such as r'\1<value>\3' misreads values that start with
            # a digit as group references and rejects backslashes.
            def substitute(match, replacement=replacement):
                return match.group(1) + replacement + match.group(3)

            new_source = re.sub(pattern, substitute, new_source)
            new_source = new_source.replace(
                '{' + str(old_name) + '}', replacement)

        if return_name != '':
            return self.replace_return_statement(return_name, new_source)
        return new_source

    def replace_return_statement(self, return_name, source):
        """Substitute return statement with variable assignment.

        The base implementation does nothing and returns ``None``;
        language-specific subclasses override it.

        :param return_name: Variable name to replace return with.
        :type return_name: string
        :param source: Source code to apply substitution to.
        :type source: string
        :rtype: string
        """
        pass  # pragma: no cover

96 

97 

class ParsedPythonType(ParsedType):
    """Abstract base class for python parsing."""

    def __str__(self) -> str:
        """Render the object as a Python function definition.

        The header is built from the name and arguments; every line of the
        stored source is prefixed with the indent string.

        :returns: Function source code
        :rtype: string
        """
        signature = ', '.join(self.args)
        # NOTE(review): the indent literal is a single space as it appears in
        # this dump; whitespace may have been collapsed - confirm upstream.
        body = ' ' + self.source.replace('\n', '\n ')
        return f'def {self.name}({signature}):\n{body}\n'

110 

111 

class ParsedRType(ParsedType):
    """Abstract base class for R parsing."""

    def __str__(self) -> str:
        """Render the object as an R function definition.

        The result has the shape ``name <- function(args) {``, the indented
        source, and a closing brace on its own line.

        :returns: R source code
        :rtype: string
        """
        parameters = ', '.join(self.args)
        opening = self.name + ' <- function(' + parameters + ') {\n'
        # NOTE(review): the indent literal is a single space as it appears in
        # this dump; whitespace may have been collapsed - confirm upstream.
        body = ' ' + self.source.replace('\n', '\n ')
        return opening + body + '\n}\n'

124 

125 

class ParsedPrimitive(ParsedPythonType):
    """Class parsing primitives for translation to clean Python code."""

    def __init__(self, primitive: Primitive) -> None:
        """Construct ParsedPrimitive object with parsed function specs.

        If the primitive's value is a function, its arguments, body, module
        imports and dunder-prefixed helper functions are extracted via
        ``inspect``.  Otherwise the value itself is used as source and the
        primitive has no arguments, imports or dependencies.

        :param primitive: A Primitive object
        :type primitive: dreamcoder.program.Primitive
        """
        implementation = primitive.value

        if inspect.isfunction(implementation):
            args = inspect.getfullargspec(implementation).args
            source = self.parse_source(implementation)
            imports = self.get_imports(implementation)
            dependencies = self.get_dependencies(implementation)
        else:
            args = []
            source = implementation
            imports = set()
            dependencies = []

        self.handle = primitive.name
        # Strip everything in front of the first lowercase letter.
        self.name = re.sub(r'^[^a-z]+', '', self.handle)
        # Keep only imports/dependencies whose name occurs in the source.
        self.imports = {module for module in imports if module in source}
        self.dependencies = {d[1] for d in dependencies if d[0] in source}
        self.source = source.strip()
        self.args = args
        self.arg_types = self.parse_argument_types(primitive.tp)
        # The last flattened type is the return type.
        self.return_type = self.arg_types.pop()

    def parse_source(self, implementation) -> str:
        """Extract the function body from the source of ``implementation``.

        Everything up to and including the first colon is dropped.  If the
        first word character then sits at index 1 (body on the header line),
        only the leading character is removed; otherwise the leading newline
        and four leading spaces per line are removed.  If no word character
        is found, the text is left unchanged apart from the comment strip.

        :param implementation: The function referenced by primitive
        :type implementation: callable
        :returns: New source code
        :rtype: string
        """
        source = inspect.getsource(implementation)
        source = source[source.find(':') + 1:]

        first_word = re.search(r'\w', source)
        if first_word is not None:
            if first_word.start() == 1:
                source = source[1:]
            else:
                source = re.sub(r'^\n', '', source)
                source = re.sub(r'^ {4}', '', source, flags=re.MULTILINE)

        # Without re.MULTILINE only a comment at the very end is removed.
        # NOTE(review): the pattern is reproduced as it appears in this dump;
        # whitespace may have been collapsed - confirm upstream.
        return re.sub(' #.+$', '', source)

    def get_imports(self, implementation) -> set:
        """Find import modules that might be required by primitives.

        :param implementation: The function referenced by a primitive
        :type implementation: function
        :returns: A set of module names as strings
        :rtype: set
        """
        main_module = inspect.getmodule(implementation)
        imports = inspect.getmembers(main_module, inspect.ismodule)
        return {name for name, _ in imports}

    def get_dependencies(self, implementation) -> list:
        """Find functions called by primitives that are not built-ins.

        Collects every function of the implementation's module whose name
        starts with two underscores.

        :param implementation: The function referenced by a primitive
        :type implementation: function
        :returns: A list of (function name, source) tuples
        :rtype: list
        """
        module = inspect.getmodule(implementation)
        functions = inspect.getmembers(module, inspect.isfunction)
        return [
            (name, inspect.getsource(function))
            for name, function in functions
            if name.startswith('__')
        ]

    def resolve_lambdas(self) -> "ParsedPrimitive":
        """Remove lambda functions from source and extend list of arguments.

        Works on a shallow copy; this object is left untouched.

        :returns: New, cleaner parsed primitive
        :rtype: lapspython.types.ParsedPrimitive
        """
        new_primitive = copy.copy(self)
        pattern = r'lambda (\S+): '
        new_primitive.args = self.args + re.findall(pattern, self.source)
        new_primitive.source = re.sub(pattern, '', self.source)
        return new_primitive

    def replace_return_statement(self, return_name, source) -> str:
        """Substitute return statement with variable assignment.

        Only the standalone keyword ``return`` is replaced, so identifiers
        that merely contain it (e.g. ``return_value``) stay intact.

        :param return_name: Variable name to replace return with.
        :type return_name: string
        :param source: Source code to apply substitution to.
        :type source: string
        :rtype: string
        """
        # Callable replacement: the name is inserted literally, never
        # interpreted as a replacement template.
        return re.sub(r'\breturn\b', lambda _: f'{return_name} =', source)

228 

229 

class ParsedRPrimitive(ParsedRType):
    """Class parsing primitives for translation to clean R code."""

    def __init__(self, primitive: Primitive):
        """Extract name, path and source of R primitive.

        For function-valued primitives, the R implementation is read from a
        file whose path is the Python source path with its last two
        characters replaced by ``R``.  Otherwise the value itself is used as
        source.

        :param primitive: A primitive extracted from LAPS.
        :type primitive: dreamcoder.program.Primitive
        :raises ValueError: If the source file of the primitive cannot be
            determined.
        """
        self.handle = primitive.name
        # Strip everything in front of the first lowercase letter.
        self.name = re.sub(r'^[^a-z]+', '', self.handle)
        py_implementation = primitive.value

        if inspect.isfunction(py_implementation):
            py_path = inspect.getsourcefile(py_implementation)
            if not isinstance(py_path, str):  # pragma: no cover
                msg = f'Cannot get source of primitive {self.name}.'
                raise ValueError(msg)
            self.path = py_path[:-2] + 'R'
            # parse_source also sets self.args for the primitive itself.
            source = self.parse_source(self.name, self.path)
            imports = self.get_imports(self.path)
            dependencies = self.get_dependencies(primitive.value)
        else:
            source = py_implementation
            imports = set()
            dependencies = set()
            self.args = []

        self.imports = imports
        self.dependencies = {d[1] for d in dependencies if d[0] in source}
        self.source = source.strip()
        self.arg_types = self.parse_argument_types(primitive.tp)
        # The last flattened type is the return type.
        self.return_type = self.arg_types.pop()

    def parse_source(self, name: str, path: str, is_dep=False) -> str:
        """Extract source code of primitive from R file.

        Looks for the first line starting with ``<name> <- ``.  A definition
        whose line does not end in an opening brace is returned as a
        one-liner without the assignment prefix.  Otherwise the lines up to
        the first line consisting only of a closing brace are returned; for
        dependencies the header and closing-brace lines are included.

        ``self.args`` is only updated when ``is_dep`` is false, so parsing a
        dependency never overwrites the arguments of the primitive itself.

        :param name: Function name in source file.
        :type name: string
        :param path: Path of the R source file.
        :type path: string
        :param is_dep: Whether the function is a dependency rather than the
            primitive itself.
        :type is_dep: bool
        :returns: Source code of corresponding function.
        :rtype: string
        :raises ValueError: If the file contains no definition of ``name``.
        """
        with open(path, 'r') as r_file:
            lines = r_file.readlines()

        prefix = f'{name} <- '
        start = next(
            (i for i, line in enumerate(lines) if line.startswith(prefix)),
            None)
        if start is None:
            raise ValueError(f'Cannot find definition of {name} in {path}.')

        header = lines[start]
        if not header.endswith('{\n'):
            if not is_dep:
                self.args = []
            # Literal removal; the name is not treated as a regex.
            return header.replace(prefix, '')

        if not is_dep:
            self.args = self.get_args(header)

        # Dependencies keep their header line, primitives only their body.
        body = lines[start + 1 - is_dep:]
        # Without a closing-brace line the index falls back to the last
        # line, as the original index loop did.
        end = next(
            (j for j, line in enumerate(body) if line == '}\n'),
            len(body) - 1)
        return ''.join(body[:end + is_dep])

    def get_imports(self, path) -> set:
        """Find libraries loaded in the R file of a primitive.

        :param path: Path of the R source file.
        :type path: string
        :returns: A set of library names as strings
        :rtype: set
        """
        pattern = r'library\((.+)\)'
        with open(path, 'r') as r_file:
            return set(re.findall(pattern, r_file.read()))

    def get_dependencies(self, implementation):
        """Find functions called by primitives that are not built-ins.

        Every function of the implementation's Python module whose name
        starts with two underscores is looked up (without that prefix) in
        the R file next to the function's Python source file.

        :param implementation: The function referenced by a primitive
        :type implementation: function
        :returns: A list of (function name, source) tuples
        :rtype: list
        :raises ValueError: If the source file of a dependency cannot be
            determined or it has no definition in the R file.
        """
        module = inspect.getmodule(implementation)
        functions = inspect.getmembers(module, inspect.isfunction)

        dependencies = []
        for py_name, function in functions:
            if not py_name.startswith('__'):
                continue
            py_path = inspect.getsourcefile(function)
            if not isinstance(py_path, str):  # pragma: no cover
                msg = f'Cannot get source of dependency {py_name}.'
                raise ValueError(msg)
            r_name = py_name[2:]
            r_source = self.parse_source(r_name, py_path[:-2] + 'R', True)
            dependencies.append((r_name, r_source))

        return dependencies

    def get_args(self, header: str):
        """Get list of arguments from function header.

        :param header: First line of the function definition.
        :type header: string
        :returns: Argument names, or an empty list if the header contains no
            non-empty parenthesised part.
        :rtype: list
        """
        match = re.search(r'\(.+\)', header)
        if match is None:  # pragma: no cover
            return []
        args = match[0][1:-1]
        return args.split(', ')

    def resolve_lambdas(self) -> "ParsedRPrimitive":
        """No lambdas in R, but required for backwards compatibility."""
        return self

    def replace_return_statement(self, return_name, source):
        """Substitute return statement with variable assignment.

        ``return(<expr>)`` becomes ``<return_name> <- <expr>``.

        :param return_name: Variable name to replace return with.
        :type return_name: string
        :param source: Source code to apply substitution to.
        :type source: string
        :rtype: string
        """
        # Callable replacement: the name is inserted literally, never
        # interpreted as a replacement template.
        return re.sub(
            r'return\((.+)\)',
            lambda match: f'{return_name} <- {match.group(1)}',
            source)

354 

355 

356 

class ParsedInvented(ParsedPythonType):
    """Class parsing invented primitives for translation to Python."""

    def __init__(self, invented: Invented, name: str):
        """Construct ParsedInvented object with parsed specs.

        :param invented: An invented primitive object
        :type invented: dreamcoder.program.Invented
        :param name: A custom name since invented primitives are unnamed
        :type name: string
        """
        self.program = invented
        self.name = name
        self.handle = str(invented)

        # The flattened signature ends with the return type.
        flat_types = self.parse_argument_types(invented.tp)
        self.arg_types = flat_types
        self.return_type = flat_types.pop()

        # Source translation is deferred to
        # lapspython.extraction.GrammarParser to avoid circular imports, so
        # these start out empty.
        self.source = ''
        self.args: list = []
        self.imports: set = set()
        self.dependencies: set = set()

    def resolve_variables(self, args: list, return_name: str) -> str:
        """Insert arguments in function call rather than definition.

        :param args: Argument names passed to the call.
        :type args: list
        :param return_name: Variable the call result is assigned to.
        :type return_name: string
        :returns: Assignment of the function call to ``return_name``.
        :rtype: string
        """
        call = f'{self.name}({", ".join(args)})'
        return f'{return_name} = {call}'

386 

387 

class ParsedRInvented(ParsedRType):
    """Class parsing invented primitives for translation to R."""

    def __init__(self, invented: Invented, name: str):
        """Construct ParsedRInvented object with parsed specs.

        :param invented: An invented primitive object
        :type invented: dreamcoder.program.Invented
        :param name: A custom name since invented primitives are unnamed
        :type name: string
        """
        self.program = invented
        self.name = name
        self.handle = str(invented)

        # The flattened signature ends with the return type.
        flat_types = self.parse_argument_types(invented.tp)
        self.arg_types = flat_types
        self.return_type = flat_types.pop()

        # Source translation is deferred to
        # lapspython.extraction.GrammarParser to avoid circular imports, so
        # these start out empty.
        self.source = ''
        self.args: list = []
        self.imports: set = set()
        self.dependencies: set = set()

    def resolve_variables(self, args: list, return_name: str) -> str:
        """Insert arguments in function call rather than definition.

        :param args: Argument names passed to the call.
        :type args: list
        :param return_name: Variable the call result is assigned to.
        :type return_name: string
        :returns: R assignment of the function call to ``return_name``.
        :rtype: string
        """
        call = f'{self.name}({", ".join(args)})'
        return f'{return_name} <- {call}'

417 

418 

class ParsedProgramBase(ParsedType):
    """Class parsing synthesized programs."""

    def __init__(
        self,
        name: str,
        source: str,
        args: list,
        imports: set,
        dependencies: set
    ):
        """Store translated program with its name, arguments and requirements.

        :param name: Task name or invented primitive handle
        :type name: string
        :param source: The translation of a given program
        :type source: string
        :param args: List of arguments to be resolved when used
        :type args: list
        :param imports: Names of modules required by the program
        :type imports: set
        :param dependencies: Source codes of called functions
        :type dependencies: set
        """
        # Programs have no separate handle; the name is used for both.
        self.name = self.handle = name
        self.source, self.args = source, args
        self.imports, self.dependencies = imports, dependencies

    @abstractmethod
    def __str__(self) -> str:  # pragma: no cover
        """Return imports, dependencies and source code as string.

        :returns: Full source code of translated program
        :rtype: string
        """
        pass

    @abstractmethod
    def verify(self, examples: list) -> bool:  # pragma: no cover
        """Verify code for a list of examples from task.

        :param examples: A list of (input, output) tuples
        :type examples: list
        :returns: Whether the translated program is correct.
        :rtype: bool
        """
        pass

467 

468 

class ParsedProgram(ParsedProgramBase, ParsedType):
    """Class parsing synthesized programs."""

    def __str__(self) -> str:
        """Return dependencies and source code as string.

        The result consists of one ``import`` line per module, a blank
        line, the dependency sources and the function definition.

        :returns: Full source code of translated program
        :rtype: string
        """
        imports = '\n'.join([f'import {module}' for module in self.imports])
        dependencies = '\n'.join(self.dependencies) + '\n'
        header = f'def {self.name}({", ".join(self.args)}):\n'
        # NOTE(review): the indent literal is a single space as it appears in
        # this dump; whitespace may have been collapsed - confirm upstream.
        indent_source = re.sub(r'^', ' ', self.source, flags=re.MULTILINE)
        return imports + '\n\n' + dependencies + header + indent_source

    def verify(self, examples: list) -> bool:
        """Verify code for a list of examples from task.

        For every example the translated program is executed with the
        example inputs (converted to strings) and its result is compared
        with the string form of the expected output.

        :param examples: A list of (input, output) tuples
        :type examples: list
        :returns: Whether the translated program is correct.
        :rtype: bool
        :raises BaseException: If executing the translation fails; the
            message contains the executed code and the original error is
            attached as cause.
        """
        exec_translation = str(self) + '\n\n'

        for example in examples:
            # repr() yields a valid string literal even if the input
            # contains quotes or backslashes.
            example_inputs = [repr(str(x)) for x in example[0]]
            example_output = str(example[1])

            joined_inputs = ', '.join(example_inputs)
            exec_example = f'python_output = {self.name}({joined_inputs})'
            exec_string = exec_translation + exec_example

            loc: dict = {}
            try:
                # NOTE(review): exec runs generated code; only use this on
                # translations from a trusted source.
                exec(exec_string, loc)
            except Exception as err:
                # The raised type is kept for existing callers.
                # KeyboardInterrupt and SystemExit now pass through unwrapped.
                raise BaseException('\n' + exec_string) from err

            if loc['python_output'] != example_output:
                return False
        return True

510 

511 

class ParsedRProgram(ParsedProgramBase, ParsedRType):
    """Class parsing synthesized programs."""

    def __str__(self) -> str:
        """Return dependencies and source code as string.

        The result consists of one ``library(...)`` line per import, a blank
        line, the dependency sources and the R function definition.

        :returns: Full source code of translated program
        :rtype: string
        """
        library_lines = '\n'.join(
            'library(' + str(module) + ')' for module in self.imports)
        dependency_block = '\n'.join(self.dependencies) + '\n'
        parameters = ', '.join(self.args)
        opening = self.name + ' <- function(' + parameters + ') {\n'
        # NOTE(review): the indent literal is a single space as it appears in
        # this dump; whitespace may have been collapsed - confirm upstream.
        body = ' ' + self.source.replace('\n', '\n ')
        return library_lines + '\n\n' + dependency_block + opening + body + '\n}'

    def verify(self, examples: list) -> bool:
        """Verify code for a list of examples from task.

        Not implemented yet: every program is reported as correct.

        :param examples: A list of (input, output) tuples
        :type examples: list
        :returns: Whether the translated program is correct.
        :rtype: bool
        """
        return True  # TODO

536 

537 

class ParsedGrammar:
    """Data class containing parsed (invented) primitives."""

    def __init__(
        self,
        primitives: dict,
        invented: dict,
        mode: str = 'python'
    ) -> None:
        """Store parsed (invented) primitives in member variables.

        :param primitives: A (name, ParsedPrimitive) dictionary.
        :type primitives: dict
        :param invented: A (name, ParsedInvented) dictionary.
        :type invented: dict
        :param mode: Stored as given; defaults to ``'python'``.
        :type mode: string
        """
        self.primitives: dict = primitives
        self.invented: dict = invented
        self.mode: str = mode

    def as_dict(self):
        """Return member attributes as dict for json dumping.

        Both groups are re-keyed by the handle of each entry.

        :returns: Dictionary with the keys ``primitives`` and ``invented``.
        :rtype: dict
        """
        def by_handle(parsed: dict) -> dict:
            # Re-key by handle instead of the dictionary's own keys.
            result = {}
            for entry in parsed.values():
                result[entry.handle] = entry.as_dict()
            return result

        return {
            'primitives': by_handle(self.primitives),
            'invented': by_handle(self.invented)
        }

566 

567 

class CompactFrontier:
    """Data class containing the important specs of extracted frontiers."""

    def __init__(self, frontier: Frontier, annotation: str = '') -> None:
        """Construct condensed frontier object with optional annotation.

        Copies name, request and examples from the frontier's task and
        stores the programs of all entries ordered by descending log
        posterior.

        :param frontier: The frontier to condense.
        :type frontier: dreamcoder.frontier.Frontier
        :param annotation: Optional annotation stored alongside.
        :type annotation: string
        """
        task = frontier.task
        ranked = sorted(
            frontier.entries,
            key=lambda entry: -entry.logPosterior)

        self.annotation = annotation
        self.name = task.name
        self.requested_types = task.request
        self.examples = task.examples
        self.programs = [entry.program for entry in ranked]

        # Source translation is handled by
        # lapspython.extraction.ProgramExtractor instead of the constructor
        # to avoid circular imports, so these start out empty.
        self.translations: list = []
        self.failed: list = []

584 

585 

class CompactResult:
    """Class containing (compact) extracted frontiers."""

    def __init__(self, hit: dict, miss: dict) -> None:
        """Store HIT and MISS CompactFrontiers in member variables.

        :param hit: A (name, HIT CompactFrontier) dictionary.
        :type hit: dict
        :param miss: A (name, MISS CompactFrontier) dictionary.
        :type miss: dict
        """
        self.hit_frontiers: dict = hit
        self.miss_frontiers: dict = miss

    def get_best(self) -> List[Dict]:
        """Return the HIT frontiers as dict with best posteriors.

        Each dictionary holds the annotation, the first program, and the
        first valid and first failed translation as strings (``None`` if
        the respective list is empty).

        :returns: A list of minimal CompactFrontier dictionaries.
        :rtype: List[Dict]
        """
        def first_as_str(items):
            # String form of the first element, or None for an empty list.
            return str(items[0]) if len(items) > 0 else None

        summaries = []
        for frontier in self.hit_frontiers.values():
            valid = first_as_str(frontier.translations)
            invalid = first_as_str(frontier.failed)
            summaries.append({
                'annotation': frontier.annotation,
                'best_program': str(frontier.programs[0]),
                'best_valid_translation': valid,
                'best_invalid_translation': invalid
            })
        return summaries

    def sample(self) -> dict:
        """Return a random HIT frontier with valid translation.

        :returns: A minimal CompactFrontier dictionary, or an empty
            dictionary if no HIT frontier has a valid translation.
        :rtype: dict
        """
        candidates = [
            summary for summary in self.get_best()
            if summary['best_valid_translation'] is not None
        ]
        if not candidates:
            return {}
        return random.choice(candidates)