Merging upstream version 21.1.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
4e41aa0bbb
commit
bf03050a25
91 changed files with 49165 additions and 47854 deletions
|
@ -204,7 +204,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.TimeAdd,
|
||||
exp.TimeStrToTime,
|
||||
exp.TimeSub,
|
||||
exp.Timestamp,
|
||||
exp.TimestampAdd,
|
||||
exp.TimestampSub,
|
||||
exp.UnixToTime,
|
||||
|
@ -276,6 +275,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
|
||||
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
|
||||
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
|
||||
exp.Timestamp: lambda self, e: self._annotate_with_type(
|
||||
e,
|
||||
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
|
||||
),
|
||||
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
|
||||
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
|
||||
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
|
||||
|
|
|
@ -38,7 +38,12 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
|
|||
if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
|
||||
return exp.cast(node.this, to=exp.DataType.Type.DATE)
|
||||
if isinstance(node, exp.Timestamp) and not node.expression:
|
||||
return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)
|
||||
if not node.type:
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
|
||||
node = annotate_types(node)
|
||||
return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
|
@ -76,9 +81,8 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
|
|||
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
|
||||
if (
|
||||
isinstance(expression, exp.Cast)
|
||||
and expression.to.type
|
||||
and expression.this.type
|
||||
and expression.to.type.this == expression.this.type.this
|
||||
and expression.to.this == expression.this.type.this
|
||||
):
|
||||
return expression.this
|
||||
return expression
|
||||
|
|
|
@ -6,7 +6,7 @@ import typing as t
|
|||
from sqlglot import alias, exp
|
||||
from sqlglot.dialects.dialect import Dialect, DialectType
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.helper import seq_get, SingleValuedMapping
|
||||
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
|
||||
from sqlglot.optimizer.simplify import simplify_parens
|
||||
from sqlglot.schema import Schema, ensure_schema
|
||||
|
@ -586,8 +586,8 @@ class Resolver:
|
|||
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
|
||||
self.scope = scope
|
||||
self.schema = schema
|
||||
self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
|
||||
self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
|
||||
self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
|
||||
self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
|
||||
self._all_columns: t.Optional[t.Set[str]] = None
|
||||
self._infer_schema = infer_schema
|
||||
|
||||
|
@ -640,7 +640,7 @@ class Resolver:
|
|||
}
|
||||
return self._all_columns
|
||||
|
||||
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
|
||||
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
|
||||
"""Resolve the source columns for a given source `name`."""
|
||||
if name not in self.scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {name}")
|
||||
|
@ -662,10 +662,15 @@ class Resolver:
|
|||
else:
|
||||
column_aliases = []
|
||||
|
||||
# If the source's columns are aliased, their aliases shadow the corresponding column names
|
||||
return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]
|
||||
if column_aliases:
|
||||
# If the source's columns are aliased, their aliases shadow the corresponding column names.
|
||||
# This can be expensive if there are lots of columns, so only do this if column_aliases exist.
|
||||
return [
|
||||
alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
|
||||
]
|
||||
return columns
|
||||
|
||||
def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
|
||||
def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
|
||||
if self._source_columns is None:
|
||||
self._source_columns = {
|
||||
source_name: self.get_source_columns(source_name)
|
||||
|
@ -676,8 +681,8 @@ class Resolver:
|
|||
return self._source_columns
|
||||
|
||||
def _get_unambiguous_columns(
|
||||
self, source_columns: t.Dict[str, t.List[str]]
|
||||
) -> t.Dict[str, str]:
|
||||
self, source_columns: t.Dict[str, t.Sequence[str]]
|
||||
) -> t.Mapping[str, str]:
|
||||
"""
|
||||
Find all the unambiguous columns in sources.
|
||||
|
||||
|
@ -693,12 +698,17 @@ class Resolver:
|
|||
source_columns_pairs = list(source_columns.items())
|
||||
|
||||
first_table, first_columns = source_columns_pairs[0]
|
||||
unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
|
||||
|
||||
if len(source_columns_pairs) == 1:
|
||||
# Performance optimization - avoid copying first_columns if there is only one table.
|
||||
return SingleValuedMapping(first_columns, first_table)
|
||||
|
||||
unambiguous_columns = {col: first_table for col in first_columns}
|
||||
all_columns = set(unambiguous_columns)
|
||||
|
||||
for table, columns in source_columns_pairs[1:]:
|
||||
unique = self._find_unique_columns(columns)
|
||||
ambiguous = set(all_columns).intersection(unique)
|
||||
unique = set(columns)
|
||||
ambiguous = all_columns.intersection(unique)
|
||||
all_columns.update(columns)
|
||||
|
||||
for column in ambiguous:
|
||||
|
@ -707,19 +717,3 @@ class Resolver:
|
|||
unambiguous_columns[column] = table
|
||||
|
||||
return unambiguous_columns
|
||||
|
||||
@staticmethod
|
||||
def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
|
||||
"""
|
||||
Find the unique columns in a list of columns.
|
||||
|
||||
Example:
|
||||
>>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
|
||||
['a', 'c']
|
||||
|
||||
This is necessary because duplicate column names are ambiguous.
|
||||
"""
|
||||
counts: t.Dict[str, int] = {}
|
||||
for column in columns:
|
||||
counts[column] = counts.get(column, 0) + 1
|
||||
return {column for column, count in counts.items() if count == 1}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue