
Merging upstream version 10.0.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:54:32 +01:00
parent 407314e8d2
commit efc1e37108
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
67 changed files with 2461 additions and 840 deletions


@@ -245,23 +245,31 @@ class TypeAnnotator:
def annotate(self, expression):
if isinstance(expression, self.TRAVERSABLES):
for scope in traverse_scope(expression):
subscope_selects = {
name: {select.alias_or_name: select for select in source.selects}
for name, source in scope.sources.items()
if isinstance(source, Scope)
}
selects = {}
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.Values):
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
source.expression.expressions[0].expressions,
)
}
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
# First annotate the current scope's column references
for col in scope.columns:
source = scope.sources[col.table]
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
else:
col.type = subscope_selects[col.table][col.name].type
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
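
The annotator can now resolve column types through VALUES-backed sources, not just ordinary subqueries, by mapping each alias column name to the expression in the first row. A hedged sketch of the intent (the query is illustrative and the expected type is an assumption based on the integer literal):

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

# v.a should pick up the type of the literal 1 from the VALUES row
expression = annotate_types(
    sqlglot.parse_one("SELECT v.a FROM (VALUES (1, 'x')) AS v(a, b)")
)
print(expression.selects[0].type)  # expected: exp.DataType.Type.INT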


@@ -0,0 +1,48 @@
import itertools
from sqlglot import exp
def canonicalize(expression: exp.Expression) -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
conversions rely on type inference.
Args:
expression: The expression to canonicalize.
"""
exp.replace_children(expression, canonicalize)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
return expression
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
node = exp.Concat(this=node.this, expression=node.expression)
return node
def coerce_type(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Binary):
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
elif isinstance(node, exp.Extract):
if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
_replace_cast(node.expression, "datetime")
return node
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
_replace_cast(b, "date")
def _replace_cast(node: exp.Expression, to: str) -> None:
data_type = exp.DataType.build(to)
cast = exp.Cast(this=node.copy(), to=data_type)
cast.type = data_type
node.replace(cast)
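
canonicalize standardizes type-dependent constructs once annotate_types has run; for example, coerce_type wraps the non-DATE side of a date comparison in an explicit cast. A hedged sketch (the query and the rendered output are illustrative expectations):

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize

expression = annotate_types(
    sqlglot.parse_one("SELECT CAST('2022-01-01' AS DATE) = '2022-01-02'")
)
# roughly: SELECT CAST('2022-01-01' AS DATE) = CAST('2022-01-02' AS DATE)
print(canonicalize(expression).sql())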


@@ -128,8 +128,8 @@ def join_condition(join):
Tuple of (source key, join key, remaining predicate)
"""
name = join.this.alias_or_name
on = join.args.get("on") or exp.TRUE
on = on.copy()
on = (join.args.get("on") or exp.true()).copy()
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
@@ -141,7 +141,7 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
for condition in on.flatten() if isinstance(on, exp.And) else [on]:
for condition in on.flatten():
if isinstance(condition, exp.EQ):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
@@ -150,13 +150,12 @@ def join_condition(join):
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
condition.replace(exp.TRUE)
condition.replace(exp.true())
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
condition.replace(exp.TRUE)
condition.replace(exp.true())
on = simplify(on)
remaining_condition = None if on == exp.TRUE else on
remaining_condition = None if on == exp.true() else on
return source_key, join_key, remaining_condition
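
Two changes here: the ON clause is always wrapped in an AND so flatten() applies uniformly, and the shared exp.TRUE constant is replaced by the exp.true() builder, a pattern repeated throughout this commit. Presumably the builder matters because replace() wires its argument into a specific parent, so each replacement should get its own node. A minimal sketch of the distinction:

from sqlglot import exp

a = exp.true()
b = exp.true()
assert a == b       # structurally equal Boolean literals
assert a is not b   # but distinct nodes, so each can safely be given its own parent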


@@ -29,7 +29,7 @@ def optimize_joins(expression):
if isinstance(on, exp.Connector):
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
predicate.replace(exp.TRUE)
predicate.replace(exp.true())
join.on(predicate, copy=False)
expression = reorder_joins(expression)
@@ -70,6 +70,6 @@ def normalize(expression):
def other_table_names(join, exclude):
return [
name
for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
for name in (exp.column_table_names(join.args.get("on") or exp.true()))
if name != exclude
]


@@ -1,4 +1,6 @@
import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
@@ -28,6 +30,8 @@ RULES = (
merge_subqueries,
eliminate_joins,
eliminate_ctes,
annotate_types,
canonicalize,
quote_identities,
)
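
With annotate_types and canonicalize appended to RULES, a plain optimize() call now performs type annotation and canonicalization along with the other rewrites. A hedged sketch (the schema and query are illustrative):

import sqlglot
from sqlglot.optimizer import optimize

expression = optimize(
    sqlglot.parse_one("SELECT a + b FROM x"),
    schema={"x": {"a": "INT", "b": "INT"}},
)
print(expression.sql())  # qualified, type-annotated, canonicalized, identifier-quoted SQL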


@@ -64,11 +64,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join):
predicate.replace(exp.TRUE)
predicate.replace(exp.true())
node.on(predicate, copy=False)
break
if isinstance(node, exp.Select):
predicate.replace(exp.TRUE)
predicate.replace(exp.true())
node.where(replace_aliases(node, predicate), copy=False)
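
The same fresh-TRUE pattern applies in predicate pushdown: once a conjunct is copied into the join or subquery it belongs to, the slot it occupied is filled with exp.true(). A hedged usage sketch (the query is illustrative):

import sqlglot
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

expression = pushdown_predicates(
    sqlglot.parse_one("SELECT y.a FROM (SELECT x.a FROM x) AS y WHERE y.a = 1")
)
print(expression.sql())  # the filter on y.a should be pushed into the derived table as x.a = 1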


@@ -382,9 +382,7 @@ class _Resolver:
raise OptimizeError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
values_alias = source.expression.parent
if hasattr(values_alias, "alias_column_names"):
return values_alias.alias_column_names
return source.expression.alias_column_names
# Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects
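
The resolver now reads alias column names directly off the VALUES expression instead of probing its parent with hasattr, so star expansion over a VALUES-backed source resolves cleanly. A hedged sketch (the query and the expected expansion are assumptions):

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.schema import ensure_schema

expression = qualify_columns(
    sqlglot.parse_one("SELECT * FROM (VALUES (1, 2)) AS v(a, b)"),
    schema=ensure_schema({}),
)
print(expression.sql())  # the star should expand to v.a and v.b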


@@ -1,10 +1,11 @@
import itertools
from sqlglot import alias, exp
from sqlglot.helper import csv_reader
from sqlglot.optimizer.scope import traverse_scope
def qualify_tables(expression, db=None, catalog=None):
def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
Rewrite sqlglot AST to have fully qualified tables.
@@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None):
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
schema: A schema to populate
Returns:
sqlglot.Expression: qualified expression
"""
@@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None):
source.set("catalog", exp.to_identifier(catalog))
if not source.alias:
source.replace(
source = source.replace(
alias(
source.copy(),
source.this if identifier else f"_q_{next(sequence)}",
@@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None):
)
)
if schema and isinstance(source.this, exp.ReadCSV):
with csv_reader(source.this) as reader:
header = next(reader)
columns = next(reader)
schema.add_table(
source, {k: type(v).__name__ for k, v in zip(header, columns)}
)
return expression
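
When a schema object is passed in, qualify_tables now samples READ_CSV sources: the header row supplies column names and the first data row supplies rough Python type names, which are registered on the schema. A hedged sketch; the file name, its contents, and the MappingSchema usage are illustrative assumptions:

import csv
import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.schema import MappingSchema

# Write a tiny CSV so READ_CSV has a header row and one data row to sample.
with open("people.csv", "w", newline="") as f:
    csv.writer(f).writerows([["name", "age"], ["alice", "30"]])

schema = MappingSchema({})
qualify_tables(sqlglot.parse_one("SELECT * FROM READ_CSV('people.csv')"), schema=schema)
# schema should now hold the CSV's columns under the generated table alias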


@@ -189,11 +189,11 @@ def absorb_and_eliminate(expression):
# absorb
if is_complement(b, aa):
aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
aa.replace(exp.true() if kind == exp.And else exp.false())
elif is_complement(b, ab):
ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
ab.replace(exp.true() if kind == exp.And else exp.false())
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
a.replace(exp.false() if kind == exp.And else exp.true())
elif isinstance(b, kind):
# eliminate
rhs = b.unnest_operands()
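
absorb_and_eliminate now plugs freshly built TRUE/FALSE nodes into the tree when it collapses complementary terms. For instance, a conjunction of complementary disjunctions should reduce to the shared term. A hedged sketch (the expected result is an assumption, not asserted):

import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.parse_one("(A OR B) AND (A OR NOT B)")).sql())  # expected: A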


@@ -169,7 +169,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
select.parent.replace(alias)
for key, column, predicate in keys:
predicate.replace(exp.TRUE)
predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
if key in group_by: