Coverage for lapspython/utils.py: 89%
28 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-27 20:15 +0200
1"""Utility functions that do not fit in other modules."""
3import json
4import os
6import dill
8from dreamcoder.dreamcoder import ECResult
9from lapspython.types import CompactResult, ParsedGrammar
def load_checkpoint(filename: str) -> ECResult:
    """Deserialize a training checkpoint from disk.

    The checkpoint is looked up as ``checkpoints/<filename>.pickle``,
    relative to the current working directory.

    :param filename: Checkpoint name inside the checkpoints directory,
        without the ``.pickle`` extension.
    :type filename: str
    :returns: The unpickled checkpoint object.
    :rtype: dreamcoder.dreamcoder.ECResult
    """
    checkpoint_path = f'checkpoints/{filename}.pickle'
    # NOTE: dill, like pickle, can execute arbitrary code while loading;
    # only load checkpoint files that come from a trusted source.
    with open(checkpoint_path, 'rb') as checkpoint_file:
        checkpoint = dill.load(checkpoint_file)
    return checkpoint
def json_dump(
    filename: str,
    grammar: ParsedGrammar,
    result: CompactResult
) -> None:
    """Store grammar and best results in a json file.

    The output is written to ``checkpoints/<filename>.json``. The data is
    serialized to a string before the file is opened, so a serialization
    failure neither leaves a partial file behind nor truncates an existing
    one.

    :param filename: File name in checkpoints folder without file extension.
    :type filename: str
    :param grammar: Grammar extracted and parsed from checkpoint.
    :type grammar: lapspython.types.ParsedGrammar
    :param result: Result extracted and translated from checkpoint.
    :type result: lapspython.types.CompactResult
    :raises TypeError: If the grammar or result contains values the json
        module cannot serialize.
    :raises ValueError: If the data contains a circular reference.
    """
    json_path = f'checkpoints/{filename}.json'
    json_dict = {
        'grammar': grammar.as_dict(),
        'result': result.get_best()
    }
    # Serialize before touching the file: an unserializable value then
    # raises without creating or clobbering anything on disk. (Deleting a
    # file while it is still open for writing also fails on Windows.)
    serialized = json.dumps(json_dict, indent=4)
    with open(json_path, 'w', encoding='utf-8') as json_file:
        json_file.write(serialized)
def json_read(filename: str) -> dict:
    """Read grammar and results from a json file.

    The file is read from ``checkpoints/<filename>.json``. Its ``'grammar'``
    entry is replaced by a ``ParsedGrammar`` built from the stored
    ``'primitives'`` and ``'invented'`` entries; all other entries are
    returned as loaded.

    :param filename: File name in checkpoints folder without file extension.
    :type filename: str
    :returns: Dictionary with the parsed ``'grammar'`` and the remaining
        stored entries (e.g. ``'result'`` as written by ``json_dump``), or an
        empty dictionary if the file does not exist.
    :rtype: dict
    :raises KeyError: If the file lacks the ``grammar`` entry or its
        ``primitives``/``invented`` keys.
    """
    json_path = f'checkpoints/{filename}.json'
    # Only a missing file is treated as "no data"; keep the try minimal so
    # that errors from parsing or grammar construction still propagate.
    try:
        json_file = open(json_path, 'r', encoding='utf-8')
    except FileNotFoundError:
        return {}
    with json_file:
        json_dict = json.load(json_file)
    grammar = json_dict['grammar']
    json_dict['grammar'] = ParsedGrammar(
        grammar['primitives'], grammar['invented']
    )
    return json_dict