1
0
Fork 0

Merging upstream version 10.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:53:05 +01:00
parent 528822bfd4
commit b7d21c45b7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
98 changed files with 4080 additions and 1666 deletions

View file

@ -1,5 +1,5 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@ -48,35 +48,65 @@ class TypeAnnotator:
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.BIGINT
),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.CurrentTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.TimestampSub: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATE
),
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@ -88,32 +118,52 @@ class TypeAnnotator:
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DOUBLE
),
exp.RegexpLike: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.BOOLEAN
),
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.StrToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATE
),
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.UnixToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.VariancePop: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DOUBLE
),
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
}
@ -124,7 +174,11 @@ class TypeAnnotator:
exp.DataType.Type.TEXT: set(),
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.NCHAR: {
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
exp.DataType.Type.TEXT,
},
exp.DataType.Type.CHAR: {
exp.DataType.Type.NCHAR,
exp.DataType.Type.VARCHAR,
@ -135,7 +189,11 @@ class TypeAnnotator:
exp.DataType.Type.DOUBLE: set(),
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.BIGINT: {
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.INT: {
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
@ -160,7 +218,10 @@ class TypeAnnotator:
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
exp.DataType.Type.TIMESTAMPLTZ: set(),
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.TIMESTAMP: {
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
},
exp.DataType.Type.DATETIME: {
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
@ -219,7 +280,7 @@ class TypeAnnotator:
def _annotate_args(self, expression):
for value in expression.args.values():
for v in ensure_list(value):
for v in ensure_collection(value):
self._maybe_annotate(v)
return expression
@ -243,7 +304,9 @@ class TypeAnnotator:
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
expression.type = exp.DataType.build(
"NULLABLE", expressions=exp.DataType.build("BOOLEAN")
)
else:
expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)):
@ -276,3 +339,17 @@ class TypeAnnotator:
def _annotate_with_type(self, expression, target_type):
expression.type = target_type
return self._annotate_args(expression)
def _annotate_by_args(self, expression, *args):
self._annotate_args(expression)
expressions = []
for arg in args:
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
last_datatype = None
for expr in expressions:
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
expression.type = last_datatype or exp.DataType.Type.UNKNOWN
return expression

View file

@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias):
on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
else:
on_clause_columns = set()
return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
return any(
column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
)
def _is_joined_on_all_unique_outputs(scope, join):

View file

@ -45,7 +45,13 @@ def eliminate_subqueries(expression):
# All table names are taken
for scope in root.traverse():
taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
taken.update(
{
source.name: source
for _, source in scope.sources.items()
if isinstance(source, exp.Table)
}
)
# Map of Expression->alias
# Existing CTES in the root expression. We'll use this for deduplication.
@ -70,7 +76,9 @@ def eliminate_subqueries(expression):
new_ctes.append(cte_scope.expression.parent)
# Now append the rest
for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
for scope in itertools.chain(
root.union_scopes, root.subquery_scopes, root.derived_table_scopes
):
for child_scope in scope.traverse():
new_cte = _eliminate(child_scope, existing_ctes, taken)
if new_cte:

View file

@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
unmergable_window_columns = [
column
for column in outer_scope.columns
if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)
if column.find_ancestor(
exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
)
]
window_expressions_in_unmergable = [
column
@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
and not (
isinstance(from_or_join, exp.From)
and inner_select.args.get("where")
and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
and any(
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
)
)
and not _is_a_window_expression_in_unmergable_operation()
)
@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
if table.alias_or_name == node_to_replace.alias_or_name:
table.set("this", exp.to_identifier(new_subquery.alias_or_name))
outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
outer_scope.add_source(
new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
)
def _merge_joins(outer_scope, inner_scope, from_or_join):
@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope):
inner_scope (sqlglot.optimizer.scope.Scope)
"""
if (
any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
any(
outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
)
or len(outer_scope.selected_sources) != 1
or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
):

View file

@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False):
Returns:
int: difference
"""
return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
return sum(_predicate_lengths(expression, dnf)) - (
len(list(expression.find_all(exp.Connector))) + 1
)
def _predicate_lengths(expression, dnf):

View file

@ -68,4 +68,8 @@ def normalize(expression):
def other_table_names(join, exclude):
return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
return [
name
for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
if name != exclude
]

View file

@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
rule_kwargs = {
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
expression = rule(expression, **rule_kwargs)
return expression

View file

@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count):
condition = condition.replace(simplify(condition))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
predicates = list(
condition.flatten()
if isinstance(condition, exp.And if cnf_like else exp.Or)
else [condition]
)
if cnf_like:
pushdown_cnf(predicates, sources, scope_ref_count)
@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
for column in predicate.find_all(exp.Column):
if column.table == table:
condition = column.find_ancestor(exp.Condition)
predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
predicate_condition = (
exp.and_(predicate_condition, condition)
if predicate_condition
else condition
)
if predicate_condition:
conditions[table] = (
exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
exp.or_(conditions[table], predicate_condition)
if table in conditions
else predicate_condition
)
for name, node in nodes.items():
@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
# We can't push down window expressions
has_window_expression = any(select for select in node.selects if select.find(exp.Window))
has_window_expression = any(
select for select in node.selects if select.find(exp.Window)
)
# we can't push down predicates to select statements if they are referenced in
# multiple places.
if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
if (
not node.args.get("group")
and scope_ref_count[id(source)] < 2
and not has_window_expression
):
nodes[table] = node
return nodes
@ -165,7 +181,7 @@ def replace_aliases(source, predicate):
def _replace_alias(column):
if isinstance(column, exp.Column) and column.name in aliases:
return aliases[column.name]
return aliases[column.name].copy()
return column
return predicate.transform(_replace_alias)

View file

@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections):
def _remove_indexed_selections(scope, indexes_to_remove):
new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
new_selections = [
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
scope.expression.set("expressions", new_selections)

View file

@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver):
# Determine whether each reference in the order by clause is to a column or an alias.
for ordered in scope.find_all(exp.Ordered):
for column in ordered.find_all(exp.Column):
if not column.table and column.parent is not ordered and column.name in resolver.all_columns:
if (
not column.table
and column.parent is not ordered
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
for having in scope.find_all(exp.Having):
for column in having.find_all(exp.Column):
if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns:
if (
not column.table
and column.find_ancestor(exp.AggFunc)
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)
for column in columns_missing_from_scope:
@ -295,7 +303,9 @@ def _qualify_outputs(scope):
"""Ensure all output columns are aliased"""
new_selections = []
for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.selects, scope.outer_column_list)
):
if isinstance(selection, exp.Column):
# convoluted setter because a simple selection.replace(alias) would require a copy
alias_ = alias(exp.column(""), alias=selection.name)
@ -343,14 +353,18 @@ class _Resolver:
(str) table name
"""
if self._unambiguous_columns is None:
self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
self._unambiguous_columns = self._get_unambiguous_columns(
self._get_all_source_columns()
)
return self._unambiguous_columns.get(column_name)
@property
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
self._all_columns = set(
column for columns in self._get_all_source_columns().values() for column in columns
)
return self._all_columns
def get_source_columns(self, name, only_visible=False):
@ -377,7 +391,9 @@ class _Resolver:
def _get_all_source_columns(self):
if self._source_columns is None:
self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
self._source_columns = {
k: self.get_source_columns(k) for k in self.scope.selected_sources
}
return self._source_columns
def _get_unambiguous_columns(self, source_columns):

View file

@ -226,7 +226,9 @@ class Scope:
self._ensure_collected()
columns = self._raw_columns
external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
external_columns = [
column for scope in self.subquery_scopes for column in scope.external_columns
]
named_outputs = {e.alias_or_name for e in self.expression.expressions}
@ -278,7 +280,11 @@ class Scope:
Returns:
dict[str, Scope]: Mapping of source alias to Scope
"""
return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
return {
alias: scope
for alias, scope in self.sources.items()
if isinstance(scope, Scope) and scope.is_cte
}
@property
def selects(self):
@ -307,7 +313,9 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
self._external_columns = [
c for c in self.columns if c.table not in self.selected_sources
]
return self._external_columns
@property

View file

@ -229,7 +229,9 @@ def simplify_literals(expression):
operands.append(a)
if len(operands) < size:
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
return functools.reduce(
lambda a, b: expression.__class__(this=a, expression=b), operands
)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b):
return TRUE if not_ else FALSE
if a == NULL:
return FALSE if not_ else TRUE
elif isinstance(expression, exp.NullSafeEQ):
if a == b:
return TRUE
elif isinstance(expression, exp.NullSafeNEQ):
if a == b:
return FALSE
elif NULL in (a, b):
return NULL
@ -357,7 +365,7 @@ def extract_date(cast):
def extract_interval(interval):
try:
from dateutil.relativedelta import relativedelta
from dateutil.relativedelta import relativedelta # type: ignore
except ModuleNotFoundError:
return None

View file

@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence):
return
if isinstance(predicate, exp.Binary):
key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
key = (
predicate.right
if any(node is column for node, *_ in predicate.left.walk())
else predicate.left
)
else:
return
@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
else:
parent_predicate = _replace(parent_predicate, "TRUE")
elif isinstance(parent_predicate, exp.All):
parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
parent_predicate = _replace(
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
)
elif isinstance(parent_predicate, exp.Any):
if value.this in group_by:
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by:
key.replace(nested)
parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
parent_predicate = _replace(
parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
)
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,