Merging upstream version 6.1.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
3c6d649c90
commit
08ecea3adf
61 changed files with 1844 additions and 1555 deletions
|
@ -1,2 +1,2 @@
|
|||
from sqlglot.optimizer.optimizer import optimize
|
||||
from sqlglot.optimizer.optimizer import RULES, optimize
|
||||
from sqlglot.optimizer.schema import Schema
|
||||
|
|
|
@ -13,9 +13,7 @@ def isolate_table_selects(expression):
|
|||
continue
|
||||
|
||||
if not isinstance(source.parent, exp.Alias):
|
||||
raise OptimizeError(
|
||||
"Tables require an alias. Run qualify_tables optimization."
|
||||
)
|
||||
raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")
|
||||
|
||||
parent = source.parent
|
||||
|
||||
|
|
232
sqlglot/optimizer/merge_derived_tables.py
Normal file
232
sqlglot/optimizer/merge_derived_tables.py
Normal file
|
@ -0,0 +1,232 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from sqlglot import expressions as exp
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
from sqlglot.optimizer.simplify import simplify
|
||||
|
||||
|
||||
def merge_derived_tables(expression):
|
||||
"""
|
||||
Rewrite sqlglot AST to merge derived tables into the outer query.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")
|
||||
>>> merge_derived_tables(expression).sql()
|
||||
'SELECT x.a FROM x'
|
||||
|
||||
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to optimize
|
||||
Returns:
|
||||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
for outer_scope in traverse_scope(expression):
|
||||
for subquery in outer_scope.derived_tables:
|
||||
inner_select = subquery.unnest()
|
||||
if (
|
||||
isinstance(outer_scope.expression, exp.Select)
|
||||
and isinstance(inner_select, exp.Select)
|
||||
and _mergeable(inner_select)
|
||||
):
|
||||
alias = subquery.alias_or_name
|
||||
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
||||
inner_scope = outer_scope.sources[alias]
|
||||
|
||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
_merge_from(outer_scope, inner_scope, subquery)
|
||||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_expressions(outer_scope, inner_scope, alias)
|
||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
return expression
|
||||
|
||||
|
||||
# If a derived table has these Select args, it can't be merged
|
||||
UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
|
||||
"expressions",
|
||||
"from",
|
||||
"joins",
|
||||
"where",
|
||||
"order",
|
||||
}
|
||||
|
||||
|
||||
def _mergeable(inner_select):
|
||||
"""
|
||||
Return True if `inner_select` can be merged into outer query.
|
||||
|
||||
Args:
|
||||
inner_select (exp.Select)
|
||||
Returns:
|
||||
bool: True if can be merged
|
||||
"""
|
||||
return (
|
||||
isinstance(inner_select, exp.Select)
|
||||
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
|
||||
and inner_select.args.get("from")
|
||||
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
|
||||
)
|
||||
|
||||
|
||||
def _rename_inner_sources(outer_scope, inner_scope, alias):
|
||||
"""
|
||||
Renames any sources in the inner query that conflict with names in the outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (sqlglot.optimizer.scope.Scope)
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
alias (str)
|
||||
"""
|
||||
taken = set(outer_scope.selected_sources)
|
||||
conflicts = taken.intersection(set(inner_scope.selected_sources))
|
||||
conflicts = conflicts - {alias}
|
||||
|
||||
for conflict in conflicts:
|
||||
new_name = _find_new_name(taken, conflict)
|
||||
|
||||
source, _ = inner_scope.selected_sources[conflict]
|
||||
new_alias = exp.to_identifier(new_name)
|
||||
|
||||
if isinstance(source, exp.Subquery):
|
||||
source.set("alias", exp.TableAlias(this=new_alias))
|
||||
elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
|
||||
source.parent.set("alias", new_alias)
|
||||
elif isinstance(source, exp.Table):
|
||||
source.replace(exp.alias_(source.copy(), new_alias))
|
||||
|
||||
for column in inner_scope.source_columns(conflict):
|
||||
column.set("table", exp.to_identifier(new_name))
|
||||
|
||||
inner_scope.rename_source(conflict, new_name)
|
||||
|
||||
|
||||
def _find_new_name(taken, base):
|
||||
"""
|
||||
Searches for a new source name.
|
||||
|
||||
Args:
|
||||
taken (set[str]): set of taken names
|
||||
base (str): base name to alter
|
||||
"""
|
||||
i = 2
|
||||
new = f"{base}_{i}"
|
||||
while new in taken:
|
||||
i += 1
|
||||
new = f"{base}_{i}"
|
||||
return new
|
||||
|
||||
|
||||
def _merge_from(outer_scope, inner_scope, subquery):
|
||||
"""
|
||||
Merge FROM clause of inner query into outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (sqlglot.optimizer.scope.Scope)
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
subquery (exp.Subquery)
|
||||
"""
|
||||
new_subquery = inner_scope.expression.args.get("from").expressions[0]
|
||||
subquery.replace(new_subquery)
|
||||
outer_scope.remove_source(subquery.alias_or_name)
|
||||
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
|
||||
|
||||
|
||||
def _merge_joins(outer_scope, inner_scope, from_or_join):
|
||||
"""
|
||||
Merge JOIN clauses of inner query into outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (sqlglot.optimizer.scope.Scope)
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
from_or_join (exp.From|exp.Join)
|
||||
"""
|
||||
|
||||
new_joins = []
|
||||
comma_joins = inner_scope.expression.args.get("from").expressions[1:]
|
||||
for subquery in comma_joins:
|
||||
new_joins.append(exp.Join(this=subquery, kind="CROSS"))
|
||||
outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
|
||||
|
||||
joins = inner_scope.expression.args.get("joins") or []
|
||||
for join in joins:
|
||||
new_joins.append(join)
|
||||
outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])
|
||||
|
||||
if new_joins:
|
||||
outer_joins = outer_scope.expression.args.get("joins", [])
|
||||
|
||||
# Maintain the join order
|
||||
if isinstance(from_or_join, exp.From):
|
||||
position = 0
|
||||
else:
|
||||
position = outer_joins.index(from_or_join) + 1
|
||||
outer_joins[position:position] = new_joins
|
||||
|
||||
outer_scope.expression.set("joins", outer_joins)
|
||||
|
||||
|
||||
def _merge_expressions(outer_scope, inner_scope, alias):
|
||||
"""
|
||||
Merge projections of inner query into outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (sqlglot.optimizer.scope.Scope)
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
alias (str)
|
||||
"""
|
||||
# Collect all columns that for the alias of the inner query
|
||||
outer_columns = defaultdict(list)
|
||||
for column in outer_scope.columns:
|
||||
if column.table == alias:
|
||||
outer_columns[column.name].append(column)
|
||||
|
||||
# Replace columns with the projection expression in the inner query
|
||||
for expression in inner_scope.expression.expressions:
|
||||
projection_name = expression.alias_or_name
|
||||
if not projection_name:
|
||||
continue
|
||||
columns_to_replace = outer_columns.get(projection_name, [])
|
||||
for column in columns_to_replace:
|
||||
column.replace(expression.unalias())
|
||||
|
||||
|
||||
def _merge_where(outer_scope, inner_scope, from_or_join):
|
||||
"""
|
||||
Merge WHERE clause of inner query into outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (sqlglot.optimizer.scope.Scope)
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
from_or_join (exp.From|exp.Join)
|
||||
"""
|
||||
where = inner_scope.expression.args.get("where")
|
||||
if not where or not where.this:
|
||||
return
|
||||
|
||||
if isinstance(from_or_join, exp.Join) and from_or_join.side:
|
||||
# Merge predicates from an outer join to the ON clause
|
||||
from_or_join.on(where.this, copy=False)
|
||||
from_or_join.set("on", simplify(from_or_join.args.get("on")))
|
||||
else:
|
||||
outer_scope.expression.where(where.this, copy=False)
|
||||
outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where")))
|
||||
|
||||
|
||||
def _merge_order(outer_scope, inner_scope):
|
||||
"""
|
||||
Merge ORDER clause of inner query into outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (sqlglot.optimizer.scope.Scope)
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
"""
|
||||
if (
|
||||
any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
|
||||
or len(outer_scope.selected_sources) != 1
|
||||
or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
|
||||
):
|
||||
return
|
||||
|
||||
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
|
|
@ -22,18 +22,14 @@ def normalize(expression, dnf=False, max_distance=128):
|
|||
"""
|
||||
expression = simplify(expression)
|
||||
|
||||
expression = while_changing(
|
||||
expression, lambda e: distributive_law(e, dnf, max_distance)
|
||||
)
|
||||
expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
|
||||
return simplify(expression)
|
||||
|
||||
|
||||
def normalized(expression, dnf=False):
|
||||
ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
|
||||
|
||||
return not any(
|
||||
connector.find_ancestor(ancestor) for connector in expression.find_all(root)
|
||||
)
|
||||
return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
|
||||
|
||||
|
||||
def normalization_distance(expression, dnf=False):
|
||||
|
@ -54,9 +50,7 @@ def normalization_distance(expression, dnf=False):
|
|||
Returns:
|
||||
int: difference
|
||||
"""
|
||||
return sum(_predicate_lengths(expression, dnf)) - (
|
||||
len(list(expression.find_all(exp.Connector))) + 1
|
||||
)
|
||||
return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
|
||||
|
||||
|
||||
def _predicate_lengths(expression, dnf):
|
||||
|
@ -73,11 +67,7 @@ def _predicate_lengths(expression, dnf):
|
|||
left, right = expression.args.values()
|
||||
|
||||
if isinstance(expression, exp.And if dnf else exp.Or):
|
||||
x = [
|
||||
a + b
|
||||
for a in _predicate_lengths(left, dnf)
|
||||
for b in _predicate_lengths(right, dnf)
|
||||
]
|
||||
x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
|
||||
return x
|
||||
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
|
||||
|
||||
|
@ -102,9 +92,7 @@ def distributive_law(expression, dnf, max_distance):
|
|||
to_func = exp.and_ if to_exp == exp.And else exp.or_
|
||||
|
||||
if isinstance(a, to_exp) and isinstance(b, to_exp):
|
||||
if len(tuple(a.find_all(exp.Connector))) > len(
|
||||
tuple(b.find_all(exp.Connector))
|
||||
):
|
||||
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
|
||||
return _distribute(a, b, from_func, to_func)
|
||||
return _distribute(b, a, from_func, to_func)
|
||||
if isinstance(a, to_exp):
|
||||
|
|
|
@ -68,8 +68,4 @@ def normalize(expression):
|
|||
|
||||
|
||||
def other_table_names(join, exclude):
|
||||
return [
|
||||
name
|
||||
for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
|
||||
if name != exclude
|
||||
]
|
||||
return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
|
||||
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
|
||||
from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
|
||||
from sqlglot.optimizer.normalize import normalize
|
||||
from sqlglot.optimizer.optimize_joins import optimize_joins
|
||||
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
||||
|
@ -10,8 +11,23 @@ from sqlglot.optimizer.qualify_tables import qualify_tables
|
|||
from sqlglot.optimizer.quote_identities import quote_identities
|
||||
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
||||
|
||||
RULES = (
|
||||
qualify_tables,
|
||||
isolate_table_selects,
|
||||
qualify_columns,
|
||||
pushdown_projections,
|
||||
normalize,
|
||||
unnest_subqueries,
|
||||
expand_multi_table_selects,
|
||||
pushdown_predicates,
|
||||
optimize_joins,
|
||||
eliminate_subqueries,
|
||||
merge_derived_tables,
|
||||
quote_identities,
|
||||
)
|
||||
|
||||
def optimize(expression, schema=None, db=None, catalog=None):
|
||||
|
||||
def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
|
||||
"""
|
||||
Rewrite a sqlglot AST into an optimized form.
|
||||
|
||||
|
@ -25,19 +41,18 @@ def optimize(expression, schema=None, db=None, catalog=None):
|
|||
3. {catalog: {db: {table: {col: type}}}}
|
||||
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
|
||||
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
|
||||
rules (list): sequence of optimizer rules to use
|
||||
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
|
||||
Returns:
|
||||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
|
||||
expression = expression.copy()
|
||||
expression = qualify_tables(expression, db=db, catalog=catalog)
|
||||
expression = isolate_table_selects(expression)
|
||||
expression = qualify_columns(expression, schema)
|
||||
expression = pushdown_projections(expression)
|
||||
expression = normalize(expression)
|
||||
expression = unnest_subqueries(expression)
|
||||
expression = expand_multi_table_selects(expression)
|
||||
expression = pushdown_predicates(expression)
|
||||
expression = optimize_joins(expression)
|
||||
expression = eliminate_subqueries(expression)
|
||||
expression = quote_identities(expression)
|
||||
for rule in rules:
|
||||
|
||||
# Find any additional rule parameters, beyond `expression`
|
||||
rule_params = rule.__code__.co_varnames
|
||||
rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
|
||||
|
||||
expression = rule(expression, **rule_kwargs)
|
||||
return expression
|
||||
|
|
|
@ -42,11 +42,7 @@ def pushdown(condition, sources):
|
|||
condition = condition.replace(simplify(condition))
|
||||
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
|
||||
|
||||
predicates = list(
|
||||
condition.flatten()
|
||||
if isinstance(condition, exp.And if cnf_like else exp.Or)
|
||||
else [condition]
|
||||
)
|
||||
predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
|
||||
|
||||
if cnf_like:
|
||||
pushdown_cnf(predicates, sources)
|
||||
|
@ -105,17 +101,11 @@ def pushdown_dnf(predicates, scope):
|
|||
for column in predicate.find_all(exp.Column):
|
||||
if column.table == table:
|
||||
condition = column.find_ancestor(exp.Condition)
|
||||
predicate_condition = (
|
||||
exp.and_(predicate_condition, condition)
|
||||
if predicate_condition
|
||||
else condition
|
||||
)
|
||||
predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
|
||||
|
||||
if predicate_condition:
|
||||
conditions[table] = (
|
||||
exp.or_(conditions[table], predicate_condition)
|
||||
if table in conditions
|
||||
else predicate_condition
|
||||
exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
|
||||
)
|
||||
|
||||
for name, node in nodes.items():
|
||||
|
@ -133,9 +123,7 @@ def pushdown_dnf(predicates, scope):
|
|||
def nodes_for_predicate(predicate, sources):
|
||||
nodes = {}
|
||||
tables = exp.column_table_names(predicate)
|
||||
where_condition = isinstance(
|
||||
predicate.find_ancestor(exp.Join, exp.Where), exp.Where
|
||||
)
|
||||
where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
|
||||
|
||||
for table in tables:
|
||||
node, source = sources.get(table) or (None, None)
|
||||
|
|
|
@ -226,9 +226,7 @@ def _expand_stars(scope, resolver):
|
|||
tables = list(scope.selected_sources)
|
||||
_add_except_columns(expression, tables, except_columns)
|
||||
_add_replace_columns(expression, tables, replace_columns)
|
||||
elif isinstance(expression, exp.Column) and isinstance(
|
||||
expression.this, exp.Star
|
||||
):
|
||||
elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
|
||||
tables = [expression.table]
|
||||
_add_except_columns(expression.this, tables, except_columns)
|
||||
_add_replace_columns(expression.this, tables, replace_columns)
|
||||
|
@ -245,9 +243,7 @@ def _expand_stars(scope, resolver):
|
|||
if name not in except_columns.get(table_id, set()):
|
||||
alias_ = replace_columns.get(table_id, {}).get(name, name)
|
||||
column = exp.column(name, table)
|
||||
new_selections.append(
|
||||
alias(column, alias_) if alias_ != name else column
|
||||
)
|
||||
new_selections.append(alias(column, alias_) if alias_ != name else column)
|
||||
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
@ -280,9 +276,7 @@ def _qualify_outputs(scope):
|
|||
"""Ensure all output columns are aliased"""
|
||||
new_selections = []
|
||||
|
||||
for i, (selection, aliased_column) in enumerate(
|
||||
itertools.zip_longest(scope.selects, scope.outer_column_list)
|
||||
):
|
||||
for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
|
||||
if isinstance(selection, exp.Column):
|
||||
# convoluted setter because a simple selection.replace(alias) would require a copy
|
||||
alias_ = alias(exp.column(""), alias=selection.name)
|
||||
|
@ -302,11 +296,7 @@ def _qualify_outputs(scope):
|
|||
|
||||
|
||||
def _check_unknown_tables(scope):
|
||||
if (
|
||||
scope.external_columns
|
||||
and not scope.is_unnest
|
||||
and not scope.is_correlated_subquery
|
||||
):
|
||||
if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
|
||||
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
|
||||
|
||||
|
||||
|
@ -334,20 +324,14 @@ class _Resolver:
|
|||
(str) table name
|
||||
"""
|
||||
if self._unambiguous_columns is None:
|
||||
self._unambiguous_columns = self._get_unambiguous_columns(
|
||||
self._get_all_source_columns()
|
||||
)
|
||||
self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
|
||||
return self._unambiguous_columns.get(column_name)
|
||||
|
||||
@property
|
||||
def all_columns(self):
|
||||
"""All available columns of all sources in this scope"""
|
||||
if self._all_columns is None:
|
||||
self._all_columns = set(
|
||||
column
|
||||
for columns in self._get_all_source_columns().values()
|
||||
for column in columns
|
||||
)
|
||||
self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
|
||||
return self._all_columns
|
||||
|
||||
def get_source_columns(self, name):
|
||||
|
@ -369,9 +353,7 @@ class _Resolver:
|
|||
|
||||
def _get_all_source_columns(self):
|
||||
if self._source_columns is None:
|
||||
self._source_columns = {
|
||||
k: self.get_source_columns(k) for k in self.scope.selected_sources
|
||||
}
|
||||
self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
|
||||
return self._source_columns
|
||||
|
||||
def _get_unambiguous_columns(self, source_columns):
|
||||
|
@ -389,9 +371,7 @@ class _Resolver:
|
|||
source_columns = list(source_columns.items())
|
||||
|
||||
first_table, first_columns = source_columns[0]
|
||||
unambiguous_columns = {
|
||||
col: first_table for col in self._find_unique_columns(first_columns)
|
||||
}
|
||||
unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
|
||||
all_columns = set(unambiguous_columns)
|
||||
|
||||
for table, columns in source_columns[1:]:
|
||||
|
|
|
@ -27,9 +27,7 @@ def qualify_tables(expression, db=None, catalog=None):
|
|||
for derived_table in scope.ctes + scope.derived_tables:
|
||||
if not derived_table.args.get("alias"):
|
||||
alias_ = f"_q_{next(sequence)}"
|
||||
derived_table.set(
|
||||
"alias", exp.TableAlias(this=exp.to_identifier(alias_))
|
||||
)
|
||||
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
|
||||
scope.rename_source(None, alias_)
|
||||
|
||||
for source in scope.sources.values():
|
||||
|
|
|
@ -57,9 +57,7 @@ class MappingSchema(Schema):
|
|||
|
||||
for forbidden in self.forbidden_args:
|
||||
if table.text(forbidden):
|
||||
raise ValueError(
|
||||
f"Schema doesn't support {forbidden}. Received: {table.sql()}"
|
||||
)
|
||||
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
|
||||
return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
|
||||
|
||||
|
||||
|
|
|
@ -104,9 +104,7 @@ class Scope:
|
|||
elif isinstance(node, exp.CTE):
|
||||
self._ctes.append(node)
|
||||
prune = True
|
||||
elif isinstance(node, exp.Subquery) and isinstance(
|
||||
parent, (exp.From, exp.Join)
|
||||
):
|
||||
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
|
||||
self._derived_tables.append(node)
|
||||
prune = True
|
||||
elif isinstance(node, exp.Subqueryable):
|
||||
|
@ -195,20 +193,14 @@ class Scope:
|
|||
self._ensure_collected()
|
||||
columns = self._raw_columns
|
||||
|
||||
external_columns = [
|
||||
column
|
||||
for scope in self.subquery_scopes
|
||||
for column in scope.external_columns
|
||||
]
|
||||
external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
|
||||
|
||||
named_outputs = {e.alias_or_name for e in self.expression.expressions}
|
||||
|
||||
self._columns = [
|
||||
c
|
||||
for c in columns + external_columns
|
||||
if not (
|
||||
c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs
|
||||
)
|
||||
if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
|
||||
]
|
||||
return self._columns
|
||||
|
||||
|
@ -229,9 +221,7 @@ class Scope:
|
|||
for table in self.tables:
|
||||
referenced_names.append(
|
||||
(
|
||||
table.parent.alias
|
||||
if isinstance(table.parent, exp.Alias)
|
||||
else table.name,
|
||||
table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
|
||||
table,
|
||||
)
|
||||
)
|
||||
|
@ -274,9 +264,7 @@ class Scope:
|
|||
sources in the current scope.
|
||||
"""
|
||||
if self._external_columns is None:
|
||||
self._external_columns = [
|
||||
c for c in self.columns if c.table not in self.selected_sources
|
||||
]
|
||||
self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
|
||||
return self._external_columns
|
||||
|
||||
def source_columns(self, source_name):
|
||||
|
@ -310,6 +298,16 @@ class Scope:
|
|||
columns = self.sources.pop(old_name or "", [])
|
||||
self.sources[new_name] = columns
|
||||
|
||||
def add_source(self, name, source):
|
||||
"""Add a source to this scope"""
|
||||
self.sources[name] = source
|
||||
self.clear_cache()
|
||||
|
||||
def remove_source(self, name):
|
||||
"""Remove a source from this scope"""
|
||||
self.sources.pop(name, None)
|
||||
self.clear_cache()
|
||||
|
||||
|
||||
def traverse_scope(expression):
|
||||
"""
|
||||
|
@ -334,7 +332,7 @@ def traverse_scope(expression):
|
|||
Args:
|
||||
expression (exp.Expression): expression to traverse
|
||||
Returns:
|
||||
List[Scope]: scope instances
|
||||
list[Scope]: scope instances
|
||||
"""
|
||||
return list(_traverse_scope(Scope(expression)))
|
||||
|
||||
|
@ -356,9 +354,7 @@ def _traverse_scope(scope):
|
|||
def _traverse_select(scope):
|
||||
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
|
||||
yield from _traverse_subqueries(scope)
|
||||
yield from _traverse_derived_tables(
|
||||
scope.derived_tables, scope, ScopeType.DERIVED_TABLE
|
||||
)
|
||||
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
|
||||
_add_table_sources(scope)
|
||||
|
||||
|
||||
|
@ -367,15 +363,11 @@ def _traverse_union(scope):
|
|||
|
||||
# The last scope to be yield should be the top most scope
|
||||
left = None
|
||||
for left in _traverse_scope(
|
||||
scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
|
||||
):
|
||||
for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
|
||||
yield left
|
||||
|
||||
right = None
|
||||
for right in _traverse_scope(
|
||||
scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
|
||||
):
|
||||
for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
|
||||
yield right
|
||||
|
||||
scope.union = (left, right)
|
||||
|
@ -387,14 +379,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
|
|||
for derived_table in derived_tables:
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
derived_table
|
||||
if isinstance(derived_table, (exp.Unnest, exp.Lateral))
|
||||
else derived_table.this,
|
||||
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
|
||||
add_sources=sources if scope_type == ScopeType.CTE else None,
|
||||
outer_column_list=derived_table.alias_column_names,
|
||||
scope_type=ScopeType.UNNEST
|
||||
if isinstance(derived_table, exp.Unnest)
|
||||
else scope_type,
|
||||
scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
|
||||
)
|
||||
):
|
||||
yield child_scope
|
||||
|
@ -430,9 +418,7 @@ def _add_table_sources(scope):
|
|||
def _traverse_subqueries(scope):
|
||||
for subquery in scope.subqueries:
|
||||
top = None
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
|
||||
):
|
||||
for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
|
||||
yield child_scope
|
||||
top = child_scope
|
||||
scope.subquery_scopes.append(top)
|
||||
|
|
|
@ -188,9 +188,7 @@ def absorb_and_eliminate(expression):
|
|||
aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
|
||||
elif is_complement(b, ab):
|
||||
ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
|
||||
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(
|
||||
a.flatten()
|
||||
):
|
||||
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
|
||||
a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
|
||||
elif isinstance(b, kind):
|
||||
# eliminate
|
||||
|
@ -227,9 +225,7 @@ def simplify_literals(expression):
|
|||
operands.append(a)
|
||||
|
||||
if len(operands) < size:
|
||||
return functools.reduce(
|
||||
lambda a, b: expression.__class__(this=a, expression=b), operands
|
||||
)
|
||||
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
|
||||
elif isinstance(expression, exp.Neg):
|
||||
this = expression.this
|
||||
if this.is_number:
|
||||
|
|
|
@ -89,11 +89,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
return
|
||||
|
||||
if isinstance(predicate, exp.Binary):
|
||||
key = (
|
||||
predicate.right
|
||||
if any(node is column for node, *_ in predicate.left.walk())
|
||||
else predicate.left
|
||||
)
|
||||
key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
|
||||
else:
|
||||
return
|
||||
|
||||
|
@ -124,9 +120,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
# if the value of the subquery is not an agg or a key, we need to collect it into an array
|
||||
# so that it can be grouped
|
||||
if not value.find(exp.AggFunc) and value.this not in group_by:
|
||||
select.select(
|
||||
f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
|
||||
)
|
||||
select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
|
||||
|
||||
# exists queries should not have any selects as it only checks if there are any rows
|
||||
# all selects will be added by the optimizer and only used for join keys
|
||||
|
@ -151,16 +145,12 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
else:
|
||||
parent_predicate = _replace(parent_predicate, "TRUE")
|
||||
elif isinstance(parent_predicate, exp.All):
|
||||
parent_predicate = _replace(
|
||||
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
|
||||
)
|
||||
parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
|
||||
elif isinstance(parent_predicate, exp.Any):
|
||||
if value.this in group_by:
|
||||
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
|
||||
else:
|
||||
parent_predicate = _replace(
|
||||
parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
|
||||
)
|
||||
parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
|
||||
elif isinstance(parent_predicate, exp.In):
|
||||
if value.this in group_by:
|
||||
parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
|
||||
|
@ -178,9 +168,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
|
||||
if key in group_by:
|
||||
key.replace(nested)
|
||||
parent_predicate = _replace(
|
||||
parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
|
||||
)
|
||||
parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
|
||||
elif isinstance(predicate, exp.EQ):
|
||||
parent_predicate = _replace(
|
||||
parent_predicate,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue