
Merging upstream version 20.11.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:19:58 +01:00
parent 1bce3d0317
commit e71ccc03da
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
141 changed files with 66644 additions and 54334 deletions

View file

@ -4,7 +4,6 @@ import functools
import typing as t
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import (
ensure_list,
is_date_unit,
@ -17,7 +16,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
B = t.TypeVar("B", bound=exp.Binary)
from sqlglot._typing import B, E
BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
BinaryCoercions = t.Dict[
@ -479,6 +478,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, target_type)
return self._annotate_args(expression)
@t.no_type_check
def _annotate_struct_value(
self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
alias = expression.args.get("alias")
if alias:
return exp.ColumnDef(this=alias.copy(), kind=expression.type)
# Case: key = value or key := value
if expression.expression:
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
return expression.type
@t.no_type_check
def _annotate_by_args(
self,
@ -516,16 +529,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
if struct:
expressions = [
expr.type
if not expr.args.get("alias")
else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
for expr in expressions
]
self._set_type(
expression,
exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
exp.DataType(
this=exp.DataType.Type.STRUCT,
expressions=[self._annotate_struct_value(expr) for expr in expressions],
nested=True,
),
)
return expression
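
A minimal sketch of the struct annotation this hunk introduces, assuming BigQuery-style STRUCT syntax and a schema-free annotate_types call; the type string in the comment is an expectation, not guaranteed output:

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# Aliased entries (and key := value entries) become ColumnDefs inside the STRUCT type
ast = annotate_types(parse_one("SELECT STRUCT(1 AS a, 'x' AS b)", read="bigquery"))
print(ast.selects[0].type.sql())  # expected to resemble STRUCT<a INT, b VARCHAR>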

View file

@ -3,18 +3,18 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
@t.overload
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
...
if t.TYPE_CHECKING:
from sqlglot._typing import E
@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
...
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
def normalize_identifiers(expression, dialect=None):
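
For context, the two overloads above correspond to the two accepted input types; a small sketch (Snowflake's upper-casing is used for illustration, exact output may vary by version):

from sqlglot import parse_one
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

# Expression in, expression out: unquoted identifiers are case-normalized, quoted ones kept as-is
print(normalize_identifiers(parse_one('SELECT a FROM "Tbl"'), dialect="snowflake").sql(dialect="snowflake"))
# Plain string in, exp.Identifier out
print(normalize_identifiers("col_name", dialect="snowflake").sql())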

View file

@ -4,7 +4,6 @@ import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
@ -12,6 +11,9 @@ from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
from sqlglot._typing import E
def qualify_columns(
expression: exp.Expression,
@ -210,7 +212,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not node:
return
for column, *_ in walk_in_scope(node):
for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
if not isinstance(column, exp.Column):
continue
@ -525,6 +527,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
selection = alias(
selection,
alias=selection.output_name or f"_col_{i}",
copy=False,
)
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))
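
A usage sketch for qualify_columns with an ad-hoc schema (the schema, query, and resulting alias names are illustrative assumptions):

from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import qualify_columns

# Alias references in later projections are expanded back to their defining expressions,
# and unaliased projections receive generated _col_<i> names
sql = "SELECT x.a + 1 AS b, b * 2 FROM x"
print(qualify_columns(parse_one(sql), schema={"x": {"a": "INT"}}).sql())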

View file

@ -4,12 +4,14 @@ import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
if t.TYPE_CHECKING:
from sqlglot._typing import E
def qualify_tables(
expression: E,
@ -46,6 +48,18 @@ def qualify_tables(
db = exp.parse_identifier(db, dialect=dialect) if db else None
catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
def _qualify(table: exp.Table) -> None:
if isinstance(table.this, exp.Identifier):
if not table.args.get("db"):
table.set("db", db)
if not table.args.get("catalog") and table.args.get("db"):
table.set("catalog", catalog)
if not isinstance(expression, exp.Subqueryable):
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
if isinstance(node, exp.Table):
_qualify(node)
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
if isinstance(derived_table, exp.Subquery):
@ -66,11 +80,7 @@ def qualify_tables(
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", db)
if not source.args.get("catalog") and source.args.get("db"):
source.set("catalog", catalog)
_qualify(source)
pivots = source.args.get("pivots")
if not source.alias:
@ -107,5 +117,14 @@ def qualify_tables(
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
else:
for node, parent, _ in scope.walk():
if (
isinstance(node, exp.Table)
and not node.alias
and isinstance(parent, (exp.From, exp.Join))
):
# Mutates the table by attaching an alias to it
alias(node, node.name, copy=False, table=True)
return expression
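
A sketch of the table qualification behavior, with hypothetical db/catalog defaults; the extracted _qualify helper now also applies to statements that are not Subqueryable:

from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables

# Missing db/catalog parts are filled in and unaliased tables get an alias
print(qualify_tables(parse_one("SELECT 1 FROM tbl"), db="db1", catalog="cat1").sql())
# expected to resemble: SELECT 1 FROM cat1.db1.tbl AS tbl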

View file

@ -323,9 +323,14 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
self._external_columns = [
c for c in self.columns if c.table not in self.selected_sources
]
if isinstance(self.expression, exp.Union):
left, right = self.union_scopes
self._external_columns = left.external_columns + right.external_columns
else:
self._external_columns = [
c for c in self.columns if c.table not in self.selected_sources
]
return self._external_columns
@property
@ -477,11 +482,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Args:
expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
"""
if isinstance(expression, exp.Unionable) or (
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
):
return list(_traverse_scope(Scope(expression)))
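
A sketch of how the UNION branch of external_columns behaves, using an assumed correlated query; scope ordering and output formatting may differ:

from sqlglot import parse_one
from sqlglot.optimizer.scope import traverse_scope

sql = "SELECT * FROM x WHERE x.a IN (SELECT y.b FROM y UNION SELECT z.c FROM z WHERE z.d = x.e)"
for scope in traverse_scope(parse_one(sql)):
    # For the UNION scope, external_columns is now the concatenation of both sides,
    # so the correlated reference x.e surfaces there as well
    print(type(scope.expression).__name__, [c.sql() for c in scope.external_columns])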

View file

@ -1068,9 +1068,11 @@ def extract_interval(expression):
def date_literal(date):
return exp.cast(
exp.Literal.string(date),
exp.DataType.Type.DATETIME
if isinstance(date, datetime.datetime)
else exp.DataType.Type.DATE,
(
exp.DataType.Type.DATETIME
if isinstance(date, datetime.datetime)
else exp.DataType.Type.DATE
),
)
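
date_literal feeds the date-arithmetic folding in simplify; a minimal sketch (the folded result mentioned in the comment is an assumption):

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# Literal date arithmetic is folded into a single casted literal,
# roughly CAST('2024-01-02' AS DATE)
print(simplify(parse_one("CAST('2024-01-01' AS DATE) + INTERVAL '1' DAY")).sql())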

View file

@ -50,11 +50,12 @@ def unnest(select, parent_select, next_alias_name):
):
return
clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
# This subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
column = exp.column(select.selects[0].alias_or_name, alias)
clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
clause_parent_select = clause.parent_select if clause else None
if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
@ -84,12 +85,18 @@ def unnest(select, parent_select, next_alias_name):
column = _other_operand(predicate)
value = select.selects[0]
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
_replace(predicate, f"NOT {on.right} IS NULL")
join_key = exp.column(value.alias, alias)
join_key_not_null = join_key.is_(exp.null()).not_()
if isinstance(clause, exp.Join):
_replace(predicate, exp.true())
parent_select.where(join_key_not_null, copy=False)
else:
_replace(predicate, join_key_not_null)
parent_select.join(
select.group_by(value.this, copy=False),
on=on,
on=column.eq(join_key),
join_type="LEFT",
join_alias=alias,
copy=False,
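
A usage sketch for the unnesting rewrite this hunk adjusts, on an assumed correlated scalar subquery; generated aliases and the exact join shape may differ:

from sqlglot import parse_one
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

# The correlated scalar subquery is rewritten into a grouped LEFT JOIN; the original
# predicate is re-expressed against the join key, and when it sits inside a JOIN the
# IS NOT NULL guard is moved to the parent WHERE clause instead
sql = "SELECT * FROM x WHERE x.a > (SELECT SUM(y.b) FROM y WHERE y.c = x.c)"
print(unnest_subqueries(parse_one(sql)).sql())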