
Merging upstream version 23.10.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:31:23 +01:00
parent 6cbc5d6f97
commit 49aa147013
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
91 changed files with 52881 additions and 50396 deletions

View file

@@ -212,6 +212,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
             exp.Month,
             exp.Week,
             exp.Year,
+            exp.Quarter,
         },
         exp.DataType.Type.VARCHAR: {
             exp.ArrayConcat,
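
A minimal sketch of what the new entry means in practice, assuming the enclosing key of this set maps to an integer-like type (the key itself is outside the hunk): exp.Quarter nodes now pick up that type during annotation.

from sqlglot import exp
from sqlglot.optimizer.annotate_types import annotate_types

# Build a QUARTER(...) node directly so no dialect-specific parsing is assumed.
node = exp.Quarter(this=exp.cast(exp.Literal.string("2024-05-01"), "date"))
annotate_types(node)

# With exp.Quarter registered alongside exp.Month/exp.Week/exp.Year, this should
# now report an integer-like type instead of UNKNOWN.
print(node.type)
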
@@ -504,7 +505,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
                 last_datatype = expr_type
                 break

-            last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
+            if not expr_type.is_type(exp.DataType.Type.NULL, exp.DataType.Type.UNKNOWN):
+                last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)

         self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
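
A sketch of the behavioral difference, assuming this coercion loop is the one used for multi-argument expressions such as COALESCE: operands typed as NULL or UNKNOWN are now skipped, so they no longer drag the coerced result type down.

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# The NULL argument is ignored during coercion, so the projection should keep an
# integer type rather than collapsing to NULL.
expression = annotate_types(parse_one("SELECT COALESCE(1, NULL) AS c"))
print(expression.selects[0].type)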

View file

@@ -66,7 +66,7 @@ def qualify(
     """
     schema = ensure_schema(schema, dialect=dialect)
     expression = normalize_identifiers(expression, dialect=dialect)
-    expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
+    expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema, dialect=dialect)

     if isolate_tables:
         expression = isolate_table_selects(expression, schema=schema)
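
An illustrative call with a hypothetical one-table schema: the dialect passed to qualify is now forwarded to qualify_tables, so table qualification can apply dialect-specific behavior.

from sqlglot import parse_one
from sqlglot.optimizer.qualify import qualify

expression = parse_one("SELECT a FROM t", read="bigquery")
qualified = qualify(expression, schema={"t": {"a": "INT"}}, dialect="bigquery")
print(qualified.sql(dialect="bigquery"))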

View file

@@ -7,6 +7,7 @@ from sqlglot import alias, exp
 from sqlglot.dialects.dialect import Dialect, DialectType
 from sqlglot.errors import OptimizeError
 from sqlglot.helper import seq_get, SingleValuedMapping
+from sqlglot.optimizer.annotate_types import annotate_types
 from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 from sqlglot.optimizer.simplify import simplify_parens
 from sqlglot.schema import Schema, ensure_schema
@@ -652,8 +653,19 @@ class Resolver:
         if isinstance(source, exp.Table):
             columns = self.schema.column_names(source, only_visible)
-        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
-            columns = source.expression.alias_column_names
+        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
+            columns = source.expression.named_selects
+
+            # in bigquery, unnest structs are automatically scoped as tables, so you can
+            # directly select a struct field in a query.
+            # this handles the case where the unnest is statically defined.
+            if self.schema.dialect == "bigquery":
+                expression = source.expression
+                annotate_types(expression)
+
+                if expression.is_type(exp.DataType.Type.STRUCT):
+                    for k in expression.type.expressions:  # type: ignore
+                        columns.append(k.name)
         else:
             columns = source.expression.named_selects
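
A sketch of the query shape the new branch targets, assuming a statically defined UNNEST of structs on BigQuery: the struct fields become resolvable columns of the unnested source, so selecting them directly should now qualify cleanly.

from sqlglot import parse_one
from sqlglot.optimizer.qualify import qualify

sql = "SELECT x FROM UNNEST([STRUCT(1 AS x, 'a' AS y)])"
expression = parse_one(sql, read="bigquery")

# annotate_types marks the unnested expression as a STRUCT, so its field names
# (x and y) are appended to the source's columns and `x` should resolve.
print(qualify(expression, dialect="bigquery").sql(dialect="bigquery"))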

View file

@@ -55,7 +55,7 @@ def qualify_tables(
         if not table.args.get("catalog") and table.args.get("db"):
             table.set("catalog", catalog)

-    if not isinstance(expression, exp.Query):
+    if (db or catalog) and not isinstance(expression, exp.Query):
         for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
             if isinstance(node, exp.Table):
                 _qualify(node)
@@ -78,10 +78,10 @@
             if pivots and not pivots[0].alias:
                 pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

+        table_aliases = {}
+
         for name, source in scope.sources.items():
             if isinstance(source, exp.Table):
-                _qualify(source)
-
                 pivots = pivots = source.args.get("pivots")
                 if not source.alias:
                     # Don't add the pivot's alias to the pivoted table, use the table's name instead
@@ -91,6 +91,12 @@
                     # Mutates the source by attaching an alias to it
                     alias(source, name or source.name or next_alias_name(), copy=False, table=True)

+                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
+                    source.alias
+                )
+
+                _qualify(source)
+
                 if pivots and not pivots[0].alias:
                     pivots[0].set(
                         "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
@@ -127,4 +133,13 @@
                 # Mutates the table by attaching an alias to it
                 alias(node, node.name, copy=False, table=True)

+        for column in scope.columns:
+            if column.db:
+                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
+
+                if table_alias:
+                    for p in exp.COLUMN_PARTS[1:]:
+                        column.set(p, None)
+
+                    column.set("table", table_alias)
+
     return expression
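
An illustrative run of qualify_tables over a database-qualified column, with a hypothetical catalog: the table is qualified and aliased, and the db-prefixed column reference is rewritten to use the generated alias via the new table_aliases mapping.

from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables

expression = parse_one("SELECT mydb.tbl.col FROM mydb.tbl")

# The table should come out as something like cat.mydb.tbl AS tbl, and the
# mydb.tbl.col reference should be rewritten to tbl.col.
print(qualify_tables(expression, catalog="cat").sql())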

View file

@@ -600,7 +600,7 @@ def _traverse_ctes(scope):
     sources = {}

     for cte in scope.ctes:
-        recursive_scope = None
+        cte_name = cte.alias

         # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
         # thus the recursive scope is the first section of the union.
@@ -609,7 +609,7 @@
             union = cte.this

             if isinstance(union, exp.Union):
-                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
+                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)

         child_scope = None
@@ -623,15 +623,9 @@
         ):
             yield child_scope

-            alias = cte.alias
-            sources[alias] = child_scope
-
-            if recursive_scope:
-                child_scope.add_source(alias, recursive_scope)
-                child_scope.cte_sources[alias] = recursive_scope
-
         # append the final child_scope yielded
         if child_scope:
+            sources[cte_name] = child_scope
             scope.cte_scopes.append(child_scope)

     scope.sources.update(sources)
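
A sketch of the recursive-CTE shape this code handles: the CTE name is captured once, the recursive branch of the union resolves the CTE to its base case, and the final child scope is registered under the same name.

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope

sql = """
WITH RECURSIVE seq AS (
    SELECT 1 AS n
    UNION ALL
    SELECT n + 1 FROM seq WHERE n < 5
)
SELECT n FROM seq
"""

root = build_scope(parse_one(sql))

# The root scope sees the CTE as a source, and one CTE scope was collected.
print(list(root.sources))
print(len(root.cte_scopes))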

View file

@@ -41,8 +41,6 @@ def unnest(select, parent_select, next_alias_name):
         return

     predicate = select.find_ancestor(exp.Condition)
-    alias = next_alias_name()
-
     if (
         not predicate
         or parent_select is not predicate.parent_select
@@ -50,6 +48,10 @@
     ):
         return

+    if isinstance(select, exp.Union):
+        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))
+
+    alias = next_alias_name()
     clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

     # This subquery returns a scalar and can just be converted to a cross join
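
A sketch of the newly handled shape, with hypothetical table and column names: when the subquery under a predicate such as IN is a UNION, it is first wrapped in a derived table so the usual unnesting rewrite can proceed.

from sqlglot import parse_one
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

sql = """
SELECT *
FROM x
WHERE x.a IN (SELECT y.a AS a FROM y UNION SELECT z.a AS a FROM z)
"""

# Depending on the rule's remaining checks (outside this hunk), the union subquery may be
# wrapped as a derived table and rewritten into a join; otherwise it is left unchanged.
print(unnest_subqueries(parse_one(sql)).sql(pretty=True))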