Coverage for kye/validate.py: 26%
202 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-01-10 16:28 -0700
1from kye.types import Type, TYPE_REF, EDGE
2from kye.loader.loader import Loader, struct_pack
3from duckdb import DuckDBPyConnection, DuckDBPyRelation
4import kye.parser.kye_ast as AST
def struct_pack(edges: list[str], r: DuckDBPyRelation):
    """Build a DuckDB ``struct_pack(...)`` SQL expression over *edges*.

    Only the edges that exist as columns of *r* are included, in the order
    they appear in *edges*; each becomes a struct field of the same name.
    When none of them are present the result is ``'struct_pack()'``.
    """
    available = r.columns
    fields = []
    for edge_name in edges:
        if edge_name not in available:
            continue
        fields.append(f'"{edge_name}":="{edge_name}"')
    return 'struct_pack(' + ','.join(fields) + ')'
def string_list(strings: list[str]):
    """Render *strings* as a comma-separated list of SQL string literals.

    Each element is wrapped in single quotes, with embedded single quotes
    doubled (standard SQL escaping): ``["a", "it's"]`` -> ``'a','it''s'``.
    An empty list renders as ``''`` (one empty-string literal), so the
    result is always valid inside ``IN (...)`` or ``[...]``.
    """
    escaped = (s.replace("'", "''") for s in strings)
    return "'" + "','".join(escaped) + "'"
def flag(err_msg: str, condition: str, r: DuckDBPyRelation, **kwargs):
    """Record the rows of *r* matching *condition* as errors; return the rest.

    Matching rows are inserted into the ``errors`` table as
    ``(err_msg, tbl, idx, row, col, val)``. For each of those five fields a
    keyword argument overrides the column: ``None`` writes ``NULL``, and any
    other value is written as an (escaped) SQL string literal. Without an
    override, the same-named column of *r* is used when present, otherwise
    ``NULL``.

    NOTE: rows for which *condition* evaluates to NULL are neither flagged
    nor returned (SQL ``WHERE`` semantics for both filters).

    Args:
        err_msg: Error code written into the first column.
        condition: SQL boolean expression selecting the offending rows.
        r: Relation to check.
        **kwargs: Optional overrides for ``tbl``, ``idx``, ``row``, ``col``, ``val``.

    Returns:
        ``r`` filtered to the rows where ``NOT(condition)``.
    """
    def literal(value) -> str:
        # Quote as a SQL string literal, doubling embedded single quotes;
        # str() lets non-string override values through instead of raising.
        return "'" + str(value).replace("'", "''") + "'"

    fields = [literal(err_msg)]
    for field in ('tbl', 'idx', 'row', 'col', 'val'):
        if field in kwargs:
            value = kwargs[field]
            fields.append('NULL' if value is None else literal(value))
        elif field in r.columns:
            fields.append(field)
        else:
            fields.append('NULL')

    err = r.filter(condition).select(','.join(fields))
    err.insert_into('errors')
    return r.filter(f'''NOT({condition})''')
def collect(groupby, relations: dict[str, DuckDBPyRelation]):
    """Outer-join several two-column relations on a shared key.

    Every relation in *relations* must have exactly two columns, the first
    being *groupby*. Its second column is renamed to the relation's dict key,
    and all relations are outer-joined on *groupby*, in dict order.

    Returns:
        The joined relation, or ``None`` if *relations* is empty.
    """
    collected = None
    for alias, r in relations.items():
        assert len(r.columns) == 2
        assert r.columns[0] == groupby
        r = r.select(f'''{groupby}, {r.columns[1]} as {alias}''').set_alias(alias)
        # Compare with None explicitly. Truth-testing a DuckDB relation falls
        # back to __len__ (its row count), which runs the query and is False
        # for an empty relation, silently dropping the columns collected so far.
        collected = r if collected is None else collected.join(r, groupby, how='outer')
    return collected
def row_index(r: DuckDBPyRelation, index: list[str], name: str = 'idx'):
    """Hash the *index* columns of each row of *r* into a single key.

    Rows with a NULL in any of the *index* columns are dropped. The result
    has the columns ``row`` and *name*, where *name* is
    ``hash(struct_pack(...))`` over the index columns in sorted order, so the
    order in which *index* lists them does not affect the key.
    """
    not_null = ' AND '.join(f'{column} IS NOT NULL' for column in index)
    packed = struct_pack(sorted(index), r)
    complete_rows = r.filter(not_null)
    return complete_rows.select(f'''row, hash({packed}) as {name}''')
def row_indexes(edges: DuckDBPyRelation, typ: Type):
    """Attach to each row's partial-index keys the full-index id(s) they resolve to.

    For every index in ``typ.indexes`` a ``partial`` hash is computed per row
    (via ``row_index``). These are unioned and left-joined to the row's
    global ``idx`` hash over ``typ.index``. The distinct global ids seen for
    each partial key form ``partial_map`` (presumably grouped by ``partial``
    — confirm DuckDB's implicit grouping here). That map is joined back, so
    each ``(row, partial)`` pair carries the resolved ``idx``.

    Raises:
        ValueError: if ``typ.indexes`` is empty.
    """
    global_index = row_index(edges, typ.index).set_alias('idx')
    partial_index = None
    for idx in typ.indexes:
        r = row_index(edges, idx, name='partial')
        # Explicit None check: truth-testing a DuckDB relation runs it (via
        # __len__), and an empty relation would discard the earlier partials.
        partial_index = r if partial_index is None else partial_index.union(r)
    if partial_index is None:
        raise ValueError(f'type {typ.ref!r} declares no indexes')
    r = partial_index.join(global_index, 'row', how='left')
    # Create a map of partial ids to full ids
    partial_map = r.aggregate('partial, unnest(list_distinct(list(idx))) as idx').set_alias('partial_map')
    # Redefine index using the partial_map
    r = r.select('row, partial').join(partial_map, 'partial', how='left')
    return r
def compute_index(typ: Type, db: DuckDBPyConnection):
    """Resolve an index id for every row of *typ* in the ``edges`` table.

    Each check records failures into the ``errors`` table via ``flag``, and
    the failing rows are removed before the next step:

    * ``MULTIPLE_INDEX_VALUES``: a ``(row, col)`` pair has more than one
      distinct value for an index column.
    * ``MISSING_INDEX``: a row has no partial key, i.e. no entry of
      ``typ.indexes`` has all of its columns set.
    * ``MISSING_INDEX_COMPLETION``: a row's partial key maps to no full id.
    * ``CONFLICTING_INDEX``: a row maps to more than one distinct full id.

    Returns:
        Relation ``(row, idx, cnt)`` of the rows that passed every check.
    """
    table = db.table('edges').filter(f'''tbl = '{typ.ref}' ''')
    # Only the columns named in the type's index matter here.
    edges = table.filter(f'''col in ({string_list(typ.index)})''')
    edges = flag('MULTIPLE_INDEX_VALUES', 'cnt > 1',
        edges.aggregate('''row, col, first(val) as val, count(distinct(val)) as cnt'''), tbl=typ.ref, val=None)
    # Pivot to one column per index edge: row, <edge1>, <edge2>, ...
    edges = collect('row', {
        col: edges.filter(f"col = '{col}'").select('row, val')
        for col in typ.index
    })
    indexes = row_indexes(edges, typ)
    # Every row of the type, left-joined so rows without any index survive
    # long enough to be flagged.
    r = table.aggregate('row').join(indexes, 'row', how='left')
    r = flag('MISSING_INDEX', 'partial IS NULL', r, tbl=typ.ref)
    r = flag('MISSING_INDEX_COMPLETION', 'idx IS NULL', r, tbl=typ.ref)
    r = flag('CONFLICTING_INDEX', 'cnt > 1',
        r.aggregate('row, first(idx) as idx, count(distinct(idx)) as cnt'), tbl=typ.ref, idx=None)
    return r
class Table:
    """Per-type view over the long-format ``edges`` relation, with index checks.

    Construction filters *edges* to the rows whose ``tbl`` equals
    ``typ.ref``. It then immediately records ``MULTIPLE_INDEX_VALUES`` and
    ``MISSING_INDEX`` errors into the ``errors`` table.
    """
    typ: Type
    r: DuckDBPyRelation

    def __init__(self, typ: Type, edges: DuckDBPyRelation):
        self.typ = typ
        self.r = edges.filter(f'''tbl = '{typ.ref}' ''')
        self.flag_multiple_index_values()
        self.flag_missing_index()

    def string_list(self, strings: list[str]):
        """Render *strings* as comma-separated SQL string literals.

        Embedded single quotes are doubled; an empty list renders as ``''``.
        """
        return "'" + "','".join(s.replace("'", "''") for s in strings) + "'"

    def flag_multiple_index_values(self):
        """Flag ``(row, col)`` pairs where an index column has several distinct values."""
        self.flag(
            'MULTIPLE_INDEX_VALUES',
            self.r.aggregate('row, col, count(distinct val) > 1 as has_many')
                .filter(f'''has_many AND col in ({self.string_list(self.typ.index)})'''),
        )

    def flag_missing_index(self):
        """Flag rows whose set of columns fully covers none of ``typ.indexes``."""
        # NOTE(review): an empty typ.indexes yields the invalid SQL 'NOT ()'.
        has_one_of_the_indexes = ' OR '.join(
            f'''list_has_all(columns, [{self.string_list(idx)}])'''
            for idx in self.typ.indexes
        )
        self.flag(
            'MISSING_INDEX',
            self.row_columns().filter(f'NOT ({has_one_of_the_indexes})'),
        )

    def flag(self, err, r: DuckDBPyRelation):
        """Insert every row of *r* into ``errors`` as ``(err, tbl, row, col, val)``.

        ``row``/``col``/``val`` come from *r* when it has those columns and
        are written as ``NULL`` otherwise.
        """
        row_col_val = ','.join(
            field if field in r.columns else 'NULL'
            for field in ['row', 'col', 'val']
        )
        err = r.select("'{err}' as err, '{tbl}' as tbl, {row_col_val}".format(
            err=err,
            tbl=self.typ.ref,
            row_col_val=row_col_val,
        ))
        err.insert_into('errors')

    def row_columns(self):
        """Return ``(row, columns)`` with the distinct column names set on each row."""
        return self.r.aggregate("row, list(distinct(col)) as columns")

    def row_defines(self, columns: list[str]):
        """Return ``(row, defines)``, true when the row has all of *columns*."""
        return self.row_columns().select(
            f'''row, list_has_all(columns, [{self.string_list(columns)}]) as defines'''
        )

    def row_edge(self, edge: EDGE):
        """Return ``(row, val)`` with the first value of *edge* on each row."""
        r = self.r.filter(f"col = {self.string_list([edge])}")
        return r.aggregate(f'''row, first(val) as val''')

    def collect(self, groupby, relations: dict[str, DuckDBPyRelation]):
        """Left-join two-column relations onto every distinct *groupby* value of ``self.r``.

        Each relation's second column is renamed to its dict key.
        """
        collected = self.r.aggregate(groupby)
        for alias, r in relations.items():
            assert len(r.columns) == 2
            assert r.columns[0] == groupby
            r = r.select(f'''{groupby}, {r.columns[1]} as {alias}''').set_alias(alias)
            collected = collected.join(r, groupby, how='left')
        return collected

    def row_index(self, index: list[str]):
        """Return ``(row, idx)``: a hash of the *index* values for rows defining all of them."""
        r = self.collect('row', {
            idx: self.row_edge(idx)
            for idx in index
        })
        r = r.select(f'''row, hash({struct_pack(sorted(index), r)}) as idx''')
        r = self.collect('row', {
            'idx': r,
            'defines': self.row_defines(index),
        }).filter('defines').select('row, idx')
        return r

    def row_indexes(self):
        """Return ``(row, idx)`` where ``idx`` lists the full-index ids reachable from the row's partial ids."""
        global_index = self.row_index(self.typ.index).set_alias('idx')
        partial_index = None
        for idx in self.typ.indexes:
            r = self.row_index(idx).select("row, idx as partial")
            # Explicit None check: truth-testing a DuckDB relation would run it.
            partial_index = r if partial_index is None else partial_index.union(r)
        r = partial_index.join(global_index, 'row', how='left')
        partial_map = r.aggregate('partial, unnest(list_distinct(list(idx))) as idx').set_alias('partial_map')
        r = r.select('row, partial').join(partial_map, 'partial').aggregate('row, list(distinct idx) as idx')
        return r

    def get_edge(self, edge: EDGE):
        """Return ``(row, <edge>)`` pairs for *edge*, aliased ``<typ.ref>.<edge>``."""
        assert self.typ.has_edge(edge)
        return (
            self.r.filter(f"col = {self.string_list([edge])}")
            .select(f'row, val as "{edge}"')
            .set_alias(self.typ.ref + '.' + edge)
        )
class Validate:
    """Validate every model loaded by a ``Loader`` and store the results.

    On construction an ``errors`` table is created in the loader's database.
    Each loaded model table is then validated and materialised as a table
    named ``"<model>.validated"``, retrievable with ``validate[model_name]``.
    """
    loader: Loader
    tables: dict[TYPE_REF, DuckDBPyRelation]

    def __init__(self, loader: Loader):
        self.loader = loader
        self.tables = {}

        # NOTE(review): this 4-column schema differs from the rows that the
        # module-level flag() and Table.flag insert — confirm which is current.
        self.db.sql('CREATE TABLE errors (rule_ref TEXT, error_type TEXT, object_id UINT64, val JSON);')
        self.errors = self.db.table('errors')

        for model_name, table in self.loader.tables.items():
            table = self._validate_model(self.models[model_name], table)
            table_name = f'"{model_name}.validated"'
            table.create(table_name)
            self.tables[model_name] = self.db.table(table_name)

    @property
    def db(self) -> DuckDBPyConnection:
        return self.loader.db

    @property
    def models(self) -> dict[TYPE_REF, Type]:
        return self.loader.models

    def _add_errors_where(self, r: DuckDBPyRelation, condition: str, rule_ref: str, error_type: str):
        """Insert rows of *r* matching *condition* into ``errors``; return the others.

        *r* must have ``_index`` and ``val`` columns; ``val`` is stored as JSON.
        """
        err = r.filter(condition)
        err = err.select(f''' '{rule_ref}' as rule_ref, '{error_type}' as error_type, _index as object_id, to_json(val) as val''')
        err.insert_into('errors')
        return r.filter(f'''NOT ({condition})''')

    def check_for_index_collision(self, typ: Type, r: DuckDBPyRelation):
        """Flag ``NON_UNIQUE_INDEX`` and return the ``_index`` of rows without collisions.

        Each row contributes one packed value per index in ``typ.indexes``.
        A value shared by more than one ``_index`` is a collision for every
        row involved.
        """
        packed_indexes = ','.join(f"list_pack({','.join(sorted(index))})" for index in typ.indexes)
        r = r.select(f'''_index, UNNEST([{packed_indexes}]) as index_val''')
        r = r.aggregate('index_val, list_distinct(list(_index)) as _indexes')

        r = r.select('index_val as val, unnest(_indexes) as _index, len(_indexes) > 1 as collision')

        self._add_errors_where(r,
            condition = 'collision',
            rule_ref = typ.ref,
            error_type = 'NON_UNIQUE_INDEX'
        )
        # Select the good indexes
        return r.aggregate('_index, bool_or(collision) as collision').filter('not collision').select('_index')

    def _validate_model(self, typ: Type, r: DuckDBPyRelation):
        """Return ``_index`` plus one validated column per edge of *typ*.

        When *typ* has several indexes, rows with colliding index values are
        excluded first.
        """
        edges = r.aggregate('_index')

        # No need to check for conflicting indexes if there is only one
        if len(typ.indexes) > 1:
            edges = self.check_for_index_collision(typ, r)

        for edge in typ.edges:
            edge_rel = r.select(f'''_index, {edge if edge in r.columns else 'CAST(NULL as VARCHAR)'} as val''')
            edge_rel = self._validate_edge(typ, edge, edge_rel).set_alias(typ.ref + '.' + edge)
            edge_rel = edge_rel.select(f'''_index, val as {edge}''')
            edges = edges.join(edge_rel, '_index', how='left')
        return edges

    def _validate_edge(self, typ: Type, edge: EDGE, r: DuckDBPyRelation):
        """Collect, cardinality-check and value-check one edge's ``(_index, val)`` rows.

        Values are gathered into a distinct list per ``_index`` (flattened if
        ``val`` is already a list). Empty lists raise ``NOT_NULLABLE`` and
        lists longer than one raise ``NOT_MULTIPLE``, where the edge disallows
        it. Single-valued edges return a scalar ``val``; multi-valued edges
        return a list.
        """
        agg_fun = 'list_distinct(flatten(list(val)))' if r.val.dtypes[0].id == 'list' else 'list_distinct(list(val))'
        r = r.aggregate(f'''_index, {agg_fun} as val''')

        if not typ.allows_null(edge):
            r = self._add_errors_where(r, 'len(val) == 0', typ.ref + '.' + edge, 'NOT_NULLABLE')

        if not typ.allows_multiple(edge):
            r = self._add_errors_where(r, 'len(val) > 1', typ.ref + '.' + edge, 'NOT_MULTIPLE')
            r = r.select(f'''_index, val[1] as val''')
        else:
            r = r.select(f'''_index, unnest(val) as val''')

        r = r.filter('val IS NOT NULL')
        r = self._validate_value(typ.get_edge(edge), r)

        if typ.allows_multiple(edge):
            r = r.aggregate('_index, list(val) as val')

        return r

    def _validate_value(self, typ: Type, r: DuckDBPyRelation):
        """Validate individual values of type *typ*; currently returns *r* unchanged."""
        # TODO: Look up object references and see if they exist
        # TODO: Check primitive types once type assertions are available, e.g.
        #   self._add_errors_where(r, 'TRY_CAST(val AS DOUBLE) IS NULL', typ.ref, 'INVALID_VALUE')
        return r

    def __getitem__(self, model_name: TYPE_REF):
        return self.tables[model_name]

    def __repr__(self):
        return f"<Validate {','.join(self.tables.keys())}>"
class ComputedEdges:
    """Registry of SQL expressions that compute derived edges.

    Computations are keyed by ``(type_ref, edge)``. A type without its own
    entry inherits the computation from ``typ.extends``, provided the parent
    also has that edge.
    """
    computations: dict[tuple[str, str], str]

    def __init__(self):
        self.computations = dict()

    def _get(self, typ: Type, edge: EDGE):
        """Return the SQL registered for *edge* on *typ* (searching ``extends``), or ``None``."""
        assert typ.has_edge(edge)
        comp = self.computations.get((typ.ref, edge))
        if comp is None:
            if typ.extends and typ.extends.has_edge(edge):
                return self._get(typ.extends, edge)
        return comp

    def has(self, typ: Type, edge: EDGE):
        """Return whether a computation for *edge* is available on *typ*."""
        return self._get(typ, edge) is not None

    # allow multiple relations to be passed in
    # each being an argument to the function
    # then join each of the relations together?
    # but only if the other arguments are also indexes and not literals?
    def apply(self, typ: Type, edge: EDGE, r: DuckDBPyRelation):
        """Select ``_index`` and the computed *edge* (as ``val``) from *r*.

        Raises:
            KeyError: if no computation is registered for *edge* on *typ* or
                its ancestors. Without this check the SQL would contain
                ``None as val`` and fail later with an opaque binder error.
        """
        comp = self._get(typ, edge)
        if comp is None:
            raise KeyError(f'No computation registered for edge "{edge}" of type "{typ.ref}"')
        return r.select(f'_index, {comp} as val')

    def set(self, type_ref: TYPE_REF, edge: EDGE, sql: str):
        """Register *sql* as the computation of *edge* on *type_ref*.

        Raises:
            ValueError: if a computation for ``(type_ref, edge)`` already exists.
        """
        if (type_ref, edge) in self.computations:
            raise ValueError(f'Computation for "{type_ref}.{edge}" is already registered')
        self.computations[(type_ref, edge)] = sql
# Global registry of computed edges. Built-in computations are registered
# here at import time; subtypes (e.g. anything extending String) inherit them.
COMPUTED_EDGES = ComputedEdges()
COMPUTED_EDGES.set(type_ref='String', edge='length', sql='LENGTH(val)')
330def expression(typ: Type, r: DuckDBPyRelation, expr: AST.Expression):
331 assert isinstance(expr, AST.Expression)
332 if isinstance(expr, AST.Identifier):
333 if expr.kind == 'edge':
334 edge = expr.name
335 assert typ.has_edge(edge), f'Unknown edge: "{edge}"'
336 edge_ref = typ.edge_origin(edge).ref + '.' + edge