1
0
Fork 0

Merging upstream version 26.19.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-24 07:15:28 +02:00
parent 58527c3d26
commit a99682f526
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
98 changed files with 67345 additions and 65319 deletions

View file

@ -32,7 +32,7 @@ def annotate_types(
schema: t.Optional[t.Dict | Schema] = None,
annotators: t.Optional[AnnotatorsType] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
dialect: t.Optional[DialectType] = None,
dialect: DialectType = None,
) -> E:
"""
Infers the types of an expression, annotating its AST accordingly.
@ -55,9 +55,9 @@ def annotate_types(
The expression annotated with types.
"""
schema = ensure_schema(schema)
schema = ensure_schema(schema, dialect=dialect)
return TypeAnnotator(schema, annotators, coerces_to, dialect=dialect).annotate(expression)
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
@ -182,11 +182,12 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
annotators: t.Optional[AnnotatorsType] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
binary_coercions: t.Optional[BinaryCoercions] = None,
dialect: t.Optional[DialectType] = None,
) -> None:
self.schema = schema
self.annotators = annotators or Dialect.get_or_raise(dialect).ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
self.annotators = annotators or Dialect.get_or_raise(schema.dialect).ANNOTATORS
self.coerces_to = (
coerces_to or Dialect.get_or_raise(schema.dialect).COERCES_TO or self.COERCES_TO
)
self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
# Caches the ids of annotated sub-Expressions, to ensure we only visit them once
@ -311,7 +312,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
def _maybe_coerce(
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
self,
type1: exp.DataType | exp.DataType.Type,
type2: exp.DataType | exp.DataType.Type,
) -> exp.DataType | exp.DataType.Type:
"""
Returns type2 if type1 can be coerced into it, otherwise type1.

View file

@ -1,17 +1,32 @@
from __future__ import annotations
import typing as t
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.schema import ensure_schema
if t.TYPE_CHECKING:
from sqlglot._typing import E
from sqlglot.schema import Schema
from sqlglot.dialects.dialect import DialectType
def isolate_table_selects(expression, schema=None):
schema = ensure_schema(schema)
def isolate_table_selects(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
dialect: DialectType = None,
) -> E:
schema = ensure_schema(schema, dialect=dialect)
for scope in traverse_scope(expression):
if len(scope.selected_sources) == 1:
continue
for _, source in scope.selected_sources.values():
assert source.parent
if (
not isinstance(source, exp.Table)
or not schema.column_names(source)

View file

@ -92,9 +92,9 @@ def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, table, alias)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_order(outer_scope, inner_scope)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
_pop_cte(inner_scope)
outer_scope.clear_cache()
@ -111,9 +111,9 @@ def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) ->
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_order(outer_scope, inner_scope)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
outer_scope.clear_cache()

View file

@ -1,3 +1,6 @@
from __future__ import annotations
import typing as t
from collections import defaultdict
from sqlglot import alias, exp
@ -7,6 +10,11 @@ from sqlglot.schema import ensure_schema
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
if t.TYPE_CHECKING:
from sqlglot._typing import E
from sqlglot.schema import Schema
from sqlglot.dialects.dialect import DialectType
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
@ -16,7 +24,12 @@ def default_selection(is_agg: bool) -> exp.Alias:
return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
remove_unused_selections: bool = True,
dialect: DialectType = None,
) -> E:
"""
Rewrite sqlglot AST to remove unused columns projections.
@ -34,9 +47,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
sqlglot.Expression: optimized expression
"""
# Map of Scope to all columns being selected by outer queries.
schema = ensure_schema(schema)
source_column_alias_count = {}
referenced_columns = defaultdict(set)
schema = ensure_schema(schema, dialect=dialect)
source_column_alias_count: t.Dict[exp.Expression | Scope, int] = {}
referenced_columns: t.DefaultDict[Scope, t.Set[str | object]] = defaultdict(set)
# We build the scope tree (which is traversed in DFS postorder), then iterate
# over the result in reverse order. This should ensure that the set of selected
@ -69,12 +82,12 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
if scope.expression.args.get("by_name"):
referenced_columns[right] = referenced_columns[left]
else:
referenced_columns[right] = [
referenced_columns[right] = {
right.expression.selects[i].alias_or_name
for i, select in enumerate(left.expression.selects)
if SELECT_ALL in parent_selections
or select.alias_or_name in parent_selections
]
}
if isinstance(scope.expression, exp.Select):
if remove_unused_selections:

View file

@ -23,6 +23,7 @@ def qualify_columns(
expand_stars: bool = True,
infer_schema: t.Optional[bool] = None,
allow_partial_qualification: bool = False,
dialect: DialectType = None,
) -> exp.Expression:
"""
Rewrite sqlglot AST to have fully qualified columns.
@ -50,7 +51,7 @@ def qualify_columns(
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
"""
schema = ensure_schema(schema)
schema = ensure_schema(schema, dialect=dialect)
annotator = TypeAnnotator(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
dialect = Dialect.get_or_raise(schema.dialect)