
Adding upstream version 6.2.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:42:49 +01:00
parent d62bab68ae
commit 24cf9d8984
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
22 changed files with 361 additions and 98 deletions

View file

@ -1,6 +1,6 @@
# SQLGlot
SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
It is a very comprehensive generic SQL parser with a robust [test suite](tests). It is also quite [performant](#benchmarks) while being written purely in Python.
@ -30,7 +30,7 @@ sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read='duckdb', write='hive')
```
```sql
SELECT TO_UTC_TIMESTAMP(FROM_UNIXTIME(1618088028295 / 1000, 'yyyy-MM-dd HH:mm:ss'), 'UTC')
SELECT FROM_UNIXTIME(1618088028295 / 1000)
```
SQLGlot can even translate custom time formats.
@ -299,7 +299,7 @@ class Custom(Dialect):
}
Dialects["custom"]
Dialect["custom"]
```
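
The corrected line reflects that dialect lookup happens on the `Dialect` class itself, not a `Dialects` mapping. A minimal sketch (assuming a registered `Custom` dialect like the one defined above):

```python
from sqlglot.dialects.dialect import Dialect

class Custom(Dialect):
    """Registered under its lowercased class name by the Dialect metaclass."""

# Registry lookup goes through the Dialect class, as the corrected README line shows:
assert Dialect["custom"] is Custom
```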
## Benchmarks

View file

@ -20,7 +20,7 @@ from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "6.2.6"
__version__ = "6.2.8"
pretty = False

View file

@ -33,10 +33,10 @@ def _date_add_sql(data_type, kind):
return func
def _subquery_to_unnest_if_values(self, expression):
if not isinstance(expression.this, exp.Values):
return self.subquery_sql(expression)
rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)]
def _derived_table_values_to_unnest(self, expression):
if not isinstance(expression.unnest().parent, exp.From):
return self.values_sql(expression)
rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)]
structs = []
for row in rows:
aliases = [
@ -99,6 +99,7 @@ class BigQuery(Dialect):
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
}
class Parser(Parser):
@ -140,9 +141,10 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.VariancePop: rename_func("VAR_POP"),
exp.Subquery: _subquery_to_unnest_if_values,
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
}
TYPE_MAPPING = {
@ -160,6 +162,16 @@ class BigQuery(Dialect):
exp.DataType.Type.NVARCHAR: "STRING",
}
ROOT_PROPERTIES = {
exp.LanguageProperty,
exp.ReturnsProperty,
exp.VolatilityProperty,
}
WITH_PROPERTIES = {
exp.AnonymousProperty,
}
def in_unnest_op(self, unnest):
return self.sql(unnest)
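
The renamed `_derived_table_values_to_unnest` only rewrites a `VALUES` clause when it is used as a derived table in a `FROM`. A short sketch of the effect, with the output taken from the updated BigQuery tests:

```python
import sqlglot

# Each VALUES row becomes a STRUCT whose fields are named after the
# table alias columns, wrapped in UNNEST([...]).
sql = "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
print(sqlglot.transpile(sql, write="bigquery")[0])
# SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])
```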

View file

@ -77,6 +77,7 @@ class Dialect(metaclass=_Dialect):
alias_post_tablesample = False
normalize_functions = "upper"
null_ordering = "nulls_are_small"
wrap_derived_values = True
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
@ -169,6 +170,7 @@ class Dialect(metaclass=_Dialect):
"alias_post_tablesample": self.alias_post_tablesample,
"normalize_functions": self.normalize_functions,
"null_ordering": self.null_ordering,
"wrap_derived_values": self.wrap_derived_values,
**opts,
}
)

View file

@ -177,6 +177,8 @@ class Snowflake(Dialect):
exp.ReturnsProperty,
exp.LanguageProperty,
exp.SchemaCommentProperty,
exp.ExecuteAsProperty,
exp.VolatilityProperty,
}
def except_op(self, expression):
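
With `ExecuteAsProperty` and `VolatilityProperty` added to Snowflake's `ROOT_PROPERTIES`, procedure definitions round-trip; a sketch mirroring the new Snowflake tests:

```python
import sqlglot

# EXECUTE AS CALLER now survives a Snowflake round trip.
sql = "CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'"
assert sqlglot.transpile(sql, read="snowflake", write="snowflake")[0] == sql
```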

View file

@ -47,6 +47,8 @@ def _unix_to_time(self, expression):
class Spark(Hive):
wrap_derived_values = False
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,

View file

@ -213,21 +213,23 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
def walk(self, bfs=True):
def walk(self, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in this tree.
Args:
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
prune ((node, parent, arg_key) -> bool): callable that returns True if
the generator should stop traversing this branch of the tree.
Returns:
the generator object.
"""
if bfs:
yield from self.bfs()
yield from self.bfs(prune=prune)
else:
yield from self.dfs()
yield from self.dfs(prune=prune)
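
A minimal sketch of the new `prune` hook: the pruned node itself is still yielded, but its subtree is not traversed.

```python
from sqlglot import exp, parse_one

tree = parse_one("SELECT a FROM (SELECT b FROM x) AS t")

# Stop descending at subqueries; the Subquery node is yielded,
# but the columns and tables inside it are skipped.
names = [
    type(node).__name__
    for node, _, _ in tree.walk(bfs=False, prune=lambda node, *_: isinstance(node, exp.Subquery))
]
assert "Table" not in names  # the only Table (x) lives inside the pruned subquery
```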
def dfs(self, parent=None, key=None, prune=None):
"""
@ -506,6 +508,10 @@ class DerivedTable(Expression):
return [select.alias_or_name for select in self.selects]
class UDTF(DerivedTable):
pass
class Annotation(Expression):
arg_types = {
"this": True,
@ -652,7 +658,13 @@ class Delete(Expression):
class Drop(Expression):
arg_types = {"this": False, "kind": False, "exists": False}
arg_types = {
"this": False,
"kind": False,
"exists": False,
"temporary": False,
"materialized": False,
}
class Filter(Expression):
@ -827,7 +839,7 @@ class Join(Expression):
return join
class Lateral(DerivedTable):
class Lateral(UDTF):
arg_types = {"this": True, "outer": False, "alias": False}
@ -915,6 +927,14 @@ class LanguageProperty(Property):
pass
class ExecuteAsProperty(Property):
pass
class VolatilityProperty(Property):
arg_types = {"this": True}
class Properties(Expression):
arg_types = {"expressions": True}
@ -1098,7 +1118,7 @@ class Intersect(Union):
pass
class Unnest(DerivedTable):
class Unnest(UDTF):
arg_types = {
"expressions": True,
"ordinality": False,
@ -1116,8 +1136,12 @@ class Update(Expression):
}
class Values(Expression):
arg_types = {"expressions": True}
class Values(UDTF):
arg_types = {
"expressions": True,
"ordinality": False,
"alias": False,
}
class Var(Expression):
@ -2033,23 +2057,17 @@ class Func(Condition):
@classmethod
def from_arg_list(cls, args):
args_num = len(args)
if cls.is_var_len_args:
all_arg_keys = list(cls.arg_types)
# If this function supports variable length argument treat the last argument as such.
non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
num_non_var = len(non_var_len_arg_keys)
all_arg_keys = list(cls.arg_types)
# If this function supports variable length argument treat the last argument as such.
non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
args_dict = {arg_key: arg for arg, arg_key in zip(args, non_var_len_arg_keys)}
args_dict[all_arg_keys[-1]] = args[num_non_var:]
else:
args_dict = {arg_key: arg for arg, arg_key in zip(args, cls.arg_types)}
args_dict = {}
arg_idx = 0
for arg_key in non_var_len_arg_keys:
if arg_idx >= args_num:
break
if args[arg_idx] is not None:
args_dict[arg_key] = args[arg_idx]
arg_idx += 1
if arg_idx < args_num and cls.is_var_len_args:
args_dict[all_arg_keys[-1]] = args[arg_idx:]
return cls(**args_dict)
@classmethod

View file

@ -49,10 +49,12 @@ class Generator:
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
}
NULL_ORDERING_SUPPORTED = True
@ -99,6 +101,7 @@ class Generator:
"unsupported_messages",
"null_ordering",
"max_unsupported",
"wrap_derived_values",
"_indent",
"_replace_backslash",
"_escaped_quote_end",
@ -127,6 +130,7 @@ class Generator:
null_ordering=None,
max_unsupported=3,
leading_comma=False,
wrap_derived_values=True,
):
import sqlglot
@ -150,6 +154,7 @@ class Generator:
self.unsupported_messages = []
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
self.wrap_derived_values = wrap_derived_values
self._indent = indent
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
@ -407,7 +412,9 @@ class Generator:
this = self.sql(expression, "this")
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
return f"DROP {kind}{exists_sql}{this}"
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}"
def except_sql(self, expression):
return self.prepend_ctes(
@ -583,7 +590,14 @@ class Generator:
return self.prepend_ctes(expression, sql)
def values_sql(self, expression):
return f"VALUES{self.seg('')}{self.expressions(expression)}"
alias = self.sql(expression, "alias")
args = self.expressions(expression)
if not alias:
return f"VALUES{self.seg('')}{args}"
alias = f" AS {alias}" if alias else alias
if self.wrap_derived_values:
return f"(VALUES{self.seg('')}{args}){alias}"
return f"VALUES{self.seg('')}{args}{alias}"
def var_sql(self, expression):
return self.sql(expression, "this")

View file

@ -32,8 +32,8 @@ def merge_subqueries(expression, leave_tables_isolated=False):
Returns:
sqlglot.Expression: optimized expression
"""
merge_ctes(expression, leave_tables_isolated)
merge_derived_tables(expression, leave_tables_isolated)
expression = merge_ctes(expression, leave_tables_isolated)
expression = merge_derived_tables(expression, leave_tables_isolated)
return expression
@ -76,14 +76,14 @@ def merge_ctes(expression, leave_tables_isolated=False):
alias = node_to_replace.alias
else:
alias = table.name
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_pop_cte(inner_scope)
return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
@ -97,10 +97,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
return expression
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
@ -229,7 +230,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
continue
columns_to_replace = outer_columns.get(projection_name, [])
for column in columns_to_replace:
column.replace(expression.unalias())
column.replace(expression.unalias().copy())
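
Both rules now return the (possibly replaced) expression, and callers must thread the result through. A sketch of standalone use, assuming columns are qualified first as in the optimizer pipeline (the schema and SQL here are illustrative):

```python
from sqlglot import parse_one
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.qualify_columns import qualify_columns

expression = parse_one("SELECT a FROM (SELECT a FROM x) AS t")
expression = qualify_columns(expression, schema={"x": {"a": "INT"}})

# Use the return value: the root node may be replaced during merging.
expression = merge_subqueries(expression)
print(expression.sql())  # roughly: SELECT x.a AS a FROM x AS x
```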
def _merge_where(outer_scope, inner_scope, from_or_join):

View file

@ -5,8 +5,6 @@ from sqlglot.errors import OptimizeError
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import traverse_scope
SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
def qualify_columns(expression, schema):
"""
@ -35,7 +33,7 @@ def qualify_columns(expression, schema):
_expand_group_by(scope, resolver)
_expand_order_by(scope)
_qualify_columns(scope, resolver)
if not isinstance(scope.expression, SKIP_QUALIFY):
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
_check_unknown_tables(scope)
@ -50,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2))
"""
for derived_table in derived_tables:
if isinstance(derived_table, SKIP_QUALIFY):
if isinstance(derived_table, exp.UDTF):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
@ -202,7 +200,7 @@ def _qualify_columns(scope, resolver):
if not column_table:
column_table = resolver.get_table(column_name)
if not scope.is_subquery and not scope.is_unnest:
if not scope.is_subquery and not scope.is_udtf:
if column_name not in resolver.all_columns:
raise OptimizeError(f"Unknown column: {column_name}")
@ -296,7 +294,7 @@ def _qualify_outputs(scope):
def _check_unknown_tables(scope):
if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")

View file

@ -1,5 +1,4 @@
import itertools
from copy import copy
from enum import Enum, auto
from sqlglot import exp
@ -12,7 +11,7 @@ class ScopeType(Enum):
DERIVED_TABLE = auto()
CTE = auto()
UNION = auto()
UNNEST = auto()
UDTF = auto()
class Scope:
@ -70,14 +69,11 @@ class Scope:
self._columns = None
self._external_columns = None
def branch(self, expression, scope_type, add_sources=None, **kwargs):
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
sources = copy(self.sources)
if add_sources:
sources.update(add_sources)
return Scope(
expression=expression.unnest(),
sources=sources,
sources={**self.cte_sources, **(chain_sources or {})},
parent=self,
scope_type=scope_type,
**kwargs,
@ -90,30 +86,21 @@ class Scope:
self._derived_tables = []
self._raw_columns = []
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
prune = False
for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
prune = False
for node, parent, _ in self.walk(bfs=False):
if node is self.expression:
continue
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table):
self._tables.append(node)
elif isinstance(node, (exp.Unnest, exp.Lateral)):
elif isinstance(node, exp.UDTF):
self._derived_tables.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
prune = True
elif isinstance(node, exp.Subqueryable):
self._subqueries.append(node)
prune = True
self._collected = True
@ -121,6 +108,43 @@ class Scope:
if not self._collected:
self._collect()
def walk(self, bfs=True):
return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
"""
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching
the criteria was found.
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
def find_all(self, *expression_types, bfs=True):
"""
Returns a generator object which visits all nodes in this scope and only yields those that
match at least one of the specified expression types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
"""
for expression, _, _ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
def replace(self, old, new):
"""
Replace `old` with `new`.
@ -246,6 +270,16 @@ class Scope:
self._selected_sources = result
return self._selected_sources
@property
def cte_sources(self):
"""
Sources that are CTEs.
Returns:
dict[str, Scope]: Mapping of source alias to Scope
"""
return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
@property
def selects(self):
"""
@ -313,9 +347,9 @@ class Scope:
return self.scope_type == ScopeType.ROOT
@property
def is_unnest(self):
"""Determine if this scope is an unnest"""
return self.scope_type == ScopeType.UNNEST
def is_udtf(self):
"""Determine if this scope is a UDTF (User Defined Table Function)"""
return self.scope_type == ScopeType.UDTF
@property
def is_correlated_subquery(self):
@ -348,7 +382,7 @@ class Scope:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
@ -399,7 +433,7 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
elif isinstance(scope.expression, exp.UDTF):
pass
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
@ -410,8 +444,8 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
yield from _traverse_subqueries(scope)
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
yield from _traverse_subqueries(scope)
_add_table_sources(scope)
@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
top = None
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
add_sources=sources if scope_type == ScopeType.CTE else None,
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
chain_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
)
):
yield child_scope
@ -483,3 +517,35 @@ def _traverse_subqueries(scope):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
def walk_in_scope(expression, bfs=True):
"""
Returns a generator object which visits all nodes in the syntax tree, stopping at
nodes that start child scopes.
Args:
expression (exp.Expression):
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
"""
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
prune = False
for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
prune = False
yield node, parent, key
if node is expression:
continue
elif isinstance(node, exp.CTE):
prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
prune = True
elif isinstance(node, exp.Subqueryable):
prune = True
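
A sketch of `walk_in_scope`, mirroring the new optimizer test: nodes that start child scopes are yielded but not entered.

```python
from sqlglot import exp, parse_one
from sqlglot.optimizer.scope import walk_in_scope

expression = parse_one("SELECT a FROM x WHERE b > (SELECT MAX(c) FROM y)")

# The subquery starts a child scope, so `c` is never visited.
columns = {node.sql() for node, *_ in walk_in_scope(expression) if isinstance(node, exp.Column)}
assert columns == {"a", "b"}
```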

View file

@ -126,6 +126,8 @@ class Parser:
TokenType.CONSTRAINT,
TokenType.DEFAULT,
TokenType.DELETE,
TokenType.DETERMINISTIC,
TokenType.EXECUTE,
TokenType.ENGINE,
TokenType.ESCAPE,
TokenType.EXPLAIN,
@ -139,6 +141,7 @@ class Parser:
TokenType.IF,
TokenType.INDEX,
TokenType.ISNULL,
TokenType.IMMUTABLE,
TokenType.INTERVAL,
TokenType.LAZY,
TokenType.LANGUAGE,
@ -163,6 +166,7 @@ class Parser:
TokenType.SEED,
TokenType.SET,
TokenType.SHOW,
TokenType.STABLE,
TokenType.STORED,
TokenType.TABLE,
TokenType.TABLE_FORMAT,
@ -175,6 +179,8 @@ class Parser:
TokenType.UNIQUE,
TokenType.UNPIVOT,
TokenType.PROPERTIES,
TokenType.PROCEDURE,
TokenType.VOLATILE,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
}
@ -204,7 +210,7 @@ class Parser:
TokenType.DATETIME,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
*NESTED_TYPE_TOKENS,
*TYPE_TOKENS,
*SUBQUERY_PREDICATES,
}
@ -379,6 +385,13 @@ class Parser:
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
TokenType.EXECUTE: lambda self: self._parse_execute_as(),
TokenType.DETERMINISTIC: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")),
TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")),
TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")),
}
CONSTRAINT_PARSERS = {
@ -418,7 +431,7 @@ class Parser:
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX}
CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE}
STRICT_CAST = True
@ -615,18 +628,20 @@ class Parser:
return expression
def _parse_drop(self):
if self._match(TokenType.TABLE):
kind = "TABLE"
elif self._match(TokenType.VIEW):
kind = "VIEW"
else:
self.raise_error("Expected TABLE or View")
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
return
return self.expression(
exp.Drop,
exists=self._parse_exists(),
this=self._parse_table(schema=True),
kind=kind,
temporary=temporary,
materialized=materialized,
)
def _parse_exists(self, not_=False):
@ -644,14 +659,15 @@ class Parser:
create_token = self._match_set(self.CREATABLES) and self._prev
if not create_token:
self.raise_error("Expected TABLE, VIEW, INDEX, or FUNCTION")
self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
return
exists = self._parse_exists(not_=True)
this = None
expression = None
properties = None
if create_token.token_type == TokenType.FUNCTION:
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function()
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
@ -747,7 +763,9 @@ class Parser:
if is_table:
if self._match(TokenType.LT):
value = self.expression(
exp.Schema, this="TABLE", expressions=self._parse_csv(self._parse_struct_kwargs)
exp.Schema,
this="TABLE",
expressions=self._parse_csv(self._parse_struct_kwargs),
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
@ -763,6 +781,14 @@ class Parser:
is_table=is_table,
)
def _parse_execute_as(self):
self._match(TokenType.ALIAS)
return self.expression(
exp.ExecuteAsProperty,
this=exp.Literal.string("EXECUTE AS"),
value=self._parse_var(),
)
def _parse_properties(self):
properties = []
@ -997,7 +1023,12 @@ class Parser:
)
def _parse_subquery(self, this):
return self.expression(exp.Subquery, this=this, pivots=self._parse_pivots(), alias=self._parse_table_alias())
return self.expression(
exp.Subquery,
this=this,
pivots=self._parse_pivots(),
alias=self._parse_table_alias(),
)
def _parse_query_modifiers(self, this):
if not isinstance(this, self.MODIFIABLES):
@ -1118,6 +1149,11 @@ class Parser:
if unnest:
return unnest
values = self._parse_derived_table_values()
if values:
return values
subquery = self._parse_select(table=True)
if subquery:
@ -1186,6 +1222,24 @@ class Parser:
alias=alias,
)
def _parse_derived_table_values(self):
is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
if not is_derived and not self._match(TokenType.VALUES):
return None
expressions = self._parse_csv(self._parse_value)
if is_derived:
self._match_r_paren()
alias = self._parse_table_alias()
return self.expression(
exp.Values,
expressions=expressions,
alias=alias,
)
def _parse_table_sample(self):
if not self._match(TokenType.TABLE_SAMPLE):
return None
@ -1700,7 +1754,11 @@ class Parser:
return self._parse_window(this)
def _parse_user_defined_function(self):
this = self._parse_var()
this = self._parse_id_var()
while self._match(TokenType.DOT):
this = self.expression(exp.Dot, this=this, expression=self._parse_id_var())
if not self._match(TokenType.L_PAREN):
return this
expressions = self._parse_csv(self._parse_udf_kwarg)
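
Taken together, the parser changes let procedures and dotted UDF names round-trip; a sketch using statements from the new identity fixtures:

```python
import sqlglot

for sql in (
    "CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END'",
    "DROP PROCEDURE a.b.c (INT)",
    "CREATE FUNCTION a.b.c()",
):
    # Identity fixtures: each statement parses and regenerates unchanged.
    assert sqlglot.transpile(sql)[0] == sql
```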

View file

@ -136,6 +136,7 @@ class TokenType(AutoName):
DEFAULT = auto()
DELETE = auto()
DESC = auto()
DETERMINISTIC = auto()
DISTINCT = auto()
DISTRIBUTE_BY = auto()
DROP = auto()
@ -144,6 +145,7 @@ class TokenType(AutoName):
ENGINE = auto()
ESCAPE = auto()
EXCEPT = auto()
EXECUTE = auto()
EXISTS = auto()
EXPLAIN = auto()
FALSE = auto()
@ -167,6 +169,7 @@ class TokenType(AutoName):
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
IMMUTABLE = auto()
IN = auto()
INDEX = auto()
INNER = auto()
@ -215,6 +218,7 @@ class TokenType(AutoName):
PLACEHOLDER = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
PROCEDURE = auto()
PROPERTIES = auto()
QUALIFY = auto()
QUOTE = auto()
@ -238,6 +242,7 @@ class TokenType(AutoName):
SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
STABLE = auto()
STORED = auto()
STRUCT = auto()
TABLE_FORMAT = auto()
@ -258,6 +263,7 @@ class TokenType(AutoName):
USING = auto()
VALUES = auto()
VIEW = auto()
VOLATILE = auto()
WHEN = auto()
WHERE = auto()
WINDOW = auto()
@ -430,6 +436,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DEFAULT": TokenType.DEFAULT,
"DELETE": TokenType.DELETE,
"DESC": TokenType.DESC,
"DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DROP": TokenType.DROP,
@ -438,6 +445,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ENGINE": TokenType.ENGINE,
"ESCAPE": TokenType.ESCAPE,
"EXCEPT": TokenType.EXCEPT,
"EXECUTE": TokenType.EXECUTE,
"EXISTS": TokenType.EXISTS,
"EXPLAIN": TokenType.EXPLAIN,
"FALSE": TokenType.FALSE,
@ -456,6 +464,7 @@ class Tokenizer(metaclass=_Tokenizer):
"HAVING": TokenType.HAVING,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IMMUTABLE": TokenType.IMMUTABLE,
"IGNORE NULLS": TokenType.IGNORE_NULLS,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
@ -504,6 +513,7 @@ class Tokenizer(metaclass=_Tokenizer):
"PIVOT": TokenType.PIVOT,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"PROCEDURE": TokenType.PROCEDURE,
"RANGE": TokenType.RANGE,
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
@ -522,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SHOW": TokenType.SHOW,
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
"STABLE": TokenType.STABLE,
"STORED": TokenType.STORED,
"TABLE": TokenType.TABLE,
"TABLE_FORMAT": TokenType.TABLE_FORMAT,
@ -542,6 +553,7 @@ class Tokenizer(metaclass=_Tokenizer):
"USING": TokenType.USING,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
"VOLATILE": TokenType.VOLATILE,
"WHEN": TokenType.WHEN,
"WHERE": TokenType.WHERE,
"WITH": TokenType.WITH,
@ -637,6 +649,7 @@ class Tokenizer(metaclass=_Tokenizer):
"_char",
"_end",
"_peek",
"_prev_token_type",
)
def __init__(self):
@ -657,6 +670,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._char = None
self._end = None
self._peek = None
self._prev_token_type = None
def tokenize(self, sql):
self.reset()
@ -706,8 +720,8 @@ class Tokenizer(metaclass=_Tokenizer):
return self.sql[self._start : self._current]
def _add(self, token_type, text=None):
text = self._text if text is None else text
self.tokens.append(Token(token_type, text, self._line, self._col))
self._prev_token_type = token_type
self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
self._start = self._current
@ -910,7 +924,11 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
else:
break
self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
self._add(
TokenType.VAR
if self._prev_token_type == TokenType.PARAMETER
else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
)
def _extract_string(self, delimiter):
text = ""

View file

@ -239,7 +239,7 @@ class TestBigQuery(Validator):
self.validate_all(
"SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
write={
"spark": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
"spark": "SELECT cola, colb FROM VALUES (1, 'test') AS tab(cola, colb)",
"bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])",
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
},
@ -253,7 +253,7 @@ class TestBigQuery(Validator):
def test_user_defined_functions(self):
self.validate_identity(
"CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 LANGUAGE js AS 'return x*y;'"
"CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
)
self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")

View file

@ -1009,7 +1009,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT * FROM VALUES ('x'), ('y') AS t(z)",
write={
"spark": "SELECT * FROM (VALUES ('x'), ('y')) AS t(z)",
"spark": "SELECT * FROM VALUES ('x'), ('y') AS t(z)",
},
)
self.validate_all(

View file

@ -293,3 +293,15 @@ class TestSnowflake(Validator):
"bigquery": "CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1",
},
)
self.validate_all(
"CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'",
write={
"snowflake": "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'",
},
)
def test_stored_procedures(self):
self.validate_identity("CALL a.b.c(x, y)")
self.validate_identity(
"CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'"
)

View file

@ -50,6 +50,7 @@ a.B()
a['x'].C()
int.x
map.x
a.b.INT(1.234)
x IN (-1, 1)
x IN ('a', 'a''a')
x IN ((1))
@ -357,6 +358,7 @@ SELECT * REPLACE (a + 1 AS b, b AS C)
SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)
SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C)
SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals)
SELECT zoo, animals FROM UNNEST(ARRAY(STRUCT('oakland' AS zoo, ARRAY('a', 'b') AS animals), STRUCT('sf' AS zoo, ARRAY('b', 'c') AS animals))) AS t(zoo, animals)
WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2
WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
@ -444,6 +446,8 @@ CREATE OR REPLACE TEMPORARY VIEW x AS SELECT *
CREATE TEMPORARY VIEW x AS SELECT a FROM d
CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d
CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y
CREATE MATERIALIZED VIEW x.y.z AS SELECT a FROM b
DROP MATERIALIZED VIEW x.y.z
CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3))
CREATE TABLE z (end INT)
CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3))
@ -471,10 +475,13 @@ CREATE FUNCTION f AS 'g'
CREATE FUNCTION a(b INT, c VARCHAR) AS 'SELECT 1'
CREATE FUNCTION a() LANGUAGE sql
CREATE FUNCTION a() LANGUAGE sql RETURNS INT
CREATE FUNCTION a.b.c()
DROP FUNCTION a.b.c (INT)
CREATE INDEX abc ON t (a)
CREATE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b)
DROP INDEX a.b.c
CACHE TABLE x
CACHE LAZY TABLE x
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value')
@ -484,6 +491,8 @@ CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE TABLE x AS (SELECT 1 AS y)
CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2')
CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END'
DROP PROCEDURE a.b.c (INT)
INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y

View file

@ -97,3 +97,11 @@ WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM
-- Nested CTE
SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x);
SELECT x.a AS a, x.b AS b FROM x AS x;
-- Inner select is an expression
SELECT a FROM (SELECT a FROM (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) AS x) AS x;
SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
-- CTE select is an expression
WITH x AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x AS x) AS x;
SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;

View file

@ -137,3 +137,20 @@ SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT
AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg"
FROM "x" AS "x";
SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
SELECT
"tab"."cola" AS "cola",
"tab"."colb" AS "colb"
FROM (VALUES
(1, 'test'),
(2, 'test2')) AS "tab"("cola", "colb");
# dialect: spark
SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
SELECT
`tab`.`cola` AS `cola`,
`tab`.`colb` AS `colb`
FROM VALUES
(1, 'test'),
(2, 'test2') AS `tab`(`cola`, `colb`);

View file

@ -39,3 +39,15 @@ SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q
SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a);
SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0";
SELECT x FROM (VALUES(1, 2)) AS q(x, y);
SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
SELECT x FROM UNNEST([1, 2]) AS q(x, y);
SELECT q.x AS x FROM UNNEST(ARRAY(1, 2)) AS q(x, y);
WITH t1 AS (SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)]) AS "q"("cola", "colb")) SELECT cola FROM t1;
WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS colb))) AS "q"("cola", "colb")) SELECT t1.cola AS cola FROM t1;
SELECT x FROM VALUES(1, 2) AS q(x, y);
SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);

View file

@ -5,7 +5,7 @@ from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
from sqlglot.optimizer.scope import build_scope, traverse_scope
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
@ -264,12 +264,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
ON s.b = r.b
WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b)
"""
for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()):
expression = parse_one(sql)
for scopes in traverse_scope(expression), list(build_scope(expression).traverse()):
self.assertEqual(len(scopes), 5)
self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
@ -279,6 +280,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(len(scopes[4].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
self.assertEqual({c.sql() for c in scopes[0].find_all(exp.Column)}, {"x.b"})
# Check that we can walk in scope from an arbitrary node
self.assertEqual(
{node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)},
{"s.b"},
)
def test_literal_type_annotation(self):
tests = {
"SELECT 5": exp.DataType.Type.INT,

View file

@ -122,6 +122,9 @@ class TestParser(unittest.TestCase):
def test_parameter(self):
self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1")
def test_var(self):
self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")
def test_annotations(self):
expression = parse_one(
"""