Coverage for kye/validate.py: 26%

202 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-01-10 16:28 -0700

1from kye.types import Type, TYPE_REF, EDGE 

2from kye.loader.loader import Loader, struct_pack 

3from duckdb import DuckDBPyConnection, DuckDBPyRelation 

4import kye.parser.kye_ast as AST 

5 

def struct_pack(edges: list[str], r: DuckDBPyRelation):
    """Build a ``struct_pack(...)`` SQL expression from the given edge names.

    Only names that are also columns of ``r`` are included; each one is
    emitted as ``"name":="name"`` and the entries are comma-separated, in
    the order they appear in ``edges``.

    NOTE: this definition shadows the ``struct_pack`` name imported from
    ``kye.loader.loader`` at the top of the module.
    """
    entries = []
    for edge_name in edges:
        if edge_name not in r.columns:
            continue
        entries.append(f'''"{edge_name}":="{edge_name}"''')
    body = ','.join(entries)
    return 'struct_pack(' + body + ')'

12 

def string_list(strings: list[str]):
    """Render ``strings`` as a comma-separated list of SQL string literals.

    Each value is wrapped in single quotes, and any single quote inside a
    value is doubled (the standard SQL escape) so that the literal cannot be
    terminated early by its own content.

    An empty input yields ``''`` (one empty literal), matching the historical
    behaviour so that ``col in (...)`` callers still receive a non-empty list.
    """
    escaped = (s.replace("'", "''") for s in strings)
    return "'" + "','".join(escaped) + "'"

16 

def flag(err_msg: str, condition: str, r: DuckDBPyRelation, **kwargs):
    """Insert the rows of ``r`` matching ``condition`` into ``errors``.

    The inserted row starts with ``err_msg`` as a quoted literal, followed by
    one expression for each of ``tbl``, ``idx``, ``row``, ``col`` and ``val``
    (in that order), chosen as follows:

    * if the field is passed as a keyword argument, its value is used as a
      quoted literal, or ``NULL`` when the value is ``None``;
    * otherwise, if ``r`` has a column with that name, the column is used;
    * otherwise ``NULL``.

    Returns ``r`` filtered to the rows for which ``NOT(condition)`` holds.
    """
    def resolve(field):
        # Explicit keyword overrides win over columns of the relation.
        if field in kwargs:
            override = kwargs[field]
            return 'NULL' if override is None else "'" + override + "'"
        return field if field in r.columns else 'NULL'

    fields = ["'" + err_msg + "'"]
    fields.extend(resolve(field) for field in ('tbl', 'idx', 'row', 'col', 'val'))

    offending = r.filter(condition)
    offending.select(','.join(fields)).insert_into('errors')
    return r.filter(f'''NOT({condition})''')

33 

def collect(groupby, relations: dict[str, DuckDBPyRelation]):
    """Outer-join several two-column relations on a shared key column.

    Each relation in ``relations`` must have exactly two columns, the first
    of which is named ``groupby``. Its second column is renamed to the
    relation's key in ``relations`` (which is also used as its alias), and
    the relations are then outer-joined together on ``groupby`` in iteration
    order.

    Returns the joined relation, or ``None`` when ``relations`` is empty.

    Raises:
        AssertionError: if a relation does not have the expected shape.
    """
    collected = None
    for alias, r in relations.items():
        assert len(r.columns) == 2
        assert r.columns[0] == groupby
        r = r.select(f'''{groupby}, {r.columns[1]} as {alias}''').set_alias(alias)
        # Compare against None explicitly: truth-testing a relation goes
        # through its __len__, so an empty (zero-row) accumulated relation
        # would otherwise be treated as "nothing collected yet" and dropped.
        if collected is None:
            collected = r
        else:
            collected = collected.join(r, groupby, how='outer')
    return collected

42 

def row_index(r: DuckDBPyRelation, index: list[str], name: str = 'idx'):
    """Compute a hashed index value per row from the given index columns.

    Rows where any of the ``index`` columns is NULL are filtered out. For
    the remaining rows the result has two columns: ``row`` and the hash of a
    struct packed from the (sorted) index columns, named ``name``.
    """
    all_present = ' AND '.join(f'{idx} IS NOT NULL' for idx in index)
    complete_rows = r.filter(all_present)
    # The struct is built against the columns of the incoming relation.
    packed = struct_pack(sorted(index), r)
    return complete_rows.select(f'''row, hash({packed}) as {name}''')

47 

def row_indexes(edges: DuckDBPyRelation, typ: Type):
    """Map every row to its partial index hashes and the matching full index.

    A "partial" hash is computed for each index in ``typ.indexes`` and the
    results are unioned. Each partial hash is then associated with the
    distinct global index hashes (computed from ``typ.index``) of the rows
    that share it, and that mapping is joined back onto the rows.

    Returns a relation built from ``row, partial`` left-joined with the
    partial-to-global mapping on ``partial``.

    Raises:
        ValueError: if ``typ.indexes`` is empty, since there is then nothing
            to build the partial index from.
    """
    global_index = row_index(edges, typ.index).set_alias('idx')
    partial_index = None
    for idx in typ.indexes:
        r = row_index(edges, idx, name='partial')
        # Compare against None explicitly: truth-testing a relation goes
        # through its __len__, so an empty accumulated relation would
        # otherwise be replaced instead of unioned.
        partial_index = r if partial_index is None else partial_index.union(r)
    if partial_index is None:
        raise ValueError('row_indexes requires at least one index in typ.indexes')
    r = partial_index.join(global_index, 'row', how='left')
    # Create a map of partial ids to full ids
    partial_map = r.aggregate('partial, unnest(list_distinct(list(idx))) as idx').set_alias('partial_map')
    # Redefine index using the partial_map
    r = r.select('row, partial').join(partial_map, 'partial', how='left')
    return r

60 

def compute_index(typ: Type, db: DuckDBPyConnection):
    """Validate the index edges of ``typ`` and record violations in ``errors``.

    Works on the rows of the ``edges`` table whose ``tbl`` equals
    ``typ.ref`` and applies these checks in order, each via ``flag``:

    * ``MULTIPLE_INDEX_VALUES``: an index column has more than one distinct
      value for the same row;
    * ``MISSING_INDEX``: a row has no partial index;
    * ``MISSING_INDEX_COMPLETION``: a row's partial index does not resolve
      to a full index;
    * ``CONFLICTING_INDEX``: a row resolves to more than one distinct index.

    Returns the relation left after the final check (previously the result
    was computed and then discarded, so callers received ``None``).
    """
    table = db.table('edges').filter(f'''tbl = '{typ.ref}' ''')
    edges = table.filter(f'''col in ({string_list(typ.index)})''')
    edges = flag(
        'MULTIPLE_INDEX_VALUES', 'cnt > 1',
        edges.aggregate('''row, col, first(val) as val, count(distinct(val)) as cnt'''),
        tbl=typ.ref, val=None,
    )
    # Pivot the index edges into one column per index edge, keyed by row.
    edges = collect('row', {
        col: edges.filter(f"col = '{col}'").select('row, val')
        for col in typ.index
    })
    indexes = row_indexes(edges, typ)
    # Left-join onto every row of the table so rows without any index
    # surface as NULLs and can be flagged below.
    r = table.aggregate('row').join(indexes, 'row', how='left')
    r = flag('MISSING_INDEX', 'partial IS NULL', r, tbl=typ.ref)
    r = flag('MISSING_INDEX_COMPLETION', 'idx IS NULL', r, tbl=typ.ref)
    r = flag(
        'CONFLICTING_INDEX', 'cnt > 1',
        r.aggregate('row, first(idx) as idx, count(distinct(idx)) as cnt'),
        tbl=typ.ref, idx=None,
    )
    return r

77 

class Table:
    """The edge rows belonging to one type, with index-validation helpers.

    On construction the edges are restricted to rows whose ``tbl`` equals
    ``typ.ref`` and two checks are run immediately, inserting any violations
    into the ``errors`` table: ``MULTIPLE_INDEX_VALUES`` and
    ``MISSING_INDEX``.
    """

    typ: Type
    r: DuckDBPyRelation

    def __init__(self, typ: Type, edges: DuckDBPyRelation):
        self.typ = typ
        self.r = edges.filter(f'''tbl = '{typ.ref}' ''')
        self.flag_multiple_index_values()
        self.flag_missing_index()

    def string_list(self, strings: list[str]):
        """Render ``strings`` as comma-separated SQL string literals.

        Single quotes inside a value are doubled so the literal stays intact.
        """
        return "'" + "','".join(s.replace("'", "''") for s in strings) + "'"

    def flag_multiple_index_values(self):
        """Flag (row, col) pairs of index columns with more than one distinct value."""
        self.flag(
            'MULTIPLE_INDEX_VALUES',
            self.r.aggregate('row, col, count(distinct val) > 1 as has_many')
                .filter(f'''has_many AND col in ({self.string_list(self.typ.index)})''')
        )

    def flag_missing_index(self):
        """Flag rows whose columns do not cover any of the type's indexes."""
        has_one_of_the_indexes = ' OR '.join(
            f'''list_has_all(columns, [{self.string_list(idx)}])'''
            for idx in self.typ.indexes
        )
        self.flag(
            'MISSING_INDEX',
            self.row_columns().filter(f'NOT ({has_one_of_the_indexes})')
        )

    def flag(self, err, r: DuckDBPyRelation):
        """Insert every row of ``r`` into ``errors`` tagged with ``err``.

        The inserted columns are ``err``, ``tbl`` (this table's type ref)
        and ``row``/``col``/``val``, each taken from ``r`` when it has a
        column of that name and ``NULL`` otherwise.
        """
        err = r.select("'{err}' as err, '{tbl}' as tbl, {row_col_val}".format(
            err=err,
            tbl=self.typ.ref,
            row_col_val=','.join([
                field if field in r.columns else 'NULL'
                for field in ['row', 'col', 'val']
            ])
        ))
        err.insert_into('errors')

    def row_columns(self):
        """Return, per row, the list of distinct columns defined for it."""
        return self.r.aggregate("row, list(distinct(col)) as columns")

    def row_defines(self, columns: list[str]):
        """Return, per row, a boolean ``defines``: whether it has all ``columns``."""
        return self.row_columns().select(
            f'''row, list_has_all(columns, [{self.string_list(columns)}]) as defines'''
        )

    def row_edge(self, edge: EDGE):
        """Return ``row, val`` for one edge, keeping the first value per row."""
        r = self.r.filter(f"col = '{edge}'")
        return r.aggregate(f'''row, first(val) as val''')

    def collect(self, groupby, relations: dict[str, DuckDBPyRelation]):
        """Left-join two-column relations onto the distinct ``groupby`` values.

        Each relation must have exactly two columns, the first being
        ``groupby``; its second column is renamed to the relation's key.
        """
        collected = self.r.aggregate(groupby)
        for alias, r in relations.items():
            assert len(r.columns) == 2
            assert r.columns[0] == groupby
            r = r.select(f'''{groupby}, {r.columns[1]} as {alias}''').set_alias(alias)
            collected = collected.join(r, groupby, how='left')
        return collected

    def row_index(self, index: list[str]):
        """Return ``row, idx`` for rows that define every column of ``index``.

        ``idx`` is the hash of a struct packed from the sorted index columns.
        """
        r = self.collect('row', {
            idx: self.row_edge(idx)
            for idx in index
        })
        r = r.select(f'''row, hash({struct_pack(sorted(index), r)}) as idx''')
        # Keep only the rows that actually define all of the index columns.
        r = self.collect('row', {
            'idx': r,
            'defines': self.row_defines(index)
        }).filter('defines').select('row, idx')
        return r

    def row_indexes(self):
        """Return, per row, the list of distinct full indexes it resolves to.

        Partial index hashes (one per entry of ``typ.indexes``) are mapped
        to the global index hashes of the rows sharing them, and each row
        collects the distinct global indexes reachable through its partials.
        """
        global_index = self.row_index(self.typ.index).set_alias('idx')
        partial_index = None
        for idx in self.typ.indexes:
            r = self.row_index(idx).select("row, idx as partial")
            if partial_index is None:
                partial_index = r
            else:
                partial_index = partial_index.union(r)
        r = partial_index.join(global_index, 'row', how='left')
        partial_map = r.aggregate('partial, unnest(list_distinct(list(idx))) as idx').set_alias('partial_map')
        r = r.select('row, partial').join(partial_map, 'partial').aggregate('row, list(distinct idx) as idx')
        return r

    def get_edge(self, edge: EDGE):
        """Return ``row, id`` rows aliased as ``<type ref>.<edge>``.

        Raises:
            AssertionError: if the type does not have ``edge``.
        """
        assert self.typ.has_edge(edge)
        # NOTE(review): the filter and output column are hard-coded to 'id'
        # and do not depend on `edge` (only the alias does) - confirm this
        # is intended before relying on it for other edges.
        return self.r.filter("col = 'id'").select('row, val as id').set_alias(self.typ.ref + '.' + edge)

192 

193 

194 

class Validate:
    """Validate every loaded model table and collect violations in ``errors``.

    On construction an ``errors`` table is created on the loader's database,
    each table in ``loader.tables`` is validated against its model, and the
    result is materialised as ``"<model_name>.validated"``. Validated tables
    are available by model name via ``validate[model_name]``.
    """

    loader: Loader
    tables: dict[TYPE_REF, DuckDBPyRelation]
    errors: DuckDBPyRelation

    def __init__(self, loader: Loader):
        self.loader = loader
        self.tables = {}

        self.db.sql('CREATE TABLE errors (rule_ref TEXT, error_type TEXT, object_id UINT64, val JSON);')
        self.errors = self.db.table('errors')

        for model_name, table in self.loader.tables.items():
            table = self._validate_model(self.models[model_name], table)
            table_name = f'"{model_name}.validated"'
            table.create(table_name)
            self.tables[model_name] = self.db.table(table_name)

    @property
    def db(self) -> DuckDBPyConnection:
        """The loader's database connection."""
        return self.loader.db

    @property
    def models(self) -> dict[TYPE_REF, Type]:
        """The loader's models, keyed by type reference."""
        return self.loader.models

    def _add_errors_where(self, r: DuckDBPyRelation, condition: str, rule_ref: str, error_type: str):
        """Insert rows of ``r`` matching ``condition`` into ``errors``.

        Each inserted row carries ``rule_ref``, ``error_type``, the row's
        ``_index`` as ``object_id`` and its ``val`` converted to JSON.
        Returns ``r`` filtered to the rows for which ``NOT (condition)`` holds.
        """
        err = r.filter(condition)
        err = err.select(f''' '{rule_ref}' as rule_ref, '{error_type}' as error_type, _index as object_id, to_json(val) as val''')
        err.insert_into('errors')
        return r.filter(f'''NOT ({condition})''')

    def check_for_index_collision(self, typ: Type, r: DuckDBPyRelation):
        """Flag objects whose index values are shared with another object.

        Every index of ``typ`` is packed into a list value per row; an index
        value that maps to more than one distinct ``_index`` is recorded as
        ``NON_UNIQUE_INDEX``. Returns the ``_index`` values that are not
        involved in any collision.
        """
        packed_indexes = ','.join(f"list_pack({','.join(sorted(index))})" for index in typ.indexes)
        r = r.select(f'''_index, UNNEST([{packed_indexes}]) as index_val''')
        r = r.aggregate('index_val, list_distinct(list(_index)) as _indexes')

        r = r.select('index_val as val, unnest(_indexes) as _index, len(_indexes) > 1 as collision')

        self._add_errors_where(
            r,
            condition='collision',
            rule_ref=typ.ref,
            error_type='NON_UNIQUE_INDEX',
        )
        # Select the good indexes
        return r.aggregate('_index, bool_or(collision) as collision').filter('not collision').select('_index')

    def _validate_model(self, typ: Type, r: DuckDBPyRelation):
        """Validate each edge of ``typ`` and join the results per ``_index``."""
        edges = r.aggregate('_index')

        # No need to check for conflicting indexes if there is only one
        if len(typ.indexes) > 1:
            edges = self.check_for_index_collision(typ, r)

        for edge in typ.edges:
            # Edges missing from the source table are validated as all-NULL.
            edge_rel = r.select(f'''_index, {edge if edge in r.columns else 'CAST(NULL as VARCHAR)'} as val''')
            edge_rel = self._validate_edge(typ, edge, edge_rel).set_alias(typ.ref + '.' + edge)
            edge_rel = edge_rel.select(f'''_index, val as {edge}''')
            edges = edges.join(edge_rel, '_index', how='left')
        return edges

    def _validate_edge(self, typ: Type, edge: EDGE, r: DuckDBPyRelation):
        """Check nullability and multiplicity of one edge, then its values.

        Values are first gathered into a distinct list per ``_index``.
        ``NOT_NULLABLE`` is recorded for empty lists when the edge does not
        allow null, and ``NOT_MULTIPLE`` for lists longer than one when the
        edge does not allow multiple values. Remaining non-NULL values go
        through ``_validate_value``; multi-valued edges are re-aggregated
        into a list per ``_index`` before returning.
        """
        agg_fun = 'list_distinct(flatten(list(val)))' if r.val.dtypes[0].id == 'list' else 'list_distinct(list(val))'
        r = r.aggregate(f'''_index, {agg_fun} as val''')

        if not typ.allows_null(edge):
            r = self._add_errors_where(r, 'len(val) == 0', typ.ref + '.' + edge, 'NOT_NULLABLE')

        if not typ.allows_multiple(edge):
            r = self._add_errors_where(r, 'len(val) > 1', typ.ref + '.' + edge, 'NOT_MULTIPLE')
            r = r.select(f'''_index, val[1] as val''')
        else:
            r = r.select(f'''_index, unnest(val) as val''')

        r = r.filter('val IS NOT NULL')
        r = self._validate_value(typ.get_edge(edge), r)

        if typ.allows_multiple(edge):
            r = r.aggregate('_index, list(val) as val')

        return r

    def _validate_value(self, typ: Type, r: DuckDBPyRelation):
        """Validate individual values of ``typ``; currently a pass-through.

        Returns ``r`` unchanged. No value-level rule is enforced yet.
        """
        # TODO: Look up object references and see if they exist
        # TODO: Enforce type assertions, e.g. flag INVALID_VALUE where
        #       TRY_CAST(val AS BOOLEAN) / TRY_CAST(val AS DOUBLE) IS NULL
        #       for types asserting 'this is Boolean' / 'this is Number'.
        return r

    def __getitem__(self, model_name: TYPE_REF):
        return self.tables[model_name]

    def __repr__(self):
        return f"<Validate {','.join(self.tables.keys())}>"

296 

297 

class ComputedEdges:
    """Registry of SQL expressions that compute an edge for a type.

    Computations are keyed by ``(type_ref, edge)``. Lookups fall back to the
    type's ``extends`` chain, so a computation registered on a base type
    also applies to types extending it.
    """

    computations: dict[tuple[str, str], str]

    def __init__(self):
        self.computations = {}

    def _get(self, typ: Type, edge: EDGE):
        """Return the SQL for ``edge`` on ``typ`` or an ancestor, else ``None``."""
        assert typ.has_edge(edge)
        comp = self.computations.get((typ.ref, edge))
        if comp is None and typ.extends and typ.extends.has_edge(edge):
            return self._get(typ.extends, edge)
        return comp

    def has(self, typ: Type, edge: EDGE):
        """Return whether a computation is available for ``edge`` on ``typ``."""
        return self._get(typ, edge) is not None

    # allow multiple relations to be passed in
    # each being an argument to the function
    # then join each of the relations together?
    # but only if the other arguments are also indexes and not literals?
    def apply(self, typ: Type, edge: EDGE, r: DuckDBPyRelation):
        """Select ``_index`` and the computed edge (as ``val``) from ``r``.

        Raises:
            LookupError: if no computation is registered for the edge on
                ``typ`` or any type it extends. Without this check the text
                ``None`` would be interpolated into the SQL expression.
        """
        comp = self._get(typ, edge)
        if comp is None:
            raise LookupError(f'No computation registered for {typ.ref}.{edge}')
        return r.select(f'_index, {comp} as val')

    def set(self, type_ref: TYPE_REF, edge: EDGE, sql: str):
        """Register ``sql`` as the computation of ``edge`` on ``type_ref``.

        Each ``(type_ref, edge)`` pair may only be registered once.
        """
        assert (type_ref, edge) not in self.computations
        self.computations[(type_ref, edge)] = sql

326 

# Module-level registry of computed edges, pre-populated with the
# `length` edge of `String`.
COMPUTED_EDGES = ComputedEdges()
COMPUTED_EDGES.set(type_ref='String', edge='length', sql='LENGTH(val)')

329 

330def expression(typ: Type, r: DuckDBPyRelation, expr: AST.Expression): 

331 assert isinstance(expr, AST.Expression) 

332 if isinstance(expr, AST.Identifier): 

333 if expr.kind == 'edge': 

334 edge = expr.name 

335 assert typ.has_edge(edge), f'Unknown edge: "{edge}"' 

336 edge_ref = typ.edge_origin(edge).ref + '.' + edge