Merging upstream version 10.5.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
77197f1e44
commit
e0f3bbb5f3
58 changed files with 1480 additions and 383 deletions
|
@ -43,7 +43,7 @@ class TypeAnnotator:
|
|||
},
|
||||
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
|
||||
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
|
||||
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
|
||||
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
|
||||
exp.Alias: lambda self, expr: self._annotate_unary(expr),
|
||||
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
|
|
|
@ -57,7 +57,7 @@ def _join_is_used(scope, join, alias):
|
|||
# But columns in the ON clause shouldn't count.
|
||||
on = join.args.get("on")
|
||||
if on:
|
||||
on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
|
||||
on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
|
||||
else:
|
||||
on_clause_columns = set()
|
||||
return any(
|
||||
|
@ -71,7 +71,7 @@ def _is_joined_on_all_unique_outputs(scope, join):
|
|||
return False
|
||||
|
||||
_, join_keys, _ = join_condition(join)
|
||||
remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys)
|
||||
remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
|
||||
return not remaining_unique_outputs
|
||||
|
||||
|
||||
|
|
|
@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False):
|
|||
|
||||
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
|
||||
for outer_scope, inner_scope, table in singular_cte_selections:
|
||||
inner_select = inner_scope.expression.unnest()
|
||||
from_or_join = table.find_ancestor(exp.From, exp.Join)
|
||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||
if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
||||
alias = table.alias_or_name
|
||||
|
||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
_merge_from(outer_scope, inner_scope, table, alias)
|
||||
_merge_expressions(outer_scope, inner_scope, alias)
|
||||
|
@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False):
|
|||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_hints(outer_scope, inner_scope)
|
||||
_pop_cte(inner_scope)
|
||||
outer_scope.clear_cache()
|
||||
return expression
|
||||
|
||||
|
||||
def merge_derived_tables(expression, leave_tables_isolated=False):
|
||||
for outer_scope in traverse_scope(expression):
|
||||
for subquery in outer_scope.derived_tables:
|
||||
inner_select = subquery.unnest()
|
||||
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||
alias = subquery.alias_or_name
|
||||
inner_scope = outer_scope.sources[alias]
|
||||
|
||||
alias = subquery.alias_or_name
|
||||
inner_scope = outer_scope.sources[alias]
|
||||
if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
_merge_from(outer_scope, inner_scope, subquery, alias)
|
||||
_merge_expressions(outer_scope, inner_scope, alias)
|
||||
|
@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
|
|||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_hints(outer_scope, inner_scope)
|
||||
outer_scope.clear_cache()
|
||||
return expression
|
||||
|
||||
|
||||
def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||
def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
||||
"""
|
||||
Return True if the inner select of `inner_scope` can be merged into the outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (Scope)
|
||||
inner_select (exp.Select)
|
||||
inner_scope (Scope)
|
||||
leave_tables_isolated (bool)
|
||||
from_or_join (exp.From|exp.Join)
|
||||
Returns:
|
||||
bool: True if can be merged
|
||||
"""
|
||||
inner_select = inner_scope.expression.unnest()
|
||||
|
||||
def _is_a_window_expression_in_unmergable_operation():
|
||||
window_expressions = inner_select.find_all(exp.Window)
|
||||
|
@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
|||
]
|
||||
return any(window_expressions_in_unmergable)
|
||||
|
||||
def _outer_select_joins_on_inner_select_join():
|
||||
"""
|
||||
All columns from the inner select in the ON clause must be from the first FROM table.
|
||||
|
||||
That is, this can be merged:
|
||||
SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
|
||||
^^^ ^
|
||||
But this can't:
|
||||
SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
|
||||
^^^ ^
|
||||
"""
|
||||
if not isinstance(from_or_join, exp.Join):
|
||||
return False
|
||||
|
||||
alias = from_or_join.this.alias_or_name
|
||||
|
||||
on = from_or_join.args.get("on")
|
||||
if not on:
|
||||
return False
|
||||
selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
|
||||
inner_from = inner_scope.expression.args.get("from")
|
||||
if not inner_from:
|
||||
return False
|
||||
inner_from_table = inner_from.expressions[0].alias_or_name
|
||||
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
|
||||
return any(
|
||||
col.table != inner_from_table
|
||||
for selection in selections
|
||||
for col in inner_projections[selection].find_all(exp.Column)
|
||||
)
|
||||
|
||||
return (
|
||||
isinstance(outer_scope.expression, exp.Select)
|
||||
and isinstance(inner_select, exp.Select)
|
||||
and isinstance(inner_select, exp.Select)
|
||||
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
|
||||
and inner_select.args.get("from")
|
||||
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
|
||||
|
@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
|||
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
|
||||
)
|
||||
)
|
||||
and not _outer_select_joins_on_inner_select_join()
|
||||
and not _is_a_window_expression_in_unmergable_operation()
|
||||
)
|
||||
|
||||
|
@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
|
|||
"""
|
||||
taken = set(outer_scope.selected_sources)
|
||||
conflicts = taken.intersection(set(inner_scope.selected_sources))
|
||||
conflicts = conflicts - {alias}
|
||||
conflicts -= {alias}
|
||||
|
||||
for conflict in conflicts:
|
||||
new_name = find_new_name(taken, conflict)
|
||||
|
|
|
@ -15,6 +15,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections
|
|||
from sqlglot.optimizer.qualify_columns import qualify_columns
|
||||
from sqlglot.optimizer.qualify_tables import qualify_tables
|
||||
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
||||
from sqlglot.schema import ensure_schema
|
||||
|
||||
RULES = (
|
||||
lower_identities,
|
||||
|
@ -51,12 +52,13 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
|
|||
If no schema is provided then the default schema defined at `sqlglot.schema` will be used
|
||||
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
|
||||
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
|
||||
rules (list): sequence of optimizer rules to use
|
||||
rules (sequence): sequence of optimizer rules to use
|
||||
**kwargs: If a rule has a keyword argument with the same name in **kwargs, it will be passed in.
|
||||
Returns:
|
||||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
|
||||
schema = ensure_schema(schema or sqlglot.schema)
|
||||
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
|
||||
expression = expression.copy()
|
||||
for rule in rules:
|
||||
|
||||
|
|
|
@ -79,6 +79,7 @@ def _remove_unused_selections(scope, parent_selections):
|
|||
order_refs = set()
|
||||
|
||||
new_selections = []
|
||||
removed = False
|
||||
for i, selection in enumerate(scope.selects):
|
||||
if (
|
||||
SELECT_ALL in parent_selections
|
||||
|
@ -88,12 +89,15 @@ def _remove_unused_selections(scope, parent_selections):
|
|||
new_selections.append(selection)
|
||||
else:
|
||||
removed_indexes.append(i)
|
||||
removed = True
|
||||
|
||||
# If there are no remaining selections, just select a single constant
|
||||
if not new_selections:
|
||||
new_selections.append(DEFAULT_SELECTION.copy())
|
||||
|
||||
scope.expression.set("expressions", new_selections)
|
||||
if removed:
|
||||
scope.clear_cache()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
|
|
|
@ -365,9 +365,9 @@ class _Resolver:
|
|||
def all_columns(self):
|
||||
"""All available columns of all sources in this scope"""
|
||||
if self._all_columns is None:
|
||||
self._all_columns = set(
|
||||
self._all_columns = {
|
||||
column for columns in self._get_all_source_columns().values() for column in columns
|
||||
)
|
||||
}
|
||||
return self._all_columns
|
||||
|
||||
def get_source_columns(self, name, only_visible=False):
|
||||
|
|
|
@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b):
|
|||
return boolean
|
||||
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
|
||||
a, b = extract_date(a), extract_interval(b)
|
||||
if b:
|
||||
if a and b:
|
||||
if isinstance(expression, exp.Add):
|
||||
return date_literal(a + b)
|
||||
if isinstance(expression, exp.Sub):
|
||||
|
@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b):
|
|||
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
|
||||
a, b = extract_interval(a), extract_date(b)
|
||||
# you cannot subtract a date from an interval
|
||||
if a and isinstance(expression, exp.Add):
|
||||
if a and b and isinstance(expression, exp.Add):
|
||||
return date_literal(a + b)
|
||||
|
||||
return None
|
||||
|
@ -424,9 +424,15 @@ def eval_boolean(expression, a, b):
|
|||
|
||||
|
||||
def extract_date(cast):
|
||||
if cast.args["to"].this == exp.DataType.Type.DATE:
|
||||
return datetime.date.fromisoformat(cast.name)
|
||||
return None
|
||||
# The "fromisoformat" conversion could fail if the cast is used on an identifier,
|
||||
# so in that case we can't extract the date.
|
||||
try:
|
||||
if cast.args["to"].this == exp.DataType.Type.DATE:
|
||||
return datetime.date.fromisoformat(cast.name)
|
||||
if cast.args["to"].this == exp.DataType.Type.DATETIME:
|
||||
return datetime.datetime.fromisoformat(cast.name)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def extract_interval(interval):
|
||||
|
@ -450,7 +456,8 @@ def extract_interval(interval):
|
|||
|
||||
|
||||
def date_literal(date):
|
||||
return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))
|
||||
expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
|
||||
return exp.Cast(this=exp.Literal.string(date), to=expr_type)
|
||||
|
||||
|
||||
def boolean_literal(condition):
|
||||
|
|
|
@ -15,8 +15,7 @@ def unnest_subqueries(expression):
|
|||
>>> import sqlglot
|
||||
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
|
||||
>>> unnest_subqueries(expression).sql()
|
||||
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
|
||||
AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
|
||||
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to unnest
|
||||
|
@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
other = _other_operand(parent_predicate)
|
||||
|
||||
if isinstance(parent_predicate, exp.Exists):
|
||||
if value.this in group_by:
|
||||
parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
|
||||
else:
|
||||
parent_predicate = _replace(parent_predicate, "TRUE")
|
||||
alias = exp.column(list(key_aliases.values())[0], table_alias)
|
||||
parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
|
||||
elif isinstance(parent_predicate, exp.All):
|
||||
parent_predicate = _replace(
|
||||
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
|
||||
|
@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
else:
|
||||
if is_subquery_projection:
|
||||
alias = exp.alias_(alias, select.parent.alias)
|
||||
|
||||
# COUNT always returns 0 on empty datasets, so we need to take that into consideration here
|
||||
# by transforming all counts into 0 and using that as the coalesced value
|
||||
if value.find(exp.Count):
|
||||
|
||||
def remove_aggs(node):
|
||||
if isinstance(node, exp.Count):
|
||||
return exp.Literal.number(0)
|
||||
elif isinstance(node, exp.AggFunc):
|
||||
return exp.null()
|
||||
return node
|
||||
|
||||
alias = exp.Coalesce(
|
||||
this=alias,
|
||||
expressions=[value.this.transform(remove_aggs)],
|
||||
)
|
||||
|
||||
select.parent.replace(alias)
|
||||
|
||||
for key, column, predicate in keys:
|
||||
|
@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
|
||||
if key in group_by:
|
||||
key.replace(nested)
|
||||
parent_predicate = _replace(
|
||||
parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
|
||||
)
|
||||
elif isinstance(predicate, exp.EQ):
|
||||
parent_predicate = _replace(
|
||||
parent_predicate,
|
||||
|
@ -245,7 +256,14 @@ def _other_operand(expression):
|
|||
if isinstance(expression, exp.In):
|
||||
return expression.this
|
||||
|
||||
if isinstance(expression, (exp.Any, exp.All)):
|
||||
return _other_operand(expression.parent)
|
||||
|
||||
if isinstance(expression, exp.Binary):
|
||||
return expression.right if expression.arg_key == "this" else expression.left
|
||||
return (
|
||||
expression.right
|
||||
if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
|
||||
else expression.left
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue