Merging upstream version 6.2.6.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
0f5b9ddee1
commit
66e2d714bf
49 changed files with 1741 additions and 566 deletions
162
sqlglot/optimizer/annotate_types.py
Normal file
162
sqlglot/optimizer/annotate_types.py
Normal file
|
@ -0,0 +1,162 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.helper import ensure_list, subclasses
|
||||
|
||||
|
||||
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Infer types for the nodes of an expression syntax tree and record them on the nodes.

    This is a convenience wrapper: it builds a `TypeAnnotator` from the given options
    and runs it over `expression`.

    (TODO -- replace this with a better example after adding some functionality)
    Example:
        >>> import sqlglot
        >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3'))
        >>> annotated_expression.type
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|sqlglot.optimizer.Schema): Database schema.
        annotators (dict): Maps expression type to corresponding annotation function.
        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
    Returns:
        sqlglot.Expression: expression annotated with types
    """
    type_annotator = TypeAnnotator(schema=schema, annotators=annotators, coerces_to=coerces_to)
    return type_annotator.annotate(expression)
|
||||
|
||||
|
||||
class TypeAnnotator:
    """
    Walks an expression tree and records an inferred type on each handled node's `type`.

    Dispatch is by the node's exact class through `annotators`. A node whose class has
    no entry receives no type of its own, but its children are still visited.
    """

    # Default dispatch table: expression class -> callable(annotator, expression).
    # Every subclass of Unary / Binary found in the expressions module is registered.
    ANNOTATORS = {
        **{
            expr_type: lambda self, expr: self._annotate_unary(expr)
            for expr_type in subclasses(exp.__name__, exp.Unary)
        },
        **{
            expr_type: lambda self, expr: self._annotate_binary(expr)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        exp.Cast: lambda self, expr: self._annotate_cast(expr),
        exp.DataType: lambda self, expr: self._annotate_data_type(expr),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        exp.Boolean: lambda self, expr: self._annotate_boolean(expr),
    }

    # Default coercion table: type -> set of wider types it may be coerced into.
    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = {
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        exp.DataType.Type.TEXT: set(),
        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
        exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
        exp.DataType.Type.CHAR: {
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        exp.DataType.Type.DOUBLE: set(),
        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
        exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
        exp.DataType.Type.INT: {
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.SMALLINT: {
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.TINYINT: {
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        exp.DataType.Type.TIMESTAMPLTZ: set(),
        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
        exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
        exp.DataType.Type.DATETIME: {
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATE: {
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
    }

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """
        Args:
            schema: Database schema; stored on the instance (no method of this class reads it).
            annotators (dict): Dispatch table to use; a falsy value selects `ANNOTATORS`.
            coerces_to (dict): Coercion table to use; a falsy value selects `COERCES_TO`.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        """
        Annotate `expression` and everything beneath it.

        Returns:
            The same expression object, or None when `expression` is not an
            `exp.Expression` (for example a plain string or None argument value).
        """
        if not isinstance(expression, exp.Expression):
            return None

        # Lookup is by exact class, so only registered classes get their own rule.
        annotator = self.annotators.get(expression.__class__)
        return annotator(self, expression) if annotator else self._annotate_args(expression)

    def _annotate_args(self, expression):
        """Annotate every argument value of `expression` (lists are flattened) and return it."""
        for value in expression.args.values():
            for v in ensure_list(value):
                self.annotate(v)

        return expression

    def _annotate_cast(self, expression):
        """Type a cast as the `this` of its `to` argument, then annotate its arguments."""
        expression.type = expression.args["to"].this
        return self._annotate_args(expression)

    def _annotate_data_type(self, expression):
        """Type a data-type node as its own `this`, then annotate its arguments."""
        expression.type = expression.this
        return self._annotate_args(expression)

    def _maybe_coerce(self, type1, type2):
        """
        Pick the result type of combining `type1` with `type2`.

        Returns `type2` when the coercion table lists it as a target of `type1`,
        otherwise `type1`. A `type1` that has no entry in the table (for example
        BOOLEAN, which the default table does not list) is returned unchanged
        rather than raising `KeyError`.
        """
        return type2 if type2 in self.coerces_to.get(type1, ()) else type1

    def _annotate_binary(self, expression):
        """
        Type a binary node: BOOLEAN for conditions and predicates, otherwise the
        coerced type of its already-annotated left and right operands.
        """
        self._annotate_args(expression)

        if isinstance(expression, (exp.Condition, exp.Predicate)):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(expression.left.type, expression.right.type)

        return expression

    def _annotate_unary(self, expression):
        """
        Type a unary node: BOOLEAN for conditions other than parentheses, otherwise
        the type of its already-annotated operand.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    def _annotate_literal(self, expression):
        """Type a literal: VARCHAR for strings, INT for integers, DOUBLE for anything else."""
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    def _annotate_boolean(self, expression):
        """Type a boolean node as BOOLEAN."""
        expression.type = exp.DataType.Type.BOOLEAN
        return expression
|
|
@ -1,48 +1,144 @@
|
|||
import itertools
|
||||
|
||||
from sqlglot import alias, exp, select, table
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
from sqlglot import expressions as exp
|
||||
from sqlglot.helper import find_new_name
|
||||
from sqlglot.optimizer.scope import build_scope
|
||||
from sqlglot.optimizer.simplify import simplify
|
||||
|
||||
|
||||
def eliminate_subqueries(expression):
|
||||
"""
|
||||
Rewrite duplicate subqueries from sqlglot AST.
|
||||
Rewrite subqueries as CTES, deduplicating if possible.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
|
||||
>>> eliminate_subqueries(expression).sql()
|
||||
'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'
|
||||
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
|
||||
|
||||
This also deduplicates common subqueries:
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
|
||||
>>> eliminate_subqueries(expression).sql()
|
||||
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to qualify
|
||||
schema (dict|sqlglot.optimizer.Schema): Database schema
|
||||
expression (sqlglot.Expression): expression
|
||||
Returns:
|
||||
sqlglot.Expression: qualified expression
|
||||
sqlglot.Expression: expression
|
||||
"""
|
||||
if isinstance(expression, exp.Subquery):
|
||||
# It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
|
||||
eliminate_subqueries(expression.this)
|
||||
return expression
|
||||
|
||||
expression = simplify(expression)
|
||||
queries = {}
|
||||
root = build_scope(expression)
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
query = scope.expression
|
||||
queries[query] = queries.get(query, []) + [query]
|
||||
# Map of alias->Scope|Table
|
||||
# These are all aliases that are already used in the expression.
|
||||
# We don't want to create new CTEs that conflict with these names.
|
||||
taken = {}
|
||||
|
||||
sequence = itertools.count()
|
||||
# All CTE aliases in the root scope are taken
|
||||
for scope in root.cte_scopes:
|
||||
taken[scope.expression.parent.alias] = scope
|
||||
|
||||
for query, duplicates in queries.items():
|
||||
if len(duplicates) == 1:
|
||||
continue
|
||||
# All table names are taken
|
||||
for scope in root.traverse():
|
||||
taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
|
||||
|
||||
alias_ = f"_e_{next(sequence)}"
|
||||
# Map of Expression->alias
|
||||
# Existing CTES in the root expression. We'll use this for deduplication.
|
||||
existing_ctes = {}
|
||||
|
||||
for dup in duplicates:
|
||||
parent = dup.parent
|
||||
if isinstance(parent, exp.Subquery):
|
||||
parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
|
||||
elif isinstance(parent, exp.Union):
|
||||
dup.replace(select("*").from_(alias_))
|
||||
with_ = root.expression.args.get("with")
|
||||
if with_:
|
||||
for cte in with_.expressions:
|
||||
existing_ctes[cte.this] = cte.alias
|
||||
new_ctes = []
|
||||
|
||||
expression.with_(alias_, as_=query, copy=False)
|
||||
# We're adding more CTEs, but we want to maintain the DAG order.
|
||||
# Derived tables within an existing CTE need to come before the existing CTE.
|
||||
for cte_scope in root.cte_scopes:
|
||||
# Append all the new CTEs from this existing CTE
|
||||
for scope in cte_scope.traverse():
|
||||
new_cte = _eliminate(scope, existing_ctes, taken)
|
||||
if new_cte:
|
||||
new_ctes.append(new_cte)
|
||||
|
||||
# Append the existing CTE itself
|
||||
new_ctes.append(cte_scope.expression.parent)
|
||||
|
||||
# Now append the rest
|
||||
for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
|
||||
for child_scope in scope.traverse():
|
||||
new_cte = _eliminate(child_scope, existing_ctes, taken)
|
||||
if new_cte:
|
||||
new_ctes.append(new_cte)
|
||||
|
||||
if new_ctes:
|
||||
expression.set("with", exp.With(expressions=new_ctes))
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def _eliminate(scope, existing_ctes, taken):
    """
    Dispatch a scope to the matching CTE-extraction helper.

    Union scopes go to `_eliminate_union`; derived-table scopes whose expression is
    not an UNNEST or LATERAL go to `_eliminate_derived_table`. Any other scope is
    left untouched.

    Returns:
        Whatever the chosen helper returns, or None when no helper applies.
    """
    if scope.is_union:
        handler = _eliminate_union
    elif scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)):
        handler = _eliminate_derived_table
    else:
        return None

    return handler(scope, existing_ctes, taken)
|
||||
|
||||
|
||||
def _eliminate_union(scope, existing_ctes, taken):
    """
    Replace the expression of a union scope with a SELECT from a CTE.

    Args:
        scope: Scope whose expression is replaced in the tree.
        existing_ctes (dict): Maps expression -> CTE alias; used to reuse an alias for a
            duplicate expression and updated when a new CTE is created.
        taken (dict): Maps already-used names -> source; updated with the chosen alias.
    Returns:
        exp.CTE wrapping the original expression when a new CTE is needed, or None when
        the expression duplicates an existing CTE.
    """
    duplicate_cte_alias = existing_ctes.get(scope.expression)

    alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")

    taken[alias] = scope

    # Try to maintain the selections.
    # NOTE(review): "expressions" may be absent from the args (presumably when the scope's
    # expression is itself a nested union) -- treat that as "no selections" rather than
    # iterating None.
    expressions = scope.expression.args.get("expressions") or []
    selects = [
        exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
        for e in expressions
        if e.alias_or_name
    ]
    # If there is nothing to select, or not all selections have an alias, just select *
    if not selects or len(selects) != len(expressions):
        selects = ["*"]

    scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))

    if not duplicate_cte_alias:
        existing_ctes[scope.expression] = alias
        return exp.CTE(
            this=scope.expression,
            alias=exp.TableAlias(this=exp.to_identifier(alias)),
        )

    # Duplicate of an existing CTE: the reference was rewritten, no new CTE is needed.
    return None
|
||||
|
||||
|
||||
def _eliminate_derived_table(scope, existing_ctes, taken):
    """
    Replace a derived table with a reference to a CTE.

    The parent of the scope's expression is swapped for an aliased table reference.
    The reference keeps the derived table's alias (or a freshly generated one when it
    has none); the CTE name is the alias of an identical existing CTE if there is one,
    otherwise the alias itself, de-duplicated against `taken` when necessary.

    Args:
        scope: Derived-table scope to rewrite.
        existing_ctes (dict): Maps expression -> CTE name; updated when a new CTE is created.
        taken (dict): Maps already-used names -> source; updated with the chosen name.
    Returns:
        exp.CTE for the extracted query, or None when an existing CTE was reused.
    """
    query = scope.expression
    reused_name = existing_ctes.get(query)
    wrapper = query.parent

    alias = wrapper.alias or find_new_name(taken=taken, base="cte")

    if reused_name:
        name = reused_name
    elif taken.get(alias):
        # The alias is already used by something else; derive a new CTE name from it.
        name = find_new_name(taken=taken, base=alias)
    else:
        name = alias

    taken[name] = scope

    wrapper.replace(exp.alias_(exp.table_(name), alias=alias))

    if reused_name:
        return None

    existing_ctes[query] = name
    return exp.CTE(
        this=query,
        alias=exp.TableAlias(this=exp.to_identifier(name)),
    )
|
||||
|
|
|
@ -1,45 +1,39 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from sqlglot import expressions as exp
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
from sqlglot.helper import find_new_name
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.optimizer.simplify import simplify
|
||||
|
||||
|
||||
def merge_derived_tables(expression):
|
||||
def merge_subqueries(expression, leave_tables_isolated=False):
|
||||
"""
|
||||
Rewrite sqlglot AST to merge derived tables into the outer query.
|
||||
|
||||
This also merges CTEs if they are selected from only once.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")
|
||||
>>> merge_derived_tables(expression).sql()
|
||||
'SELECT x.a FROM x'
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
|
||||
>>> merge_subqueries(expression).sql()
|
||||
'SELECT x.a FROM x JOIN y'
|
||||
|
||||
If `leave_tables_isolated` is True, this will not merge inner queries into outer
|
||||
queries if it would result in multiple table selects in a single query:
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
|
||||
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
|
||||
'SELECT a FROM (SELECT x.a FROM x) JOIN y'
|
||||
|
||||
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to optimize
|
||||
leave_tables_isolated (bool): if True, do not merge an inner query when the outer query selects from more than one source
|
||||
Returns:
|
||||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
for outer_scope in traverse_scope(expression):
|
||||
for subquery in outer_scope.derived_tables:
|
||||
inner_select = subquery.unnest()
|
||||
if (
|
||||
isinstance(outer_scope.expression, exp.Select)
|
||||
and isinstance(inner_select, exp.Select)
|
||||
and _mergeable(inner_select)
|
||||
):
|
||||
alias = subquery.alias_or_name
|
||||
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
||||
inner_scope = outer_scope.sources[alias]
|
||||
|
||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
_merge_from(outer_scope, inner_scope, subquery)
|
||||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_expressions(outer_scope, inner_scope, alias)
|
||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
merge_ctes(expression, leave_tables_isolated)
|
||||
merge_derived_tables(expression, leave_tables_isolated)
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -53,20 +47,81 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
|
|||
}
|
||||
|
||||
|
||||
def _mergeable(inner_select):
|
||||
def merge_ctes(expression, leave_tables_isolated=False):
    """
    Merge CTEs that are selected from exactly once into the query that selects from them.

    Args:
        expression: expression whose scopes are traversed and rewritten in place
        leave_tables_isolated (bool): forwarded to `_mergeable`
    """
    # Every place a CTE is selected from, grouped by the identity of the CTE's scope
    # so that CTEs selected from more than once can be told apart.
    selections_by_cte = defaultdict(list)
    for outer_scope in traverse_scope(expression):
        for table, source in outer_scope.selected_sources.values():
            if not isinstance(source, Scope) or not source.is_cte:
                continue
            selections_by_cte[id(source)].append((outer_scope, source, table))

    for selections in selections_by_cte.values():
        if len(selections) != 1:
            continue

        outer_scope, inner_scope, table = selections[0]
        if not _mergeable(outer_scope, inner_scope.expression.unnest(), leave_tables_isolated):
            continue

        from_or_join = table.find_ancestor(exp.From, exp.Join)

        # When the table reference is wrapped in an alias, the alias node is what gets replaced.
        is_aliased = isinstance(table.parent, exp.Alias)
        node_to_replace = table.parent if is_aliased else table
        alias = node_to_replace.alias if is_aliased else table.name

        _rename_inner_sources(outer_scope, inner_scope, alias)
        _merge_from(outer_scope, inner_scope, node_to_replace, alias)
        _merge_joins(outer_scope, inner_scope, from_or_join)
        _merge_expressions(outer_scope, inner_scope, alias)
        _merge_where(outer_scope, inner_scope, from_or_join)
        _merge_order(outer_scope, inner_scope)
        _pop_cte(inner_scope)
|
||||
|
||||
|
||||
def merge_derived_tables(expression, leave_tables_isolated=False):
    """
    Merge each mergeable derived table into the query that selects from it.

    Args:
        expression: expression whose scopes are traversed and rewritten in place
        leave_tables_isolated (bool): forwarded to `_mergeable`
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            if not _mergeable(outer_scope, subquery.unnest(), leave_tables_isolated):
                continue

            alias = subquery.alias_or_name
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
            inner_scope = outer_scope.sources[alias]

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, subquery, alias)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_order(outer_scope, inner_scope)
|
||||
|
||||
|
||||
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
    """
    Tell whether `inner_select` can be merged into the outer query.

    All of the following must hold: both the outer scope's expression and
    `inner_select` are SELECTs; `inner_select` sets none of `UNMERGABLE_ARGS`; it has
    a FROM clause; none of its projections contains an aggregate function or a nested
    SELECT; and, when `leave_tables_isolated` is set, the outer scope selects from at
    most one source.

    Args:
        outer_scope (Scope)
        inner_select (exp.Select)
        leave_tables_isolated (bool)
    Returns:
        A truthy value if the query can be merged, a falsy one (False or None) otherwise.
    """
    return (
        isinstance(outer_scope.expression, exp.Select)
        and isinstance(inner_select, exp.Select)
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from")
        and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
    )
|
||||
|
||||
|
||||
|
@ -84,7 +139,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
|
|||
conflicts = conflicts - {alias}
|
||||
|
||||
for conflict in conflicts:
|
||||
new_name = _find_new_name(taken, conflict)
|
||||
new_name = find_new_name(taken, conflict)
|
||||
|
||||
source, _ = inner_scope.selected_sources[conflict]
|
||||
new_alias = exp.to_identifier(new_name)
|
||||
|
@ -102,34 +157,19 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
|
|||
inner_scope.rename_source(conflict, new_name)
|
||||
|
||||
|
||||
def _find_new_name(taken, base):
|
||||
"""
|
||||
Searches for a new source name.
|
||||
|
||||
Args:
|
||||
taken (set[str]): set of taken names
|
||||
base (str): base name to alter
|
||||
"""
|
||||
i = 2
|
||||
new = f"{base}_{i}"
|
||||
while new in taken:
|
||||
i += 1
|
||||
new = f"{base}_{i}"
|
||||
return new
|
||||
|
||||
|
||||
def _merge_from(outer_scope, inner_scope, subquery):
|
||||
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
|
||||
"""
|
||||
Merge FROM clause of inner query into outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (sqlglot.optimizer.scope.Scope)
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
subquery (exp.Subquery)
|
||||
node_to_replace (exp.Subquery|exp.Table)
|
||||
alias (str)
|
||||
"""
|
||||
new_subquery = inner_scope.expression.args.get("from").expressions[0]
|
||||
subquery.replace(new_subquery)
|
||||
outer_scope.remove_source(subquery.alias_or_name)
|
||||
node_to_replace.replace(new_subquery)
|
||||
outer_scope.remove_source(alias)
|
||||
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
|
||||
|
||||
|
||||
|
@ -176,7 +216,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
|
|||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
alias (str)
|
||||
"""
|
||||
# Collect all columns that for the alias of the inner query
|
||||
# Collect all columns that reference the alias of the inner query
|
||||
outer_columns = defaultdict(list)
|
||||
for column in outer_scope.columns:
|
||||
if column.table == alias:
|
||||
|
@ -205,7 +245,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
|
|||
if not where or not where.this:
|
||||
return
|
||||
|
||||
if isinstance(from_or_join, exp.Join) and from_or_join.side:
|
||||
if isinstance(from_or_join, exp.Join):
|
||||
# Merge predicates from an outer join to the ON clause
|
||||
from_or_join.on(where.this, copy=False)
|
||||
from_or_join.set("on", simplify(from_or_join.args.get("on")))
|
||||
|
@ -230,3 +270,18 @@ def _merge_order(outer_scope, inner_scope):
|
|||
return
|
||||
|
||||
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
|
||||
|
||||
|
||||
def _pop_cte(inner_scope):
|
||||
"""
|
||||
Remove CTE from the AST.
|
||||
|
||||
Args:
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
"""
|
||||
cte = inner_scope.expression.parent
|
||||
with_ = cte.parent
|
||||
if len(with_.expressions) == 1:
|
||||
with_.pop()
|
||||
else:
|
||||
cte.pop()
|
|
@ -1,7 +1,7 @@
|
|||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
|
||||
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
|
||||
from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
|
||||
from sqlglot.optimizer.merge_subqueries import merge_subqueries
|
||||
from sqlglot.optimizer.normalize import normalize
|
||||
from sqlglot.optimizer.optimize_joins import optimize_joins
|
||||
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
||||
|
@ -22,7 +22,7 @@ RULES = (
|
|||
pushdown_predicates,
|
||||
optimize_joins,
|
||||
eliminate_subqueries,
|
||||
merge_derived_tables,
|
||||
merge_subqueries,
|
||||
quote_identities,
|
||||
)
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ def pushdown_projections(expression):
|
|||
parent_selections = {SELECT_ALL}
|
||||
|
||||
if isinstance(scope.expression, exp.Union):
|
||||
left, right = scope.union
|
||||
left, right = scope.union_scopes
|
||||
referenced_columns[left] = parent_selections
|
||||
referenced_columns[right] = parent_selections
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ def ensure_schema(schema):
|
|||
|
||||
|
||||
def fs_get(table):
|
||||
name = table.this.name.upper()
|
||||
name = table.this.name
|
||||
|
||||
if name.upper() == "READ_CSV":
|
||||
with csv_reader(table) as reader:
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import itertools
|
||||
from copy import copy
|
||||
from enum import Enum, auto
|
||||
|
||||
|
@ -32,10 +33,11 @@ class Scope:
|
|||
The inner query would have `["col1", "col2"]` for its `outer_column_list`
|
||||
parent (Scope): Parent scope
|
||||
scope_type (ScopeType): Type of this scope, relative to its parent
|
||||
subquery_scopes (list[Scope]): List of all child scopes for subqueries.
|
||||
This does not include derived tables or CTEs.
|
||||
union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
|
||||
a tuple of the left and right child scopes.
|
||||
subquery_scopes (list[Scope]): List of all child scopes for subqueries
|
||||
cte_scopes (list[Scope]): List of all child scopes for CTEs
|
||||
derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
|
||||
union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
|
||||
a list of the left and right child scopes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -52,7 +54,9 @@ class Scope:
|
|||
self.parent = parent
|
||||
self.scope_type = scope_type
|
||||
self.subquery_scopes = []
|
||||
self.union = None
|
||||
self.derived_table_scopes = []
|
||||
self.cte_scopes = []
|
||||
self.union_scopes = []
|
||||
self.clear_cache()
|
||||
|
||||
def clear_cache(self):
|
||||
|
@ -197,11 +201,16 @@ class Scope:
|
|||
|
||||
named_outputs = {e.alias_or_name for e in self.expression.expressions}
|
||||
|
||||
self._columns = [
|
||||
c
|
||||
for c in columns + external_columns
|
||||
if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
|
||||
]
|
||||
self._columns = []
|
||||
for column in columns + external_columns:
|
||||
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint)
|
||||
if (
|
||||
not ancestor
|
||||
or column.table
|
||||
or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint))
|
||||
):
|
||||
self._columns.append(column)
|
||||
|
||||
return self._columns
|
||||
|
||||
@property
|
||||
|
@ -283,6 +292,26 @@ class Scope:
|
|||
"""Determine if this scope is a subquery"""
|
||||
return self.scope_type == ScopeType.SUBQUERY
|
||||
|
||||
@property
def is_derived_table(self):
    """Whether this scope's type is `ScopeType.DERIVED_TABLE`."""
    kind = self.scope_type
    return kind == ScopeType.DERIVED_TABLE
|
||||
|
||||
@property
def is_union(self):
    """Whether this scope's type is `ScopeType.UNION`."""
    kind = self.scope_type
    return kind == ScopeType.UNION
|
||||
|
||||
@property
def is_cte(self):
    """Whether this scope's type is `ScopeType.CTE` (a common table expression)."""
    kind = self.scope_type
    return kind == ScopeType.CTE
|
||||
|
||||
@property
def is_root(self):
    """Whether this scope's type is `ScopeType.ROOT`."""
    kind = self.scope_type
    return kind == ScopeType.ROOT
|
||||
|
||||
@property
|
||||
def is_unnest(self):
|
||||
"""Determine if this scope is an unnest"""
|
||||
|
@ -308,6 +337,22 @@ class Scope:
|
|||
self.sources.pop(name, None)
|
||||
self.clear_cache()
|
||||
|
||||
def __repr__(self):
|
||||
return f"Scope<{self.expression.sql()}>"
|
||||
|
||||
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Children are visited group by group -- CTE scopes, then union scopes, then
    subquery scopes, then derived-table scopes -- and each child's whole subtree is
    yielded before this scope itself.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    child_groups = (
        self.cte_scopes,
        self.union_scopes,
        self.subquery_scopes,
        self.derived_table_scopes,
    )
    for group in child_groups:
        for child_scope in group:
            yield from child_scope.traverse()
    yield self
|
||||
|
||||
|
||||
def traverse_scope(expression):
|
||||
"""
|
||||
|
@ -337,6 +382,18 @@ def traverse_scope(expression):
|
|||
return list(_traverse_scope(Scope(expression)))
|
||||
|
||||
|
||||
def build_scope(expression):
    """
    Build the scope tree for an expression and return its root.

    The root is taken to be the last scope produced by `traverse_scope`.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope
    """
    scopes = traverse_scope(expression)
    return scopes[-1]
|
||||
|
||||
|
||||
def _traverse_scope(scope):
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
yield from _traverse_select(scope)
|
||||
|
@ -370,13 +427,14 @@ def _traverse_union(scope):
|
|||
for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
|
||||
yield right
|
||||
|
||||
scope.union = (left, right)
|
||||
scope.union_scopes = [left, right]
|
||||
|
||||
|
||||
def _traverse_derived_tables(derived_tables, scope, scope_type):
|
||||
sources = {}
|
||||
|
||||
for derived_table in derived_tables:
|
||||
top = None
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
|
||||
|
@ -386,11 +444,16 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
|
|||
)
|
||||
):
|
||||
yield child_scope
|
||||
top = child_scope
|
||||
# Tables without aliases will be set as ""
|
||||
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
|
||||
# Until then, this means that only a single, unaliased derived table is allowed (rather,
|
||||
# the latest one wins).
|
||||
sources[derived_table.alias] = child_scope
|
||||
if scope_type == ScopeType.CTE:
|
||||
scope.cte_scopes.append(top)
|
||||
else:
|
||||
scope.derived_table_scopes.append(top)
|
||||
scope.sources.update(sources)
|
||||
|
||||
|
||||
|
@ -407,8 +470,6 @@ def _add_table_sources(scope):
|
|||
if table_name in scope.sources:
|
||||
# This is a reference to a parent source (e.g. a CTE), not an actual table.
|
||||
scope.sources[source_name] = scope.sources[table_name]
|
||||
elif source_name in scope.sources:
|
||||
raise OptimizeError(f"Duplicate table name: {source_name}")
|
||||
else:
|
||||
sources[source_name] = table
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue