Merging upstream version 10.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
528822bfd4
commit
b7d21c45b7
98 changed files with 4080 additions and 1666 deletions
|
@ -1,5 +1,5 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.helper import ensure_list, subclasses
|
||||
from sqlglot.helper import ensure_collection, ensure_list, subclasses
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import ensure_schema
|
||||
|
||||
|
@ -48,35 +48,65 @@ class TypeAnnotator:
|
|||
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
|
||||
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
|
||||
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.BIGINT
|
||||
),
|
||||
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
|
||||
exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DATETIME
|
||||
),
|
||||
exp.CurrentTime: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
|
||||
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
|
||||
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DATETIME
|
||||
),
|
||||
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DATETIME
|
||||
),
|
||||
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.TimestampSub: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DATE
|
||||
),
|
||||
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
|
||||
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
|
||||
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
|
||||
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
|
||||
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
|
@ -88,32 +118,52 @@ class TypeAnnotator:
|
|||
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DOUBLE
|
||||
),
|
||||
exp.RegexpLike: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.BOOLEAN
|
||||
),
|
||||
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.StrToTime: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DATE
|
||||
),
|
||||
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.UnixToTime: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.VariancePop: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DOUBLE
|
||||
),
|
||||
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
}
|
||||
|
@ -124,7 +174,11 @@ class TypeAnnotator:
|
|||
exp.DataType.Type.TEXT: set(),
|
||||
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
|
||||
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
|
||||
exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
|
||||
exp.DataType.Type.NCHAR: {
|
||||
exp.DataType.Type.VARCHAR,
|
||||
exp.DataType.Type.NVARCHAR,
|
||||
exp.DataType.Type.TEXT,
|
||||
},
|
||||
exp.DataType.Type.CHAR: {
|
||||
exp.DataType.Type.NCHAR,
|
||||
exp.DataType.Type.VARCHAR,
|
||||
|
@ -135,7 +189,11 @@ class TypeAnnotator:
|
|||
exp.DataType.Type.DOUBLE: set(),
|
||||
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
|
||||
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
|
||||
exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
|
||||
exp.DataType.Type.BIGINT: {
|
||||
exp.DataType.Type.DECIMAL,
|
||||
exp.DataType.Type.FLOAT,
|
||||
exp.DataType.Type.DOUBLE,
|
||||
},
|
||||
exp.DataType.Type.INT: {
|
||||
exp.DataType.Type.BIGINT,
|
||||
exp.DataType.Type.DECIMAL,
|
||||
|
@ -160,7 +218,10 @@ class TypeAnnotator:
|
|||
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
|
||||
exp.DataType.Type.TIMESTAMPLTZ: set(),
|
||||
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
|
||||
exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
|
||||
exp.DataType.Type.TIMESTAMP: {
|
||||
exp.DataType.Type.TIMESTAMPTZ,
|
||||
exp.DataType.Type.TIMESTAMPLTZ,
|
||||
},
|
||||
exp.DataType.Type.DATETIME: {
|
||||
exp.DataType.Type.TIMESTAMP,
|
||||
exp.DataType.Type.TIMESTAMPTZ,
|
||||
|
@ -219,7 +280,7 @@ class TypeAnnotator:
|
|||
|
||||
def _annotate_args(self, expression):
|
||||
for value in expression.args.values():
|
||||
for v in ensure_list(value):
|
||||
for v in ensure_collection(value):
|
||||
self._maybe_annotate(v)
|
||||
|
||||
return expression
|
||||
|
@ -243,7 +304,9 @@ class TypeAnnotator:
|
|||
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
|
||||
expression.type = exp.DataType.Type.NULL
|
||||
elif exp.DataType.Type.NULL in (left_type, right_type):
|
||||
expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
|
||||
expression.type = exp.DataType.build(
|
||||
"NULLABLE", expressions=exp.DataType.build("BOOLEAN")
|
||||
)
|
||||
else:
|
||||
expression.type = exp.DataType.Type.BOOLEAN
|
||||
elif isinstance(expression, (exp.Condition, exp.Predicate)):
|
||||
|
@ -276,3 +339,17 @@ class TypeAnnotator:
|
|||
def _annotate_with_type(self, expression, target_type):
|
||||
expression.type = target_type
|
||||
return self._annotate_args(expression)
|
||||
|
||||
def _annotate_by_args(self, expression, *args):
|
||||
self._annotate_args(expression)
|
||||
expressions = []
|
||||
for arg in args:
|
||||
arg_expr = expression.args.get(arg)
|
||||
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
|
||||
|
||||
last_datatype = None
|
||||
for expr in expressions:
|
||||
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
|
||||
|
||||
expression.type = last_datatype or exp.DataType.Type.UNKNOWN
|
||||
return expression
|
||||
|
|
|
@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias):
|
|||
on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
|
||||
else:
|
||||
on_clause_columns = set()
|
||||
return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
|
||||
return any(
|
||||
column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
|
||||
)
|
||||
|
||||
|
||||
def _is_joined_on_all_unique_outputs(scope, join):
|
||||
|
|
|
@ -45,7 +45,13 @@ def eliminate_subqueries(expression):
|
|||
|
||||
# All table names are taken
|
||||
for scope in root.traverse():
|
||||
taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
|
||||
taken.update(
|
||||
{
|
||||
source.name: source
|
||||
for _, source in scope.sources.items()
|
||||
if isinstance(source, exp.Table)
|
||||
}
|
||||
)
|
||||
|
||||
# Map of Expression->alias
|
||||
# Existing CTES in the root expression. We'll use this for deduplication.
|
||||
|
@ -70,7 +76,9 @@ def eliminate_subqueries(expression):
|
|||
new_ctes.append(cte_scope.expression.parent)
|
||||
|
||||
# Now append the rest
|
||||
for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
|
||||
for scope in itertools.chain(
|
||||
root.union_scopes, root.subquery_scopes, root.derived_table_scopes
|
||||
):
|
||||
for child_scope in scope.traverse():
|
||||
new_cte = _eliminate(child_scope, existing_ctes, taken)
|
||||
if new_cte:
|
||||
|
|
|
@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
|||
unmergable_window_columns = [
|
||||
column
|
||||
for column in outer_scope.columns
|
||||
if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)
|
||||
if column.find_ancestor(
|
||||
exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
|
||||
)
|
||||
]
|
||||
window_expressions_in_unmergable = [
|
||||
column
|
||||
|
@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
|||
and not (
|
||||
isinstance(from_or_join, exp.From)
|
||||
and inner_select.args.get("where")
|
||||
and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
|
||||
and any(
|
||||
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
|
||||
)
|
||||
)
|
||||
and not _is_a_window_expression_in_unmergable_operation()
|
||||
)
|
||||
|
@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
|
|||
if table.alias_or_name == node_to_replace.alias_or_name:
|
||||
table.set("this", exp.to_identifier(new_subquery.alias_or_name))
|
||||
outer_scope.remove_source(alias)
|
||||
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
|
||||
outer_scope.add_source(
|
||||
new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
|
||||
)
|
||||
|
||||
|
||||
def _merge_joins(outer_scope, inner_scope, from_or_join):
|
||||
|
@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope):
|
|||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
"""
|
||||
if (
|
||||
any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
|
||||
any(
|
||||
outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
|
||||
)
|
||||
or len(outer_scope.selected_sources) != 1
|
||||
or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
|
||||
):
|
||||
|
|
|
@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False):
|
|||
Returns:
|
||||
int: difference
|
||||
"""
|
||||
return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
|
||||
return sum(_predicate_lengths(expression, dnf)) - (
|
||||
len(list(expression.find_all(exp.Connector))) + 1
|
||||
)
|
||||
|
||||
|
||||
def _predicate_lengths(expression, dnf):
|
||||
|
|
|
@ -68,4 +68,8 @@ def normalize(expression):
|
|||
|
||||
|
||||
def other_table_names(join, exclude):
|
||||
return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
|
||||
return [
|
||||
name
|
||||
for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
|
||||
if name != exclude
|
||||
]
|
||||
|
|
|
@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
|
|||
|
||||
# Find any additional rule parameters, beyond `expression`
|
||||
rule_params = rule.__code__.co_varnames
|
||||
rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
|
||||
rule_kwargs = {
|
||||
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
|
||||
}
|
||||
expression = rule(expression, **rule_kwargs)
|
||||
return expression
|
||||
|
|
|
@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count):
|
|||
condition = condition.replace(simplify(condition))
|
||||
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
|
||||
|
||||
predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
|
||||
predicates = list(
|
||||
condition.flatten()
|
||||
if isinstance(condition, exp.And if cnf_like else exp.Or)
|
||||
else [condition]
|
||||
)
|
||||
|
||||
if cnf_like:
|
||||
pushdown_cnf(predicates, sources, scope_ref_count)
|
||||
|
@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
|
|||
for column in predicate.find_all(exp.Column):
|
||||
if column.table == table:
|
||||
condition = column.find_ancestor(exp.Condition)
|
||||
predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
|
||||
predicate_condition = (
|
||||
exp.and_(predicate_condition, condition)
|
||||
if predicate_condition
|
||||
else condition
|
||||
)
|
||||
|
||||
if predicate_condition:
|
||||
conditions[table] = (
|
||||
exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
|
||||
exp.or_(conditions[table], predicate_condition)
|
||||
if table in conditions
|
||||
else predicate_condition
|
||||
)
|
||||
|
||||
for name, node in nodes.items():
|
||||
|
@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
|
|||
nodes[table] = node
|
||||
elif isinstance(node, exp.Select) and len(tables) == 1:
|
||||
# We can't push down window expressions
|
||||
has_window_expression = any(select for select in node.selects if select.find(exp.Window))
|
||||
has_window_expression = any(
|
||||
select for select in node.selects if select.find(exp.Window)
|
||||
)
|
||||
# we can't push down predicates to select statements if they are referenced in
|
||||
# multiple places.
|
||||
if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
|
||||
if (
|
||||
not node.args.get("group")
|
||||
and scope_ref_count[id(source)] < 2
|
||||
and not has_window_expression
|
||||
):
|
||||
nodes[table] = node
|
||||
return nodes
|
||||
|
||||
|
@ -165,7 +181,7 @@ def replace_aliases(source, predicate):
|
|||
|
||||
def _replace_alias(column):
|
||||
if isinstance(column, exp.Column) and column.name in aliases:
|
||||
return aliases[column.name]
|
||||
return aliases[column.name].copy()
|
||||
return column
|
||||
|
||||
return predicate.transform(_replace_alias)
|
||||
|
|
|
@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections):
|
|||
|
||||
|
||||
def _remove_indexed_selections(scope, indexes_to_remove):
|
||||
new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
|
||||
new_selections = [
|
||||
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
|
||||
]
|
||||
if not new_selections:
|
||||
new_selections.append(DEFAULT_SELECTION)
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
|
|
@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver):
|
|||
# Determine whether each reference in the order by clause is to a column or an alias.
|
||||
for ordered in scope.find_all(exp.Ordered):
|
||||
for column in ordered.find_all(exp.Column):
|
||||
if not column.table and column.parent is not ordered and column.name in resolver.all_columns:
|
||||
if (
|
||||
not column.table
|
||||
and column.parent is not ordered
|
||||
and column.name in resolver.all_columns
|
||||
):
|
||||
columns_missing_from_scope.append(column)
|
||||
|
||||
# Determine whether each reference in the having clause is to a column or an alias.
|
||||
for having in scope.find_all(exp.Having):
|
||||
for column in having.find_all(exp.Column):
|
||||
if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns:
|
||||
if (
|
||||
not column.table
|
||||
and column.find_ancestor(exp.AggFunc)
|
||||
and column.name in resolver.all_columns
|
||||
):
|
||||
columns_missing_from_scope.append(column)
|
||||
|
||||
for column in columns_missing_from_scope:
|
||||
|
@ -295,7 +303,9 @@ def _qualify_outputs(scope):
|
|||
"""Ensure all output columns are aliased"""
|
||||
new_selections = []
|
||||
|
||||
for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
|
||||
for i, (selection, aliased_column) in enumerate(
|
||||
itertools.zip_longest(scope.selects, scope.outer_column_list)
|
||||
):
|
||||
if isinstance(selection, exp.Column):
|
||||
# convoluted setter because a simple selection.replace(alias) would require a copy
|
||||
alias_ = alias(exp.column(""), alias=selection.name)
|
||||
|
@ -343,14 +353,18 @@ class _Resolver:
|
|||
(str) table name
|
||||
"""
|
||||
if self._unambiguous_columns is None:
|
||||
self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
|
||||
self._unambiguous_columns = self._get_unambiguous_columns(
|
||||
self._get_all_source_columns()
|
||||
)
|
||||
return self._unambiguous_columns.get(column_name)
|
||||
|
||||
@property
|
||||
def all_columns(self):
|
||||
"""All available columns of all sources in this scope"""
|
||||
if self._all_columns is None:
|
||||
self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
|
||||
self._all_columns = set(
|
||||
column for columns in self._get_all_source_columns().values() for column in columns
|
||||
)
|
||||
return self._all_columns
|
||||
|
||||
def get_source_columns(self, name, only_visible=False):
|
||||
|
@ -377,7 +391,9 @@ class _Resolver:
|
|||
|
||||
def _get_all_source_columns(self):
|
||||
if self._source_columns is None:
|
||||
self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
|
||||
self._source_columns = {
|
||||
k: self.get_source_columns(k) for k in self.scope.selected_sources
|
||||
}
|
||||
return self._source_columns
|
||||
|
||||
def _get_unambiguous_columns(self, source_columns):
|
||||
|
|
|
@ -226,7 +226,9 @@ class Scope:
|
|||
self._ensure_collected()
|
||||
columns = self._raw_columns
|
||||
|
||||
external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
|
||||
external_columns = [
|
||||
column for scope in self.subquery_scopes for column in scope.external_columns
|
||||
]
|
||||
|
||||
named_outputs = {e.alias_or_name for e in self.expression.expressions}
|
||||
|
||||
|
@ -278,7 +280,11 @@ class Scope:
|
|||
Returns:
|
||||
dict[str, Scope]: Mapping of source alias to Scope
|
||||
"""
|
||||
return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
|
||||
return {
|
||||
alias: scope
|
||||
for alias, scope in self.sources.items()
|
||||
if isinstance(scope, Scope) and scope.is_cte
|
||||
}
|
||||
|
||||
@property
|
||||
def selects(self):
|
||||
|
@ -307,7 +313,9 @@ class Scope:
|
|||
sources in the current scope.
|
||||
"""
|
||||
if self._external_columns is None:
|
||||
self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
|
||||
self._external_columns = [
|
||||
c for c in self.columns if c.table not in self.selected_sources
|
||||
]
|
||||
return self._external_columns
|
||||
|
||||
@property
|
||||
|
|
|
@ -229,7 +229,9 @@ def simplify_literals(expression):
|
|||
operands.append(a)
|
||||
|
||||
if len(operands) < size:
|
||||
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
|
||||
return functools.reduce(
|
||||
lambda a, b: expression.__class__(this=a, expression=b), operands
|
||||
)
|
||||
elif isinstance(expression, exp.Neg):
|
||||
this = expression.this
|
||||
if this.is_number:
|
||||
|
@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b):
|
|||
return TRUE if not_ else FALSE
|
||||
if a == NULL:
|
||||
return FALSE if not_ else TRUE
|
||||
elif isinstance(expression, exp.NullSafeEQ):
|
||||
if a == b:
|
||||
return TRUE
|
||||
elif isinstance(expression, exp.NullSafeNEQ):
|
||||
if a == b:
|
||||
return FALSE
|
||||
elif NULL in (a, b):
|
||||
return NULL
|
||||
|
||||
|
@ -357,7 +365,7 @@ def extract_date(cast):
|
|||
|
||||
def extract_interval(interval):
|
||||
try:
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from dateutil.relativedelta import relativedelta # type: ignore
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
|
||||
|
|
|
@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
return
|
||||
|
||||
if isinstance(predicate, exp.Binary):
|
||||
key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
|
||||
key = (
|
||||
predicate.right
|
||||
if any(node is column for node, *_ in predicate.left.walk())
|
||||
else predicate.left
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
|
@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
else:
|
||||
parent_predicate = _replace(parent_predicate, "TRUE")
|
||||
elif isinstance(parent_predicate, exp.All):
|
||||
parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
|
||||
parent_predicate = _replace(
|
||||
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
|
||||
)
|
||||
elif isinstance(parent_predicate, exp.Any):
|
||||
if value.this in group_by:
|
||||
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
|
||||
|
@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
|
||||
if key in group_by:
|
||||
key.replace(nested)
|
||||
parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
|
||||
parent_predicate = _replace(
|
||||
parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
|
||||
)
|
||||
elif isinstance(predicate, exp.EQ):
|
||||
parent_predicate = _replace(
|
||||
parent_predicate,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue