Merging upstream version 25.7.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
  parent dba379232c
  commit aa0eae236a

102 changed files with 52995 additions and 52070 deletions
sqlglot/dialects/clickhouse.py

@@ -334,6 +334,11 @@ class ClickHouse(Dialect):
         RESERVED_TOKENS = parser.Parser.RESERVED_TOKENS - {TokenType.SELECT}
 
+        ID_VAR_TOKENS = {
+            *parser.Parser.ID_VAR_TOKENS,
+            TokenType.LIKE,
+        }
+
         AGG_FUNC_MAPPING = (
             lambda functions, suffixes: {
                 f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
sqlglot/dialects/dialect.py

@@ -8,7 +8,7 @@ from functools import reduce
 from sqlglot import exp
 from sqlglot.errors import ParseError
 from sqlglot.generator import Generator
-from sqlglot.helper import AutoName, flatten, is_int, seq_get
+from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
 from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
 from sqlglot.parser import Parser
 from sqlglot.time import TIMEZONES, format_time

@@ -23,6 +23,10 @@ JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
 if t.TYPE_CHECKING:
     from sqlglot._typing import B, E, F
+
+    from sqlglot.optimizer.annotate_types import TypeAnnotator
+
+    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]
 
 logger = logging.getLogger("sqlglot")
 
 UNESCAPED_SEQUENCES = {

@@ -37,6 +41,10 @@ UNESCAPED_SEQUENCES = {
 }
 
 
+def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
+    return lambda self, e: self._annotate_with_type(e, data_type)
+
+
 class Dialects(str, Enum):
     """Dialects supported by SQLGLot."""
 

@@ -489,6 +497,167 @@ class Dialect(metaclass=_Dialect):
         "CENTURIES": "CENTURY",
     }
 
+    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
+        exp.DataType.Type.BIGINT: {
+            exp.ApproxDistinct,
+            exp.ArraySize,
+            exp.Count,
+            exp.Length,
+        },
+        exp.DataType.Type.BOOLEAN: {
+            exp.Between,
+            exp.Boolean,
+            exp.In,
+            exp.RegexpLike,
+        },
+        exp.DataType.Type.DATE: {
+            exp.CurrentDate,
+            exp.Date,
+            exp.DateFromParts,
+            exp.DateStrToDate,
+            exp.DiToDate,
+            exp.StrToDate,
+            exp.TimeStrToDate,
+            exp.TsOrDsToDate,
+        },
+        exp.DataType.Type.DATETIME: {
+            exp.CurrentDatetime,
+            exp.Datetime,
+            exp.DatetimeAdd,
+            exp.DatetimeSub,
+        },
+        exp.DataType.Type.DOUBLE: {
+            exp.ApproxQuantile,
+            exp.Avg,
+            exp.Div,
+            exp.Exp,
+            exp.Ln,
+            exp.Log,
+            exp.Pow,
+            exp.Quantile,
+            exp.Round,
+            exp.SafeDivide,
+            exp.Sqrt,
+            exp.Stddev,
+            exp.StddevPop,
+            exp.StddevSamp,
+            exp.Variance,
+            exp.VariancePop,
+        },
+        exp.DataType.Type.INT: {
+            exp.Ceil,
+            exp.DatetimeDiff,
+            exp.DateDiff,
+            exp.TimestampDiff,
+            exp.TimeDiff,
+            exp.DateToDi,
+            exp.Levenshtein,
+            exp.Sign,
+            exp.StrPosition,
+            exp.TsOrDiToDi,
+        },
+        exp.DataType.Type.JSON: {
+            exp.ParseJSON,
+        },
+        exp.DataType.Type.TIME: {
+            exp.Time,
+        },
+        exp.DataType.Type.TIMESTAMP: {
+            exp.CurrentTime,
+            exp.CurrentTimestamp,
+            exp.StrToTime,
+            exp.TimeAdd,
+            exp.TimeStrToTime,
+            exp.TimeSub,
+            exp.TimestampAdd,
+            exp.TimestampSub,
+            exp.UnixToTime,
+        },
+        exp.DataType.Type.TINYINT: {
+            exp.Day,
+            exp.Month,
+            exp.Week,
+            exp.Year,
+            exp.Quarter,
+        },
+        exp.DataType.Type.VARCHAR: {
+            exp.ArrayConcat,
+            exp.Concat,
+            exp.ConcatWs,
+            exp.DateToDateStr,
+            exp.GroupConcat,
+            exp.Initcap,
+            exp.Lower,
+            exp.Substring,
+            exp.TimeToStr,
+            exp.TimeToTimeStr,
+            exp.Trim,
+            exp.TsOrDsToDateStr,
+            exp.UnixToStr,
+            exp.UnixToTimeStr,
+            exp.Upper,
+        },
+    }
+
+    ANNOTATORS: AnnotatorsType = {
+        **{
+            expr_type: lambda self, e: self._annotate_unary(e)
+            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
+        },
+        **{
+            expr_type: lambda self, e: self._annotate_binary(e)
+            for expr_type in subclasses(exp.__name__, exp.Binary)
+        },
+        **{
+            expr_type: _annotate_with_type_lambda(data_type)
+            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
+            for expr_type in expressions
+        },
+        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
+        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
+        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
+        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
+        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+        exp.Bracket: lambda self, e: self._annotate_bracket(e),
+        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
+        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
+        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
+        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
+        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
+        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
+        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
+        exp.Div: lambda self, e: self._annotate_div(e),
+        exp.Dot: lambda self, e: self._annotate_dot(e),
+        exp.Explode: lambda self, e: self._annotate_explode(e),
+        exp.Extract: lambda self, e: self._annotate_extract(e),
+        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
+        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
+            e, exp.DataType.build("ARRAY<DATE>")
+        ),
+        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
+        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
+        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
+        exp.Literal: lambda self, e: self._annotate_literal(e),
+        exp.Map: lambda self, e: self._annotate_map(e),
+        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
+        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
+        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
+        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
+        exp.Struct: lambda self, e: self._annotate_struct(e),
+        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
+        exp.Timestamp: lambda self, e: self._annotate_with_type(
+            e,
+            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
+        ),
+        exp.ToMap: lambda self, e: self._annotate_to_map(e),
+        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
+        exp.Unnest: lambda self, e: self._annotate_unnest(e),
+        exp.VarMap: lambda self, e: self._annotate_map(e),
+    }
+
     @classmethod
     def get_or_raise(cls, dialect: DialectType) -> Dialect:
         """
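Note: a quick sanity check of the relocated defaults (assumes sqlglot >= 25.7; printed repr is approximate):

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    # TYPE_TO_EXPRESSIONS and ANNOTATORS now live on Dialect, so the generic
    # rules above still apply when no dialect is given: LENGTH -> BIGINT.
    e = annotate_types(sqlglot.parse_one("LENGTH('abc')"))
    print(e.type.this)  # expected: Type.BIGINT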
@@ -1419,3 +1588,24 @@ def build_timestamp_from_parts(args: t.List) -> exp.Func:
 
 def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
     return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
+
+
+def sequence_sql(self: Generator, expression: exp.GenerateSeries):
+    start = expression.args["start"]
+    end = expression.args["end"]
+    step = expression.args.get("step")
+
+    if isinstance(start, exp.Cast):
+        target_type = start.to
+    elif isinstance(end, exp.Cast):
+        target_type = end.to
+    else:
+        target_type = None
+
+    if target_type and target_type.is_type("timestamp"):
+        if target_type is start.to:
+            end = exp.cast(end, target_type)
+        else:
+            start = exp.cast(start, target_type)
+
+    return self.func("SEQUENCE", start, end, step)
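Note: this shared helper centralizes what Presto's generateseries_sql used to do (that method is removed from presto.py below); a minimal check, assuming sqlglot >= 25.7:

    import sqlglot

    # One bound is cast to TIMESTAMP, so sequence_sql casts the other to match.
    print(
        sqlglot.transpile(
            "SELECT GENERATE_SERIES(CAST('2020-01-01' AS TIMESTAMP), '2020-01-05')",
            write="presto",
        )[0]
    )
    # expected (approximately):
    # SELECT SEQUENCE(CAST('2020-01-01' AS TIMESTAMP), CAST('2020-01-05' AS TIMESTAMP))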
sqlglot/dialects/duckdb.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import typing as t
 
 from sqlglot import exp, generator, parser, tokens, transforms
+from sqlglot.expressions import DATA_TYPE
 from sqlglot.dialects.dialect import (
     Dialect,
     JSON_EXTRACT_TYPE,

@@ -35,20 +36,34 @@ from sqlglot.dialects.dialect import (
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 
-
-def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
-    this = self.sql(expression, "this")
-    interval = self.sql(exp.Interval(this=expression.expression, unit=unit_to_var(expression)))
-    return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}"
+DATETIME_DELTA = t.Union[
+    exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd, exp.DateSub, exp.DatetimeSub
+]
 
 
-def _date_delta_sql(
-    self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub | exp.TimeAdd
-) -> str:
-    this = self.sql(expression, "this")
+def _date_delta_sql(self: DuckDB.Generator, expression: DATETIME_DELTA) -> str:
+    this = expression.this
     unit = unit_to_var(expression)
-    op = "+" if isinstance(expression, (exp.DateAdd, exp.TimeAdd)) else "-"
-    return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
+    op = (
+        "+"
+        if isinstance(expression, (exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd))
+        else "-"
+    )
+
+    to_type: t.Optional[DATA_TYPE] = None
+    if isinstance(expression, exp.TsOrDsAdd):
+        to_type = expression.return_type
+    elif this.is_string:
+        # Cast string literals (i.e function parameters) to the appropriate type for +/- interval to work
+        to_type = (
+            exp.DataType.Type.DATETIME
+            if isinstance(expression, (exp.DatetimeAdd, exp.DatetimeSub))
+            else exp.DataType.Type.DATE
+        )
+
+    this = exp.cast(this, to_type) if to_type else this
+
+    return f"{self.sql(this)} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
 # BigQuery -> DuckDB conversion for the DATE function
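Note: with the widened DATETIME_DELTA union, the DATETIME_*/TS_OR_DS_* deltas share one code path and bare string operands get an explicit cast first; an illustrative transpile (assumes sqlglot >= 25.7; output shape approximate):

    import sqlglot

    # The string operand should be cast before the +/- INTERVAL arithmetic.
    print(
        sqlglot.transpile(
            "SELECT DATETIME_ADD('2020-01-01', INTERVAL 1 DAY)",
            read="bigquery",
            write="duckdb",
        )[0]
    )
    # expected (approximately): SELECT CAST('2020-01-01' AS TIMESTAMP) + INTERVAL '1' DAY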
@@ -119,7 +134,12 @@ def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
 
     # BigQuery allows inline construction such as "STRUCT<a STRING, b INTEGER>('str', 1)" which is
     # canonicalized to "ROW('str', 1) AS STRUCT(a TEXT, b INT)" in DuckDB
-    is_struct_cast = expression.find_ancestor(exp.Cast)
+    # The transformation to ROW will take place if a cast to STRUCT / ARRAY of STRUCTs is found
+    ancestor_cast = expression.find_ancestor(exp.Cast)
+    is_struct_cast = ancestor_cast and any(
+        casted_type.is_type(exp.DataType.Type.STRUCT)
+        for casted_type in ancestor_cast.find_all(exp.DataType)
+    )
 
     for i, expr in enumerate(expression.expressions):
         is_property_eq = isinstance(expr, exp.PropertyEQ)

@@ -168,7 +188,7 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
 
 def _arrow_json_extract_sql(self: DuckDB.Generator, expression: JSON_EXTRACT_TYPE) -> str:
     arrow_sql = arrow_json_extract_sql(self, expression)
-    if not expression.same_parent and isinstance(expression.parent, exp.Binary):
+    if not expression.same_parent and isinstance(expression.parent, (exp.Binary, exp.Bracket)):
         arrow_sql = self.wrap(arrow_sql)
     return arrow_sql
 

@@ -420,6 +440,8 @@ class DuckDB(Dialect):
             ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.Datetime: no_datetime_sql,
+            exp.DatetimeSub: _date_delta_sql,
+            exp.DatetimeAdd: _date_delta_sql,
             exp.DateToDi: lambda self,
             e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
             exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False),

@@ -484,7 +506,7 @@ class DuckDB(Dialect):
             exp.TimeToUnix: rename_func("EPOCH"),
             exp.TsOrDiToDi: lambda self,
             e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
-            exp.TsOrDsAdd: _ts_or_ds_add_sql,
+            exp.TsOrDsAdd: _date_delta_sql,
             exp.TsOrDsDiff: lambda self, e: self.func(
                 "DATE_DIFF",
                 f"'{e.args.get('unit') or 'DAY'}'",

@@ -790,3 +812,18 @@ class DuckDB(Dialect):
             )
 
             return self.sql(case)
+
+        def objectinsert_sql(self, expression: exp.ObjectInsert) -> str:
+            this = expression.this
+            key = expression.args.get("key")
+            key_sql = key.name if isinstance(key, exp.Expression) else ""
+            value_sql = self.sql(expression, "value")
+
+            kv_sql = f"{key_sql} := {value_sql}"
+
+            # If the input struct is empty e.g. transpiling OBJECT_INSERT(OBJECT_CONSTRUCT(), key, value) from Snowflake
+            # then we can generate STRUCT_PACK which will build it since STRUCT_INSERT({}, key := value) is not valid DuckDB
+            if isinstance(this, exp.Struct) and not this.expressions:
+                return self.func("STRUCT_PACK", kv_sql)
+
+            return self.func("STRUCT_INSERT", this, kv_sql)
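Note: together with the new exp.ObjectInsert node in expressions.py below, this enables Snowflake -> DuckDB transpilation of OBJECT_INSERT; a sketch (assumes sqlglot >= 25.7):

    import sqlglot

    # Empty input struct -> STRUCT_PACK, otherwise STRUCT_INSERT.
    print(
        sqlglot.transpile(
            "SELECT OBJECT_INSERT(OBJECT_CONSTRUCT(), 'k', 'v')",
            read="snowflake",
            write="duckdb",
        )[0]
    )
    # expected (approximately): SELECT STRUCT_PACK(k := 'v')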
sqlglot/dialects/hive.py

@@ -31,6 +31,7 @@ from sqlglot.dialects.dialect import (
     timestrtotime_sql,
     unit_to_str,
     var_map_sql,
+    sequence_sql,
 )
 from sqlglot.transforms import (
     remove_unique_constraints,

@@ -310,6 +311,7 @@ class Hive(Dialect):
             "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
                 this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
             ),
+            "SEQUENCE": exp.GenerateSeries.from_arg_list,
             "SIZE": exp.ArraySize.from_arg_list,
             "SPLIT": exp.RegexpSplit.from_arg_list,
             "STR_TO_MAP": lambda args: exp.StrToMap(

@@ -506,6 +508,7 @@ class Hive(Dialect):
             exp.FileFormatProperty: lambda self,
             e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
             exp.FromBase64: rename_func("UNBASE64"),
+            exp.GenerateSeries: sequence_sql,
             exp.If: if_sql(),
             exp.ILike: no_ilike_sql,
             exp.IsNan: rename_func("ISNAN"),
sqlglot/dialects/mysql.py

@@ -691,6 +691,7 @@ class MySQL(Dialect):
         SUPPORTS_TO_NUMBER = False
         PARSE_JSON_NAME = None
+        PAD_FILL_PATTERN_IS_REQUIRED = True
         WRAP_DERIVED_VALUES = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
sqlglot/dialects/postgres.py

@@ -365,6 +365,7 @@ class Postgres(Dialect):
             "NOW": exp.CurrentTimestamp.from_arg_list,
             "REGEXP_REPLACE": _build_regexp_replace,
             "TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"),
             "TO_DATE": build_formatted_time(exp.StrToDate, "postgres"),
             "TO_TIMESTAMP": _build_to_timestamp,
             "UNNEST": exp.Explode.from_arg_list,
+            "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
sqlglot/dialects/presto.py

@@ -28,6 +28,7 @@ from sqlglot.dialects.dialect import (
     timestrtotime_sql,
     ts_or_ds_add_cast,
     unit_to_str,
+    sequence_sql,
 )
 from sqlglot.dialects.hive import Hive
 from sqlglot.dialects.mysql import MySQL

@@ -204,11 +205,11 @@ def _jsonextract_sql(self: Presto.Generator, expression: exp.JSONExtract) -> str
     return f"{this}{expr}"
 
 
-def _to_int(expression: exp.Expression) -> exp.Expression:
+def _to_int(self: Presto.Generator, expression: exp.Expression) -> exp.Expression:
     if not expression.type:
         from sqlglot.optimizer.annotate_types import annotate_types
 
-        annotate_types(expression)
+        annotate_types(expression, dialect=self.dialect)
     if expression.type and expression.type.this not in exp.DataType.INTEGER_TYPES:
         return exp.cast(expression, to=exp.DataType.Type.BIGINT)
     return expression

@@ -229,7 +230,7 @@ def _date_delta_sql(
     name: str, negate_interval: bool = False
 ) -> t.Callable[[Presto.Generator, DATE_ADD_OR_SUB], str]:
     def _delta_sql(self: Presto.Generator, expression: DATE_ADD_OR_SUB) -> str:
-        interval = _to_int(expression.expression)
+        interval = _to_int(self, expression.expression)
         return self.func(
             name,
             unit_to_str(expression),

@@ -256,6 +257,21 @@ class Presto(Dialect):
     # https://github.com/prestodb/presto/issues/2863
    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
 
+    # The result of certain math functions in Presto/Trino is of type
+    # equal to the input type e.g: FLOOR(5.5/2) -> DECIMAL, FLOOR(5/2) -> BIGINT
+    ANNOTATORS = {
+        **Dialect.ANNOTATORS,
+        exp.Floor: lambda self, e: self._annotate_by_args(e, "this"),
+        exp.Ceil: lambda self, e: self._annotate_by_args(e, "this"),
+        exp.Mod: lambda self, e: self._annotate_by_args(e, "this", "expression"),
+        exp.Round: lambda self, e: self._annotate_by_args(e, "this"),
+        exp.Sign: lambda self, e: self._annotate_by_args(e, "this"),
+        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
+        exp.Rand: lambda self, e: self._annotate_by_args(e, "this")
+        if e.this
+        else self._set_type(e, exp.DataType.Type.DOUBLE),
+    }
+
     class Tokenizer(tokens.Tokenizer):
         UNICODE_STRINGS = [
             (prefix + q, q)
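Note: with annotators resolved per dialect, Presto's math functions keep their argument type; a small check (assumes sqlglot >= 25.7):

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    # Under Presto's annotators FLOOR keeps the input type instead of a
    # generic INT annotation (5.5 is annotated as DOUBLE here).
    e = annotate_types(sqlglot.parse_one("FLOOR(5.5)"), dialect="presto")
    print(e.type.this)  # expected: Type.DOUBLE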
@@ -420,6 +436,7 @@ class Presto(Dialect):
             exp.FirstValue: _first_last_sql,
             exp.FromTimeZone: lambda self,
             e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
+            exp.GenerateSeries: sequence_sql,
             exp.Group: transforms.preprocess([transforms.unalias_group]),
             exp.GroupConcat: lambda self, e: self.func(
                 "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")

@@ -572,11 +589,20 @@ class Presto(Dialect):
             # timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
             # which seems to be using the same time mapping as Hive, as per:
             # https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
-            value_as_text = exp.cast(expression.this, exp.DataType.Type.TEXT)
+            this = expression.this
+            value_as_text = exp.cast(this, exp.DataType.Type.TEXT)
+            value_as_timestamp = (
+                exp.cast(this, exp.DataType.Type.TIMESTAMP) if this.is_string else this
+            )
+
             parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
+
+            formatted_value = self.func(
+                "DATE_FORMAT", value_as_timestamp, self.format_time(expression)
+            )
             parse_with_tz = self.func(
                 "PARSE_DATETIME",
-                value_as_text,
+                formatted_value,
                 self.format_time(expression, Hive.INVERSE_TIME_MAPPING, Hive.INVERSE_TIME_TRIE),
             )
             coalesced = self.func("COALESCE", self.func("TRY", parse_without_tz), parse_with_tz)

@@ -636,26 +662,6 @@ class Presto(Dialect):
             modes = f" {', '.join(modes)}" if modes else ""
             return f"START TRANSACTION{modes}"
 
-        def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
-            start = expression.args["start"]
-            end = expression.args["end"]
-            step = expression.args.get("step")
-
-            if isinstance(start, exp.Cast):
-                target_type = start.to
-            elif isinstance(end, exp.Cast):
-                target_type = end.to
-            else:
-                target_type = None
-
-            if target_type and target_type.is_type("timestamp"):
-                if target_type is start.to:
-                    end = exp.cast(end, target_type)
-                else:
-                    start = exp.cast(start, target_type)
-
-            return self.func("SEQUENCE", start, end, step)
-
         def offset_limit_modifiers(
             self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
         ) -> t.List[str]:
sqlglot/dialects/snowflake.py

@@ -504,43 +504,6 @@ class Snowflake(Dialect):
 
             return lateral
 
-        def _parse_historical_data(self) -> t.Optional[exp.HistoricalData]:
-            # https://docs.snowflake.com/en/sql-reference/constructs/at-before
-            index = self._index
-            historical_data = None
-            if self._match_texts(self.HISTORICAL_DATA_PREFIX):
-                this = self._prev.text.upper()
-                kind = (
-                    self._match(TokenType.L_PAREN)
-                    and self._match_texts(self.HISTORICAL_DATA_KIND)
-                    and self._prev.text.upper()
-                )
-                expression = self._match(TokenType.FARROW) and self._parse_bitwise()
-
-                if expression:
-                    self._match_r_paren()
-                    historical_data = self.expression(
-                        exp.HistoricalData, this=this, kind=kind, expression=expression
-                    )
-                else:
-                    self._retreat(index)
-
-            return historical_data
-
-        def _parse_changes(self) -> t.Optional[exp.Changes]:
-            if not self._match_text_seq("CHANGES", "(", "INFORMATION", "=>"):
-                return None
-
-            information = self._parse_var(any_token=True)
-            self._match_r_paren()
-
-            return self.expression(
-                exp.Changes,
-                information=information,
-                at_before=self._parse_historical_data(),
-                end=self._parse_historical_data(),
-            )
-
         def _parse_table_parts(
             self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
         ) -> exp.Table:

@@ -573,14 +536,6 @@ class Snowflake(Dialect):
             else:
                 table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
 
-            changes = self._parse_changes()
-            if changes:
-                table.set("changes", changes)
-
-            at_before = self._parse_historical_data()
-            if at_before:
-                table.set("when", at_before)
-
             return table
 
         def _parse_id_var(

@@ -659,7 +614,7 @@ class Snowflake(Dialect):
             # can be joined in a query with a comma separator, as well as closing paren
             # in case of subqueries
             while self._is_connected() and not self._match_set(
-                (TokenType.COMMA, TokenType.R_PAREN), advance=False
+                (TokenType.COMMA, TokenType.L_PAREN, TokenType.R_PAREN), advance=False
             ):
                 parts.append(self._advance_any(ignore_reserved=True))
 
sqlglot/dialects/spark2.py

@@ -165,9 +165,6 @@ class Spark2(Hive):
             "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
         }
 
-        def _parse_add_column(self) -> t.Optional[exp.Expression]:
-            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
-
         def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
             return self._match_text_seq("DROP", "COLUMNS") and self.expression(
                 exp.Drop, this=self._parse_schema(), kind="COLUMNS"
sqlglot/dialects/tsql.py

@@ -855,6 +855,7 @@ class TSQL(Dialect):
                     transforms.eliminate_qualify,
                 ]
             ),
+            exp.Stddev: rename_func("STDEV"),
             exp.StrPosition: lambda self, e: self.func(
                 "CHARINDEX", e.args.get("substr"), e.this, e.args.get("position")
             ),
sqlglot/expressions.py

@@ -33,7 +33,7 @@ from sqlglot.helper import (
     seq_get,
     subclasses,
 )
-from sqlglot.tokens import Token
+from sqlglot.tokens import Token, TokenError
 
 if t.TYPE_CHECKING:
     from sqlglot._typing import E, Lit

@@ -1393,6 +1393,8 @@ class Create(DDL):
         "begin": False,
         "end": False,
         "clone": False,
+        "concurrently": False,
+        "clustered": False,
     }
 
     @property

@@ -5483,6 +5485,16 @@ class JSONTable(Func):
     }
 
 
+# https://docs.snowflake.com/en/sql-reference/functions/object_insert
+class ObjectInsert(Func):
+    arg_types = {
+        "this": True,
+        "key": True,
+        "value": True,
+        "update_flag": False,
+    }
+
+
 class OpenJSONColumnDef(Expression):
     arg_types = {"this": True, "kind": True, "path": False, "as_json": False}
 
@@ -5886,7 +5898,7 @@ class Sqrt(Func):
 
 
 class Stddev(AggFunc):
-    pass
+    _sql_names = ["STDDEV", "STDEV"]
 
 
 class StddevPop(AggFunc):
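Note: combined with the exp.Stddev: rename_func("STDEV") transform added in tsql.py above, registering STDEV as an alternate SQL name should make the function round-trip; a sketch (assumes sqlglot >= 25.7):

    import sqlglot

    # T-SQL's STDEV parses to exp.Stddev, and Stddev renders back as STDEV.
    print(sqlglot.transpile("SELECT STDEV(x) FROM t", read="tsql", write="duckdb")[0])
    # expected: SELECT STDDEV(x) FROM t
    print(sqlglot.transpile("SELECT STDDEV(x) FROM t", write="tsql")[0])
    # expected: SELECT STDEV(x) FROM t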
@@ -6881,7 +6893,7 @@ def parse_identifier(name: str | Identifier, dialect: DialectType = None) -> Identifier:
     """
     try:
         expression = maybe_parse(name, dialect=dialect, into=Identifier)
-    except ParseError:
+    except (ParseError, TokenError):
         expression = to_identifier(name)
 
     return expression
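Note: an illustration of the widened except clause, using a hypothetical input that fails tokenization (an unterminated string; assumes sqlglot >= 25.7):

    from sqlglot.expressions import parse_identifier

    # Tokenizer errors no longer propagate; the raw name becomes an Identifier.
    ident = parse_identifier("foo'bar")
    print(ident.name)  # expected: foo'bar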
sqlglot/generator.py

@@ -1027,6 +1027,14 @@ class Generator(metaclass=_Generator):
         replace = " OR REPLACE" if expression.args.get("replace") else ""
         unique = " UNIQUE" if expression.args.get("unique") else ""
 
+        clustered = expression.args.get("clustered")
+        if clustered is None:
+            clustered_sql = ""
+        elif clustered:
+            clustered_sql = " CLUSTERED COLUMNSTORE"
+        else:
+            clustered_sql = " NONCLUSTERED COLUMNSTORE"
+
         postcreate_props_sql = ""
         if properties_locs.get(exp.Properties.Location.POST_CREATE):
             postcreate_props_sql = self.properties(

@@ -1036,7 +1044,7 @@ class Generator(metaclass=_Generator):
                 wrapped=False,
             )
 
-        modifiers = "".join((replace, unique, postcreate_props_sql))
+        modifiers = "".join((clustered_sql, replace, unique, postcreate_props_sql))
 
         postexpression_props_sql = ""
         if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):

@@ -1049,6 +1057,7 @@ class Generator(metaclass=_Generator):
                 wrapped=False,
             )
 
+        concurrently = " CONCURRENTLY" if expression.args.get("concurrently") else ""
         exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
         no_schema_binding = (
             " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""

@@ -1057,7 +1066,7 @@ class Generator(metaclass=_Generator):
         clone = self.sql(expression, "clone")
         clone = f" {clone}" if clone else ""
 
-        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
+        expression_sql = f"CREATE{modifiers} {kind}{concurrently}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
         return self.prepend_ctes(expression, expression_sql)
 
     def sequenceproperties_sql(self, expression: exp.SequenceProperties) -> str:
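Note: a round trip of the new CONCURRENTLY flag (assumes sqlglot >= 25.7; spacing approximate):

    import sqlglot

    # Previously CREATE INDEX CONCURRENTLY fell back to a generic command.
    sql = "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx ON t(c)"
    print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])
    # expected (approximately): CREATE INDEX CONCURRENTLY IF NOT EXISTS idx ON t(c)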
@@ -1734,8 +1743,7 @@ class Generator(metaclass=_Generator):
         alias = f"{sep}{alias}" if alias else ""
         hints = self.expressions(expression, key="hints", sep=" ")
         hints = f" {hints}" if hints and self.TABLE_HINTS else ""
-        pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
-        pivots = f" {pivots}" if pivots else ""
+        pivots = self.expressions(expression, key="pivots", sep="", flat=True)
         joins = self.indent(
             self.expressions(expression, key="joins", sep="", flat=True), skip_first=True
         )

@@ -1822,7 +1830,7 @@ class Generator(metaclass=_Generator):
 
         alias = self.sql(expression, "alias")
         alias = f" AS {alias}" if alias else ""
-        direction = "UNPIVOT" if expression.unpivot else "PIVOT"
+        direction = self.seg("UNPIVOT" if expression.unpivot else "PIVOT")
         field = self.sql(expression, "field")
         include_nulls = expression.args.get("include_nulls")
         if include_nulls is not None:

@@ -2409,10 +2417,7 @@ class Generator(metaclass=_Generator):
     def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
         alias = self.sql(expression, "alias")
         alias = f"{sep}{alias}" if alias else ""
-
-        pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
-        pivots = f" {pivots}" if pivots else ""
-
+        pivots = self.expressions(expression, key="pivots", sep="", flat=True)
         sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots)
         return self.prepend_ctes(expression, sql)
 

@@ -3134,6 +3139,7 @@ class Generator(metaclass=_Generator):
                 expression,
                 key="actions",
                 prefix="ADD COLUMN ",
+                skip_first=True,
             )
 
         return f"ADD {self.expressions(expression, key='actions', flat=True)}"
sqlglot/optimizer/annotate_types.py

@@ -10,10 +10,10 @@ from sqlglot.helper import (
     is_iso_date,
     is_iso_datetime,
     seq_get,
-    subclasses,
 )
 from sqlglot.optimizer.scope import Scope, traverse_scope
 from sqlglot.schema import Schema, ensure_schema
+from sqlglot.dialects.dialect import Dialect
 
 if t.TYPE_CHECKING:
     from sqlglot._typing import B, E

@@ -24,12 +24,15 @@ if t.TYPE_CHECKING:
         BinaryCoercionFunc,
     ]
 
+    from sqlglot.dialects.dialect import DialectType, AnnotatorsType
+
 
 def annotate_types(
     expression: E,
     schema: t.Optional[t.Dict | Schema] = None,
-    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
+    annotators: t.Optional[AnnotatorsType] = None,
     coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
+    dialect: t.Optional[DialectType] = None,
 ) -> E:
     """
     Infers the types of an expression, annotating its AST accordingly.

@@ -54,11 +57,7 @@ def annotate_types(
 
     schema = ensure_schema(schema)
 
-    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
-
-
-def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
-    return lambda self, e: self._annotate_with_type(e, data_type)
+    return TypeAnnotator(schema, annotators, coerces_to, dialect=dialect).annotate(expression)
 
 
 def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
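Note: usage of the new dialect parameter (assumes sqlglot >= 25.7):

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    # The dialect's ANNOTATORS are resolved internally, so dialect-specific
    # typing rules apply without passing an annotators mapping by hand.
    expr = annotate_types(sqlglot.parse_one("SELECT CONCAT('a', 'b') AS c"), dialect="duckdb")
    print(expr.selects[0].type.this)  # expected: Type.VARCHAR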
@@ -133,168 +132,6 @@ class _TypeAnnotator(type):
 
 
 class TypeAnnotator(metaclass=_TypeAnnotator):
-    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
-        exp.DataType.Type.BIGINT: {
-            exp.ApproxDistinct,
-            exp.ArraySize,
-            exp.Count,
-            exp.Length,
-        },
-        exp.DataType.Type.BOOLEAN: {
-            exp.Between,
-            exp.Boolean,
-            exp.In,
-            exp.RegexpLike,
-        },
-        exp.DataType.Type.DATE: {
-            exp.CurrentDate,
-            exp.Date,
-            exp.DateFromParts,
-            exp.DateStrToDate,
-            exp.DiToDate,
-            exp.StrToDate,
-            exp.TimeStrToDate,
-            exp.TsOrDsToDate,
-        },
-        exp.DataType.Type.DATETIME: {
-            exp.CurrentDatetime,
-            exp.Datetime,
-            exp.DatetimeAdd,
-            exp.DatetimeSub,
-        },
-        exp.DataType.Type.DOUBLE: {
-            exp.ApproxQuantile,
-            exp.Avg,
-            exp.Div,
-            exp.Exp,
-            exp.Ln,
-            exp.Log,
-            exp.Pow,
-            exp.Quantile,
-            exp.Round,
-            exp.SafeDivide,
-            exp.Sqrt,
-            exp.Stddev,
-            exp.StddevPop,
-            exp.StddevSamp,
-            exp.Variance,
-            exp.VariancePop,
-        },
-        exp.DataType.Type.INT: {
-            exp.Ceil,
-            exp.DatetimeDiff,
-            exp.DateDiff,
-            exp.TimestampDiff,
-            exp.TimeDiff,
-            exp.DateToDi,
-            exp.Floor,
-            exp.Levenshtein,
-            exp.Sign,
-            exp.StrPosition,
-            exp.TsOrDiToDi,
-        },
-        exp.DataType.Type.JSON: {
-            exp.ParseJSON,
-        },
-        exp.DataType.Type.TIME: {
-            exp.Time,
-        },
-        exp.DataType.Type.TIMESTAMP: {
-            exp.CurrentTime,
-            exp.CurrentTimestamp,
-            exp.StrToTime,
-            exp.TimeAdd,
-            exp.TimeStrToTime,
-            exp.TimeSub,
-            exp.TimestampAdd,
-            exp.TimestampSub,
-            exp.UnixToTime,
-        },
-        exp.DataType.Type.TINYINT: {
-            exp.Day,
-            exp.Month,
-            exp.Week,
-            exp.Year,
-            exp.Quarter,
-        },
-        exp.DataType.Type.VARCHAR: {
-            exp.ArrayConcat,
-            exp.Concat,
-            exp.ConcatWs,
-            exp.DateToDateStr,
-            exp.GroupConcat,
-            exp.Initcap,
-            exp.Lower,
-            exp.Substring,
-            exp.TimeToStr,
-            exp.TimeToTimeStr,
-            exp.Trim,
-            exp.TsOrDsToDateStr,
-            exp.UnixToStr,
-            exp.UnixToTimeStr,
-            exp.Upper,
-        },
-    }
-
-    ANNOTATORS: t.Dict = {
-        **{
-            expr_type: lambda self, e: self._annotate_unary(e)
-            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
-        },
-        **{
-            expr_type: lambda self, e: self._annotate_binary(e)
-            for expr_type in subclasses(exp.__name__, exp.Binary)
-        },
-        **{
-            expr_type: _annotate_with_type_lambda(data_type)
-            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
-            for expr_type in expressions
-        },
-        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
-        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
-        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
-        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Bracket: lambda self, e: self._annotate_bracket(e),
-        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
-        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
-        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
-        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
-        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
-        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
-        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
-        exp.Div: lambda self, e: self._annotate_div(e),
-        exp.Dot: lambda self, e: self._annotate_dot(e),
-        exp.Explode: lambda self, e: self._annotate_explode(e),
-        exp.Extract: lambda self, e: self._annotate_extract(e),
-        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
-            e, exp.DataType.build("ARRAY<DATE>")
-        ),
-        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
-        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
-        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
-        exp.Literal: lambda self, e: self._annotate_literal(e),
-        exp.Map: lambda self, e: self._annotate_map(e),
-        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
-        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
-        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
-        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
-        exp.Struct: lambda self, e: self._annotate_struct(e),
-        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
-        exp.Timestamp: lambda self, e: self._annotate_with_type(
-            e,
-            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
-        ),
-        exp.ToMap: lambda self, e: self._annotate_to_map(e),
-        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
-        exp.Unnest: lambda self, e: self._annotate_unnest(e),
-        exp.VarMap: lambda self, e: self._annotate_map(e),
-    }
-
     NESTED_TYPES = {
         exp.DataType.Type.ARRAY,
     }

@@ -335,12 +172,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
     def __init__(
         self,
        schema: Schema,
-        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
+        annotators: t.Optional[AnnotatorsType] = None,
         coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
         binary_coercions: t.Optional[BinaryCoercions] = None,
+        dialect: t.Optional[DialectType] = None,
     ) -> None:
         self.schema = schema
-        self.annotators = annotators or self.ANNOTATORS
+        self.annotators = annotators or Dialect.get_or_raise(dialect).ANNOTATORS
         self.coerces_to = coerces_to or self.COERCES_TO
         self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
 

@@ -483,7 +321,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
 
         return expression
 
-    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
+    def _annotate_with_type(
+        self, expression: E, target_type: exp.DataType | exp.DataType.Type
+    ) -> E:
         self._set_type(expression, target_type)
         return self._annotate_args(expression)
 
sqlglot/parser.py

@@ -376,6 +376,7 @@ class Parser(metaclass=_Parser):
 
     # Tokens that can represent identifiers
     ID_VAR_TOKENS = {
+        TokenType.ALL,
         TokenType.VAR,
         TokenType.ANTI,
         TokenType.APPLY,

@@ -929,7 +930,8 @@ class Parser(metaclass=_Parser):
             enforced=self._match_text_seq("ENFORCED"),
         ),
         "COLLATE": lambda self: self.expression(
-            exp.CollateColumnConstraint, this=self._parse_var(any_token=True)
+            exp.CollateColumnConstraint,
+            this=self._parse_identifier() or self._parse_column(),
         ),
         "COMMENT": lambda self: self.expression(
             exp.CommentColumnConstraint, this=self._parse_string()

@@ -1138,7 +1140,9 @@ class Parser(metaclass=_Parser):
 
     ISOLATED_LOADING_OPTIONS: OPTIONS_TYPE = {"FOR": ("ALL", "INSERT", "NONE")}
 
-    USABLES: OPTIONS_TYPE = dict.fromkeys(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"), tuple())
+    USABLES: OPTIONS_TYPE = dict.fromkeys(
+        ("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA", "CATALOG"), tuple()
+    )
 
     CAST_ACTIONS: OPTIONS_TYPE = dict.fromkeys(("RENAME", "ADD"), ("FIELDS",))
 

@@ -1147,6 +1151,17 @@ class Parser(metaclass=_Parser):
         **dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()),
     }
 
+    KEY_CONSTRAINT_OPTIONS: OPTIONS_TYPE = {
+        "NOT": ("ENFORCED",),
+        "MATCH": (
+            "FULL",
+            "PARTIAL",
+            "SIMPLE",
+        ),
+        "INITIALLY": ("DEFERRED", "IMMEDIATE"),
+        **dict.fromkeys(("DEFERRABLE", "NORELY"), tuple()),
+    }
+
     INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
 
     CLONE_KEYWORDS = {"CLONE", "COPY"}

@@ -1663,6 +1678,15 @@ class Parser(metaclass=_Parser):
 
         unique = self._match(TokenType.UNIQUE)
 
+        if self._match_text_seq("CLUSTERED", "COLUMNSTORE"):
+            clustered = True
+        elif self._match_text_seq("NONCLUSTERED", "COLUMNSTORE") or self._match_text_seq(
+            "COLUMNSTORE"
+        ):
+            clustered = False
+        else:
+            clustered = None
+
         if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
             self._advance()
 
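Note: a round trip for the new clustered flag, which the generator above turns back into the COLUMNSTORE modifiers (assumes sqlglot >= 25.7; spacing approximate):

    import sqlglot

    # T-SQL columnstore indexes are now modeled instead of being dropped.
    sql = "CREATE CLUSTERED COLUMNSTORE INDEX idx ON dbo.t"
    print(sqlglot.transpile(sql, read="tsql", write="tsql")[0])
    # expected (approximately): CREATE CLUSTERED COLUMNSTORE INDEX idx ON dbo.t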
@@ -1677,6 +1701,7 @@ class Parser(metaclass=_Parser):
         if not properties or not create_token:
             return self._parse_as_command(start)
 
+        concurrently = self._match_text_seq("CONCURRENTLY")
         exists = self._parse_exists(not_=True)
         this = None
         expression: t.Optional[exp.Expression] = None

@@ -1802,6 +1827,8 @@ class Parser(metaclass=_Parser):
             begin=begin,
             end=end,
             clone=clone,
+            concurrently=concurrently,
+            clustered=clustered,
         )
 
     def _parse_sequence_properties(self) -> t.Optional[exp.SequenceProperties]:

@@ -2728,8 +2755,12 @@ class Parser(metaclass=_Parser):
         comments = self._prev_comments
 
         hint = self._parse_hint()
-        all_ = self._match(TokenType.ALL)
-        distinct = self._match_set(self.DISTINCT_TOKENS)
+
+        if self._next and not self._next.token_type == TokenType.DOT:
+            all_ = self._match(TokenType.ALL)
+            distinct = self._match_set(self.DISTINCT_TOKENS)
+        else:
+            all_, distinct = None, None
 
         kind = (
             self._match(TokenType.ALIAS)

@@ -2827,6 +2858,7 @@ class Parser(metaclass=_Parser):
             self.raise_error("Expected CTE to have alias")
 
         self._match(TokenType.ALIAS)
+        comments = self._prev_comments
 
         if self._match_text_seq("NOT", "MATERIALIZED"):
             materialized = False

@@ -2840,6 +2872,7 @@ class Parser(metaclass=_Parser):
             this=self._parse_wrapped(self._parse_statement),
             alias=alias,
             materialized=materialized,
+            comments=comments,
         )
 
     def _parse_table_alias(

@@ -3352,15 +3385,28 @@ class Parser(metaclass=_Parser):
         if not db and is_db_reference:
             self.raise_error(f"Expected database name but got {self._curr}")
 
-        return self.expression(
+        table = self.expression(
             exp.Table,
             comments=comments,
             this=table,
             db=db,
             catalog=catalog,
-            pivots=self._parse_pivots(),
         )
+
+        changes = self._parse_changes()
+        if changes:
+            table.set("changes", changes)
+
+        at_before = self._parse_historical_data()
+        if at_before:
+            table.set("when", at_before)
+
+        pivots = self._parse_pivots()
+        if pivots:
+            table.set("pivots", pivots)
+
+        return table
 
     def _parse_table(
         self,
         schema: bool = False,

@@ -3490,6 +3536,43 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.Version, this=this, expression=expression, kind=kind)
 
+    def _parse_historical_data(self) -> t.Optional[exp.HistoricalData]:
+        # https://docs.snowflake.com/en/sql-reference/constructs/at-before
+        index = self._index
+        historical_data = None
+        if self._match_texts(self.HISTORICAL_DATA_PREFIX):
+            this = self._prev.text.upper()
+            kind = (
+                self._match(TokenType.L_PAREN)
+                and self._match_texts(self.HISTORICAL_DATA_KIND)
+                and self._prev.text.upper()
+            )
+            expression = self._match(TokenType.FARROW) and self._parse_bitwise()
+
+            if expression:
+                self._match_r_paren()
+                historical_data = self.expression(
+                    exp.HistoricalData, this=this, kind=kind, expression=expression
+                )
+            else:
+                self._retreat(index)
+
+        return historical_data
+
+    def _parse_changes(self) -> t.Optional[exp.Changes]:
+        if not self._match_text_seq("CHANGES", "(", "INFORMATION", "=>"):
+            return None
+
+        information = self._parse_var(any_token=True)
+        self._match_r_paren()
+
+        return self.expression(
+            exp.Changes,
+            information=information,
+            at_before=self._parse_historical_data(),
+            end=self._parse_historical_data(),
+        )
+
     def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
         if not self._match(TokenType.UNNEST):
             return None
@@ -5216,18 +5299,13 @@ class Parser(metaclass=_Parser):
                     self.raise_error("Invalid key constraint")
 
                 options.append(f"ON {on} {action}")
-            elif self._match_text_seq("NOT", "ENFORCED"):
-                options.append("NOT ENFORCED")
-            elif self._match_text_seq("DEFERRABLE"):
-                options.append("DEFERRABLE")
-            elif self._match_text_seq("INITIALLY", "DEFERRED"):
-                options.append("INITIALLY DEFERRED")
-            elif self._match_text_seq("NORELY"):
-                options.append("NORELY")
-            elif self._match_text_seq("MATCH", "FULL"):
-                options.append("MATCH FULL")
             else:
-                break
+                var = self._parse_var_from_options(
+                    self.KEY_CONSTRAINT_OPTIONS, raise_unmatched=False
+                )
+                if not var:
+                    break
+                options.append(var.name)
 
         return options
 

@@ -6227,6 +6305,13 @@ class Parser(metaclass=_Parser):
         self._retreat(index)
         if not self.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and self._match_text_seq("ADD"):
             return self._parse_wrapped_csv(self._parse_field_def, optional=True)
 
+        if self._match_text_seq("ADD", "COLUMNS"):
+            schema = self._parse_schema()
+            if schema:
+                return [schema]
+            return []
+
         return self._parse_wrapped_csv(self._parse_add_column, optional=True)
 
     def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]:
sqlglot/transforms.py

@@ -229,6 +229,23 @@ def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
 def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
     """Convert cross join unnest into lateral view explode."""
     if isinstance(expression, exp.Select):
         from_ = expression.args.get("from")
 
+        if from_ and isinstance(from_.this, exp.Unnest):
+            unnest = from_.this
+            alias = unnest.args.get("alias")
+            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
+            this, *expressions = unnest.expressions
+            unnest.replace(
+                exp.Table(
+                    this=udtf(
+                        this=this,
+                        expressions=expressions,
+                    ),
+                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
+                )
+            )
+
         for join in expression.args.get("joins") or []:
             unnest = join.this
 
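Note: with this branch, a top-level SELECT ... FROM UNNEST(...) also becomes an explode-style source when targeting Hive/Spark, which apply unnest_to_explode as a generator transform; a sketch (assumes sqlglot >= 25.7; output approximate):

    import sqlglot

    print(sqlglot.transpile("SELECT x FROM UNNEST(ARRAY[1, 2]) AS t(x)", read="presto", write="spark")[0])
    # expected (approximately): SELECT x FROM EXPLODE(ARRAY(1, 2)) AS t(x)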