Merging upstream version 9.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent ebb36a5fc5
commit 4483b8ff47

87 changed files with 7994 additions and 421 deletions

@@ -1,2 +1 @@
 from sqlglot.optimizer.optimizer import RULES, optimize
-from sqlglot.optimizer.schema import Schema

@@ -1,7 +1,7 @@
 from sqlglot import exp
 from sqlglot.helper import ensure_list, subclasses
-from sqlglot.optimizer.schema import ensure_schema
 from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.schema import ensure_schema


 def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
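
The two hunks above track the relocation of the schema helpers from sqlglot.optimizer.schema to the top-level sqlglot.schema module. A minimal sketch of the relocated interface, assuming it matches the code shown later in this diff (the example table and types are illustrative):

    import sqlglot
    from sqlglot import exp
    from sqlglot.schema import ensure_schema

    # ensure_schema accepts a plain nested mapping and returns a Schema object.
    schema = ensure_schema({"users": {"id": "INT", "name": "TEXT"}})

    # Column lookups go through the Schema interface instead of raw dict access.
    table = sqlglot.parse_one("SELECT * FROM users").find(exp.Table)
    print(schema.column_names(table))  # ['id', 'name']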

@@ -86,7 +86,7 @@ def _eliminate(scope, existing_ctes, taken):
     if scope.is_union:
         return _eliminate_union(scope, existing_ctes, taken)

-    if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)):
+    if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
         return _eliminate_derived_table(scope, existing_ctes, taken)

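The rewritten check leans on exp.UDTF being the common base class of table-valued expressions, so one isinstance test covers everything the old tuple enumerated and more. A quick illustration of the relationship this hunk relies on:

    from sqlglot import exp

    # One base-class check replaces enumerating (exp.Unnest, exp.Lateral),
    # and it also picks up other table-valued nodes such as exp.Values.
    for cls in (exp.Unnest, exp.Lateral, exp.Values):
        print(cls.__name__, issubclass(cls, exp.UDTF))  # all True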

@@ -12,18 +12,16 @@ def isolate_table_selects(expression):
         if not isinstance(source, exp.Table):
             continue

-        if not isinstance(source.parent, exp.Alias):
+        if not source.alias:
             raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")

-        parent = source.parent
-
-        parent.replace(
+        source.replace(
             exp.select("*")
             .from_(
-                alias(source, source.name or parent.alias, table=True),
+                alias(source.copy(), source.name or source.alias, table=True),
                 copy=False,
             )
-            .subquery(parent.alias, copy=False)
+            .subquery(source.alias, copy=False)
         )

     return expression
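
This hunk, like several below, follows from the same refactor: a table's alias now lives on the exp.Table node itself rather than on a wrapping exp.Alias parent, so source.alias and source.alias_or_name replace the parent-based lookups. A small sketch (illustrative query):

    from sqlglot import exp, parse_one

    table = parse_one("SELECT y.a FROM x AS y").find(exp.Table)
    # The alias is read directly off the Table node; no exp.Alias wrapper is involved.
    print(table.alias)          # 'y'
    print(table.alias_or_name)  # 'y'; falls back to the table name when unaliased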

@@ -70,15 +70,10 @@ def merge_ctes(expression, leave_tables_isolated=False):
             inner_select = inner_scope.expression.unnest()
             from_or_join = table.find_ancestor(exp.From, exp.Join)
             if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
-                node_to_replace = table
-                if isinstance(node_to_replace.parent, exp.Alias):
-                    node_to_replace = node_to_replace.parent
-                    alias = node_to_replace.alias
-                else:
-                    alias = table.name
+                alias = table.alias_or_name

                 _rename_inner_sources(outer_scope, inner_scope, alias)
-                _merge_from(outer_scope, inner_scope, node_to_replace, alias)
+                _merge_from(outer_scope, inner_scope, table, alias)
                 _merge_expressions(outer_scope, inner_scope, alias)
                 _merge_joins(outer_scope, inner_scope, from_or_join)
                 _merge_where(outer_scope, inner_scope, from_or_join)
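
With aliases on the table node, alias_or_name collapses the old six-line branch into one expression, and _merge_from can take the table directly. For context, find_ancestor walks up the tree to the nearest matching node; a sketch of how the from_or_join lookup behaves (illustrative query):

    from sqlglot import exp, parse_one

    tree = parse_one("SELECT * FROM x JOIN (SELECT a FROM y) AS q ON x.a = q.a")
    subquery = tree.find(exp.Subquery)
    # The nearest From/Join ancestor tells _mergeable whether the derived
    # table is joined or selected from.
    print(type(subquery.find_ancestor(exp.From, exp.Join)).__name__)  # Join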

@@ -179,8 +174,8 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):

         if isinstance(source, exp.Subquery):
             source.set("alias", exp.TableAlias(this=new_alias))
-        elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
-            source.parent.set("alias", new_alias)
+        elif isinstance(source, exp.Table) and source.alias:
+            source.set("alias", new_alias)
         elif isinstance(source, exp.Table):
             source.replace(exp.alias_(source.copy(), new_alias))

@@ -206,8 +201,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
         tables = join_hint.find_all(exp.Table)
         for table in tables:
             if table.alias_or_name == node_to_replace.alias_or_name:
-                new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
-                table.set("this", exp.to_identifier(new_table.alias_or_name))
+                table.set("this", exp.to_identifier(new_subquery.alias_or_name))
     outer_scope.remove_source(alias)
     outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])


@@ -1,3 +1,4 @@
+import sqlglot
 from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
 from sqlglot.optimizer.eliminate_joins import eliminate_joins
 from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

@@ -43,6 +44,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
             1. {table: {col: type}}
             2. {db: {table: {col: type}}}
             3. {catalog: {db: {table: {col: type}}}}
+            If no schema is provided then the default schema defined at `sqlglot.schema` will be used
         db (str): specify the default database, as might be set by a `USE DATABASE db` statement
         catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
         rules (list): sequence of optimizer rules to use

@@ -50,13 +52,12 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
     Returns:
         sqlglot.Expression: optimized expression
     """
-    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
+    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
     expression = expression.copy()
     for rule in rules:
-
         # Find any additional rule parameters, beyond `expression`
         rule_params = rule.__code__.co_varnames
         rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}

         expression = rule(expression, **rule_kwargs)
     return expression
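
With the `schema or sqlglot.schema` fallback, callers can omit the schema argument and the module-level default is used. A minimal sketch, assuming the global default is an empty schema (the query is illustrative):

    import sqlglot
    from sqlglot.optimizer import optimize

    expression = sqlglot.parse_one("SELECT a FROM (SELECT 1 AS a)")
    # No schema argument: the sqlglot.schema default is passed to the rules.
    print(optimize(expression).sql())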

@@ -6,6 +6,9 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
 # Sentinel value that means an outer query selecting ALL columns
 SELECT_ALL = object()

+# SELECTION TO USE IF SELECTION LIST IS EMPTY
+DEFAULT_SELECTION = alias("1", "_")
+

 def pushdown_projections(expression):
     """

@@ -25,7 +28,8 @@ def pushdown_projections(expression):
     """
     # Map of Scope to all columns being selected by outer queries.
     referenced_columns = defaultdict(set)
-
+    left_union = None
+    right_union = None
     # We build the scope tree (which is traversed in DFS postorder), then iterate
     # over the result in reverse order. This should ensure that the set of selected
     # columns for a particular scope are completely built by the time we get to it.

@@ -37,12 +41,16 @@ def pushdown_projections(expression):
             parent_selections = {SELECT_ALL}

         if isinstance(scope.expression, exp.Union):
-            left, right = scope.union_scopes
-            referenced_columns[left] = parent_selections
-            referenced_columns[right] = parent_selections
+            left_union, right_union = scope.union_scopes
+            referenced_columns[left_union] = parent_selections
+            referenced_columns[right_union] = parent_selections

-        if isinstance(scope.expression, exp.Select):
-            _remove_unused_selections(scope, parent_selections)
+        if isinstance(scope.expression, exp.Select) and scope != right_union:
+            removed_indexes = _remove_unused_selections(scope, parent_selections)
+            # The left union is used for column names to select and if we remove columns from the left
+            # we need to also remove those same columns in the right that were at the same position
+            if scope is left_union:
+                _remove_indexed_selections(right_union, removed_indexes)

         # Group columns by source name
         selects = defaultdict(set)
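
The union bookkeeping matters because a UNION's output column names come from the left side: pruning a projection on the left must also drop the positionally matching projection on the right, which is what _remove_indexed_selections (added below) does. A sketch of the intended effect, assuming the rule runs standalone on an already-aliased query:

    from sqlglot import parse_one
    from sqlglot.optimizer.pushdown_projections import pushdown_projections

    sql = "SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT c, d FROM y) AS t"
    # Only `a` is referenced, so `b` is pruned on the left and `d`, at the
    # same position, is pruned on the right.
    print(pushdown_projections(parse_one(sql)).sql())
    # e.g. SELECT a FROM (SELECT a FROM x UNION ALL SELECT c FROM y) AS t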

@@ -61,6 +69,7 @@ def pushdown_projections(expression):


 def _remove_unused_selections(scope, parent_selections):
+    removed_indexes = []
     order = scope.expression.args.get("order")

     if order:

@@ -70,16 +79,26 @@ def _remove_unused_selections(scope, parent_selections):
         order_refs = set()

     new_selections = []
-    for selection in scope.selects:
+    for i, selection in enumerate(scope.selects):
         if (
             SELECT_ALL in parent_selections
             or selection.alias_or_name in parent_selections
             or selection.alias_or_name in order_refs
         ):
             new_selections.append(selection)
+        else:
+            removed_indexes.append(i)

     # If there are no remaining selections, just select a single constant
     if not new_selections:
-        new_selections.append(alias("1", "_"))
+        new_selections.append(DEFAULT_SELECTION)

     scope.expression.set("expressions", new_selections)
+    return removed_indexes
+
+
+def _remove_indexed_selections(scope, indexes_to_remove):
+    new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
+    if not new_selections:
+        new_selections.append(DEFAULT_SELECTION)
+    scope.expression.set("expressions", new_selections)

@@ -2,8 +2,8 @@ import itertools

 from sqlglot import alias, exp
 from sqlglot.errors import OptimizeError
-from sqlglot.optimizer.schema import ensure_schema
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.schema import ensure_schema


 def qualify_columns(expression, schema):

@@ -48,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
     (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
     """
     for derived_table in derived_tables:
-        if isinstance(derived_table, exp.UDTF):
+        if isinstance(derived_table.unnest(), exp.UDTF):
             continue
         table_alias = derived_table.args.get("alias")
         if table_alias:
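
Calling unnest() first means a UDTF wrapped in a subquery keeps its column aliases too, since those aliases define the UDTF's output columns. A quick illustration (illustrative query):

    from sqlglot import exp, parse_one

    # The subquery wrapper is peeled off by unnest(), exposing the VALUES node,
    # so _pop_table_column_aliases leaves t(a) alone.
    node = parse_one("SELECT * FROM (VALUES (1)) AS t(a)").find(exp.Subquery)
    print(isinstance(node.unnest(), exp.UDTF))  # True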

@@ -211,6 +211,22 @@ def _qualify_columns(scope, resolver):
         if column_table:
             column.set("table", exp.to_identifier(column_table))

+    # Determine whether each reference in the order by clause is to a column or an alias.
+    for ordered in scope.find_all(exp.Ordered):
+        for column in ordered.find_all(exp.Column):
+            column_table = column.table
+            column_name = column.name
+
+            if column_table or column.parent is ordered or column_name not in resolver.all_columns:
+                continue
+
+            column_table = resolver.get_table(column_name)
+
+            if column_table is None:
+                raise OptimizeError(f"Ambiguous column: {column_name}")
+
+            column.set("table", exp.to_identifier(column_table))
+

 def _expand_stars(scope, resolver):
     """Expand stars to lists of column selections"""
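
The new loop only qualifies ORDER BY columns that sit inside a larger expression; a bare reference directly under exp.Ordered is skipped because it may name a select alias rather than a source column. A sketch of the distinction, assuming qualify_columns runs on an already table-qualified query:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_columns

    sql = "SELECT x.a AS b FROM x AS x ORDER BY a + 1"
    # `a` appears inside `a + 1`, so it must be a real column and is resolved
    # against the sources; a bare `ORDER BY a` would be left untouched.
    print(qualify_columns(parse_one(sql), schema={"x": {"a": "INT"}}).sql())
    # e.g. SELECT x.a AS b FROM x AS x ORDER BY x.a + 1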

@@ -346,6 +362,11 @@ class _Resolver:
             except Exception as e:
                 raise OptimizeError(str(e)) from e

+        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
+            values_alias = source.expression.parent
+            if hasattr(values_alias, "alias_column_names"):
+                return values_alias.alias_column_names
+
         # Otherwise, if referencing another scope, return that scope's named selects
         return source.expression.named_selects
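
For a VALUES source, the column names live on its alias, e.g. t(a, b), so the resolver reads alias_column_names off the parent instead of looking at named selects. An illustrative sketch:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_columns

    sql = "SELECT a FROM (VALUES (1, 2)) AS t(a, b)"
    # The resolver can now answer "which source has column `a`?" for VALUES.
    print(qualify_columns(parse_one(sql), schema={}).sql())
    # e.g. SELECT t.a AS a FROM (VALUES (1, 2)) AS t(a, b)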

@@ -40,7 +40,7 @@ def qualify_tables(expression, db=None, catalog=None):
             if not source.args.get("catalog"):
                 source.set("catalog", exp.to_identifier(catalog))

-            if not isinstance(source.parent, exp.Alias):
+            if not source.alias:
                 source.replace(
                     alias(
                         source.copy(),

@@ -1,180 +0,0 @@
-import abc
-
-from sqlglot import exp
-from sqlglot.errors import OptimizeError
-from sqlglot.helper import csv_reader
-
-
-class Schema(abc.ABC):
-    """Abstract base class for database schemas"""
-
-    @abc.abstractmethod
-    def column_names(self, table, only_visible=False):
-        """
-        Get the column names for a table.
-
-        Args:
-            table (sqlglot.expressions.Table): Table expression instance
-            only_visible (bool): Whether to include invisible columns
-        Returns:
-            list[str]: list of column names
-        """
-
-    @abc.abstractmethod
-    def get_column_type(self, table, column):
-        """
-        Get the exp.DataType type of a column in the schema.
-
-        Args:
-            table (sqlglot.expressions.Table): The source table.
-            column (sqlglot.expressions.Column): The target column.
-        Returns:
-            sqlglot.expressions.DataType.Type: The resulting column type.
-        """
-
-
-class MappingSchema(Schema):
-    """
-    Schema based on a nested mapping.
-
-    Args:
-        schema (dict): Mapping in one of the following forms:
-            1. {table: {col: type}}
-            2. {db: {table: {col: type}}}
-            3. {catalog: {db: {table: {col: type}}}}
-        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
-            are assumed to be visible. The nesting should mirror that of the schema:
-            1. {table: set(*cols)}
-            2. {db: {table: set(*cols)}}
-            3. {catalog: {db: {table: set(*cols)}}}
-        dialect (str): The dialect to be used for custom type mappings.
-    """
-
-    def __init__(self, schema, visible=None, dialect=None):
-        self.schema = schema
-        self.visible = visible
-        self.dialect = dialect
-        self._type_mapping_cache = {}
-
-        depth = _dict_depth(schema)
-
-        if not depth:  # {}
-            self.supported_table_args = []
-        elif depth == 2:  # {table: {col: type}}
-            self.supported_table_args = ("this",)
-        elif depth == 3:  # {db: {table: {col: type}}}
-            self.supported_table_args = ("db", "this")
-        elif depth == 4:  # {catalog: {db: {table: {col: type}}}}
-            self.supported_table_args = ("catalog", "db", "this")
-        else:
-            raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
-
-        self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
-
-    def column_names(self, table, only_visible=False):
-        if not isinstance(table.this, exp.Identifier):
-            return fs_get(table)
-
-        args = tuple(table.text(p) for p in self.supported_table_args)
-
-        for forbidden in self.forbidden_args:
-            if table.text(forbidden):
-                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
-
-        columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
-        if not only_visible or not self.visible:
-            return columns
-
-        visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
-        return [col for col in columns if col in visible]
-
-    def get_column_type(self, table, column):
-        try:
-            schema_type = self.schema.get(table.name, {}).get(column.name).upper()
-            return self._convert_type(schema_type)
-        except:
-            raise OptimizeError(f"Failed to get type for column {column.sql()}")
-
-    def _convert_type(self, schema_type):
-        """
-        Convert a type represented as a string to the corresponding exp.DataType.Type object.
-
-        Args:
-            schema_type (str): The type we want to convert.
-        Returns:
-            sqlglot.expressions.DataType.Type: The resulting expression type.
-        """
-        if schema_type not in self._type_mapping_cache:
-            try:
-                self._type_mapping_cache[schema_type] = exp.maybe_parse(
-                    schema_type, into=exp.DataType, dialect=self.dialect
-                ).this
-            except AttributeError:
-                raise OptimizeError(f"Failed to convert type {schema_type}")
-
-        return self._type_mapping_cache[schema_type]
-
-
-def ensure_schema(schema):
-    if isinstance(schema, Schema):
-        return schema
-
-    return MappingSchema(schema)
-
-
-def fs_get(table):
-    name = table.this.name
-
-    if name.upper() == "READ_CSV":
-        with csv_reader(table) as reader:
-            return next(reader)
-
-    raise ValueError(f"Cannot read schema for {table}")
-
-
-def _nested_get(d, *path):
-    """
-    Get a value for a nested dictionary.
-
-    Args:
-        d (dict): dictionary
-        *path (tuple[str, str]): tuples of (name, key)
-            `key` is the key in the dictionary to get.
-            `name` is a string to use in the error if `key` isn't found.
-    """
-    for name, key in path:
-        d = d.get(key)
-        if d is None:
-            name = "table" if name == "this" else name
-            raise ValueError(f"Unknown {name}")
-    return d
-
-
-def _dict_depth(d):
-    """
-    Get the nesting depth of a dictionary.
-
-    For example:
-        >>> _dict_depth(None)
-        0
-        >>> _dict_depth({})
-        1
-        >>> _dict_depth({"a": "b"})
-        1
-        >>> _dict_depth({"a": {}})
-        2
-        >>> _dict_depth({"a": {"b": {}}})
-        3
-
-    Args:
-        d (dict): dictionary
-    Returns:
-        int: depth
-    """
-    try:
-        return 1 + _dict_depth(next(iter(d.values())))
-    except AttributeError:
-        # d doesn't have attribute "values"
-        return 0
-    except StopIteration:
-        # d.values() returns an empty sequence
-        return 1
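
The deleted module's contents move to the top-level sqlglot.schema module (note the new sqlglot.schema imports in the hunks above). Assuming the relocated MappingSchema keeps the interface shown here, the depth of the mapping picks the supported table args:

    from sqlglot.schema import MappingSchema

    # {db: {table: {col: type}}} has depth 3, so tables may carry a db part.
    schema = MappingSchema({"db": {"x": {"a": "INT"}}})
    print(schema.supported_table_args)  # ('db', 'this')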

@@ -257,12 +257,7 @@ class Scope:
         referenced_names = []

         for table in self.tables:
-            referenced_names.append(
-                (
-                    table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
-                    table,
-                )
-            )
+            referenced_names.append((table.alias_or_name, table))
         for derived_table in self.derived_tables:
             referenced_names.append((derived_table.alias, derived_table.unnest()))


@@ -538,8 +533,8 @@ def _add_table_sources(scope):
     for table in scope.tables:
         table_name = table.name

-        if isinstance(table.parent, exp.Alias):
-            source_name = table.parent.alias
+        if table.alias:
+            source_name = table.alias
         else:
             source_name = table_name

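
The same alias refactor reaches the scope machinery: source names now come from table.alias with a fallback to table.name. A minimal sketch of the resulting source map (illustrative query):

    from sqlglot import parse_one
    from sqlglot.optimizer.scope import traverse_scope

    # The single table x is registered under its alias `y`.
    for scope in traverse_scope(parse_one("SELECT y.a FROM x AS y")):
        print(list(scope.sources))  # ['y']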