Merging upstream version 10.0.8.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
407314e8d2
commit
efc1e37108
67 changed files with 2461 additions and 840 deletions
|
@ -245,23 +245,31 @@ class TypeAnnotator:
|
|||
def annotate(self, expression):
|
||||
if isinstance(expression, self.TRAVERSABLES):
|
||||
for scope in traverse_scope(expression):
|
||||
subscope_selects = {
|
||||
name: {select.alias_or_name: select for select in source.selects}
|
||||
for name, source in scope.sources.items()
|
||||
if isinstance(source, Scope)
|
||||
}
|
||||
|
||||
selects = {}
|
||||
for name, source in scope.sources.items():
|
||||
if not isinstance(source, Scope):
|
||||
continue
|
||||
if isinstance(source.expression, exp.Values):
|
||||
selects[name] = {
|
||||
alias: column
|
||||
for alias, column in zip(
|
||||
source.expression.alias_column_names,
|
||||
source.expression.expressions[0].expressions,
|
||||
)
|
||||
}
|
||||
else:
|
||||
selects[name] = {
|
||||
select.alias_or_name: select for select in source.expression.selects
|
||||
}
|
||||
# First annotate the current scope's column references
|
||||
for col in scope.columns:
|
||||
source = scope.sources[col.table]
|
||||
if isinstance(source, exp.Table):
|
||||
col.type = self.schema.get_column_type(source, col)
|
||||
else:
|
||||
col.type = subscope_selects[col.table][col.name].type
|
||||
|
||||
col.type = selects[col.table][col.name].type
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
||||
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
|
||||
|
||||
def _maybe_annotate(self, expression):
|
||||
|
|
48
sqlglot/optimizer/canonicalize.py
Normal file
48
sqlglot/optimizer/canonicalize.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import itertools
|
||||
|
||||
from sqlglot import exp
|
||||
|
||||
|
||||
def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Rewrite a SQL expression tree into a standard form.

    Children are canonicalized first (through ``exp.replace_children``), and
    then the node itself is passed through each rewrite step in order.

    The rewrite steps inspect the ``type`` attribute of nodes, so this method
    relies on annotate_types having been applied to the tree beforehand.

    Args:
        expression: The expression to canonicalize.

    Returns:
        The canonicalized expression. This may be a different node than the
        one passed in, because a rewrite step can return a new node.
    """
    exp.replace_children(expression, canonicalize)

    # Order matters: the Add -> Concat rewrite runs before type coercion.
    for rewrite in (add_text_to_concat, coerce_type):
        expression = rewrite(expression)

    return expression
|
||||
|
||||
|
||||
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Turn an ``Add`` whose annotated type is textual into a ``Concat``.

    Args:
        node: The expression to inspect.

    Returns:
        A new ``exp.Concat`` built from the operands of ``node`` when ``node``
        is an ``exp.Add`` whose ``type`` is in ``exp.DataType.TEXT_TYPES``;
        otherwise ``node`` itself, unchanged.
    """
    if not isinstance(node, exp.Add):
        return node

    if node.type not in exp.DataType.TEXT_TYPES:
        return node

    return exp.Concat(this=node.this, expression=node.expression)
|
||||
|
||||
|
||||
def coerce_type(node: exp.Expression) -> exp.Expression:
    """Insert casts around operands whose type does not fit their context.

    * ``exp.Binary``: the left and right operands are handed to
      ``_coerce_date``.
    * ``exp.Between``: the tested expression and the ``low`` bound are handed
      to ``_coerce_date`` (the other bound is not inspected here).
    * ``exp.Extract``: the extracted-from expression is cast to ``datetime``
      when its type is not in ``exp.DataType.TEMPORAL_TYPES``.

    Args:
        node: The expression to inspect; its operands may be replaced in place.

    Returns:
        The same ``node`` that was passed in.
    """
    if isinstance(node, exp.Binary):
        _coerce_date(node.left, node.right)
        return node

    if isinstance(node, exp.Between):
        _coerce_date(node.this, node.args["low"])
        return node

    if isinstance(node, exp.Extract):
        operand = node.expression
        if operand.type not in exp.DataType.TEMPORAL_TYPES:
            _replace_cast(operand, "datetime")

    return node
|
||||
|
||||
|
||||
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
    """Cast the non-DATE operand to ``date`` when the other operand is a DATE.

    Both orderings of the pair are checked, so it does not matter which of the
    two arguments carries the DATE type. Nothing happens when both or neither
    operand is typed as DATE.

    Args:
        a: One operand.
        b: The other operand.
    """
    date_type = exp.DataType.Type.DATE

    # Yields (a, b) and then (b, a).
    for anchor, other in itertools.permutations((a, b)):
        if anchor.type == date_type and other.type != date_type:
            _replace_cast(other, "date")
|
||||
|
||||
|
||||
def _replace_cast(node: exp.Expression, to: str) -> None:
    """Replace ``node`` with a ``Cast`` of a copy of itself to the type ``to``.

    The new cast node gets its ``type`` attribute set to the target data type.

    Args:
        node: The expression to wrap; ``node.replace`` is called on it.
        to: Name of the target type, passed to ``exp.DataType.build``.
    """
    target_type = exp.DataType.build(to)

    # The cast wraps a copy; the original node is what gets replaced.
    wrapped = exp.Cast(this=node.copy(), to=target_type)
    wrapped.type = target_type

    node.replace(wrapped)
|
|
@ -128,8 +128,8 @@ def join_condition(join):
|
|||
Tuple of (source key, join key, remaining predicate)
|
||||
"""
|
||||
name = join.this.alias_or_name
|
||||
on = join.args.get("on") or exp.TRUE
|
||||
on = on.copy()
|
||||
on = (join.args.get("on") or exp.true()).copy()
|
||||
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
|
||||
source_key = []
|
||||
join_key = []
|
||||
|
||||
|
@ -141,7 +141,7 @@ def join_condition(join):
|
|||
#
|
||||
# should pull y.b as the join key and x.a as the source key
|
||||
if normalized(on):
|
||||
for condition in on.flatten() if isinstance(on, exp.And) else [on]:
|
||||
for condition in on.flatten():
|
||||
if isinstance(condition, exp.EQ):
|
||||
left, right = condition.unnest_operands()
|
||||
left_tables = exp.column_table_names(left)
|
||||
|
@ -150,13 +150,12 @@ def join_condition(join):
|
|||
if name in left_tables and name not in right_tables:
|
||||
join_key.append(left)
|
||||
source_key.append(right)
|
||||
condition.replace(exp.TRUE)
|
||||
condition.replace(exp.true())
|
||||
elif name in right_tables and name not in left_tables:
|
||||
join_key.append(right)
|
||||
source_key.append(left)
|
||||
condition.replace(exp.TRUE)
|
||||
condition.replace(exp.true())
|
||||
|
||||
on = simplify(on)
|
||||
remaining_condition = None if on == exp.TRUE else on
|
||||
|
||||
remaining_condition = None if on == exp.true() else on
|
||||
return source_key, join_key, remaining_condition
|
||||
|
|
|
@ -29,7 +29,7 @@ def optimize_joins(expression):
|
|||
if isinstance(on, exp.Connector):
|
||||
for predicate in on.flatten():
|
||||
if name in exp.column_table_names(predicate):
|
||||
predicate.replace(exp.TRUE)
|
||||
predicate.replace(exp.true())
|
||||
join.on(predicate, copy=False)
|
||||
|
||||
expression = reorder_joins(expression)
|
||||
|
@ -70,6 +70,6 @@ def normalize(expression):
|
|||
def other_table_names(join, exclude):
|
||||
return [
|
||||
name
|
||||
for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
|
||||
for name in (exp.column_table_names(join.args.get("on") or exp.true()))
|
||||
if name != exclude
|
||||
]
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import sqlglot
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
from sqlglot.optimizer.canonicalize import canonicalize
|
||||
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
|
||||
from sqlglot.optimizer.eliminate_joins import eliminate_joins
|
||||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||
|
@ -28,6 +30,8 @@ RULES = (
|
|||
merge_subqueries,
|
||||
eliminate_joins,
|
||||
eliminate_ctes,
|
||||
annotate_types,
|
||||
canonicalize,
|
||||
quote_identities,
|
||||
)
|
||||
|
||||
|
|
|
@ -64,11 +64,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
|
|||
for predicate in predicates:
|
||||
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
|
||||
if isinstance(node, exp.Join):
|
||||
predicate.replace(exp.TRUE)
|
||||
predicate.replace(exp.true())
|
||||
node.on(predicate, copy=False)
|
||||
break
|
||||
if isinstance(node, exp.Select):
|
||||
predicate.replace(exp.TRUE)
|
||||
predicate.replace(exp.true())
|
||||
node.where(replace_aliases(node, predicate), copy=False)
|
||||
|
||||
|
||||
|
|
|
@ -382,9 +382,7 @@ class _Resolver:
|
|||
raise OptimizeError(str(e)) from e
|
||||
|
||||
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
|
||||
values_alias = source.expression.parent
|
||||
if hasattr(values_alias, "alias_column_names"):
|
||||
return values_alias.alias_column_names
|
||||
return source.expression.alias_column_names
|
||||
|
||||
# Otherwise, if referencing another scope, return that scope's named selects
|
||||
return source.expression.named_selects
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import itertools
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.helper import csv_reader
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
|
||||
|
||||
def qualify_tables(expression, db=None, catalog=None):
|
||||
def qualify_tables(expression, db=None, catalog=None, schema=None):
|
||||
"""
|
||||
Rewrite sqlglot AST to have fully qualified tables.
|
||||
|
||||
|
@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None):
|
|||
expression (sqlglot.Expression): expression to qualify
|
||||
db (str): Database name
|
||||
catalog (str): Catalog name
|
||||
schema: A schema to populate
|
||||
Returns:
|
||||
sqlglot.Expression: qualified expression
|
||||
"""
|
||||
|
@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None):
|
|||
source.set("catalog", exp.to_identifier(catalog))
|
||||
|
||||
if not source.alias:
|
||||
source.replace(
|
||||
source = source.replace(
|
||||
alias(
|
||||
source.copy(),
|
||||
source.this if identifier else f"_q_{next(sequence)}",
|
||||
|
@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None):
|
|||
)
|
||||
)
|
||||
|
||||
if schema and isinstance(source.this, exp.ReadCSV):
|
||||
with csv_reader(source.this) as reader:
|
||||
header = next(reader)
|
||||
columns = next(reader)
|
||||
schema.add_table(
|
||||
source, {k: type(v).__name__ for k, v in zip(header, columns)}
|
||||
)
|
||||
|
||||
return expression
|
||||
|
|
|
@ -189,11 +189,11 @@ def absorb_and_eliminate(expression):
|
|||
|
||||
# absorb
|
||||
if is_complement(b, aa):
|
||||
aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
|
||||
aa.replace(exp.true() if kind == exp.And else exp.false())
|
||||
elif is_complement(b, ab):
|
||||
ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
|
||||
ab.replace(exp.true() if kind == exp.And else exp.false())
|
||||
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
|
||||
a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
|
||||
a.replace(exp.false() if kind == exp.And else exp.true())
|
||||
elif isinstance(b, kind):
|
||||
# eliminate
|
||||
rhs = b.unnest_operands()
|
||||
|
|
|
@ -169,7 +169,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
select.parent.replace(alias)
|
||||
|
||||
for key, column, predicate in keys:
|
||||
predicate.replace(exp.TRUE)
|
||||
predicate.replace(exp.true())
|
||||
nested = exp.column(key_aliases[key], table_alias)
|
||||
|
||||
if key in group_by:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue