Merging upstream version 26.29.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
141a93f866
commit
4c1ec9be5a
58 changed files with 17605 additions and 17151 deletions
|
@ -12,7 +12,7 @@ from sqlglot.helper import (
|
|||
seq_get,
|
||||
)
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import Schema, ensure_schema
|
||||
from sqlglot.schema import MappingSchema, Schema, ensure_schema
|
||||
from sqlglot.dialects.dialect import Dialect
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
@ -290,9 +290,52 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
elif isinstance(source.expression, exp.Unnest):
|
||||
self._set_type(col, source.expression.type)
|
||||
|
||||
if isinstance(self.schema, MappingSchema):
|
||||
for table_column in scope.table_columns:
|
||||
source = scope.sources.get(table_column.name)
|
||||
|
||||
if isinstance(source, exp.Table):
|
||||
schema = self.schema.find(
|
||||
source, raise_on_missing=False, ensure_data_types=True
|
||||
)
|
||||
if not isinstance(schema, dict):
|
||||
continue
|
||||
|
||||
struct_type = exp.DataType(
|
||||
this=exp.DataType.Type.STRUCT,
|
||||
expressions=[
|
||||
exp.ColumnDef(this=exp.to_identifier(c), kind=kind)
|
||||
for c, kind in schema.items()
|
||||
],
|
||||
nested=True,
|
||||
)
|
||||
self._set_type(table_column, struct_type)
|
||||
elif (
|
||||
isinstance(source, Scope)
|
||||
and isinstance(source.expression, exp.Query)
|
||||
and source.expression.is_type(exp.DataType.Type.STRUCT)
|
||||
):
|
||||
self._set_type(table_column, source.expression.type)
|
||||
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
||||
if self.schema.dialect == "bigquery" and isinstance(scope.expression, exp.Query):
|
||||
struct_type = exp.DataType(
|
||||
this=exp.DataType.Type.STRUCT,
|
||||
expressions=[
|
||||
exp.ColumnDef(this=exp.to_identifier(select.output_name), kind=select.type)
|
||||
for select in scope.expression.selects
|
||||
],
|
||||
nested=True,
|
||||
)
|
||||
if not any(
|
||||
cd.kind.is_type(exp.DataType.Type.UNKNOWN)
|
||||
for cd in struct_type.expressions
|
||||
if cd.kind
|
||||
):
|
||||
self._set_type(scope.expression, struct_type)
|
||||
|
||||
def _maybe_annotate(self, expression: E) -> E:
|
||||
if id(expression) in self._visited:
|
||||
return expression # We've already inferred the expression's type
|
||||
|
|
|
@ -529,6 +529,13 @@ def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualificati
|
|||
column_table = resolver.get_table(column_name)
|
||||
if column_table:
|
||||
column.set("table", column_table)
|
||||
elif (
|
||||
resolver.schema.dialect == "bigquery"
|
||||
and len(column.parts) == 1
|
||||
and column_name in scope.selected_sources
|
||||
):
|
||||
# BigQuery allows tables to be referenced as columns, treating them as structs
|
||||
scope.replace(column, exp.TableColumn(this=column.this))
|
||||
|
||||
for pivot in scope.pivots:
|
||||
for column in pivot.find_all(exp.Column):
|
||||
|
|
|
@ -88,6 +88,7 @@ class Scope:
|
|||
def clear_cache(self):
|
||||
self._collected = False
|
||||
self._raw_columns = None
|
||||
self._table_columns = None
|
||||
self._stars = None
|
||||
self._derived_tables = None
|
||||
self._udtfs = None
|
||||
|
@ -125,6 +126,7 @@ class Scope:
|
|||
self._derived_tables = []
|
||||
self._udtfs = []
|
||||
self._raw_columns = []
|
||||
self._table_columns = []
|
||||
self._stars = []
|
||||
self._join_hints = []
|
||||
self._semi_anti_join_tables = set()
|
||||
|
@ -156,6 +158,8 @@ class Scope:
|
|||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.UNWRAPPED_QUERIES):
|
||||
self._subqueries.append(node)
|
||||
elif isinstance(node, exp.TableColumn):
|
||||
self._table_columns.append(node)
|
||||
|
||||
self._collected = True
|
||||
|
||||
|
@ -309,6 +313,13 @@ class Scope:
|
|||
|
||||
return self._columns
|
||||
|
||||
@property
|
||||
def table_columns(self):
|
||||
if self._table_columns is None:
|
||||
self._ensure_collected()
|
||||
|
||||
return self._table_columns
|
||||
|
||||
@property
|
||||
def selected_sources(self):
|
||||
"""
|
||||
|
@ -849,12 +860,14 @@ def walk_in_scope(expression, bfs=True, prune=None):
|
|||
|
||||
if node is expression:
|
||||
continue
|
||||
|
||||
if (
|
||||
isinstance(node, exp.CTE)
|
||||
or (
|
||||
isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
|
||||
and (_is_derived_table(node) or isinstance(node, exp.UDTF))
|
||||
and _is_derived_table(node)
|
||||
)
|
||||
or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
|
||||
or isinstance(node, exp.UNWRAPPED_QUERIES)
|
||||
):
|
||||
crossed_scope_boundary = True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue