1
0
Fork 0

Merging upstream version 10.5.10.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:07:05 +01:00
parent 8588db6332
commit 4d496b7a6a
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
43 changed files with 1384 additions and 356 deletions

View file

@ -1 +1,2 @@
from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope

View file

@ -1,15 +1,18 @@
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.schema import ensure_schema
def isolate_table_selects(expression):
def isolate_table_selects(expression, schema=None):
schema = ensure_schema(schema)
for scope in traverse_scope(expression):
if len(scope.selected_sources) == 1:
continue
for (_, source) in scope.selected_sources.values():
if not isinstance(source, exp.Table):
if not isinstance(source, exp.Table) or not schema.column_names(source):
continue
if not source.alias:

View file

@ -1,7 +1,8 @@
import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError, SchemaError
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@ -190,20 +191,15 @@ def _qualify_columns(scope, resolver):
column_table = column.table
column_name = column.name
if (
column_table
and column_table in scope.sources
and column_name not in resolver.get_source_columns(column_table)
):
raise OptimizeError(f"Unknown column: {column_name}")
if column_table and column_table in scope.sources:
source_columns = resolver.get_source_columns(column_table)
if source_columns and column_name not in source_columns:
raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
column_table = resolver.get_table(column_name)
if not scope.is_subquery and not scope.is_udtf:
if column_name not in resolver.all_columns:
raise OptimizeError(f"Unknown column: {column_name}")
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column_name}")
@ -265,6 +261,10 @@ def _expand_stars(scope, resolver):
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
if not columns:
raise OptimizeError(
f"Table has no schema/columns. Cannot expand star for table: {table}."
)
table_id = id(table)
for name in columns:
if name not in except_columns.get(table_id, set()):
@ -306,16 +306,11 @@ def _qualify_outputs(scope):
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.selects, scope.outer_column_list)
):
if isinstance(selection, exp.Column):
# convoluted setter because a simple selection.replace(alias) would require a copy
alias_ = alias(exp.column(""), alias=selection.name)
alias_.set("this", selection)
selection = alias_
elif isinstance(selection, exp.Subquery):
if not selection.alias:
if isinstance(selection, exp.Subquery):
if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias):
alias_ = alias(exp.column(""), f"_col_{i}")
alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
alias_.set("this", selection)
selection = alias_
@ -346,20 +341,30 @@ class _Resolver:
self._unambiguous_columns = None
self._all_columns = None
def get_table(self, column_name):
def get_table(self, column_name: str) -> t.Optional[str]:
"""
Get the table for a column name.
Args:
column_name (str)
column_name: The column name to find the table for.
Returns:
(str) table name
The table name if it can be found/inferred.
"""
if self._unambiguous_columns is None:
self._unambiguous_columns = self._get_unambiguous_columns(
self._get_all_source_columns()
)
return self._unambiguous_columns.get(column_name)
table = self._unambiguous_columns.get(column_name)
if not table:
sources_without_schema = tuple(
source for source, columns in self._get_all_source_columns().items() if not columns
)
if len(sources_without_schema) == 1:
return sources_without_schema[0]
return table
@property
def all_columns(self):
@ -379,10 +384,7 @@ class _Resolver:
# If referencing a table, return the columns from the schema
if isinstance(source, exp.Table):
try:
return self.schema.column_names(source, only_visible)
except Exception as e:
raise SchemaError(str(e)) from e
return self.schema.column_names(source, only_visible)
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
return source.expression.alias_column_names

View file

@ -230,7 +230,7 @@ class Scope:
column for scope in self.subquery_scopes for column in scope.external_columns
]
named_outputs = {e.alias_or_name for e in self.expression.expressions}
named_selects = set(self.expression.named_selects)
self._columns = []
for column in columns + external_columns:
@ -238,7 +238,7 @@ class Scope:
if (
not ancestor
or column.table
or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint))
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)