Coverage for kye/engine/engine.py: 34%

61 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-01-16 14:12 -0700

1import duckdb 

2from duckdb import DuckDBPyConnection 

3from kye.types import Type, EDGE, TYPE_REF 

4from kye.engine.load_json import json_to_edges 

5from kye.engine.validate import check_table 

6from kye.errors import error_factory, Error 

7import pandas as pd 

8 

class DuckDBEngine:
    """In-memory DuckDB store that holds model data as a flat table of edges.

    Two tables are created on construction:

    * ``edges``  -- columns ``loc, tbl, row, col, val`` (all NOT NULL text)
      plus a nullable ``idx`` (UINT64).
    * ``errors`` -- columns ``err, tbl`` (NOT NULL text) plus nullable text
      columns ``idx, row, col, val``.

    Data is appended with :meth:`load_json`; :meth:`validate` runs
    ``check_table`` for every table name present in ``edges``; the remaining
    methods read the validated result back out.
    """

    db: DuckDBPyConnection
    models: dict[TYPE_REF, Type]
    # False whenever edges were loaded since the last validate() run.
    has_validated: bool

    def __init__(self, models: dict[TYPE_REF, Type]):
        """Open an in-memory DuckDB database and create the empty tables.

        Args:
            models: Mapping from model reference to its type definition.
        """
        self.db = duckdb.connect(':memory:')
        self.models = models
        # Nothing has been loaded yet, so there is nothing to validate.
        self.has_validated = True
        self.create_tables()

    @staticmethod
    def _quote_literal(value) -> str:
        """Return ``value`` as a single-quoted SQL string literal."""
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _quote_identifier(name) -> str:
        """Return ``name`` as a double-quoted SQL identifier."""
        return '"' + str(name).replace('"', '""') + '"'

    def create_tables(self):
        """Create the ``edges`` and ``errors`` tables."""
        self.db.sql('''
            CREATE TABLE edges (
                loc TEXT NOT NULL,
                tbl TEXT NOT NULL,
                row TEXT NOT NULL,
                col TEXT NOT NULL,
                val TEXT NOT NULL,
                idx UINT64
            );
            CREATE TABLE errors (
                err TEXT NOT NULL,
                tbl TEXT NOT NULL,
                idx TEXT,
                row TEXT,
                col TEXT,
                val TEXT
            );
        ''')

    @property
    def edges(self):
        """Relation over the ``edges`` table."""
        return self.db.table('edges')

    @property
    def errors(self):
        """Relation over the ``errors`` table."""
        return self.db.table('errors')

    def load_json(self, model: TYPE_REF, data):
        """Convert ``data`` to edge rows for ``model`` and append them.

        The rows come from ``json_to_edges`` and are inserted with ``idx``
        set to NULL. The engine is marked as needing validation.

        Args:
            model: Key into ``self.models``; must already be registered.
            data: Payload handed to ``json_to_edges`` unchanged.
        """
        self.has_validated = False
        assert model in self.models
        df = pd.DataFrame(json_to_edges(self.models[model], data))
        if df.empty:
            # No edges were produced; an empty (possibly column-less) frame
            # cannot be registered with DuckDB and there is nothing to insert.
            return
        r = duckdb.df(df, connection=self.db)
        # Insert is positional: the frame's columns followed by a NULL idx.
        r.select('*, NULL as idx').insert_into('edges')

    def validate(self):
        """Re-run ``check_table`` for every loaded table, if anything changed.

        Does nothing when no data was loaded since the previous run.
        Otherwise clears ``errors``, resets ``edges.idx`` to NULL and calls
        ``check_table(model, db)`` once per distinct ``tbl`` value in
        ``edges``.
        """
        if self.has_validated:
            return
        self.db.sql('''
            TRUNCATE errors;
            UPDATE edges SET idx = NULL;
        ''')
        for (table_name,) in self.edges.aggregate('distinct tbl').fetchall():
            # NOTE(review): check_table presumably fills `errors` and assigns
            # `edges.idx`; its implementation is not visible here -- confirm.
            check_table(self.models[table_name], self.db)
        self.has_validated = True

    def get_table(self, model: TYPE_REF):
        """Return the validated rows of ``model`` as a DuckDB relation.

        Edges matched by a row in ``errors`` are dropped (a NULL field in the
        error row matches any value). The remaining edges are pivoted to one
        output row per ``idx`` with one column per edge of the model type:
        multi-valued edges become a list of distinct values, single-valued
        edges a single value. Edges with no data at all become an empty list
        or NULL respectively.

        Args:
            model: Key into ``self.models``; must already be registered.
        """
        assert model in self.models
        self.validate()
        typ = self.models[model]
        # Escape the name so a quote character cannot terminate the literal.
        tbl_literal = self._quote_literal(model)
        table = self.db.sql(f'''
            PIVOT (
                SELECT * FROM edges
                ANTI JOIN errors on
                    edges.tbl=errors.tbl
                    AND (edges.row = errors.row OR errors.row IS NULL)
                    AND (edges.col = errors.col OR errors.col IS NULL)
                    AND (edges.val = errors.val OR errors.val IS NULL)
                    AND (edges.idx = errors.idx OR errors.idx IS NULL)
                WHERE tbl = {tbl_literal}
            ) ON col USING list(val) GROUP BY idx
        ''')
        present = table.columns
        select = []
        for edge in typ.edges:
            # Quote the column so keyword-like or punctuated edge names work.
            ident = self._quote_identifier(edge)
            multiple = typ.allows_multiple(edge)
            if edge in present:
                func = 'list_distinct' if multiple else 'list_any_value'
                select.append(f'{func}({ident}) as {ident}')
            elif multiple:
                select.append(f'CAST([] AS VARCHAR[]) as {ident}')
            else:
                select.append(f'CAST(NULL AS VARCHAR) as {ident}')
        return table.select(','.join(select))

    def fetch_json(self, model: TYPE_REF):
        """Return the validated rows of ``model`` as a list of dicts."""
        assert model in self.models
        table = self.get_table(model)
        return table.fetchdf().to_dict(orient='records')

    def get_errors(self) -> list[Error]:
        """Summarise the ``errors`` table as ``Error`` objects.

        Rows are aggregated per ``(err, tbl, col)``; each group carries the
        number of distinct rows, indexes and values affected plus one example
        of each, and is turned into an ``Error`` via ``error_factory``.
        """
        r = self.errors.aggregate('''
            err,
            tbl,
            col,
            count(distinct row) as num_row,
            count(distinct idx) as num_idx,
            count(distinct val) as num_val,
            first(row) as row_example,
            first(idx) as idx_example,
            first(val) as val_example
        ''')
        return [
            error_factory(
                err_type=err,
                table_name=tbl,
                column_name=col,
                num_rows=num_row,
                num_indexes=num_idx,
                num_values=num_val,
                row_example=row_example,
                idx_example=idx_example,
                val_example=val_example,
            )
            for (
                err, tbl, col,
                num_row, num_idx, num_val,
                row_example, idx_example, val_example,
            ) in r.fetchall()
        ]