Merging upstream version 6.3.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
81e6900b0a
commit
393757f998
41 changed files with 1558 additions and 267 deletions
|
@ -1,16 +1,20 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.helper import ensure_list, subclasses
|
||||
from sqlglot.optimizer.schema import ensure_schema
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
|
||||
|
||||
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
|
||||
"""
|
||||
Recursively infer & annotate types in an expression syntax tree against a schema.
|
||||
Assumes that we've already executed the optimizer's qualify_columns step.
|
||||
|
||||
(TODO -- replace this with a better example after adding some functionality)
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3'))
|
||||
>>> annotated_expression.type
|
||||
>>> schema = {"y": {"cola": "SMALLINT"}}
|
||||
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
|
||||
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
|
||||
>>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola"
|
||||
<Type.DOUBLE: 'DOUBLE'>
|
||||
|
||||
Args:
|
||||
|
@ -22,6 +26,8 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
|
|||
sqlglot.Expression: expression annotated with types
|
||||
"""
|
||||
|
||||
schema = ensure_schema(schema)
|
||||
|
||||
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
|
||||
|
||||
|
||||
|
@ -35,10 +41,81 @@ class TypeAnnotator:
|
|||
expr_type: lambda self, expr: self._annotate_binary(expr)
|
||||
for expr_type in subclasses(exp.__name__, exp.Binary)
|
||||
},
|
||||
exp.Cast: lambda self, expr: self._annotate_cast(expr),
|
||||
exp.DataType: lambda self, expr: self._annotate_data_type(expr),
|
||||
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
|
||||
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
|
||||
exp.Alias: lambda self, expr: self._annotate_unary(expr),
|
||||
exp.Literal: lambda self, expr: self._annotate_literal(expr),
|
||||
exp.Boolean: lambda self, expr: self._annotate_boolean(expr),
|
||||
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
|
||||
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
|
||||
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
|
||||
exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
|
||||
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
|
||||
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
}
|
||||
|
||||
# Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
|
||||
|
@ -97,43 +174,82 @@ class TypeAnnotator:
|
|||
},
|
||||
}
|
||||
|
||||
TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
|
||||
|
||||
def __init__(self, schema=None, annotators=None, coerces_to=None):
|
||||
self.schema = schema
|
||||
self.annotators = annotators or self.ANNOTATORS
|
||||
self.coerces_to = coerces_to or self.COERCES_TO
|
||||
|
||||
def annotate(self, expression):
|
||||
if isinstance(expression, self.TRAVERSABLES):
|
||||
for scope in traverse_scope(expression):
|
||||
subscope_selects = {
|
||||
name: {select.alias_or_name: select for select in source.selects}
|
||||
for name, source in scope.sources.items()
|
||||
if isinstance(source, Scope)
|
||||
}
|
||||
|
||||
# First annotate the current scope's column references
|
||||
for col in scope.columns:
|
||||
source = scope.sources[col.table]
|
||||
if isinstance(source, exp.Table):
|
||||
col.type = self.schema.get_column_type(source, col)
|
||||
else:
|
||||
col.type = subscope_selects[col.table][col.name].type
|
||||
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
||||
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
|
||||
|
||||
def _maybe_annotate(self, expression):
|
||||
if not isinstance(expression, exp.Expression):
|
||||
return None
|
||||
|
||||
if expression.type:
|
||||
return expression # We've already inferred the expression's type
|
||||
|
||||
annotator = self.annotators.get(expression.__class__)
|
||||
return annotator(self, expression) if annotator else self._annotate_args(expression)
|
||||
return (
|
||||
annotator(self, expression)
|
||||
if annotator
|
||||
else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
|
||||
)
|
||||
|
||||
def _annotate_args(self, expression):
|
||||
for value in expression.args.values():
|
||||
for v in ensure_list(value):
|
||||
self.annotate(v)
|
||||
self._maybe_annotate(v)
|
||||
|
||||
return expression
|
||||
|
||||
def _annotate_cast(self, expression):
|
||||
expression.type = expression.args["to"].this
|
||||
return self._annotate_args(expression)
|
||||
|
||||
def _annotate_data_type(self, expression):
|
||||
expression.type = expression.this
|
||||
return self._annotate_args(expression)
|
||||
|
||||
def _maybe_coerce(self, type1, type2):
|
||||
# We propagate the NULL / UNKNOWN types upwards if found
|
||||
if exp.DataType.Type.NULL in (type1, type2):
|
||||
return exp.DataType.Type.NULL
|
||||
if exp.DataType.Type.UNKNOWN in (type1, type2):
|
||||
return exp.DataType.Type.UNKNOWN
|
||||
|
||||
return type2 if type2 in self.coerces_to[type1] else type1
|
||||
|
||||
def _annotate_binary(self, expression):
|
||||
self._annotate_args(expression)
|
||||
|
||||
if isinstance(expression, (exp.Condition, exp.Predicate)):
|
||||
left_type = expression.left.type
|
||||
right_type = expression.right.type
|
||||
|
||||
if isinstance(expression, (exp.And, exp.Or)):
|
||||
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
|
||||
expression.type = exp.DataType.Type.NULL
|
||||
elif exp.DataType.Type.NULL in (left_type, right_type):
|
||||
expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
|
||||
else:
|
||||
expression.type = exp.DataType.Type.BOOLEAN
|
||||
elif isinstance(expression, (exp.Condition, exp.Predicate)):
|
||||
expression.type = exp.DataType.Type.BOOLEAN
|
||||
else:
|
||||
expression.type = self._maybe_coerce(expression.left.type, expression.right.type)
|
||||
expression.type = self._maybe_coerce(left_type, right_type)
|
||||
|
||||
return expression
|
||||
|
||||
|
@ -157,6 +273,6 @@ class TypeAnnotator:
|
|||
|
||||
return expression
|
||||
|
||||
def _annotate_boolean(self, expression):
|
||||
expression.type = exp.DataType.Type.BOOLEAN
|
||||
return expression
|
||||
def _annotate_with_type(self, expression, target_type):
|
||||
expression.type = target_type
|
||||
return self._annotate_args(expression)
|
||||
|
|
|
@ -44,6 +44,7 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
|
|||
"joins",
|
||||
"where",
|
||||
"order",
|
||||
"hint",
|
||||
}
|
||||
|
||||
|
||||
|
@ -67,21 +68,22 @@ def merge_ctes(expression, leave_tables_isolated=False):
|
|||
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
|
||||
for outer_scope, inner_scope, table in singular_cte_selections:
|
||||
inner_select = inner_scope.expression.unnest()
|
||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
||||
from_or_join = table.find_ancestor(exp.From, exp.Join)
|
||||
|
||||
from_or_join = table.find_ancestor(exp.From, exp.Join)
|
||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||
node_to_replace = table
|
||||
if isinstance(node_to_replace.parent, exp.Alias):
|
||||
node_to_replace = node_to_replace.parent
|
||||
alias = node_to_replace.alias
|
||||
else:
|
||||
alias = table.name
|
||||
|
||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
|
||||
_merge_expressions(outer_scope, inner_scope, alias)
|
||||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_hints(outer_scope, inner_scope)
|
||||
_pop_cte(inner_scope)
|
||||
return expression
|
||||
|
||||
|
@ -90,9 +92,9 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
|
|||
for outer_scope in traverse_scope(expression):
|
||||
for subquery in outer_scope.derived_tables:
|
||||
inner_select = subquery.unnest()
|
||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
||||
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||
alias = subquery.alias_or_name
|
||||
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
||||
inner_scope = outer_scope.sources[alias]
|
||||
|
||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
|
@ -101,10 +103,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
|
|||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_hints(outer_scope, inner_scope)
|
||||
return expression
|
||||
|
||||
|
||||
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
||||
def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||
"""
|
||||
Return True if `inner_select` can be merged into outer query.
|
||||
|
||||
|
@ -112,6 +115,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
|||
outer_scope (Scope)
|
||||
inner_select (exp.Select)
|
||||
leave_tables_isolated (bool)
|
||||
from_or_join (exp.From|exp.Join)
|
||||
Returns:
|
||||
bool: True if can be merged
|
||||
"""
|
||||
|
@ -123,6 +127,16 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
|||
and inner_select.args.get("from")
|
||||
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
|
||||
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
|
||||
and not (
|
||||
isinstance(from_or_join, exp.Join)
|
||||
and inner_select.args.get("where")
|
||||
and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
|
||||
)
|
||||
and not (
|
||||
isinstance(from_or_join, exp.From)
|
||||
and inner_select.args.get("where")
|
||||
and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
@ -170,6 +184,12 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
|
|||
"""
|
||||
new_subquery = inner_scope.expression.args.get("from").expressions[0]
|
||||
node_to_replace.replace(new_subquery)
|
||||
for join_hint in outer_scope.join_hints:
|
||||
tables = join_hint.find_all(exp.Table)
|
||||
for table in tables:
|
||||
if table.alias_or_name == node_to_replace.alias_or_name:
|
||||
new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
|
||||
table.set("this", exp.to_identifier(new_table.alias_or_name))
|
||||
outer_scope.remove_source(alias)
|
||||
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
|
||||
|
||||
|
@ -273,6 +293,18 @@ def _merge_order(outer_scope, inner_scope):
|
|||
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
|
||||
|
||||
|
||||
def _merge_hints(outer_scope, inner_scope):
|
||||
inner_scope_hint = inner_scope.expression.args.get("hint")
|
||||
if not inner_scope_hint:
|
||||
return
|
||||
outer_scope_hint = outer_scope.expression.args.get("hint")
|
||||
if outer_scope_hint:
|
||||
for hint_expression in inner_scope_hint.expressions:
|
||||
outer_scope_hint.append("expressions", hint_expression)
|
||||
else:
|
||||
outer_scope.expression.set("hint", inner_scope_hint)
|
||||
|
||||
|
||||
def _pop_cte(inner_scope):
|
||||
"""
|
||||
Remove CTE from the AST.
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.optimizer.normalize import normalized
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
|
@ -20,22 +22,30 @@ def pushdown_predicates(expression):
|
|||
Returns:
|
||||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
for scope in reversed(traverse_scope(expression)):
|
||||
scope_ref_count = defaultdict(lambda: 0)
|
||||
scopes = traverse_scope(expression)
|
||||
scopes.reverse()
|
||||
|
||||
for scope in scopes:
|
||||
for _, source in scope.selected_sources.values():
|
||||
scope_ref_count[id(source)] += 1
|
||||
|
||||
for scope in scopes:
|
||||
select = scope.expression
|
||||
where = select.args.get("where")
|
||||
if where:
|
||||
pushdown(where.this, scope.selected_sources)
|
||||
pushdown(where.this, scope.selected_sources, scope_ref_count)
|
||||
|
||||
# joins should only pushdown into itself, not to other joins
|
||||
# so we limit the selected sources to only itself
|
||||
for join in select.args.get("joins") or []:
|
||||
name = join.this.alias_or_name
|
||||
pushdown(join.args.get("on"), {name: scope.selected_sources[name]})
|
||||
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def pushdown(condition, sources):
|
||||
def pushdown(condition, sources, scope_ref_count):
|
||||
if not condition:
|
||||
return
|
||||
|
||||
|
@ -45,17 +55,17 @@ def pushdown(condition, sources):
|
|||
predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
|
||||
|
||||
if cnf_like:
|
||||
pushdown_cnf(predicates, sources)
|
||||
pushdown_cnf(predicates, sources, scope_ref_count)
|
||||
else:
|
||||
pushdown_dnf(predicates, sources)
|
||||
pushdown_dnf(predicates, sources, scope_ref_count)
|
||||
|
||||
|
||||
def pushdown_cnf(predicates, scope):
|
||||
def pushdown_cnf(predicates, scope, scope_ref_count):
|
||||
"""
|
||||
If the predicates are in CNF like form, we can simply replace each block in the parent.
|
||||
"""
|
||||
for predicate in predicates:
|
||||
for node in nodes_for_predicate(predicate, scope).values():
|
||||
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
|
||||
if isinstance(node, exp.Join):
|
||||
predicate.replace(exp.TRUE)
|
||||
node.on(predicate, copy=False)
|
||||
|
@ -65,7 +75,7 @@ def pushdown_cnf(predicates, scope):
|
|||
node.where(replace_aliases(node, predicate), copy=False)
|
||||
|
||||
|
||||
def pushdown_dnf(predicates, scope):
|
||||
def pushdown_dnf(predicates, scope, scope_ref_count):
|
||||
"""
|
||||
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
|
||||
Additionally, we can't remove predicates from their original form.
|
||||
|
@ -91,7 +101,7 @@ def pushdown_dnf(predicates, scope):
|
|||
# (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
|
||||
for table in sorted(pushdown_tables):
|
||||
for predicate in predicates:
|
||||
nodes = nodes_for_predicate(predicate, scope)
|
||||
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
|
||||
|
||||
if table not in nodes:
|
||||
continue
|
||||
|
@ -120,7 +130,7 @@ def pushdown_dnf(predicates, scope):
|
|||
node.where(replace_aliases(node, predicate), copy=False)
|
||||
|
||||
|
||||
def nodes_for_predicate(predicate, sources):
|
||||
def nodes_for_predicate(predicate, sources, scope_ref_count):
|
||||
nodes = {}
|
||||
tables = exp.column_table_names(predicate)
|
||||
where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
|
||||
|
@ -133,7 +143,7 @@ def nodes_for_predicate(predicate, sources):
|
|||
if node and where_condition:
|
||||
node = node.find_ancestor(exp.Join, exp.From)
|
||||
|
||||
# a node can reference a CTE which should be push down
|
||||
# a node can reference a CTE which should be pushed down
|
||||
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
|
||||
node = source.expression
|
||||
|
||||
|
@ -142,7 +152,9 @@ def nodes_for_predicate(predicate, sources):
|
|||
return {}
|
||||
nodes[table] = node
|
||||
elif isinstance(node, exp.Select) and len(tables) == 1:
|
||||
if not node.args.get("group"):
|
||||
# we can't push down predicates to select statements if they are referenced in
|
||||
# multiple places.
|
||||
if not node.args.get("group") and scope_ref_count[id(source)] < 2:
|
||||
nodes[table] = node
|
||||
return nodes
|
||||
|
||||
|
|
|
@ -31,8 +31,8 @@ def qualify_columns(expression, schema):
|
|||
_pop_table_column_aliases(scope.derived_tables)
|
||||
_expand_using(scope, resolver)
|
||||
_expand_group_by(scope, resolver)
|
||||
_expand_order_by(scope)
|
||||
_qualify_columns(scope, resolver)
|
||||
_expand_order_by(scope)
|
||||
if not isinstance(scope.expression, exp.UDTF):
|
||||
_expand_stars(scope, resolver)
|
||||
_qualify_outputs(scope)
|
||||
|
@ -235,7 +235,7 @@ def _expand_stars(scope, resolver):
|
|||
for table in tables:
|
||||
if table not in scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {table}")
|
||||
columns = resolver.get_source_columns(table)
|
||||
columns = resolver.get_source_columns(table, only_visible=True)
|
||||
table_id = id(table)
|
||||
for name in columns:
|
||||
if name not in except_columns.get(table_id, set()):
|
||||
|
@ -332,7 +332,7 @@ class _Resolver:
|
|||
self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
|
||||
return self._all_columns
|
||||
|
||||
def get_source_columns(self, name):
|
||||
def get_source_columns(self, name, only_visible=False):
|
||||
"""Resolve the source columns for a given source `name`"""
|
||||
if name not in self.scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {name}")
|
||||
|
@ -342,7 +342,7 @@ class _Resolver:
|
|||
# If referencing a table, return the columns from the schema
|
||||
if isinstance(source, exp.Table):
|
||||
try:
|
||||
return self.schema.column_names(source)
|
||||
return self.schema.column_names(source, only_visible)
|
||||
except Exception as e:
|
||||
raise OptimizeError(str(e)) from e
|
||||
|
||||
|
|
|
@ -9,16 +9,28 @@ class Schema(abc.ABC):
|
|||
"""Abstract base class for database schemas"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def column_names(self, table):
|
||||
def column_names(self, table, only_visible=False):
|
||||
"""
|
||||
Get the column names for a table.
|
||||
|
||||
Args:
|
||||
table (sqlglot.expressions.Table): Table expression instance
|
||||
only_visible (bool): Whether to include invisible columns
|
||||
Returns:
|
||||
list[str]: list of column names
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_column_type(self, table, column):
|
||||
"""
|
||||
Get the exp.DataType type of a column in the schema.
|
||||
|
||||
Args:
|
||||
table (sqlglot.expressions.Table): The source table.
|
||||
column (sqlglot.expressions.Column): The target column.
|
||||
Returns:
|
||||
sqlglot.expressions.DataType.Type: The resulting column type.
|
||||
"""
|
||||
|
||||
|
||||
class MappingSchema(Schema):
|
||||
"""
|
||||
|
@ -29,10 +41,19 @@ class MappingSchema(Schema):
|
|||
1. {table: {col: type}}
|
||||
2. {db: {table: {col: type}}}
|
||||
3. {catalog: {db: {table: {col: type}}}}
|
||||
visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
|
||||
are assumed to be visible. The nesting should mirror that of the schema:
|
||||
1. {table: set(*cols)}}
|
||||
2. {db: {table: set(*cols)}}}
|
||||
3. {catalog: {db: {table: set(*cols)}}}}
|
||||
dialect (str): The dialect to be used for custom type mappings.
|
||||
"""
|
||||
|
||||
def __init__(self, schema):
|
||||
def __init__(self, schema, visible=None, dialect=None):
|
||||
self.schema = schema
|
||||
self.visible = visible
|
||||
self.dialect = dialect
|
||||
self._type_mapping_cache = {}
|
||||
|
||||
depth = _dict_depth(schema)
|
||||
|
||||
|
@ -49,7 +70,7 @@ class MappingSchema(Schema):
|
|||
|
||||
self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
|
||||
|
||||
def column_names(self, table):
|
||||
def column_names(self, table, only_visible=False):
|
||||
if not isinstance(table.this, exp.Identifier):
|
||||
return fs_get(table)
|
||||
|
||||
|
@ -58,7 +79,39 @@ class MappingSchema(Schema):
|
|||
for forbidden in self.forbidden_args:
|
||||
if table.text(forbidden):
|
||||
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
|
||||
return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
|
||||
|
||||
columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
|
||||
if not only_visible or not self.visible:
|
||||
return columns
|
||||
|
||||
visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
|
||||
return [col for col in columns if col in visible]
|
||||
|
||||
def get_column_type(self, table, column):
|
||||
try:
|
||||
schema_type = self.schema.get(table.name, {}).get(column.name).upper()
|
||||
return self._convert_type(schema_type)
|
||||
except:
|
||||
raise OptimizeError(f"Failed to get type for column {column.sql()}")
|
||||
|
||||
def _convert_type(self, schema_type):
|
||||
"""
|
||||
Convert a type represented as a string to the corresponding exp.DataType.Type object.
|
||||
|
||||
Args:
|
||||
schema_type (str): The type we want to convert.
|
||||
Returns:
|
||||
sqlglot.expressions.DataType.Type: The resulting expression type.
|
||||
"""
|
||||
if schema_type not in self._type_mapping_cache:
|
||||
try:
|
||||
self._type_mapping_cache[schema_type] = exp.maybe_parse(
|
||||
schema_type, into=exp.DataType, dialect=self.dialect
|
||||
).this
|
||||
except AttributeError:
|
||||
raise OptimizeError(f"Failed to convert type {schema_type}")
|
||||
|
||||
return self._type_mapping_cache[schema_type]
|
||||
|
||||
|
||||
def ensure_schema(schema):
|
||||
|
|
|
@ -68,6 +68,7 @@ class Scope:
|
|||
self._selected_sources = None
|
||||
self._columns = None
|
||||
self._external_columns = None
|
||||
self._join_hints = None
|
||||
|
||||
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
|
||||
"""Branch from the current scope to a new, inner scope"""
|
||||
|
@ -85,14 +86,17 @@ class Scope:
|
|||
self._subqueries = []
|
||||
self._derived_tables = []
|
||||
self._raw_columns = []
|
||||
self._join_hints = []
|
||||
|
||||
for node, parent, _ in self.walk(bfs=False):
|
||||
if node is self.expression:
|
||||
continue
|
||||
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
|
||||
self._raw_columns.append(node)
|
||||
elif isinstance(node, exp.Table):
|
||||
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
|
||||
self._tables.append(node)
|
||||
elif isinstance(node, exp.JoinHint):
|
||||
self._join_hints.append(node)
|
||||
elif isinstance(node, exp.UDTF):
|
||||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.CTE):
|
||||
|
@ -246,7 +250,7 @@ class Scope:
|
|||
table only becomes a selected source if it's included in a FROM or JOIN clause.
|
||||
|
||||
Returns:
|
||||
dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
|
||||
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
|
||||
"""
|
||||
if self._selected_sources is None:
|
||||
referenced_names = []
|
||||
|
@ -310,6 +314,18 @@ class Scope:
|
|||
self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
|
||||
return self._external_columns
|
||||
|
||||
@property
|
||||
def join_hints(self):
|
||||
"""
|
||||
Hints that exist in the scope that reference tables
|
||||
|
||||
Returns:
|
||||
list[exp.JoinHint]: Join hints that are referenced within the scope
|
||||
"""
|
||||
if self._join_hints is None:
|
||||
return []
|
||||
return self._join_hints
|
||||
|
||||
def source_columns(self, source_name):
|
||||
"""
|
||||
Get all columns in the current scope for a particular source.
|
||||
|
|
|
@ -56,12 +56,16 @@ def simplify_not(expression):
|
|||
NOT (x AND y) -> NOT x OR NOT y
|
||||
"""
|
||||
if isinstance(expression, exp.Not):
|
||||
if isinstance(expression.this, exp.Null):
|
||||
return NULL
|
||||
if isinstance(expression.this, exp.Paren):
|
||||
condition = expression.this.unnest()
|
||||
if isinstance(condition, exp.And):
|
||||
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
|
||||
if isinstance(condition, exp.Or):
|
||||
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
|
||||
if isinstance(condition, exp.Null):
|
||||
return NULL
|
||||
if always_true(expression.this):
|
||||
return FALSE
|
||||
if expression.this == FALSE:
|
||||
|
@ -95,10 +99,10 @@ def simplify_connectors(expression):
|
|||
return left
|
||||
|
||||
if isinstance(expression, exp.And):
|
||||
if NULL in (left, right):
|
||||
return NULL
|
||||
if FALSE in (left, right):
|
||||
return FALSE
|
||||
if NULL in (left, right):
|
||||
return NULL
|
||||
if always_true(left) and always_true(right):
|
||||
return TRUE
|
||||
if always_true(left):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue