Merging upstream version 20.11.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 1bce3d0317
commit e71ccc03da

141 changed files with 66644 additions and 54334 deletions
@@ -4,7 +4,6 @@ import functools
 import typing as t
 
 from sqlglot import exp
-from sqlglot._typing import E
 from sqlglot.helper import (
     ensure_list,
     is_date_unit,

@@ -17,7 +16,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
 from sqlglot.schema import Schema, ensure_schema
 
 if t.TYPE_CHECKING:
-    B = t.TypeVar("B", bound=exp.Binary)
+    from sqlglot._typing import B, E
 
 BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
 BinaryCoercions = t.Dict[

@@ -479,6 +478,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         self._set_type(expression, target_type)
         return self._annotate_args(expression)
 
+    @t.no_type_check
+    def _annotate_struct_value(
+        self, expression: exp.Expression
+    ) -> t.Optional[exp.DataType] | exp.ColumnDef:
+        alias = expression.args.get("alias")
+        if alias:
+            return exp.ColumnDef(this=alias.copy(), kind=expression.type)
+
+        # Case: key = value or key := value
+        if expression.expression:
+            return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
+
+        return expression.type
+
     @t.no_type_check
     def _annotate_by_args(
         self,

@@ -516,16 +529,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
             )
 
         if struct:
-            expressions = [
-                expr.type
-                if not expr.args.get("alias")
-                else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
-                for expr in expressions
-            ]
-
             self._set_type(
                 expression,
-                exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
+                exp.DataType(
+                    this=exp.DataType.Type.STRUCT,
+                    expressions=[self._annotate_struct_value(expr) for expr in expressions],
+                    nested=True,
+                ),
             )
 
         return expression

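Context for the two TypeAnnotator hunks above (not part of the diff): they teach the annotator to carry field names into STRUCT types by emitting exp.ColumnDef nodes for aliased struct values. A rough usage sketch, assuming a sqlglot build that contains this change and that the BigQuery dialect parses STRUCT(...) into exp.Struct:

    # Minimal sketch, assuming sqlglot >= 20.11 with the struct-typing change above.
    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    expression = parse_one("SELECT STRUCT(1 AS a, 'x' AS b)", read="bigquery")
    annotated = annotate_types(expression)

    # The projection's type should now be a nested STRUCT whose fields keep their
    # aliases, e.g. something like STRUCT<a INT64, b STRING>.
    print(annotated.selects[0].type.sql(dialect="bigquery"))
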
@@ -3,18 +3,18 @@ from __future__ import annotations
 import typing as t
 
 from sqlglot import exp
-from sqlglot._typing import E
 from sqlglot.dialects.dialect import Dialect, DialectType
 
+if t.TYPE_CHECKING:
+    from sqlglot._typing import E
+
 
 @t.overload
-def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
-    ...
+def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
 
 
 @t.overload
-def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
-    ...
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
 
 
 def normalize_identifiers(expression, dialect=None):

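The normalize_identifiers hunk above only moves the E import behind t.TYPE_CHECKING and reformats the overloads; behaviour is unchanged. For reference, a small usage sketch (assumptions: standard sqlglot API, Snowflake folds unquoted identifiers to upper case):

    # Minimal sketch of the function whose overloads are reformatted above.
    from sqlglot import parse_one
    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    # A string is parsed into an exp.Identifier and normalized for the dialect.
    print(normalize_identifiers("Foo", dialect="snowflake").sql())  # expected: FOO

    # An expression has every identifier normalized in place and is returned.
    tree = normalize_identifiers(parse_one("SELECT Bar FROM Foo"), dialect="snowflake")
    print(tree.sql(dialect="snowflake"))  # expected: SELECT BAR FROM FOO
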
@@ -4,7 +4,6 @@ import itertools
 import typing as t
 
 from sqlglot import alias, exp
-from sqlglot._typing import E
 from sqlglot.dialects.dialect import Dialect, DialectType
 from sqlglot.errors import OptimizeError
 from sqlglot.helper import seq_get

@@ -12,6 +11,9 @@ from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 from sqlglot.optimizer.simplify import simplify_parens
 from sqlglot.schema import Schema, ensure_schema
 
+if t.TYPE_CHECKING:
+    from sqlglot._typing import E
+
 
 def qualify_columns(
     expression: exp.Expression,

@@ -210,7 +212,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
         if not node:
             return
 
-        for column, *_ in walk_in_scope(node):
+        for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
             if not isinstance(column, exp.Column):
                 continue
 

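The one-line change above passes a prune callback so alias expansion does not descend below star expressions. A rough illustration of the same prune mechanism on the public Expression.walk API (assuming sqlglot 20.x, where walk yields (node, parent, key) tuples and a truthy prune result stops descent below the yielded node):

    # Minimal sketch of pruned traversal; not taken from the diff itself.
    from sqlglot import exp, parse_one

    tree = parse_one("SELECT t.*, a + 1 AS b FROM t")

    # Star nodes are still yielded, but their children are not visited.
    for node, *_ in tree.walk(prune=lambda n, *_: n.is_star):
        if isinstance(node, exp.Column):
            print(node.sql())
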
@@ -525,6 +527,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
             selection = alias(
                 selection,
                 alias=selection.output_name or f"_col_{i}",
+                copy=False,
             )
         if aliased_column:
             selection.set("alias", exp.to_identifier(aliased_column))

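For orientation (not part of the diff): qualify_outputs is normally reached through the high-level qualify entry point, which aliases unnamed projections with _col_<i> style names. A hedged sketch, assuming the documented sqlglot.optimizer.qualify API:

    # Minimal sketch of the surrounding qualify pipeline; exact output may vary by version.
    from sqlglot import parse_one
    from sqlglot.optimizer.qualify import qualify

    qualified = qualify(
        parse_one("SELECT a, b + 1 FROM t"),
        schema={"t": {"a": "int", "b": "int"}},
    )
    # Columns get table-qualified and every output gets an alias, roughly:
    # SELECT "t"."a" AS "a", "t"."b" + 1 AS "_col_1" FROM "t" AS "t"
    print(qualified.sql())
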
@@ -4,12 +4,14 @@ import itertools
 import typing as t
 
 from sqlglot import alias, exp
-from sqlglot._typing import E
 from sqlglot.dialects.dialect import DialectType
 from sqlglot.helper import csv_reader, name_sequence
 from sqlglot.optimizer.scope import Scope, traverse_scope
 from sqlglot.schema import Schema
 
+if t.TYPE_CHECKING:
+    from sqlglot._typing import E
+
 
 def qualify_tables(
     expression: E,

@@ -46,6 +48,18 @@ def qualify_tables(
     db = exp.parse_identifier(db, dialect=dialect) if db else None
     catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
 
+    def _qualify(table: exp.Table) -> None:
+        if isinstance(table.this, exp.Identifier):
+            if not table.args.get("db"):
+                table.set("db", db)
+            if not table.args.get("catalog") and table.args.get("db"):
+                table.set("catalog", catalog)
+
+    if not isinstance(expression, exp.Subqueryable):
+        for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
+            if isinstance(node, exp.Table):
+                _qualify(node)
+
     for scope in traverse_scope(expression):
         for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
             if isinstance(derived_table, exp.Subquery):

@@ -66,11 +80,7 @@ def qualify_tables(
 
         for name, source in scope.sources.items():
             if isinstance(source, exp.Table):
-                if isinstance(source.this, exp.Identifier):
-                    if not source.args.get("db"):
-                        source.set("db", db)
-                    if not source.args.get("catalog") and source.args.get("db"):
-                        source.set("catalog", catalog)
+                _qualify(source)
 
                 pivots = pivots = source.args.get("pivots")
                 if not source.alias:

@@ -107,5 +117,14 @@ def qualify_tables(
                 if isinstance(udtf, exp.Values) and not table_alias.columns:
                     for i, e in enumerate(udtf.expressions[0].expressions):
                         table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
+            else:
+                for node, parent, _ in scope.walk():
+                    if (
+                        isinstance(node, exp.Table)
+                        and not node.alias
+                        and isinstance(parent, (exp.From, exp.Join))
+                    ):
+                        # Mutates the table by attaching an alias to it
+                        alias(node, node.name, copy=False, table=True)
 
     return expression

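Taken together, the qualify_tables hunks above route db/catalog attachment through the new _qualify helper, extend it to statements that are not plain SELECTs, and alias unaliased tables found under FROM or JOIN. A usage sketch of the public entry point (assumptions: default dialect, table name reused as the alias):

    # Minimal sketch; the exact rendering may differ slightly between versions.
    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables

    tree = qualify_tables(parse_one("SELECT 1 FROM tbl"), db="db", catalog="cat")
    print(tree.sql())  # expected: SELECT 1 FROM cat.db.tbl AS tbl
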
@@ -323,9 +323,14 @@ class Scope:
                 sources in the current scope.
         """
         if self._external_columns is None:
-            self._external_columns = [
-                c for c in self.columns if c.table not in self.selected_sources
-            ]
+            if isinstance(self.expression, exp.Union):
+                left, right = self.union_scopes
+                self._external_columns = left.external_columns + right.external_columns
+            else:
+                self._external_columns = [
+                    c for c in self.columns if c.table not in self.selected_sources
+                ]
+
         return self._external_columns
 
     @property

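The external_columns change above makes a union scope report the external columns of both branches rather than computing them from the union node directly. A rough sketch of how this might surface through the scope API (assumptions about build_scope/subquery_scopes behaviour; output not asserted):

    # Minimal sketch; x.a is correlated into the union inside the IN (...) subquery.
    from sqlglot import parse_one
    from sqlglot.optimizer.scope import build_scope

    root = build_scope(
        parse_one(
            "SELECT * FROM x "
            "WHERE x.a IN (SELECT y.b FROM y UNION SELECT z.c FROM z WHERE z.d = x.a)"
        )
    )
    union_scope = root.subquery_scopes[0]
    # With the union branch above, columns referenced from outer scopes on either
    # side of the UNION (here x.a) are reported as external to the union scope.
    print([c.sql() for c in union_scope.external_columns])
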
@@ -477,11 +482,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
 
     Args:
         expression (exp.Expression): expression to traverse
 
     Returns:
         list[Scope]: scope instances
     """
     if isinstance(expression, exp.Unionable) or (
-        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
+        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
     ):
         return list(_traverse_scope(Scope(expression)))
 

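The isinstance change above widens the DDL branch from Subqueryable to Unionable source queries. For orientation, a small sketch of traversing scopes over a CREATE TABLE ... AS statement (assuming the sqlglot 20.x scope API; the count is indicative, not asserted):

    # Minimal sketch of scope traversal over DDL with a query source.
    from sqlglot import parse_one
    from sqlglot.optimizer.scope import traverse_scope

    ctas = parse_one("CREATE TABLE t AS SELECT 1 AS a UNION ALL SELECT 2 AS a")
    scopes = traverse_scope(ctas)
    # Expected: one scope per SELECT branch plus one for the set operation.
    print(len(scopes))
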
@@ -1068,9 +1068,11 @@ def extract_interval(expression):
 def date_literal(date):
     return exp.cast(
         exp.Literal.string(date),
-        exp.DataType.Type.DATETIME
-        if isinstance(date, datetime.datetime)
-        else exp.DataType.Type.DATE,
+        (
+            exp.DataType.Type.DATETIME
+            if isinstance(date, datetime.datetime)
+            else exp.DataType.Type.DATE
+        ),
     )
 
 

@@ -50,11 +50,12 @@ def unnest(select, parent_select, next_alias_name):
     ):
         return
 
-    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
-
     # This subquery returns a scalar and can just be converted to a cross join
     if not isinstance(predicate, (exp.In, exp.Any)):
         column = exp.column(select.selects[0].alias_or_name, alias)
 
+        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
         clause_parent_select = clause.parent_select if clause else None
 
         if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (

@@ -84,12 +85,18 @@ def unnest(select, parent_select, next_alias_name):
     column = _other_operand(predicate)
     value = select.selects[0]
 
-    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
-    _replace(predicate, f"NOT {on.right} IS NULL")
+    join_key = exp.column(value.alias, alias)
+    join_key_not_null = join_key.is_(exp.null()).not_()
+
+    if isinstance(clause, exp.Join):
+        _replace(predicate, exp.true())
+        parent_select.where(join_key_not_null, copy=False)
+    else:
+        _replace(predicate, join_key_not_null)
 
     parent_select.join(
         select.group_by(value.this, copy=False),
-        on=on,
+        on=column.eq(join_key),
         join_type="LEFT",
         join_alias=alias,
         copy=False,

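The second unnest hunk above swaps string-built conditions for sqlglot's expression builders and adds handling for predicates that already sit in a JOIN clause. A small sketch of the builder calls it relies on (standard sqlglot expression API; the column names here are illustrative only):

    # Minimal sketch of the builder-style predicates used above.
    from sqlglot import exp

    join_key = exp.column("value_alias", "subquery_alias")  # subquery_alias.value_alias
    other = exp.column("other_col")

    join_key_not_null = join_key.is_(exp.null()).not_()  # NOT subquery_alias.value_alias IS NULL
    on_condition = other.eq(join_key)                     # other_col = subquery_alias.value_alias

    print(join_key_not_null.sql())
    print(on_condition.sql())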