1
0
Fork 0

Merging upstream version 6.2.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:43:32 +01:00
parent 87ba722f7f
commit a62bbc24c3
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
22 changed files with 361 additions and 98 deletions

View file

@ -20,7 +20,7 @@ from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "6.2.6"
__version__ = "6.2.8"
pretty = False

View file

@ -33,10 +33,10 @@ def _date_add_sql(data_type, kind):
return func
def _subquery_to_unnest_if_values(self, expression):
if not isinstance(expression.this, exp.Values):
return self.subquery_sql(expression)
rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)]
def _derived_table_values_to_unnest(self, expression):
if not isinstance(expression.unnest().parent, exp.From):
return self.values_sql(expression)
rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)]
structs = []
for row in rows:
aliases = [
@ -99,6 +99,7 @@ class BigQuery(Dialect):
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
}
class Parser(Parser):
@ -140,9 +141,10 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.VariancePop: rename_func("VAR_POP"),
exp.Subquery: _subquery_to_unnest_if_values,
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
}
TYPE_MAPPING = {
@ -160,6 +162,16 @@ class BigQuery(Dialect):
exp.DataType.Type.NVARCHAR: "STRING",
}
ROOT_PROPERTIES = {
exp.LanguageProperty,
exp.ReturnsProperty,
exp.VolatilityProperty,
}
WITH_PROPERTIES = {
exp.AnonymousProperty,
}
def in_unnest_op(self, unnest):
return self.sql(unnest)

View file

@ -77,6 +77,7 @@ class Dialect(metaclass=_Dialect):
alias_post_tablesample = False
normalize_functions = "upper"
null_ordering = "nulls_are_small"
wrap_derived_values = True
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
@ -169,6 +170,7 @@ class Dialect(metaclass=_Dialect):
"alias_post_tablesample": self.alias_post_tablesample,
"normalize_functions": self.normalize_functions,
"null_ordering": self.null_ordering,
"wrap_derived_values": self.wrap_derived_values,
**opts,
}
)

View file

@ -177,6 +177,8 @@ class Snowflake(Dialect):
exp.ReturnsProperty,
exp.LanguageProperty,
exp.SchemaCommentProperty,
exp.ExecuteAsProperty,
exp.VolatilityProperty,
}
def except_op(self, expression):

View file

@ -47,6 +47,8 @@ def _unix_to_time(self, expression):
class Spark(Hive):
wrap_derived_values = False
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,

View file

@ -213,21 +213,23 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
def walk(self, bfs=True):
def walk(self, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in this tree.
Args:
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
prune ((node, parent, arg_key) -> bool): callable that returns True if
the generator should stop traversing this branch of the tree.
Returns:
the generator object.
"""
if bfs:
yield from self.bfs()
yield from self.bfs(prune=prune)
else:
yield from self.dfs()
yield from self.dfs(prune=prune)
def dfs(self, parent=None, key=None, prune=None):
"""
@ -506,6 +508,10 @@ class DerivedTable(Expression):
return [select.alias_or_name for select in self.selects]
class UDTF(DerivedTable):
pass
class Annotation(Expression):
arg_types = {
"this": True,
@ -652,7 +658,13 @@ class Delete(Expression):
class Drop(Expression):
arg_types = {"this": False, "kind": False, "exists": False}
arg_types = {
"this": False,
"kind": False,
"exists": False,
"temporary": False,
"materialized": False,
}
class Filter(Expression):
@ -827,7 +839,7 @@ class Join(Expression):
return join
class Lateral(DerivedTable):
class Lateral(UDTF):
arg_types = {"this": True, "outer": False, "alias": False}
@ -915,6 +927,14 @@ class LanguageProperty(Property):
pass
class ExecuteAsProperty(Property):
pass
class VolatilityProperty(Property):
arg_types = {"this": True}
class Properties(Expression):
arg_types = {"expressions": True}
@ -1098,7 +1118,7 @@ class Intersect(Union):
pass
class Unnest(DerivedTable):
class Unnest(UDTF):
arg_types = {
"expressions": True,
"ordinality": False,
@ -1116,8 +1136,12 @@ class Update(Expression):
}
class Values(Expression):
arg_types = {"expressions": True}
class Values(UDTF):
arg_types = {
"expressions": True,
"ordinality": False,
"alias": False,
}
class Var(Expression):
@ -2033,23 +2057,17 @@ class Func(Condition):
@classmethod
def from_arg_list(cls, args):
args_num = len(args)
if cls.is_var_len_args:
all_arg_keys = list(cls.arg_types)
# If this function supports variable length argument treat the last argument as such.
non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
num_non_var = len(non_var_len_arg_keys)
all_arg_keys = list(cls.arg_types)
# If this function supports variable length argument treat the last argument as such.
non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
args_dict = {arg_key: arg for arg, arg_key in zip(args, non_var_len_arg_keys)}
args_dict[all_arg_keys[-1]] = args[num_non_var:]
else:
args_dict = {arg_key: arg for arg, arg_key in zip(args, cls.arg_types)}
args_dict = {}
arg_idx = 0
for arg_key in non_var_len_arg_keys:
if arg_idx >= args_num:
break
if args[arg_idx] is not None:
args_dict[arg_key] = args[arg_idx]
arg_idx += 1
if arg_idx < args_num and cls.is_var_len_args:
args_dict[all_arg_keys[-1]] = args[arg_idx:]
return cls(**args_dict)
@classmethod

View file

@ -49,10 +49,12 @@ class Generator:
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
}
NULL_ORDERING_SUPPORTED = True
@ -99,6 +101,7 @@ class Generator:
"unsupported_messages",
"null_ordering",
"max_unsupported",
"wrap_derived_values",
"_indent",
"_replace_backslash",
"_escaped_quote_end",
@ -127,6 +130,7 @@ class Generator:
null_ordering=None,
max_unsupported=3,
leading_comma=False,
wrap_derived_values=True,
):
import sqlglot
@ -150,6 +154,7 @@ class Generator:
self.unsupported_messages = []
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
self.wrap_derived_values = wrap_derived_values
self._indent = indent
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
@ -407,7 +412,9 @@ class Generator:
this = self.sql(expression, "this")
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
return f"DROP {kind}{exists_sql}{this}"
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}"
def except_sql(self, expression):
return self.prepend_ctes(
@ -583,7 +590,14 @@ class Generator:
return self.prepend_ctes(expression, sql)
def values_sql(self, expression):
return f"VALUES{self.seg('')}{self.expressions(expression)}"
alias = self.sql(expression, "alias")
args = self.expressions(expression)
if not alias:
return f"VALUES{self.seg('')}{args}"
alias = f" AS {alias}" if alias else alias
if self.wrap_derived_values:
return f"(VALUES{self.seg('')}{args}){alias}"
return f"VALUES{self.seg('')}{args}{alias}"
def var_sql(self, expression):
return self.sql(expression, "this")

View file

@ -32,8 +32,8 @@ def merge_subqueries(expression, leave_tables_isolated=False):
Returns:
sqlglot.Expression: optimized expression
"""
merge_ctes(expression, leave_tables_isolated)
merge_derived_tables(expression, leave_tables_isolated)
expression = merge_ctes(expression, leave_tables_isolated)
expression = merge_derived_tables(expression, leave_tables_isolated)
return expression
@ -76,14 +76,14 @@ def merge_ctes(expression, leave_tables_isolated=False):
alias = node_to_replace.alias
else:
alias = table.name
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_pop_cte(inner_scope)
return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
@ -97,10 +97,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
return expression
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
@ -229,7 +230,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
continue
columns_to_replace = outer_columns.get(projection_name, [])
for column in columns_to_replace:
column.replace(expression.unalias())
column.replace(expression.unalias().copy())
def _merge_where(outer_scope, inner_scope, from_or_join):

View file

@ -5,8 +5,6 @@ from sqlglot.errors import OptimizeError
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import traverse_scope
SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
def qualify_columns(expression, schema):
"""
@ -35,7 +33,7 @@ def qualify_columns(expression, schema):
_expand_group_by(scope, resolver)
_expand_order_by(scope)
_qualify_columns(scope, resolver)
if not isinstance(scope.expression, SKIP_QUALIFY):
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
_check_unknown_tables(scope)
@ -50,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
if isinstance(derived_table, SKIP_QUALIFY):
if isinstance(derived_table, exp.UDTF):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
@ -202,7 +200,7 @@ def _qualify_columns(scope, resolver):
if not column_table:
column_table = resolver.get_table(column_name)
if not scope.is_subquery and not scope.is_unnest:
if not scope.is_subquery and not scope.is_udtf:
if column_name not in resolver.all_columns:
raise OptimizeError(f"Unknown column: {column_name}")
@ -296,7 +294,7 @@ def _qualify_outputs(scope):
def _check_unknown_tables(scope):
if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")

View file

@ -1,5 +1,4 @@
import itertools
from copy import copy
from enum import Enum, auto
from sqlglot import exp
@ -12,7 +11,7 @@ class ScopeType(Enum):
DERIVED_TABLE = auto()
CTE = auto()
UNION = auto()
UNNEST = auto()
UDTF = auto()
class Scope:
@ -70,14 +69,11 @@ class Scope:
self._columns = None
self._external_columns = None
def branch(self, expression, scope_type, add_sources=None, **kwargs):
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
sources = copy(self.sources)
if add_sources:
sources.update(add_sources)
return Scope(
expression=expression.unnest(),
sources=sources,
sources={**self.cte_sources, **(chain_sources or {})},
parent=self,
scope_type=scope_type,
**kwargs,
@ -90,30 +86,21 @@ class Scope:
self._derived_tables = []
self._raw_columns = []
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
prune = False
for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
prune = False
for node, parent, _ in self.walk(bfs=False):
if node is self.expression:
continue
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table):
self._tables.append(node)
elif isinstance(node, (exp.Unnest, exp.Lateral)):
elif isinstance(node, exp.UDTF):
self._derived_tables.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
prune = True
elif isinstance(node, exp.Subqueryable):
self._subqueries.append(node)
prune = True
self._collected = True
@ -121,6 +108,43 @@ class Scope:
if not self._collected:
self._collect()
def walk(self, bfs=True):
return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
"""
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching
the criteria was found.
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
def find_all(self, *expression_types, bfs=True):
"""
Returns a generator object which visits all nodes in this scope and only yields those that
match at least one of the specified expression types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
"""
for expression, _, _ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
def replace(self, old, new):
"""
Replace `old` with `new`.
@ -246,6 +270,16 @@ class Scope:
self._selected_sources = result
return self._selected_sources
@property
def cte_sources(self):
"""
Sources that are CTEs.
Returns:
dict[str, Scope]: Mapping of source alias to Scope
"""
return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
@property
def selects(self):
"""
@ -313,9 +347,9 @@ class Scope:
return self.scope_type == ScopeType.ROOT
@property
def is_unnest(self):
"""Determine if this scope is an unnest"""
return self.scope_type == ScopeType.UNNEST
def is_udtf(self):
"""Determine if this scope is a UDTF (User Defined Table Function)"""
return self.scope_type == ScopeType.UDTF
@property
def is_correlated_subquery(self):
@ -348,7 +382,7 @@ class Scope:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
@ -399,7 +433,7 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
elif isinstance(scope.expression, exp.UDTF):
pass
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
@ -410,8 +444,8 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
yield from _traverse_subqueries(scope)
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
yield from _traverse_subqueries(scope)
_add_table_sources(scope)
@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
top = None
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
add_sources=sources if scope_type == ScopeType.CTE else None,
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
chain_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
)
):
yield child_scope
@ -483,3 +517,35 @@ def _traverse_subqueries(scope):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
def walk_in_scope(expression, bfs=True):
"""
Returns a generator object which visits all nodes in the syntrax tree, stopping at
nodes that start child scopes.
Args:
expression (exp.Expression):
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
"""
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
prune = False
for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
prune = False
yield node, parent, key
if node is expression:
continue
elif isinstance(node, exp.CTE):
prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
prune = True
elif isinstance(node, exp.Subqueryable):
prune = True

View file

@ -126,6 +126,8 @@ class Parser:
TokenType.CONSTRAINT,
TokenType.DEFAULT,
TokenType.DELETE,
TokenType.DETERMINISTIC,
TokenType.EXECUTE,
TokenType.ENGINE,
TokenType.ESCAPE,
TokenType.EXPLAIN,
@ -139,6 +141,7 @@ class Parser:
TokenType.IF,
TokenType.INDEX,
TokenType.ISNULL,
TokenType.IMMUTABLE,
TokenType.INTERVAL,
TokenType.LAZY,
TokenType.LANGUAGE,
@ -163,6 +166,7 @@ class Parser:
TokenType.SEED,
TokenType.SET,
TokenType.SHOW,
TokenType.STABLE,
TokenType.STORED,
TokenType.TABLE,
TokenType.TABLE_FORMAT,
@ -175,6 +179,8 @@ class Parser:
TokenType.UNIQUE,
TokenType.UNPIVOT,
TokenType.PROPERTIES,
TokenType.PROCEDURE,
TokenType.VOLATILE,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
}
@ -204,7 +210,7 @@ class Parser:
TokenType.DATETIME,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
*NESTED_TYPE_TOKENS,
*TYPE_TOKENS,
*SUBQUERY_PREDICATES,
}
@ -379,6 +385,13 @@ class Parser:
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
TokenType.EXECUTE: lambda self: self._parse_execute_as(),
TokenType.DETERMINISTIC: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")),
TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")),
TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")),
}
CONSTRAINT_PARSERS = {
@ -418,7 +431,7 @@ class Parser:
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX}
CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE}
STRICT_CAST = True
@ -615,18 +628,20 @@ class Parser:
return expression
def _parse_drop(self):
if self._match(TokenType.TABLE):
kind = "TABLE"
elif self._match(TokenType.VIEW):
kind = "VIEW"
else:
self.raise_error("Expected TABLE or View")
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
return
return self.expression(
exp.Drop,
exists=self._parse_exists(),
this=self._parse_table(schema=True),
kind=kind,
temporary=temporary,
materialized=materialized,
)
def _parse_exists(self, not_=False):
@ -644,14 +659,15 @@ class Parser:
create_token = self._match_set(self.CREATABLES) and self._prev
if not create_token:
self.raise_error("Expected TABLE, VIEW, INDEX, or FUNCTION")
self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
return
exists = self._parse_exists(not_=True)
this = None
expression = None
properties = None
if create_token.token_type == TokenType.FUNCTION:
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function()
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
@ -747,7 +763,9 @@ class Parser:
if is_table:
if self._match(TokenType.LT):
value = self.expression(
exp.Schema, this="TABLE", expressions=self._parse_csv(self._parse_struct_kwargs)
exp.Schema,
this="TABLE",
expressions=self._parse_csv(self._parse_struct_kwargs),
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
@ -763,6 +781,14 @@ class Parser:
is_table=is_table,
)
def _parse_execute_as(self):
self._match(TokenType.ALIAS)
return self.expression(
exp.ExecuteAsProperty,
this=exp.Literal.string("EXECUTE AS"),
value=self._parse_var(),
)
def _parse_properties(self):
properties = []
@ -997,7 +1023,12 @@ class Parser:
)
def _parse_subquery(self, this):
return self.expression(exp.Subquery, this=this, pivots=self._parse_pivots(), alias=self._parse_table_alias())
return self.expression(
exp.Subquery,
this=this,
pivots=self._parse_pivots(),
alias=self._parse_table_alias(),
)
def _parse_query_modifiers(self, this):
if not isinstance(this, self.MODIFIABLES):
@ -1118,6 +1149,11 @@ class Parser:
if unnest:
return unnest
values = self._parse_derived_table_values()
if values:
return values
subquery = self._parse_select(table=True)
if subquery:
@ -1186,6 +1222,24 @@ class Parser:
alias=alias,
)
def _parse_derived_table_values(self):
is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
if not is_derived and not self._match(TokenType.VALUES):
return None
expressions = self._parse_csv(self._parse_value)
if is_derived:
self._match_r_paren()
alias = self._parse_table_alias()
return self.expression(
exp.Values,
expressions=expressions,
alias=alias,
)
def _parse_table_sample(self):
if not self._match(TokenType.TABLE_SAMPLE):
return None
@ -1700,7 +1754,11 @@ class Parser:
return self._parse_window(this)
def _parse_user_defined_function(self):
this = self._parse_var()
this = self._parse_id_var()
while self._match(TokenType.DOT):
this = self.expression(exp.Dot, this=this, expression=self._parse_id_var())
if not self._match(TokenType.L_PAREN):
return this
expressions = self._parse_csv(self._parse_udf_kwarg)

View file

@ -136,6 +136,7 @@ class TokenType(AutoName):
DEFAULT = auto()
DELETE = auto()
DESC = auto()
DETERMINISTIC = auto()
DISTINCT = auto()
DISTRIBUTE_BY = auto()
DROP = auto()
@ -144,6 +145,7 @@ class TokenType(AutoName):
ENGINE = auto()
ESCAPE = auto()
EXCEPT = auto()
EXECUTE = auto()
EXISTS = auto()
EXPLAIN = auto()
FALSE = auto()
@ -167,6 +169,7 @@ class TokenType(AutoName):
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
IMMUTABLE = auto()
IN = auto()
INDEX = auto()
INNER = auto()
@ -215,6 +218,7 @@ class TokenType(AutoName):
PLACEHOLDER = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
PROCEDURE = auto()
PROPERTIES = auto()
QUALIFY = auto()
QUOTE = auto()
@ -238,6 +242,7 @@ class TokenType(AutoName):
SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
STABLE = auto()
STORED = auto()
STRUCT = auto()
TABLE_FORMAT = auto()
@ -258,6 +263,7 @@ class TokenType(AutoName):
USING = auto()
VALUES = auto()
VIEW = auto()
VOLATILE = auto()
WHEN = auto()
WHERE = auto()
WINDOW = auto()
@ -430,6 +436,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DEFAULT": TokenType.DEFAULT,
"DELETE": TokenType.DELETE,
"DESC": TokenType.DESC,
"DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DROP": TokenType.DROP,
@ -438,6 +445,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ENGINE": TokenType.ENGINE,
"ESCAPE": TokenType.ESCAPE,
"EXCEPT": TokenType.EXCEPT,
"EXECUTE": TokenType.EXECUTE,
"EXISTS": TokenType.EXISTS,
"EXPLAIN": TokenType.EXPLAIN,
"FALSE": TokenType.FALSE,
@ -456,6 +464,7 @@ class Tokenizer(metaclass=_Tokenizer):
"HAVING": TokenType.HAVING,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IMMUTABLE": TokenType.IMMUTABLE,
"IGNORE NULLS": TokenType.IGNORE_NULLS,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
@ -504,6 +513,7 @@ class Tokenizer(metaclass=_Tokenizer):
"PIVOT": TokenType.PIVOT,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"PROCEDURE": TokenType.PROCEDURE,
"RANGE": TokenType.RANGE,
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
@ -522,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SHOW": TokenType.SHOW,
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
"STABLE": TokenType.STABLE,
"STORED": TokenType.STORED,
"TABLE": TokenType.TABLE,
"TABLE_FORMAT": TokenType.TABLE_FORMAT,
@ -542,6 +553,7 @@ class Tokenizer(metaclass=_Tokenizer):
"USING": TokenType.USING,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
"VOLATILE": TokenType.VOLATILE,
"WHEN": TokenType.WHEN,
"WHERE": TokenType.WHERE,
"WITH": TokenType.WITH,
@ -637,6 +649,7 @@ class Tokenizer(metaclass=_Tokenizer):
"_char",
"_end",
"_peek",
"_prev_token_type",
)
def __init__(self):
@ -657,6 +670,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._char = None
self._end = None
self._peek = None
self._prev_token_type = None
def tokenize(self, sql):
self.reset()
@ -706,8 +720,8 @@ class Tokenizer(metaclass=_Tokenizer):
return self.sql[self._start : self._current]
def _add(self, token_type, text=None):
text = self._text if text is None else text
self.tokens.append(Token(token_type, text, self._line, self._col))
self._prev_token_type = token_type
self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
self._start = self._current
@ -910,7 +924,11 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
else:
break
self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
self._add(
TokenType.VAR
if self._prev_token_type == TokenType.PARAMETER
else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
)
def _extract_string(self, delimiter):
text = ""