Adding upstream version 6.2.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent d62bab68ae
commit 24cf9d8984

22 changed files with 361 additions and 98 deletions
@@ -1,6 +1,6 @@
 # SQLGlot

-SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
+SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.

 It is a very comprehensive generic SQL parser with a robust [test suite](tests). It is also quite [performant](#benchmarks) while being written purely in Python.
@@ -30,7 +30,7 @@ sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read='duckdb', write='hive')
 ```

 ```sql
-SELECT TO_UTC_TIMESTAMP(FROM_UNIXTIME(1618088028295 / 1000, 'yyyy-MM-dd HH:mm:ss'), 'UTC')
+SELECT FROM_UNIXTIME(1618088028295 / 1000)
 ```

 SQLGlot can even translate custom time formats.
@@ -299,7 +299,7 @@ class Custom(Dialect):
     }


-Dialects["custom"]
+Dialect["custom"]
 ```

 ## Benchmarks
@@ -20,7 +20,7 @@ from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "6.2.6"
+__version__ = "6.2.8"

 pretty = False
@@ -33,10 +33,10 @@ def _date_add_sql(data_type, kind):
     return func


-def _subquery_to_unnest_if_values(self, expression):
-    if not isinstance(expression.this, exp.Values):
-        return self.subquery_sql(expression)
-    rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)]
+def _derived_table_values_to_unnest(self, expression):
+    if not isinstance(expression.unnest().parent, exp.From):
+        return self.values_sql(expression)
+    rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)]
     structs = []
     for row in rows:
         aliases = [
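The hunks above and below move BigQuery's VALUES-to-UNNEST rewrite from `exp.Subquery` onto `exp.Values` itself. The behavior is pinned down by the test expectations later in this commit; a minimal sketch:

```python
import sqlglot

# A VALUES derived table becomes an UNNEST of STRUCTs in BigQuery
# (expected output taken from the TestBigQuery case in this commit).
sql = "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
print(sqlglot.transpile(sql, write="bigquery")[0])
# SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])
```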
@@ -99,6 +99,7 @@ class BigQuery(Dialect):
            "QUALIFY": TokenType.QUALIFY,
            "UNKNOWN": TokenType.NULL,
            "WINDOW": TokenType.WINDOW,
+           "NOT DETERMINISTIC": TokenType.VOLATILE,
        }

    class Parser(Parser):
@@ -140,9 +141,10 @@ class BigQuery(Dialect):
            exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
            exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
            exp.VariancePop: rename_func("VAR_POP"),
-           exp.Subquery: _subquery_to_unnest_if_values,
+           exp.Values: _derived_table_values_to_unnest,
            exp.ReturnsProperty: _returnsproperty_sql,
            exp.Create: _create_sql,
+           exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
        }

        TYPE_MAPPING = {
@@ -160,6 +162,16 @@ class BigQuery(Dialect):
            exp.DataType.Type.NVARCHAR: "STRING",
        }

+       ROOT_PROPERTIES = {
+           exp.LanguageProperty,
+           exp.ReturnsProperty,
+           exp.VolatilityProperty,
+       }
+
+       WITH_PROPERTIES = {
+           exp.AnonymousProperty,
+       }
+
        def in_unnest_op(self, unnest):
            return self.sql(unnest)
@@ -77,6 +77,7 @@ class Dialect(metaclass=_Dialect):
     alias_post_tablesample = False
     normalize_functions = "upper"
     null_ordering = "nulls_are_small"
+    wrap_derived_values = True

     date_format = "'%Y-%m-%d'"
     dateint_format = "'%Y%m%d'"
@@ -169,6 +170,7 @@ class Dialect(metaclass=_Dialect):
                 "alias_post_tablesample": self.alias_post_tablesample,
                 "normalize_functions": self.normalize_functions,
                 "null_ordering": self.null_ordering,
+                "wrap_derived_values": self.wrap_derived_values,
                 **opts,
             }
         )
@@ -177,6 +177,8 @@ class Snowflake(Dialect):
            exp.ReturnsProperty,
            exp.LanguageProperty,
            exp.SchemaCommentProperty,
+           exp.ExecuteAsProperty,
+           exp.VolatilityProperty,
        }

        def except_op(self, expression):
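These new root properties are what allow Snowflake stored procedures with EXECUTE AS to round-trip; a minimal sketch (the statement is taken verbatim from the TestSnowflake case added in this commit):

```python
import sqlglot

sql = "CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'"
# validate_identity in the new test asserts exactly this round-trip.
assert sqlglot.transpile(sql, read="snowflake", write="snowflake")[0] == sql
```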
@@ -47,6 +47,8 @@ def _unix_to_time(self, expression):


 class Spark(Hive):
+    wrap_derived_values = False
+
     class Parser(Hive.Parser):
         FUNCTIONS = {
             **Hive.Parser.FUNCTIONS,
@@ -213,21 +213,23 @@ class Expression(metaclass=_Expression):
         """
         return self.find_ancestor(Select)

-    def walk(self, bfs=True):
+    def walk(self, bfs=True, prune=None):
         """
         Returns a generator object which visits all nodes in this tree.

         Args:
             bfs (bool): if set to True the BFS traversal order will be applied,
                 otherwise the DFS traversal will be used instead.
+            prune ((node, parent, arg_key) -> bool): callable that returns True if
+                the generator should stop traversing this branch of the tree.

         Returns:
             the generator object.
         """
         if bfs:
-            yield from self.bfs()
+            yield from self.bfs(prune=prune)
         else:
-            yield from self.dfs()
+            yield from self.dfs(prune=prune)

     def dfs(self, parent=None, key=None, prune=None):
         """
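The new `prune` hook lets callers cut off whole subtrees during traversal. A minimal sketch, with the callable signature matching the docstring above:

```python
from sqlglot import exp, parse_one

ast = parse_one("SELECT a FROM (SELECT a FROM x) AS sub")

# Walk depth-first, but stop descending once a subquery node is reached.
shallow_nodes = [
    node
    for node, parent, key in ast.walk(bfs=False, prune=lambda n, p, k: isinstance(n, exp.Subquery))
]
```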
@@ -506,6 +508,10 @@ class DerivedTable(Expression):
         return [select.alias_or_name for select in self.selects]


+class UDTF(DerivedTable):
+    pass
+
+
 class Annotation(Expression):
     arg_types = {
         "this": True,
@@ -652,7 +658,13 @@ class Delete(Expression):


 class Drop(Expression):
-    arg_types = {"this": False, "kind": False, "exists": False}
+    arg_types = {
+        "this": False,
+        "kind": False,
+        "exists": False,
+        "temporary": False,
+        "materialized": False,
+    }


 class Filter(Expression):
@@ -827,7 +839,7 @@ class Join(Expression):
         return join


-class Lateral(DerivedTable):
+class Lateral(UDTF):
     arg_types = {"this": True, "outer": False, "alias": False}


@@ -915,6 +927,14 @@ class LanguageProperty(Property):
     pass


+class ExecuteAsProperty(Property):
+    pass
+
+
+class VolatilityProperty(Property):
+    arg_types = {"this": True}
+
+
 class Properties(Expression):
     arg_types = {"expressions": True}

@@ -1098,7 +1118,7 @@ class Intersect(Union):
     pass


-class Unnest(DerivedTable):
+class Unnest(UDTF):
     arg_types = {
         "expressions": True,
         "ordinality": False,
@@ -1116,8 +1136,12 @@ class Update(Expression):
     }


-class Values(Expression):
-    arg_types = {"expressions": True}
+class Values(UDTF):
+    arg_types = {
+        "expressions": True,
+        "ordinality": False,
+        "alias": False,
+    }


 class Var(Expression):
@@ -2033,23 +2057,17 @@ class Func(Condition):
     @classmethod
     def from_arg_list(cls, args):
         args_num = len(args)
-        if cls.is_var_len_args:
-            all_arg_keys = list(cls.arg_types)
-            # If this function supports variable length argument treat the last argument as such.
-            non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
-            num_non_var = len(non_var_len_arg_keys)
-            args_dict = {arg_key: arg for arg, arg_key in zip(args, non_var_len_arg_keys)}
-            args_dict[all_arg_keys[-1]] = args[num_non_var:]
-        else:
-            args_dict = {arg_key: arg for arg, arg_key in zip(args, cls.arg_types)}
+        all_arg_keys = list(cls.arg_types)
+        # If this function supports variable length argument treat the last argument as such.
+        non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
+
+        args_dict = {}
+        arg_idx = 0
+        for arg_key in non_var_len_arg_keys:
+            if arg_idx >= args_num:
+                break
+            if args[arg_idx] is not None:
+                args_dict[arg_key] = args[arg_idx]
+            arg_idx += 1
+
+        if arg_idx < args_num and cls.is_var_len_args:
+            args_dict[all_arg_keys[-1]] = args[arg_idx:]
         return cls(**args_dict)

     @classmethod
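The rewritten `from_arg_list` walks the fixed argument keys one by one, skipping `None` placeholders, before binding whatever is left to the variable-length tail. A sketch of the effect, using a hypothetical `Func` subclass purely for illustration:

```python
from sqlglot import expressions as exp

# Hypothetical function: one required arg plus a var-args tail.
class MyFunc(exp.Func):
    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True

args = [exp.Literal.number(1), exp.Literal.number(2), exp.Literal.number(3)]
f = MyFunc.from_arg_list(args)
# "this" is bound to the first literal; the remaining two become the
# "expressions" tail instead of being mis-aligned positionally.
```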
@@ -49,10 +49,12 @@ class Generator:
         exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
         exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
         exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
-        exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
         exp.LanguageProperty: lambda self, e: self.naked_property(e),
         exp.LocationProperty: lambda self, e: self.naked_property(e),
         exp.ReturnsProperty: lambda self, e: self.naked_property(e),
+        exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
+        exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
+        exp.VolatilityProperty: lambda self, e: self.sql(e.name),
     }

     NULL_ORDERING_SUPPORTED = True
@@ -99,6 +101,7 @@ class Generator:
         "unsupported_messages",
         "null_ordering",
         "max_unsupported",
+        "wrap_derived_values",
         "_indent",
         "_replace_backslash",
         "_escaped_quote_end",
@@ -127,6 +130,7 @@ class Generator:
         null_ordering=None,
         max_unsupported=3,
         leading_comma=False,
+        wrap_derived_values=True,
     ):
         import sqlglot

@@ -150,6 +154,7 @@ class Generator:
         self.unsupported_messages = []
         self.max_unsupported = max_unsupported
         self.null_ordering = null_ordering
+        self.wrap_derived_values = wrap_derived_values
         self._indent = indent
         self._replace_backslash = self.escape == "\\"
         self._escaped_quote_end = self.escape + self.quote_end
@@ -407,7 +412,9 @@ class Generator:
         this = self.sql(expression, "this")
         kind = expression.args["kind"]
         exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
-        return f"DROP {kind}{exists_sql}{this}"
+        temporary = " TEMPORARY" if expression.args.get("temporary") else ""
+        materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
+        return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}"

     def except_sql(self, expression):
         return self.prepend_ctes(
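Together with the new `temporary`/`materialized` args on `exp.Drop` and the reworked `_parse_drop`, this makes the statements added to the identity fixtures round-trip; a minimal sketch:

```python
import sqlglot

# Both statements come from the identity.sql fixture changes in this commit.
for sql in ("CREATE MATERIALIZED VIEW x.y.z AS SELECT a FROM b", "DROP MATERIALIZED VIEW x.y.z"):
    assert sqlglot.transpile(sql)[0] == sql
```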
@@ -583,7 +590,14 @@ class Generator:
         return self.prepend_ctes(expression, sql)

     def values_sql(self, expression):
-        return f"VALUES{self.seg('')}{self.expressions(expression)}"
+        alias = self.sql(expression, "alias")
+        args = self.expressions(expression)
+        if not alias:
+            return f"VALUES{self.seg('')}{args}"
+        alias = f" AS {alias}" if alias else alias
+        if self.wrap_derived_values:
+            return f"(VALUES{self.seg('')}{args}){alias}"
+        return f"VALUES{self.seg('')}{args}{alias}"

     def var_sql(self, expression):
         return self.sql(expression, "this")
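`wrap_derived_values` is what lets Spark emit an unparenthesized VALUES clause while other dialects keep the wrapping; the expectation comes from the TestDialect change later in this diff:

```python
import sqlglot

sql = "SELECT * FROM VALUES ('x'), ('y') AS t(z)"
# Spark sets wrap_derived_values = False, so no parentheses are added.
print(sqlglot.transpile(sql, write="spark")[0])
# SELECT * FROM VALUES ('x'), ('y') AS t(z)
```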
@@ -32,8 +32,8 @@ def merge_subqueries(expression, leave_tables_isolated=False):
     Returns:
         sqlglot.Expression: optimized expression
     """
-    merge_ctes(expression, leave_tables_isolated)
-    merge_derived_tables(expression, leave_tables_isolated)
+    expression = merge_ctes(expression, leave_tables_isolated)
+    expression = merge_derived_tables(expression, leave_tables_isolated)
     return expression


@@ -76,14 +76,14 @@ def merge_ctes(expression, leave_tables_isolated=False):
             alias = node_to_replace.alias
         else:
             alias = table.name

         _rename_inner_sources(outer_scope, inner_scope, alias)
         _merge_from(outer_scope, inner_scope, node_to_replace, alias)
-        _merge_joins(outer_scope, inner_scope, from_or_join)
         _merge_expressions(outer_scope, inner_scope, alias)
+        _merge_joins(outer_scope, inner_scope, from_or_join)
         _merge_where(outer_scope, inner_scope, from_or_join)
         _merge_order(outer_scope, inner_scope)
         _pop_cte(inner_scope)
     return expression


 def merge_derived_tables(expression, leave_tables_isolated=False):
@@ -97,10 +97,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):

             _rename_inner_sources(outer_scope, inner_scope, alias)
             _merge_from(outer_scope, inner_scope, subquery, alias)
-            _merge_joins(outer_scope, inner_scope, from_or_join)
             _merge_expressions(outer_scope, inner_scope, alias)
+            _merge_joins(outer_scope, inner_scope, from_or_join)
             _merge_where(outer_scope, inner_scope, from_or_join)
             _merge_order(outer_scope, inner_scope)
+    return expression


 def _mergeable(outer_scope, inner_select, leave_tables_isolated):
@@ -229,7 +230,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
             continue
         columns_to_replace = outer_columns.get(projection_name, [])
         for column in columns_to_replace:
-            column.replace(expression.unalias())
+            column.replace(expression.unalias().copy())


 def _merge_where(outer_scope, inner_scope, from_or_join):
@@ -5,8 +5,6 @@ from sqlglot.errors import OptimizeError
 from sqlglot.optimizer.schema import ensure_schema
 from sqlglot.optimizer.scope import traverse_scope

-SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
-

 def qualify_columns(expression, schema):
     """
@@ -35,7 +33,7 @@ def qualify_columns(expression, schema):
         _expand_group_by(scope, resolver)
         _expand_order_by(scope)
         _qualify_columns(scope, resolver)
-        if not isinstance(scope.expression, SKIP_QUALIFY):
+        if not isinstance(scope.expression, exp.UDTF):
             _expand_stars(scope, resolver)
             _qualify_outputs(scope)
         _check_unknown_tables(scope)
@@ -50,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
     """
     for derived_table in derived_tables:
-        if isinstance(derived_table, SKIP_QUALIFY):
+        if isinstance(derived_table, exp.UDTF):
             continue
         table_alias = derived_table.args.get("alias")
         if table_alias:
@@ -202,7 +200,7 @@ def _qualify_columns(scope, resolver):
         if not column_table:
             column_table = resolver.get_table(column_name)

-            if not scope.is_subquery and not scope.is_unnest:
+            if not scope.is_subquery and not scope.is_udtf:
                 if column_name not in resolver.all_columns:
                     raise OptimizeError(f"Unknown column: {column_name}")

@@ -296,7 +294,7 @@ def _qualify_outputs(scope):


 def _check_unknown_tables(scope):
-    if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
+    if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
         raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
@@ -1,5 +1,4 @@
 import itertools
-from copy import copy
 from enum import Enum, auto

 from sqlglot import exp
@@ -12,7 +11,7 @@ class ScopeType(Enum):
     DERIVED_TABLE = auto()
     CTE = auto()
     UNION = auto()
-    UNNEST = auto()
+    UDTF = auto()


 class Scope:
@@ -70,14 +69,11 @@ class Scope:
         self._columns = None
         self._external_columns = None

-    def branch(self, expression, scope_type, add_sources=None, **kwargs):
+    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
         """Branch from the current scope to a new, inner scope"""
-        sources = copy(self.sources)
-        if add_sources:
-            sources.update(add_sources)
         return Scope(
             expression=expression.unnest(),
-            sources=sources,
+            sources={**self.cte_sources, **(chain_sources or {})},
             parent=self,
             scope_type=scope_type,
             **kwargs,
@@ -90,30 +86,21 @@ class Scope:
         self._derived_tables = []
         self._raw_columns = []

-        # We'll use this variable to pass state into the dfs generator.
-        # Whenever we set it to True, we exclude a subtree from traversal.
-        prune = False
-
-        for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
-            prune = False
-
+        for node, parent, _ in self.walk(bfs=False):
             if node is self.expression:
                 continue
-            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
+            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                 self._raw_columns.append(node)
             elif isinstance(node, exp.Table):
                 self._tables.append(node)
-            elif isinstance(node, (exp.Unnest, exp.Lateral)):
+            elif isinstance(node, exp.UDTF):
                 self._derived_tables.append(node)
             elif isinstance(node, exp.CTE):
                 self._ctes.append(node)
-                prune = True
             elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                 self._derived_tables.append(node)
-                prune = True
             elif isinstance(node, exp.Subqueryable):
                 self._subqueries.append(node)
-                prune = True

         self._collected = True
@@ -121,6 +108,43 @@ class Scope:
         if not self._collected:
             self._collect()

+    def walk(self, bfs=True):
+        return walk_in_scope(self.expression, bfs=bfs)
+
+    def find(self, *expression_types, bfs=True):
+        """
+        Returns the first node in this scope which matches at least one of the specified types.
+
+        This does NOT traverse into subscopes.
+
+        Args:
+            expression_types (type): the expression type(s) to match.
+            bfs (bool): True to use breadth-first search, False to use depth-first.
+
+        Returns:
+            exp.Expression: the node which matches the criteria or None if no node matching
+            the criteria was found.
+        """
+        return next(self.find_all(*expression_types, bfs=bfs), None)
+
+    def find_all(self, *expression_types, bfs=True):
+        """
+        Returns a generator object which visits all nodes in this scope and only yields those that
+        match at least one of the specified expression types.
+
+        This does NOT traverse into subscopes.
+
+        Args:
+            expression_types (type): the expression type(s) to match.
+            bfs (bool): True to use breadth-first search, False to use depth-first.
+
+        Yields:
+            exp.Expression: nodes
+        """
+        for expression, _, _ in self.walk(bfs=bfs):
+            if isinstance(expression, expression_types):
+                yield expression
+
     def replace(self, old, new):
         """
         Replace `old` with `new`.
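A sketch of how the new scope-local search reads in practice (it mirrors the assertions added to test_optimizer.py later in this diff):

```python
from sqlglot import exp, parse_one
from sqlglot.optimizer.scope import build_scope

scope = build_scope(parse_one("SELECT a FROM (SELECT b FROM x) AS y"))
# find/find_all stay inside this scope: the subquery's columns are not visited.
outer_columns = [c.sql() for c in scope.find_all(exp.Column)]
```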
@@ -246,6 +270,16 @@ class Scope:
             self._selected_sources = result
         return self._selected_sources

+    @property
+    def cte_sources(self):
+        """
+        Sources that are CTEs.
+
+        Returns:
+            dict[str, Scope]: Mapping of source alias to Scope
+        """
+        return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
+
     @property
     def selects(self):
         """
@@ -313,9 +347,9 @@ class Scope:
         return self.scope_type == ScopeType.ROOT

     @property
-    def is_unnest(self):
-        """Determine if this scope is an unnest"""
-        return self.scope_type == ScopeType.UNNEST
+    def is_udtf(self):
+        """Determine if this scope is a UDTF (User Defined Table Function)"""
+        return self.scope_type == ScopeType.UDTF

     @property
     def is_correlated_subquery(self):
@@ -348,7 +382,7 @@ class Scope:
             Scope: scope instances in depth-first-search post-order
         """
         for child_scope in itertools.chain(
-            self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
+            self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
         ):
             yield from child_scope.traverse()
         yield self
@@ -399,7 +433,7 @@ def _traverse_scope(scope):
         yield from _traverse_select(scope)
     elif isinstance(scope.expression, exp.Union):
         yield from _traverse_union(scope)
-    elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
+    elif isinstance(scope.expression, exp.UDTF):
         pass
     elif isinstance(scope.expression, exp.Subquery):
         yield from _traverse_subqueries(scope)
@@ -410,8 +444,8 @@ def _traverse_scope(scope):

 def _traverse_select(scope):
     yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
-    yield from _traverse_subqueries(scope)
     yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
+    yield from _traverse_subqueries(scope)
     _add_table_sources(scope)


@@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
     top = None
     for child_scope in _traverse_scope(
         scope.branch(
-            derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
-            add_sources=sources if scope_type == ScopeType.CTE else None,
+            derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
+            chain_sources=sources if scope_type == ScopeType.CTE else None,
             outer_column_list=derived_table.alias_column_names,
-            scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
+            scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
         )
     ):
         yield child_scope
@@ -483,3 +517,35 @@ def _traverse_subqueries(scope):
         yield child_scope
         top = child_scope
     scope.subquery_scopes.append(top)
+
+
+def walk_in_scope(expression, bfs=True):
+    """
+    Returns a generator object which visits all nodes in the syntax tree, stopping at
+    nodes that start child scopes.
+
+    Args:
+        expression (exp.Expression):
+        bfs (bool): if set to True the BFS traversal order will be applied,
+            otherwise the DFS traversal will be used instead.
+
+    Yields:
+        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
+    """
+    # We'll use this variable to pass state into the dfs generator.
+    # Whenever we set it to True, we exclude a subtree from traversal.
+    prune = False
+
+    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
+        prune = False
+
+        yield node, parent, key
+
+        if node is expression:
+            continue
+        elif isinstance(node, exp.CTE):
+            prune = True
+        elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
+            prune = True
+        elif isinstance(node, exp.Subqueryable):
+            prune = True
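`walk_in_scope` also works from an arbitrary node, which the new test exercises; a minimal sketch:

```python
from sqlglot import exp, parse_one
from sqlglot.optimizer.scope import walk_in_scope

ast = parse_one("SELECT a FROM x WHERE a > (SELECT MAX(b) FROM y)")
# Walking from the WHERE clause prunes the subquery's inner SELECT,
# so only the outer column "a" is seen here.
where_columns = {
    node.sql() for node, *_ in walk_in_scope(ast.find(exp.Where)) if isinstance(node, exp.Column)
}
```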
@@ -126,6 +126,8 @@ class Parser:
         TokenType.CONSTRAINT,
         TokenType.DEFAULT,
         TokenType.DELETE,
+        TokenType.DETERMINISTIC,
+        TokenType.EXECUTE,
         TokenType.ENGINE,
         TokenType.ESCAPE,
         TokenType.EXPLAIN,
@@ -139,6 +141,7 @@ class Parser:
         TokenType.IF,
         TokenType.INDEX,
         TokenType.ISNULL,
+        TokenType.IMMUTABLE,
         TokenType.INTERVAL,
         TokenType.LAZY,
         TokenType.LANGUAGE,
@@ -163,6 +166,7 @@ class Parser:
         TokenType.SEED,
         TokenType.SET,
         TokenType.SHOW,
+        TokenType.STABLE,
         TokenType.STORED,
         TokenType.TABLE,
         TokenType.TABLE_FORMAT,
@@ -175,6 +179,8 @@ class Parser:
         TokenType.UNIQUE,
         TokenType.UNPIVOT,
         TokenType.PROPERTIES,
+        TokenType.PROCEDURE,
+        TokenType.VOLATILE,
         *SUBQUERY_PREDICATES,
         *TYPE_TOKENS,
     }
@@ -204,7 +210,7 @@ class Parser:
         TokenType.DATETIME,
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
-        *NESTED_TYPE_TOKENS,
+        *TYPE_TOKENS,
         *SUBQUERY_PREDICATES,
     }
@@ -379,6 +385,13 @@ class Parser:
         TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
         TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
         TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
+        TokenType.EXECUTE: lambda self: self._parse_execute_as(),
+        TokenType.DETERMINISTIC: lambda self: self.expression(
+            exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+        ),
+        TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")),
+        TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")),
+        TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")),
     }

     CONSTRAINT_PARSERS = {
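These property parsers back the volatility round-trips asserted in the Snowflake tests below; a minimal sketch:

```python
import sqlglot

sql = "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'"
# IMMUTABLE is now parsed into an exp.VolatilityProperty and re-emitted.
assert sqlglot.transpile(sql, read="snowflake", write="snowflake")[0] == sql
```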
@@ -418,7 +431,7 @@ class Parser:

     MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)

-    CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX}
+    CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE}

     STRICT_CAST = True

@@ -615,18 +628,20 @@ class Parser:
         return expression

     def _parse_drop(self):
-        if self._match(TokenType.TABLE):
-            kind = "TABLE"
-        elif self._match(TokenType.VIEW):
-            kind = "VIEW"
-        else:
-            self.raise_error("Expected TABLE or View")
+        temporary = self._match(TokenType.TEMPORARY)
+        materialized = self._match(TokenType.MATERIALIZED)
+        kind = self._match_set(self.CREATABLES) and self._prev.text
+        if not kind:
+            self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
+            return

         return self.expression(
             exp.Drop,
             exists=self._parse_exists(),
             this=self._parse_table(schema=True),
             kind=kind,
+            temporary=temporary,
+            materialized=materialized,
         )

     def _parse_exists(self, not_=False):
@@ -644,14 +659,15 @@ class Parser:
         create_token = self._match_set(self.CREATABLES) and self._prev

         if not create_token:
-            self.raise_error("Expected TABLE, VIEW, INDEX, or FUNCTION")
+            self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
+            return

         exists = self._parse_exists(not_=True)
         this = None
         expression = None
         properties = None

-        if create_token.token_type == TokenType.FUNCTION:
+        if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
             this = self._parse_user_defined_function()
             properties = self._parse_properties()
         if self._match(TokenType.ALIAS):
@@ -747,7 +763,9 @@ class Parser:
         if is_table:
             if self._match(TokenType.LT):
                 value = self.expression(
-                    exp.Schema, this="TABLE", expressions=self._parse_csv(self._parse_struct_kwargs)
+                    exp.Schema,
+                    this="TABLE",
+                    expressions=self._parse_csv(self._parse_struct_kwargs),
                 )
                 if not self._match(TokenType.GT):
                     self.raise_error("Expecting >")
@@ -763,6 +781,14 @@ class Parser:
             is_table=is_table,
         )

+    def _parse_execute_as(self):
+        self._match(TokenType.ALIAS)
+        return self.expression(
+            exp.ExecuteAsProperty,
+            this=exp.Literal.string("EXECUTE AS"),
+            value=self._parse_var(),
+        )
+
     def _parse_properties(self):
         properties = []

@@ -997,7 +1023,12 @@ class Parser:
         )

     def _parse_subquery(self, this):
-        return self.expression(exp.Subquery, this=this, pivots=self._parse_pivots(), alias=self._parse_table_alias())
+        return self.expression(
+            exp.Subquery,
+            this=this,
+            pivots=self._parse_pivots(),
+            alias=self._parse_table_alias(),
+        )

     def _parse_query_modifiers(self, this):
         if not isinstance(this, self.MODIFIABLES):
@@ -1118,6 +1149,11 @@ class Parser:
         if unnest:
             return unnest

+        values = self._parse_derived_table_values()
+
+        if values:
+            return values
+
         subquery = self._parse_select(table=True)

         if subquery:
@@ -1186,6 +1222,24 @@ class Parser:
             alias=alias,
         )

+    def _parse_derived_table_values(self):
+        is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
+        if not is_derived and not self._match(TokenType.VALUES):
+            return None
+
+        expressions = self._parse_csv(self._parse_value)
+
+        if is_derived:
+            self._match_r_paren()
+
+        alias = self._parse_table_alias()
+
+        return self.expression(
+            exp.Values,
+            expressions=expressions,
+            alias=alias,
+        )
+
     def _parse_table_sample(self):
         if not self._match(TokenType.TABLE_SAMPLE):
             return None
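With VALUES parsed as its own `exp.Values` node (rather than a generic subquery), both the wrapped and unwrapped forms land on the same AST shape; a minimal sketch, mirroring the qualify_columns fixture cases added in this commit:

```python
from sqlglot import exp, parse_one

# Both spellings now produce an exp.Values node in the FROM clause.
for sql in ("SELECT x FROM (VALUES (1, 2)) AS q(x, y)", "SELECT x FROM VALUES (1, 2) AS q(x, y)"):
    assert parse_one(sql).find(exp.Values) is not None
```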
@@ -1700,7 +1754,11 @@ class Parser:
         return self._parse_window(this)

     def _parse_user_defined_function(self):
-        this = self._parse_var()
+        this = self._parse_id_var()
+
+        while self._match(TokenType.DOT):
+            this = self.expression(exp.Dot, this=this, expression=self._parse_id_var())
+
         if not self._match(TokenType.L_PAREN):
             return this
         expressions = self._parse_csv(self._parse_udf_kwarg)
@@ -136,6 +136,7 @@ class TokenType(AutoName):
     DEFAULT = auto()
     DELETE = auto()
     DESC = auto()
+    DETERMINISTIC = auto()
     DISTINCT = auto()
     DISTRIBUTE_BY = auto()
     DROP = auto()
@@ -144,6 +145,7 @@ class TokenType(AutoName):
     ENGINE = auto()
     ESCAPE = auto()
     EXCEPT = auto()
+    EXECUTE = auto()
     EXISTS = auto()
     EXPLAIN = auto()
     FALSE = auto()
@@ -167,6 +169,7 @@ class TokenType(AutoName):
     IF = auto()
     IGNORE_NULLS = auto()
     ILIKE = auto()
+    IMMUTABLE = auto()
     IN = auto()
     INDEX = auto()
     INNER = auto()
@@ -215,6 +218,7 @@ class TokenType(AutoName):
     PLACEHOLDER = auto()
     PRECEDING = auto()
     PRIMARY_KEY = auto()
+    PROCEDURE = auto()
     PROPERTIES = auto()
     QUALIFY = auto()
     QUOTE = auto()
@@ -238,6 +242,7 @@ class TokenType(AutoName):
     SIMILAR_TO = auto()
     SOME = auto()
     SORT_BY = auto()
+    STABLE = auto()
     STORED = auto()
     STRUCT = auto()
     TABLE_FORMAT = auto()
@@ -258,6 +263,7 @@ class TokenType(AutoName):
     USING = auto()
     VALUES = auto()
     VIEW = auto()
+    VOLATILE = auto()
     WHEN = auto()
     WHERE = auto()
     WINDOW = auto()
@@ -430,6 +436,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "DEFAULT": TokenType.DEFAULT,
         "DELETE": TokenType.DELETE,
         "DESC": TokenType.DESC,
+        "DETERMINISTIC": TokenType.DETERMINISTIC,
         "DISTINCT": TokenType.DISTINCT,
         "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
         "DROP": TokenType.DROP,
@@ -438,6 +445,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "ENGINE": TokenType.ENGINE,
         "ESCAPE": TokenType.ESCAPE,
         "EXCEPT": TokenType.EXCEPT,
+        "EXECUTE": TokenType.EXECUTE,
         "EXISTS": TokenType.EXISTS,
         "EXPLAIN": TokenType.EXPLAIN,
         "FALSE": TokenType.FALSE,
@@ -456,6 +464,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "HAVING": TokenType.HAVING,
         "IF": TokenType.IF,
         "ILIKE": TokenType.ILIKE,
+        "IMMUTABLE": TokenType.IMMUTABLE,
         "IGNORE NULLS": TokenType.IGNORE_NULLS,
         "IN": TokenType.IN,
         "INDEX": TokenType.INDEX,
@@ -504,6 +513,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "PIVOT": TokenType.PIVOT,
         "PRECEDING": TokenType.PRECEDING,
         "PRIMARY KEY": TokenType.PRIMARY_KEY,
+        "PROCEDURE": TokenType.PROCEDURE,
         "RANGE": TokenType.RANGE,
         "RECURSIVE": TokenType.RECURSIVE,
         "REGEXP": TokenType.RLIKE,
@@ -522,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "SHOW": TokenType.SHOW,
         "SOME": TokenType.SOME,
         "SORT BY": TokenType.SORT_BY,
+        "STABLE": TokenType.STABLE,
         "STORED": TokenType.STORED,
         "TABLE": TokenType.TABLE,
         "TABLE_FORMAT": TokenType.TABLE_FORMAT,
@@ -542,6 +553,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "USING": TokenType.USING,
         "VALUES": TokenType.VALUES,
         "VIEW": TokenType.VIEW,
+        "VOLATILE": TokenType.VOLATILE,
         "WHEN": TokenType.WHEN,
         "WHERE": TokenType.WHERE,
         "WITH": TokenType.WITH,
@@ -637,6 +649,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "_char",
         "_end",
         "_peek",
+        "_prev_token_type",
     )

     def __init__(self):
@@ -657,6 +670,7 @@ class Tokenizer(metaclass=_Tokenizer):
         self._char = None
         self._end = None
         self._peek = None
+        self._prev_token_type = None

     def tokenize(self, sql):
         self.reset()
@@ -706,8 +720,8 @@ class Tokenizer(metaclass=_Tokenizer):
         return self.sql[self._start : self._current]

     def _add(self, token_type, text=None):
-        text = self._text if text is None else text
-        self.tokens.append(Token(token_type, text, self._line, self._col))
+        self._prev_token_type = token_type
+        self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))

         if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
             self._start = self._current
@@ -910,7 +924,11 @@ class Tokenizer(metaclass=_Tokenizer):
                 self._advance()
             else:
                 break
-        self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
+        self._add(
+            TokenType.VAR
+            if self._prev_token_type == TokenType.PARAMETER
+            else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
+        )

     def _extract_string(self, delimiter):
         text = ""
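Tracking `_prev_token_type` lets a word that follows a parameter marker tokenize as a plain VAR even when it collides with a keyword; the new parser test pins this down:

```python
from sqlglot import parse_one

# JOIN after "@" is treated as a variable name, not the JOIN keyword
# (assertion taken from the test_var case added in this commit).
assert parse_one("SELECT @JOIN, @'foo'").sql() == "SELECT @JOIN, @'foo'"
```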
@@ -239,7 +239,7 @@ class TestBigQuery(Validator):
         self.validate_all(
             "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
             write={
-                "spark": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
+                "spark": "SELECT cola, colb FROM VALUES (1, 'test') AS tab(cola, colb)",
                 "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])",
                 "snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
             },
@@ -253,7 +253,7 @@ class TestBigQuery(Validator):

     def test_user_defined_functions(self):
         self.validate_identity(
-            "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 LANGUAGE js AS 'return x*y;'"
+            "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
         )
         self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
         self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")
|
@ -1009,7 +1009,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"SELECT * FROM VALUES ('x'), ('y') AS t(z)",
|
||||
write={
|
||||
"spark": "SELECT * FROM (VALUES ('x'), ('y')) AS t(z)",
|
||||
"spark": "SELECT * FROM VALUES ('x'), ('y') AS t(z)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
|
|
|
@@ -293,3 +293,15 @@ class TestSnowflake(Validator):
                 "bigquery": "CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1",
             },
         )
+        self.validate_all(
+            "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'",
+            write={
+                "snowflake": "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'",
+            },
+        )
+
+    def test_stored_procedures(self):
+        self.validate_identity("CALL a.b.c(x, y)")
+        self.validate_identity(
+            "CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'"
+        )
tests/fixtures/identity.sql (vendored): 9 changes
@@ -50,6 +50,7 @@ a.B()
 a['x'].C()
 int.x
 map.x
+a.b.INT(1.234)
 x IN (-1, 1)
 x IN ('a', 'a''a')
 x IN ((1))
@@ -357,6 +358,7 @@ SELECT * REPLACE (a + 1 AS b, b AS C)
 SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)
 SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C)
+SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals)
 SELECT zoo, animals FROM UNNEST(ARRAY(STRUCT('oakland' AS zoo, ARRAY('a', 'b') AS animals), STRUCT('sf' AS zoo, ARRAY('b', 'c') AS animals))) AS t(zoo, animals)
 WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2
 WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
 WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
@@ -444,6 +446,8 @@ CREATE OR REPLACE TEMPORARY VIEW x AS SELECT *
 CREATE TEMPORARY VIEW x AS SELECT a FROM d
 CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d
 CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y
+CREATE MATERIALIZED VIEW x.y.z AS SELECT a FROM b
+DROP MATERIALIZED VIEW x.y.z
 CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3))
 CREATE TABLE z (end INT)
 CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3))
@@ -471,10 +475,13 @@ CREATE FUNCTION f AS 'g'
 CREATE FUNCTION a(b INT, c VARCHAR) AS 'SELECT 1'
 CREATE FUNCTION a() LANGUAGE sql
 CREATE FUNCTION a() LANGUAGE sql RETURNS INT
 CREATE FUNCTION a.b.c()
+DROP FUNCTION a.b.c (INT)
 CREATE INDEX abc ON t (a)
 CREATE INDEX abc ON t (a, b, b)
 CREATE UNIQUE INDEX abc ON t (a, b, b)
+CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b)
+DROP INDEX a.b.c
 CACHE TABLE x
 CACHE LAZY TABLE x
 CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value')
@@ -484,6 +491,8 @@ CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
 CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
 CACHE TABLE x AS (SELECT 1 AS y)
 CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2')
+CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END'
+DROP PROCEDURE a.b.c (INT)
 INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y
 INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y
 INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y
@@ -97,3 +97,11 @@ WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM
 -- Nested CTE
 SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x);
 SELECT x.a AS a, x.b AS b FROM x AS x;
+
+-- Inner select is an expression
+SELECT a FROM (SELECT a FROM (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) AS x) AS x;
+SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
+
+-- CTE select is an expression
+WITH x AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x AS x) AS x;
+SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
tests/fixtures/optimizer/optimizer.sql (vendored): 17 changes
@@ -137,3 +137,20 @@ SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
 SELECT
   AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg"
 FROM "x" AS "x";
+
+SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
+SELECT
+  "tab"."cola" AS "cola",
+  "tab"."colb" AS "colb"
+FROM (VALUES
+  (1, 'test'),
+  (2, 'test2')) AS "tab"("cola", "colb");
+
+# dialect: spark
+SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
+SELECT
+  `tab`.`cola` AS `cola`,
+  `tab`.`colb` AS `colb`
+FROM VALUES
+  (1, 'test'),
+  (2, 'test2') AS `tab`(`cola`, `colb`);
@@ -39,3 +39,15 @@ SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q_0";

 SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a);
 SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0";
+
+SELECT x FROM (VALUES(1, 2)) AS q(x, y);
+SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
+
+SELECT x FROM UNNEST([1, 2]) AS q(x, y);
+SELECT q.x AS x FROM UNNEST(ARRAY(1, 2)) AS q(x, y);
+
+WITH t1 AS (SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)]) AS "q"("cola", "colb")) SELECT cola FROM t1;
+WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS colb))) AS "q"("cola", "colb")) SELECT t1.cola AS cola FROM t1;
+
+SELECT x FROM VALUES(1, 2) AS q(x, y);
+SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
@@ -5,7 +5,7 @@ from sqlglot import exp, optimizer, parse_one, table
 from sqlglot.errors import OptimizeError
 from sqlglot.optimizer.annotate_types import annotate_types
 from sqlglot.optimizer.schema import MappingSchema, ensure_schema
-from sqlglot.optimizer.scope import build_scope, traverse_scope
+from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
 from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
@@ -264,12 +264,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         ON s.b = r.b
         WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b)
         """
-        for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()):
+        expression = parse_one(sql)
+        for scopes in traverse_scope(expression), list(build_scope(expression).traverse()):
             self.assertEqual(len(scopes), 5)
             self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
             self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
-            self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
-            self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
+            self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y")
+            self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
             self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())

             self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
@@ -279,6 +280,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
             self.assertEqual(len(scopes[4].source_columns("r")), 2)
             self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})

+            self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
+            self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
+            self.assertEqual({c.sql() for c in scopes[0].find_all(exp.Column)}, {"x.b"})
+
+        # Check that we can walk in scope from an arbitrary node
+        self.assertEqual(
+            {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)},
+            {"s.b"},
+        )
+
     def test_literal_type_annotation(self):
         tests = {
             "SELECT 5": exp.DataType.Type.INT,
@@ -122,6 +122,9 @@ class TestParser(unittest.TestCase):
     def test_parameter(self):
         self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1")

+    def test_var(self):
+        self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")
+
     def test_annotations(self):
         expression = parse_one(
             """