Merging upstream version 11.3.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
f767789b5e
commit
4a70b88890
62 changed files with 28339 additions and 27272 deletions
|
@ -47,7 +47,7 @@ if t.TYPE_CHECKING:
|
|||
T = t.TypeVar("T", bound=Expression)
|
||||
|
||||
|
||||
__version__ = "11.3.0"
|
||||
__version__ = "11.3.3"
|
||||
|
||||
pretty = False
|
||||
"""Whether to format generated SQL by default."""
|
||||
|
|
|
@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
|
|||
Dialect,
|
||||
datestrtodate_sql,
|
||||
inline_array_sql,
|
||||
min_or_least,
|
||||
no_ilike_sql,
|
||||
rename_func,
|
||||
timestrtotime_sql,
|
||||
|
@ -232,6 +233,7 @@ class BigQuery(Dialect):
|
|||
exp.GroupConcat: rename_func("STRING_AGG"),
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.IntDiv: rename_func("DIV"),
|
||||
exp.Min: min_or_least,
|
||||
exp.Select: transforms.preprocess(
|
||||
[_unqualify_unnest], transforms.delegate("select_sql")
|
||||
),
|
||||
|
|
|
@ -407,6 +407,11 @@ def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
|
|||
return f"CAST({self.sql(expression, 'this')} AS DATE)"
|
||||
|
||||
|
||||
def min_or_least(self: Generator, expression: exp.Min) -> str:
|
||||
name = "LEAST" if expression.expressions else "MIN"
|
||||
return rename_func(name)(self, expression)
|
||||
|
||||
|
||||
def trim_sql(self: Generator, expression: exp.Trim) -> str:
|
||||
target = self.sql(expression, "this")
|
||||
trim_type = self.sql(expression, "position")
|
||||
|
|
|
@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import (
|
|||
no_pivot_sql,
|
||||
no_properties_sql,
|
||||
no_safe_divide_sql,
|
||||
no_tablesample_sql,
|
||||
rename_func,
|
||||
str_position_sql,
|
||||
str_to_time_sql,
|
||||
|
@ -155,7 +154,6 @@ class DuckDB(Dialect):
|
|||
exp.StrToTime: str_to_time_sql,
|
||||
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
|
||||
exp.Struct: _struct_sql,
|
||||
exp.TableSample: no_tablesample_sql,
|
||||
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
|
||||
exp.TimeStrToTime: timestrtotime_sql,
|
||||
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
|
||||
|
@ -179,3 +177,6 @@ class DuckDB(Dialect):
|
|||
**generator.Generator.STAR_MAPPING,
|
||||
"except": "EXCLUDE",
|
||||
}
|
||||
|
||||
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
|
||||
return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
|
||||
|
|
|
@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
|
|||
format_time_lambda,
|
||||
if_sql,
|
||||
locate_to_strposition,
|
||||
min_or_least,
|
||||
no_ilike_sql,
|
||||
no_recursive_cte_sql,
|
||||
no_safe_divide_sql,
|
||||
|
@ -291,6 +292,7 @@ class Hive(Dialect):
|
|||
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
|
||||
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
|
||||
exp.Map: var_map_sql,
|
||||
exp.Min: min_or_least,
|
||||
exp.VarMap: var_map_sql,
|
||||
exp.Create: create_with_partitions_sql,
|
||||
exp.Quantile: rename_func("PERCENTILE"),
|
||||
|
|
|
@ -4,6 +4,7 @@ from sqlglot import exp, generator, parser, tokens
|
|||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
locate_to_strposition,
|
||||
min_or_least,
|
||||
no_ilike_sql,
|
||||
no_paren_current_date_sql,
|
||||
no_tablesample_sql,
|
||||
|
@ -179,7 +180,7 @@ class MySQL(Dialect):
|
|||
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} # type: ignore
|
||||
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
|
@ -441,6 +442,7 @@ class MySQL(Dialect):
|
|||
exp.CurrentDate: no_paren_current_date_sql,
|
||||
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.Min: min_or_least,
|
||||
exp.TableSample: no_tablesample_sql,
|
||||
exp.TryCast: no_trycast_sql,
|
||||
exp.DateAdd: _date_add_sql("ADD"),
|
||||
|
|
|
@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
|
|||
arrow_json_extract_scalar_sql,
|
||||
arrow_json_extract_sql,
|
||||
format_time_lambda,
|
||||
min_or_least,
|
||||
no_paren_current_date_sql,
|
||||
no_tablesample_sql,
|
||||
no_trycast_sql,
|
||||
|
@ -229,6 +230,7 @@ class Postgres(Dialect):
|
|||
"REFRESH": TokenType.COMMAND,
|
||||
"REINDEX": TokenType.COMMAND,
|
||||
"RESET": TokenType.COMMAND,
|
||||
"RETURNING": TokenType.RETURNING,
|
||||
"REVOKE": TokenType.COMMAND,
|
||||
"SERIAL": TokenType.SERIAL,
|
||||
"SMALLSERIAL": TokenType.SMALLSERIAL,
|
||||
|
@ -296,6 +298,7 @@ class Postgres(Dialect):
|
|||
exp.DateSub: _date_add_sql("-"),
|
||||
exp.DateDiff: _date_diff_sql,
|
||||
exp.LogicalOr: rename_func("BOOL_OR"),
|
||||
exp.Min: min_or_least,
|
||||
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
|
||||
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
|
||||
exp.StrPosition: str_position_sql,
|
||||
|
|
|
@ -53,6 +53,7 @@ class Redshift(Postgres):
|
|||
"SUPER": TokenType.SUPER,
|
||||
"TIME": TokenType.TIMESTAMP,
|
||||
"TIMETZ": TokenType.TIMESTAMPTZ,
|
||||
"TOP": TokenType.TOP,
|
||||
"UNLOAD": TokenType.COMMAND,
|
||||
"VARBYTE": TokenType.VARBINARY,
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
|
|||
datestrtodate_sql,
|
||||
format_time_lambda,
|
||||
inline_array_sql,
|
||||
min_or_least,
|
||||
rename_func,
|
||||
timestrtotime_sql,
|
||||
ts_or_ds_to_date_sql,
|
||||
|
@ -116,10 +117,16 @@ def _div0_to_if(args):
|
|||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
|
||||
def _zeroifnull_to_if(args):
|
||||
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null())
|
||||
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
|
||||
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
|
||||
def _nullifzero_to_if(args):
|
||||
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
|
||||
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
|
||||
|
||||
|
||||
def _datatype_sql(self, expression):
|
||||
if expression.this == exp.DataType.Type.ARRAY:
|
||||
return "ARRAY"
|
||||
|
@ -167,6 +174,11 @@ class Snowflake(Dialect):
|
|||
**parser.Parser.FUNCTIONS,
|
||||
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
|
||||
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
|
||||
"DATEADD": lambda args: exp.DateAdd(
|
||||
this=seq_get(args, 2),
|
||||
expression=seq_get(args, 1),
|
||||
unit=seq_get(args, 0),
|
||||
),
|
||||
"DATE_TRUNC": lambda args: exp.DateTrunc(
|
||||
unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore
|
||||
this=seq_get(args, 1),
|
||||
|
@ -180,6 +192,7 @@ class Snowflake(Dialect):
|
|||
"DECODE": exp.Matches.from_arg_list,
|
||||
"OBJECT_CONSTRUCT": parser.parse_var_map,
|
||||
"ZEROIFNULL": _zeroifnull_to_if,
|
||||
"NULLIFZERO": _nullifzero_to_if,
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
|
@ -254,6 +267,7 @@ class Snowflake(Dialect):
|
|||
class Generator(generator.Generator):
|
||||
PARAMETER_TOKEN = "$"
|
||||
INTEGER_DIVISION = False
|
||||
MATCHED_BY_SOURCE = False
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
|
@ -278,6 +292,7 @@ class Snowflake(Dialect):
|
|||
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
|
||||
exp.UnixToTime: _unix_to_time_sql,
|
||||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
exp.Min: min_or_least,
|
||||
}
|
||||
|
||||
TYPE_MAPPING = {
|
||||
|
@ -343,11 +358,10 @@ class Snowflake(Dialect):
|
|||
expression. This might not be true in a case where the same column name can be sourced from another table that can
|
||||
properly quote but should be true in most cases.
|
||||
"""
|
||||
values_expressions = expression.find_all(exp.Values)
|
||||
values_identifiers = set(
|
||||
flatten(
|
||||
v.args.get("alias", exp.Alias()).args.get("columns", [])
|
||||
for v in values_expressions
|
||||
(v.args.get("alias") or exp.Alias()).args.get("columns", [])
|
||||
for v in expression.find_all(exp.Values)
|
||||
)
|
||||
)
|
||||
if values_identifiers:
|
||||
|
|
|
@ -13,10 +13,6 @@ from sqlglot.dialects.dialect import (
|
|||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
def _fetch_sql(self, expression):
|
||||
return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
|
||||
|
||||
|
||||
# https://www.sqlite.org/lang_aggfunc.html#group_concat
|
||||
def _group_concat_sql(self, expression):
|
||||
this = expression.this
|
||||
|
@ -94,9 +90,17 @@ class SQLite(Dialect):
|
|||
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
|
||||
exp.TryCast: no_trycast_sql,
|
||||
exp.GroupConcat: _group_concat_sql,
|
||||
exp.Fetch: _fetch_sql,
|
||||
}
|
||||
|
||||
def fetch_sql(self, expression):
|
||||
return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
|
||||
|
||||
def least_sql(self, expression):
|
||||
if len(expression.expressions) > 1:
|
||||
return rename_func("MIN")(self, expression)
|
||||
|
||||
return self.expressions(expression)
|
||||
|
||||
def transaction_sql(self, expression):
|
||||
this = expression.this
|
||||
this = f" {this}" if this else ""
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens
|
||||
from sqlglot.dialects.dialect import Dialect
|
||||
from sqlglot.dialects.dialect import Dialect, min_or_least
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
|
@ -126,6 +126,11 @@ class Teradata(Dialect):
|
|||
exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.Min: min_or_least,
|
||||
}
|
||||
|
||||
def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
|
||||
return f"PARTITION BY {self.sql(expression, 'this')}"
|
||||
|
||||
|
|
|
@ -4,7 +4,12 @@ import re
|
|||
import typing as t
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens
|
||||
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
min_or_least,
|
||||
parse_date_delta,
|
||||
rename_func,
|
||||
)
|
||||
from sqlglot.expressions import DataType
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.time import format_time
|
||||
|
@ -433,6 +438,7 @@ class TSQL(Dialect):
|
|||
exp.NumberToStr: _format_sql,
|
||||
exp.TimeToStr: _format_sql,
|
||||
exp.GroupConcat: _string_agg_sql,
|
||||
exp.Min: min_or_least,
|
||||
}
|
||||
|
||||
TRANSFORMS.pop(exp.ReturnsProperty)
|
||||
|
|
|
@ -1031,7 +1031,7 @@ class Constraint(Expression):
|
|||
|
||||
|
||||
class Delete(Expression):
|
||||
arg_types = {"with": False, "this": False, "using": False, "where": False}
|
||||
arg_types = {"with": False, "this": False, "using": False, "where": False, "returning": False}
|
||||
|
||||
|
||||
class Drop(Expression):
|
||||
|
@ -1132,6 +1132,7 @@ class Insert(Expression):
|
|||
"with": False,
|
||||
"this": True,
|
||||
"expression": False,
|
||||
"returning": False,
|
||||
"overwrite": False,
|
||||
"exists": False,
|
||||
"partition": False,
|
||||
|
@ -1139,6 +1140,10 @@ class Insert(Expression):
|
|||
}
|
||||
|
||||
|
||||
class Returning(Expression):
|
||||
arg_types = {"expressions": True}
|
||||
|
||||
|
||||
# https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html
|
||||
class Introducer(Expression):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
|
@ -1747,6 +1752,7 @@ QUERY_MODIFIERS = {
|
|||
"limit": False,
|
||||
"offset": False,
|
||||
"lock": False,
|
||||
"sample": False,
|
||||
}
|
||||
|
||||
|
||||
|
@ -1895,6 +1901,7 @@ class Update(Expression):
|
|||
"expressions": True,
|
||||
"from": False,
|
||||
"where": False,
|
||||
"returning": False,
|
||||
}
|
||||
|
||||
|
||||
|
@ -2401,6 +2408,18 @@ class Select(Subqueryable):
|
|||
**opts,
|
||||
)
|
||||
|
||||
def qualify(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
|
||||
return _apply_conjunction_builder(
|
||||
*expressions,
|
||||
instance=self,
|
||||
arg="qualify",
|
||||
append=append,
|
||||
into=Qualify,
|
||||
dialect=dialect,
|
||||
copy=copy,
|
||||
**opts,
|
||||
)
|
||||
|
||||
def distinct(self, distinct=True, copy=True) -> Select:
|
||||
"""
|
||||
Set the OFFSET expression.
|
||||
|
@ -2531,6 +2550,7 @@ class TableSample(Expression):
|
|||
"rows": False,
|
||||
"size": False,
|
||||
"seed": False,
|
||||
"kind": False,
|
||||
}
|
||||
|
||||
|
||||
|
@ -3423,7 +3443,7 @@ class JSONBExtractScalar(JSONExtract):
|
|||
|
||||
|
||||
class Least(Func):
|
||||
arg_types = {"this": True, "expressions": False}
|
||||
arg_types = {"expressions": False}
|
||||
is_var_len_args = True
|
||||
|
||||
|
||||
|
@ -3485,11 +3505,13 @@ class Matches(Func):
|
|||
|
||||
|
||||
class Max(AggFunc):
|
||||
arg_types = {"this": True, "expression": False}
|
||||
arg_types = {"this": True, "expressions": False}
|
||||
is_var_len_args = True
|
||||
|
||||
|
||||
class Min(AggFunc):
|
||||
arg_types = {"this": True, "expression": False}
|
||||
arg_types = {"this": True, "expressions": False}
|
||||
is_var_len_args = True
|
||||
|
||||
|
||||
class Month(Func):
|
||||
|
@ -3764,7 +3786,7 @@ class Merge(Expression):
|
|||
|
||||
|
||||
class When(Func):
|
||||
arg_types = {"this": True, "then": True}
|
||||
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
|
||||
|
||||
|
||||
def _norm_args(expression):
|
||||
|
|
|
@ -112,6 +112,9 @@ class Generator:
|
|||
# Whether or not to treat the division operator "/" as integer division
|
||||
INTEGER_DIVISION = True
|
||||
|
||||
# Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
|
||||
MATCHED_BY_SOURCE = True
|
||||
|
||||
TYPE_MAPPING = {
|
||||
exp.DataType.Type.NCHAR: "CHAR",
|
||||
exp.DataType.Type.NVARCHAR: "VARCHAR",
|
||||
|
@ -688,7 +691,8 @@ class Generator:
|
|||
else ""
|
||||
)
|
||||
where_sql = self.sql(expression, "where")
|
||||
sql = f"DELETE{this}{using_sql}{where_sql}"
|
||||
returning = self.sql(expression, "returning")
|
||||
sql = f"DELETE{this}{using_sql}{where_sql}{returning}"
|
||||
return self.prepend_ctes(expression, sql)
|
||||
|
||||
def drop_sql(self, expression: exp.Drop) -> str:
|
||||
|
@ -952,8 +956,9 @@ class Generator:
|
|||
self.sql(expression, "partition") if expression.args.get("partition") else ""
|
||||
)
|
||||
expression_sql = self.sql(expression, "expression")
|
||||
returning = self.sql(expression, "returning")
|
||||
sep = self.sep() if partition_sql else ""
|
||||
sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}"
|
||||
sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{returning}"
|
||||
return self.prepend_ctes(expression, sql)
|
||||
|
||||
def intersect_sql(self, expression: exp.Intersect) -> str:
|
||||
|
@ -971,6 +976,9 @@ class Generator:
|
|||
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
|
||||
return expression.name.upper()
|
||||
|
||||
def returning_sql(self, expression: exp.Returning) -> str:
|
||||
return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}"
|
||||
|
||||
def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
|
||||
fields = expression.args.get("fields")
|
||||
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
|
||||
|
@ -1009,7 +1017,7 @@ class Generator:
|
|||
|
||||
return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"
|
||||
|
||||
def tablesample_sql(self, expression: exp.TableSample) -> str:
|
||||
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
|
||||
if self.alias_post_tablesample and expression.this.alias:
|
||||
this = self.sql(expression.this, "this")
|
||||
alias = f" AS {self.sql(expression.this, 'alias')}"
|
||||
|
@ -1017,7 +1025,7 @@ class Generator:
|
|||
this = self.sql(expression, "this")
|
||||
alias = ""
|
||||
method = self.sql(expression, "method")
|
||||
method = f" {method.upper()} " if method else ""
|
||||
method = f"{method.upper()} " if method else ""
|
||||
numerator = self.sql(expression, "bucket_numerator")
|
||||
denominator = self.sql(expression, "bucket_denominator")
|
||||
field = self.sql(expression, "bucket_field")
|
||||
|
@ -1029,8 +1037,9 @@ class Generator:
|
|||
rows = f"{rows} ROWS" if rows else ""
|
||||
size = self.sql(expression, "size")
|
||||
seed = self.sql(expression, "seed")
|
||||
seed = f" SEED ({seed})" if seed else ""
|
||||
return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}"
|
||||
seed = f" {seed_prefix} ({seed})" if seed else ""
|
||||
kind = expression.args.get("kind", "TABLESAMPLE")
|
||||
return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"
|
||||
|
||||
def pivot_sql(self, expression: exp.Pivot) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
|
@ -1050,7 +1059,8 @@ class Generator:
|
|||
set_sql = self.expressions(expression, flat=True)
|
||||
from_sql = self.sql(expression, "from")
|
||||
where_sql = self.sql(expression, "where")
|
||||
sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}"
|
||||
returning = self.sql(expression, "returning")
|
||||
sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}{returning}"
|
||||
return self.prepend_ctes(expression, sql)
|
||||
|
||||
def values_sql(self, expression: exp.Values) -> str:
|
||||
|
@ -1297,6 +1307,7 @@ class Generator:
|
|||
self.sql(expression, "limit"),
|
||||
self.sql(expression, "offset"),
|
||||
self.sql(expression, "lock"),
|
||||
self.sql(expression, "sample"),
|
||||
sep="",
|
||||
)
|
||||
|
||||
|
@ -1956,7 +1967,11 @@ class Generator:
|
|||
return self.binary(expression, "=>")
|
||||
|
||||
def when_sql(self, expression: exp.When) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
matched = "MATCHED" if expression.args["matched"] else "NOT MATCHED"
|
||||
source = " BY SOURCE" if self.MATCHED_BY_SOURCE and expression.args.get("source") else ""
|
||||
condition = self.sql(expression, "condition")
|
||||
condition = f" AND {condition}" if condition else ""
|
||||
|
||||
then_expression = expression.args.get("then")
|
||||
if isinstance(then_expression, exp.Insert):
|
||||
then = f"INSERT {self.sql(then_expression, 'this')}"
|
||||
|
@ -1969,7 +1984,7 @@ class Generator:
|
|||
then = f"UPDATE SET {self.expressions(then_expression, flat=True)}"
|
||||
else:
|
||||
then = self.sql(then_expression)
|
||||
return f"WHEN {this} THEN {then}"
|
||||
return f"WHEN {matched}{source}{condition} THEN {then}"
|
||||
|
||||
def merge_sql(self, expression: exp.Merge) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
|
|
|
@ -55,9 +55,11 @@ class TypeAnnotator:
|
|||
expr, exp.DataType.Type.BIGINT
|
||||
),
|
||||
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
|
||||
exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
|
||||
exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
|
||||
exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
|
||||
exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
|
||||
exp.Sum: lambda self, expr: self._annotate_by_args(
|
||||
expr, "this", "expressions", promote=True
|
||||
),
|
||||
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
|
@ -114,6 +116,7 @@ class TypeAnnotator:
|
|||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
|
||||
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
|
|
|
@ -434,6 +434,7 @@ class Parser(metaclass=_Parser):
|
|||
exp.Having: lambda self: self._parse_having(),
|
||||
exp.With: lambda self: self._parse_with(),
|
||||
exp.Window: lambda self: self._parse_named_window(),
|
||||
exp.Qualify: lambda self: self._parse_qualify(),
|
||||
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
|
||||
}
|
||||
|
||||
|
@ -688,6 +689,7 @@ class Parser(metaclass=_Parser):
|
|||
"limit": lambda self: self._parse_limit(),
|
||||
"offset": lambda self: self._parse_offset(),
|
||||
"lock": lambda self: self._parse_lock(),
|
||||
"sample": lambda self: self._parse_table_sample(as_modifier=True),
|
||||
}
|
||||
|
||||
SHOW_PARSERS: t.Dict[str, t.Callable] = {}
|
||||
|
@ -953,7 +955,8 @@ class Parser(metaclass=_Parser):
|
|||
self._prev_comments = None
|
||||
|
||||
def _retreat(self, index: int) -> None:
|
||||
self._advance(index - self._index)
|
||||
if index != self._index:
|
||||
self._advance(index - self._index)
|
||||
|
||||
def _parse_command(self) -> exp.Expression:
|
||||
return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
|
||||
|
@ -1515,12 +1518,10 @@ class Parser(metaclass=_Parser):
|
|||
def _parse_insert(self) -> exp.Expression:
|
||||
overwrite = self._match(TokenType.OVERWRITE)
|
||||
local = self._match(TokenType.LOCAL)
|
||||
|
||||
this: t.Optional[exp.Expression]
|
||||
|
||||
alternative = None
|
||||
|
||||
if self._match_text_seq("DIRECTORY"):
|
||||
this = self.expression(
|
||||
this: t.Optional[exp.Expression] = self.expression(
|
||||
exp.Directory,
|
||||
this=self._parse_var_or_string(),
|
||||
local=local,
|
||||
|
@ -1540,10 +1541,17 @@ class Parser(metaclass=_Parser):
|
|||
exists=self._parse_exists(),
|
||||
partition=self._parse_partition(),
|
||||
expression=self._parse_ddl_select(),
|
||||
returning=self._parse_returning(),
|
||||
overwrite=overwrite,
|
||||
alternative=alternative,
|
||||
)
|
||||
|
||||
def _parse_returning(self) -> t.Optional[exp.Expression]:
|
||||
if not self._match(TokenType.RETURNING):
|
||||
return None
|
||||
|
||||
return self.expression(exp.Returning, expressions=self._parse_csv(self._parse_column))
|
||||
|
||||
def _parse_row(self) -> t.Optional[exp.Expression]:
|
||||
if not self._match(TokenType.FORMAT):
|
||||
return None
|
||||
|
@ -1601,6 +1609,7 @@ class Parser(metaclass=_Parser):
|
|||
this=self._parse_table(schema=True),
|
||||
using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
|
||||
where=self._parse_where(),
|
||||
returning=self._parse_returning(),
|
||||
)
|
||||
|
||||
def _parse_update(self) -> exp.Expression:
|
||||
|
@ -1611,6 +1620,7 @@ class Parser(metaclass=_Parser):
|
|||
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
|
||||
"from": self._parse_from(),
|
||||
"where": self._parse_where(),
|
||||
"returning": self._parse_returning(),
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -2156,11 +2166,12 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias())
|
||||
|
||||
def _parse_table_sample(self) -> t.Optional[exp.Expression]:
|
||||
if not self._match(TokenType.TABLE_SAMPLE):
|
||||
def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Expression]:
|
||||
if not self._match(TokenType.TABLE_SAMPLE) and not (
|
||||
as_modifier and self._match_text_seq("USING", "SAMPLE")
|
||||
):
|
||||
return None
|
||||
|
||||
method = self._parse_var()
|
||||
bucket_numerator = None
|
||||
bucket_denominator = None
|
||||
bucket_field = None
|
||||
|
@ -2169,7 +2180,12 @@ class Parser(metaclass=_Parser):
|
|||
size = None
|
||||
seed = None
|
||||
|
||||
self._match_l_paren()
|
||||
kind = "TABLESAMPLE" if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
|
||||
method = self._parse_var(tokens=(TokenType.ROW,))
|
||||
|
||||
self._match(TokenType.L_PAREN)
|
||||
|
||||
num = self._parse_number()
|
||||
|
||||
if self._match(TokenType.BUCKET):
|
||||
bucket_numerator = self._parse_number()
|
||||
|
@ -2177,19 +2193,20 @@ class Parser(metaclass=_Parser):
|
|||
bucket_denominator = bucket_denominator = self._parse_number()
|
||||
self._match(TokenType.ON)
|
||||
bucket_field = self._parse_field()
|
||||
elif self._match_set((TokenType.PERCENT, TokenType.MOD)):
|
||||
percent = num
|
||||
elif self._match(TokenType.ROWS):
|
||||
rows = num
|
||||
else:
|
||||
num = self._parse_number()
|
||||
size = num
|
||||
|
||||
if self._match(TokenType.PERCENT):
|
||||
percent = num
|
||||
elif self._match(TokenType.ROWS):
|
||||
rows = num
|
||||
else:
|
||||
size = num
|
||||
self._match(TokenType.R_PAREN)
|
||||
|
||||
self._match_r_paren()
|
||||
|
||||
if self._match(TokenType.SEED):
|
||||
if self._match(TokenType.L_PAREN):
|
||||
method = self._parse_var()
|
||||
seed = self._match(TokenType.COMMA) and self._parse_number()
|
||||
self._match_r_paren()
|
||||
elif self._match_texts(("SEED", "REPEATABLE")):
|
||||
seed = self._parse_wrapped(self._parse_number)
|
||||
|
||||
return self.expression(
|
||||
|
@ -2202,6 +2219,7 @@ class Parser(metaclass=_Parser):
|
|||
rows=rows,
|
||||
size=size,
|
||||
seed=seed,
|
||||
kind=kind,
|
||||
)
|
||||
|
||||
def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]:
|
||||
|
@ -2531,7 +2549,7 @@ class Parser(metaclass=_Parser):
|
|||
this = self._parse_column()
|
||||
|
||||
if type_token:
|
||||
if this and not isinstance(this, exp.Star):
|
||||
if isinstance(this, exp.Literal):
|
||||
return self.expression(exp.Cast, this=this, to=type_token)
|
||||
if not type_token.args.get("expressions"):
|
||||
self._retreat(index)
|
||||
|
@ -2626,7 +2644,12 @@ class Parser(metaclass=_Parser):
|
|||
if value is None:
|
||||
value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
|
||||
elif type_token == TokenType.INTERVAL:
|
||||
value = self.expression(exp.Interval, unit=self._parse_var())
|
||||
unit = self._parse_var()
|
||||
|
||||
if not unit:
|
||||
value = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL)
|
||||
else:
|
||||
value = self.expression(exp.Interval, unit=unit)
|
||||
|
||||
if maybe_func and check_func:
|
||||
index2 = self._index
|
||||
|
@ -3495,8 +3518,14 @@ class Parser(metaclass=_Parser):
|
|||
return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
|
||||
return self._parse_placeholder()
|
||||
|
||||
def _parse_var(self, any_token: bool = False) -> t.Optional[exp.Expression]:
|
||||
if (any_token and self._advance_any()) or self._match(TokenType.VAR):
|
||||
def _parse_var(
|
||||
self, any_token: bool = False, tokens: t.Optional[t.Collection[TokenType]] = None
|
||||
) -> t.Optional[exp.Expression]:
|
||||
if (
|
||||
(any_token and self._advance_any())
|
||||
or self._match(TokenType.VAR)
|
||||
or (self._match_set(tokens) if tokens else False)
|
||||
):
|
||||
return self.expression(exp.Var, this=self._prev.text)
|
||||
return self._parse_placeholder()
|
||||
|
||||
|
@ -3732,19 +3761,26 @@ class Parser(metaclass=_Parser):
|
|||
return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
|
||||
|
||||
def _parse_alter(self) -> t.Optional[exp.Expression]:
|
||||
start = self._prev
|
||||
|
||||
if not self._match(TokenType.TABLE):
|
||||
return self._parse_as_command(self._prev)
|
||||
return self._parse_as_command(start)
|
||||
|
||||
exists = self._parse_exists()
|
||||
this = self._parse_table(schema=True)
|
||||
|
||||
if not self._curr:
|
||||
return None
|
||||
if self._next:
|
||||
self._advance()
|
||||
parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
|
||||
|
||||
parser = self.ALTER_PARSERS.get(self._curr.text.upper())
|
||||
actions = ensure_list(self._advance() or parser(self)) if parser else [] # type: ignore
|
||||
|
||||
return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions)
|
||||
if parser:
|
||||
return self.expression(
|
||||
exp.AlterTable,
|
||||
this=this,
|
||||
exists=exists,
|
||||
actions=ensure_list(parser(self)),
|
||||
)
|
||||
return self._parse_as_command(start)
|
||||
|
||||
def _parse_show(self) -> t.Optional[exp.Expression]:
|
||||
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore
|
||||
|
@ -3775,7 +3811,15 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
whens = []
|
||||
while self._match(TokenType.WHEN):
|
||||
this = self._parse_conjunction()
|
||||
matched = not self._match(TokenType.NOT)
|
||||
self._match_text_seq("MATCHED")
|
||||
source = (
|
||||
False
|
||||
if self._match_text_seq("BY", "TARGET")
|
||||
else self._match_text_seq("BY", "SOURCE")
|
||||
)
|
||||
condition = self._parse_conjunction() if self._match(TokenType.AND) else None
|
||||
|
||||
self._match(TokenType.THEN)
|
||||
|
||||
if self._match(TokenType.INSERT):
|
||||
|
@ -3800,8 +3844,18 @@ class Parser(metaclass=_Parser):
|
|||
)
|
||||
elif self._match(TokenType.DELETE):
|
||||
then = self.expression(exp.Var, this=self._prev.text)
|
||||
else:
|
||||
then = None
|
||||
|
||||
whens.append(self.expression(exp.When, this=this, then=then))
|
||||
whens.append(
|
||||
self.expression(
|
||||
exp.When,
|
||||
matched=matched,
|
||||
source=source,
|
||||
condition=condition,
|
||||
then=then,
|
||||
)
|
||||
)
|
||||
|
||||
return self.expression(
|
||||
exp.Merge,
|
||||
|
|
|
@ -855,11 +855,12 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
def _scan_keywords(self) -> None:
|
||||
size = 0
|
||||
word = None
|
||||
chars: t.Optional[str] = self._text
|
||||
chars = self._text
|
||||
char = chars
|
||||
prev_space = False
|
||||
skip = False
|
||||
trie = self.KEYWORD_TRIE
|
||||
single_token = char in self.SINGLE_TOKENS
|
||||
|
||||
while chars:
|
||||
if skip:
|
||||
|
@ -876,6 +877,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
|
||||
if end < self.size:
|
||||
char = self.sql[end]
|
||||
single_token = single_token or char in self.SINGLE_TOKENS
|
||||
is_space = char in self.WHITE_SPACE
|
||||
|
||||
if not is_space or not prev_space:
|
||||
|
@ -887,7 +889,9 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
else:
|
||||
skip = True
|
||||
else:
|
||||
chars = None
|
||||
chars = " "
|
||||
|
||||
word = None if not single_token and chars[-1] not in self.WHITE_SPACE else word
|
||||
|
||||
if not word:
|
||||
if self._char in self.SINGLE_TOKENS:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue