Merging upstream version 26.19.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
58527c3d26
commit
a99682f526
98 changed files with 67345 additions and 65319 deletions
|
@ -32,7 +32,7 @@ def annotate_types(
|
|||
schema: t.Optional[t.Dict | Schema] = None,
|
||||
annotators: t.Optional[AnnotatorsType] = None,
|
||||
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
|
||||
dialect: t.Optional[DialectType] = None,
|
||||
dialect: DialectType = None,
|
||||
) -> E:
|
||||
"""
|
||||
Infers the types of an expression, annotating its AST accordingly.
|
||||
|
@ -55,9 +55,9 @@ def annotate_types(
|
|||
The expression annotated with types.
|
||||
"""
|
||||
|
||||
schema = ensure_schema(schema)
|
||||
schema = ensure_schema(schema, dialect=dialect)
|
||||
|
||||
return TypeAnnotator(schema, annotators, coerces_to, dialect=dialect).annotate(expression)
|
||||
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
|
||||
|
||||
|
||||
def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
|
||||
|
@ -182,11 +182,12 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
annotators: t.Optional[AnnotatorsType] = None,
|
||||
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
|
||||
binary_coercions: t.Optional[BinaryCoercions] = None,
|
||||
dialect: t.Optional[DialectType] = None,
|
||||
) -> None:
|
||||
self.schema = schema
|
||||
self.annotators = annotators or Dialect.get_or_raise(dialect).ANNOTATORS
|
||||
self.coerces_to = coerces_to or self.COERCES_TO
|
||||
self.annotators = annotators or Dialect.get_or_raise(schema.dialect).ANNOTATORS
|
||||
self.coerces_to = (
|
||||
coerces_to or Dialect.get_or_raise(schema.dialect).COERCES_TO or self.COERCES_TO
|
||||
)
|
||||
self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
|
||||
|
||||
# Caches the ids of annotated sub-Expressions, to ensure we only visit them once
|
||||
|
@ -311,7 +312,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
return expression
|
||||
|
||||
def _maybe_coerce(
|
||||
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
|
||||
self,
|
||||
type1: exp.DataType | exp.DataType.Type,
|
||||
type2: exp.DataType | exp.DataType.Type,
|
||||
) -> exp.DataType | exp.DataType.Type:
|
||||
"""
|
||||
Returns type2 if type1 can be coerced into it, otherwise type1.
|
||||
|
|
|
@ -1,17 +1,32 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
from sqlglot.schema import ensure_schema
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import E
|
||||
from sqlglot.schema import Schema
|
||||
from sqlglot.dialects.dialect import DialectType
|
||||
|
||||
def isolate_table_selects(expression, schema=None):
|
||||
schema = ensure_schema(schema)
|
||||
|
||||
def isolate_table_selects(
|
||||
expression: E,
|
||||
schema: t.Optional[t.Dict | Schema] = None,
|
||||
dialect: DialectType = None,
|
||||
) -> E:
|
||||
schema = ensure_schema(schema, dialect=dialect)
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
if len(scope.selected_sources) == 1:
|
||||
continue
|
||||
|
||||
for _, source in scope.selected_sources.values():
|
||||
assert source.parent
|
||||
|
||||
if (
|
||||
not isinstance(source, exp.Table)
|
||||
or not schema.column_names(source)
|
||||
|
|
|
@ -92,9 +92,9 @@ def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
|
|||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
_merge_from(outer_scope, inner_scope, table, alias)
|
||||
_merge_expressions(outer_scope, inner_scope, alias)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_hints(outer_scope, inner_scope)
|
||||
_pop_cte(inner_scope)
|
||||
outer_scope.clear_cache()
|
||||
|
@ -111,9 +111,9 @@ def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) ->
|
|||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
_merge_from(outer_scope, inner_scope, subquery, alias)
|
||||
_merge_expressions(outer_scope, inner_scope, alias)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_hints(outer_scope, inner_scope)
|
||||
outer_scope.clear_cache()
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlglot import alias, exp
|
||||
|
@ -7,6 +10,11 @@ from sqlglot.schema import ensure_schema
|
|||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import seq_get
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import E
|
||||
from sqlglot.schema import Schema
|
||||
from sqlglot.dialects.dialect import DialectType
|
||||
|
||||
# Sentinel value that means an outer query selecting ALL columns
|
||||
SELECT_ALL = object()
|
||||
|
||||
|
@ -16,7 +24,12 @@ def default_selection(is_agg: bool) -> exp.Alias:
|
|||
return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
|
||||
|
||||
|
||||
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
|
||||
def pushdown_projections(
|
||||
expression: E,
|
||||
schema: t.Optional[t.Dict | Schema] = None,
|
||||
remove_unused_selections: bool = True,
|
||||
dialect: DialectType = None,
|
||||
) -> E:
|
||||
"""
|
||||
Rewrite sqlglot AST to remove unused columns projections.
|
||||
|
||||
|
@ -34,9 +47,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
# Map of Scope to all columns being selected by outer queries.
|
||||
schema = ensure_schema(schema)
|
||||
source_column_alias_count = {}
|
||||
referenced_columns = defaultdict(set)
|
||||
schema = ensure_schema(schema, dialect=dialect)
|
||||
source_column_alias_count: t.Dict[exp.Expression | Scope, int] = {}
|
||||
referenced_columns: t.DefaultDict[Scope, t.Set[str | object]] = defaultdict(set)
|
||||
|
||||
# We build the scope tree (which is traversed in DFS postorder), then iterate
|
||||
# over the result in reverse order. This should ensure that the set of selected
|
||||
|
@ -69,12 +82,12 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
if scope.expression.args.get("by_name"):
|
||||
referenced_columns[right] = referenced_columns[left]
|
||||
else:
|
||||
referenced_columns[right] = [
|
||||
referenced_columns[right] = {
|
||||
right.expression.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.expression.selects)
|
||||
if SELECT_ALL in parent_selections
|
||||
or select.alias_or_name in parent_selections
|
||||
]
|
||||
}
|
||||
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
if remove_unused_selections:
|
||||
|
|
|
@ -23,6 +23,7 @@ def qualify_columns(
|
|||
expand_stars: bool = True,
|
||||
infer_schema: t.Optional[bool] = None,
|
||||
allow_partial_qualification: bool = False,
|
||||
dialect: DialectType = None,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Rewrite sqlglot AST to have fully qualified columns.
|
||||
|
@ -50,7 +51,7 @@ def qualify_columns(
|
|||
Notes:
|
||||
- Currently only handles a single PIVOT or UNPIVOT operator
|
||||
"""
|
||||
schema = ensure_schema(schema)
|
||||
schema = ensure_schema(schema, dialect=dialect)
|
||||
annotator = TypeAnnotator(schema)
|
||||
infer_schema = schema.empty if infer_schema is None else infer_schema
|
||||
dialect = Dialect.get_or_raise(schema.dialect)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue