
Merging upstream version 25.7.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Commit aa0eae236a (parent dba379232c)
Daniel Baumann, 2025-02-13 21:51:42 +01:00
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
102 changed files with 52995 additions and 52070 deletions

sqlglot/dialects/clickhouse.py

@@ -334,6 +334,11 @@ class ClickHouse(Dialect):
RESERVED_TOKENS = parser.Parser.RESERVED_TOKENS - {TokenType.SELECT}
ID_VAR_TOKENS = {
*parser.Parser.ID_VAR_TOKENS,
TokenType.LIKE,
}
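Since TokenType.LIKE joins ClickHouse's ID_VAR_TOKENS, `like` becomes usable as a plain identifier. A minimal sketch of the expected effect; the query is my own illustration, not from the commit:

import sqlglot

# Assumed consequence of the hunk above: "like" parses as a column name.
ast = sqlglot.parse_one("SELECT like FROM t", read="clickhouse")
print(ast.sql(dialect="clickhouse"))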
AGG_FUNC_MAPPING = (
lambda functions, suffixes: {
f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions

sqlglot/dialects/dialect.py

@@ -8,7 +8,7 @@ from functools import reduce
from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
@@ -23,6 +23,10 @@ JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
if t.TYPE_CHECKING:
from sqlglot._typing import B, E, F
from sqlglot.optimizer.annotate_types import TypeAnnotator
AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]
logger = logging.getLogger("sqlglot")
UNESCAPED_SEQUENCES = {
@@ -37,6 +41,10 @@ UNESCAPED_SEQUENCES = {
}
def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
return lambda self, e: self._annotate_with_type(e, data_type)
class Dialects(str, Enum):
"""Dialects supported by SQLGLot."""
@@ -489,6 +497,167 @@ class Dialect(metaclass=_Dialect):
"CENTURIES": "CENTURY",
}
TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
exp.DataType.Type.BIGINT: {
exp.ApproxDistinct,
exp.ArraySize,
exp.Count,
exp.Length,
},
exp.DataType.Type.BOOLEAN: {
exp.Between,
exp.Boolean,
exp.In,
exp.RegexpLike,
},
exp.DataType.Type.DATE: {
exp.CurrentDate,
exp.Date,
exp.DateFromParts,
exp.DateStrToDate,
exp.DiToDate,
exp.StrToDate,
exp.TimeStrToDate,
exp.TsOrDsToDate,
},
exp.DataType.Type.DATETIME: {
exp.CurrentDatetime,
exp.Datetime,
exp.DatetimeAdd,
exp.DatetimeSub,
},
exp.DataType.Type.DOUBLE: {
exp.ApproxQuantile,
exp.Avg,
exp.Div,
exp.Exp,
exp.Ln,
exp.Log,
exp.Pow,
exp.Quantile,
exp.Round,
exp.SafeDivide,
exp.Sqrt,
exp.Stddev,
exp.StddevPop,
exp.StddevSamp,
exp.Variance,
exp.VariancePop,
},
exp.DataType.Type.INT: {
exp.Ceil,
exp.DatetimeDiff,
exp.DateDiff,
exp.TimestampDiff,
exp.TimeDiff,
exp.DateToDi,
exp.Levenshtein,
exp.Sign,
exp.StrPosition,
exp.TsOrDiToDi,
},
exp.DataType.Type.JSON: {
exp.ParseJSON,
},
exp.DataType.Type.TIME: {
exp.Time,
},
exp.DataType.Type.TIMESTAMP: {
exp.CurrentTime,
exp.CurrentTimestamp,
exp.StrToTime,
exp.TimeAdd,
exp.TimeStrToTime,
exp.TimeSub,
exp.TimestampAdd,
exp.TimestampSub,
exp.UnixToTime,
},
exp.DataType.Type.TINYINT: {
exp.Day,
exp.Month,
exp.Week,
exp.Year,
exp.Quarter,
},
exp.DataType.Type.VARCHAR: {
exp.ArrayConcat,
exp.Concat,
exp.ConcatWs,
exp.DateToDateStr,
exp.GroupConcat,
exp.Initcap,
exp.Lower,
exp.Substring,
exp.TimeToStr,
exp.TimeToTimeStr,
exp.Trim,
exp.TsOrDsToDateStr,
exp.UnixToStr,
exp.UnixToTimeStr,
exp.Upper,
},
}
ANNOTATORS: AnnotatorsType = {
**{
expr_type: lambda self, e: self._annotate_unary(e)
for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
},
**{
expr_type: lambda self, e: self._annotate_binary(e)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
**{
expr_type: _annotate_with_type_lambda(data_type)
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
for expr_type in expressions
},
exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Bracket: lambda self, e: self._annotate_bracket(e),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
exp.DateSub: lambda self, e: self._annotate_timeunit(e),
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Div: lambda self, e: self._annotate_div(e),
exp.Dot: lambda self, e: self._annotate_dot(e),
exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Extract: lambda self, e: self._annotate_extract(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
e, exp.DataType.build("ARRAY<DATE>")
),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Literal: lambda self, e: self._annotate_literal(e),
exp.Map: lambda self, e: self._annotate_map(e),
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Struct: lambda self, e: self._annotate_struct(e),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
exp.ToMap: lambda self, e: self._annotate_to_map(e),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Unnest: lambda self, e: self._annotate_unnest(e),
exp.VarMap: lambda self, e: self._annotate_map(e),
}
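This block moves the type-inference rules onto Dialect itself, so individual dialects can override single entries (Presto does exactly that further down). A small sketch, assuming the attribute layout shown in the hunk:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

# ANNOTATORS is now a class-level mapping on Dialect, inherited and
# overridable per dialect.
annotators = Dialect.get_or_raise("presto").ANNOTATORS
print(exp.Floor in annotators)  # expected: True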
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
"""
@@ -1419,3 +1588,24 @@ def build_timestamp_from_parts(args: t.List) -> exp.Func:
def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
def sequence_sql(self: Generator, expression: exp.GenerateSeries):
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step")
if isinstance(start, exp.Cast):
target_type = start.to
elif isinstance(end, exp.Cast):
target_type = end.to
else:
target_type = None
if target_type and target_type.is_type("timestamp"):
if target_type is start.to:
end = exp.cast(end, target_type)
else:
start = exp.cast(start, target_type)
return self.func("SEQUENCE", start, end, step)
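sequence_sql is the new shared generator for exp.GenerateSeries (registered for Hive and Presto below): if either endpoint carries a TIMESTAMP cast, the other endpoint is cast to match, so SEQUENCE sees arguments of one type. A hedged round-trip sketch; the expected output is my reading of the function, not taken from tests:

import sqlglot

sql = "SELECT SEQUENCE(CAST('2020-01-01' AS TIMESTAMP), end_ts)"
print(sqlglot.transpile(sql, read="hive", write="hive")[0])
# expected: SELECT SEQUENCE(CAST('2020-01-01' AS TIMESTAMP), CAST(end_ts AS TIMESTAMP))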

sqlglot/dialects/duckdb.py

@@ -3,6 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.expressions import DATA_TYPE
from sqlglot.dialects.dialect import (
Dialect,
JSON_EXTRACT_TYPE,
@@ -35,20 +36,34 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
interval = self.sql(exp.Interval(this=expression.expression, unit=unit_to_var(expression)))
return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}"
DATETIME_DELTA = t.Union[
exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd, exp.DateSub, exp.DatetimeSub
]
def _date_delta_sql(
self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub | exp.TimeAdd
) -> str:
this = self.sql(expression, "this")
def _date_delta_sql(self: DuckDB.Generator, expression: DATETIME_DELTA) -> str:
this = expression.this
unit = unit_to_var(expression)
op = "+" if isinstance(expression, (exp.DateAdd, exp.TimeAdd)) else "-"
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
op = (
"+"
if isinstance(expression, (exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd))
else "-"
)
to_type: t.Optional[DATA_TYPE] = None
if isinstance(expression, exp.TsOrDsAdd):
to_type = expression.return_type
elif this.is_string:
# Cast string literals (i.e function parameters) to the appropriate type for +/- interval to work
to_type = (
exp.DataType.Type.DATETIME
if isinstance(expression, (exp.DatetimeAdd, exp.DatetimeSub))
else exp.DataType.Type.DATE
)
this = exp.cast(this, to_type) if to_type else this
return f"{self.sql(this)} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
# BigQuery -> DuckDB conversion for the DATE function
@@ -119,7 +134,12 @@ def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
# BigQuery allows inline construction such as "STRUCT<a STRING, b INTEGER>('str', 1)" which is
# canonicalized to "ROW('str', 1) AS STRUCT(a TEXT, b INT)" in DuckDB
is_struct_cast = expression.find_ancestor(exp.Cast)
# The transformation to ROW will take place if a cast to STRUCT / ARRAY of STRUCTs is found
ancestor_cast = expression.find_ancestor(exp.Cast)
is_struct_cast = ancestor_cast and any(
casted_type.is_type(exp.DataType.Type.STRUCT)
for casted_type in ancestor_cast.find_all(exp.DataType)
)
for i, expr in enumerate(expression.expressions):
is_property_eq = isinstance(expr, exp.PropertyEQ)
@@ -168,7 +188,7 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
def _arrow_json_extract_sql(self: DuckDB.Generator, expression: JSON_EXTRACT_TYPE) -> str:
arrow_sql = arrow_json_extract_sql(self, expression)
if not expression.same_parent and isinstance(expression.parent, exp.Binary):
if not expression.same_parent and isinstance(expression.parent, (exp.Binary, exp.Bracket)):
arrow_sql = self.wrap(arrow_sql)
return arrow_sql
@@ -420,6 +440,8 @@ class DuckDB(Dialect):
),
exp.DateStrToDate: datestrtodate_sql,
exp.Datetime: no_datetime_sql,
exp.DatetimeSub: _date_delta_sql,
exp.DatetimeAdd: _date_delta_sql,
exp.DateToDi: lambda self,
e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False),
@@ -484,7 +506,7 @@ class DuckDB(Dialect):
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsAdd: _date_delta_sql,
exp.TsOrDsDiff: lambda self, e: self.func(
"DATE_DIFF",
f"'{e.args.get('unit') or 'DAY'}'",
@@ -790,3 +812,18 @@ class DuckDB(Dialect):
)
return self.sql(case)
def objectinsert_sql(self, expression: exp.ObjectInsert) -> str:
this = expression.this
key = expression.args.get("key")
key_sql = key.name if isinstance(key, exp.Expression) else ""
value_sql = self.sql(expression, "value")
kv_sql = f"{key_sql} := {value_sql}"
# If the input struct is empty e.g. transpiling OBJECT_INSERT(OBJECT_CONSTRUCT(), key, value) from Snowflake
# then we can generate STRUCT_PACK which will build it since STRUCT_INSERT({}, key := value) is not valid DuckDB
if isinstance(this, exp.Struct) and not this.expressions:
return self.func("STRUCT_PACK", kv_sql)
return self.func("STRUCT_INSERT", this, kv_sql)

sqlglot/dialects/hive.py

@@ -31,6 +31,7 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
unit_to_str,
var_map_sql,
sequence_sql,
)
from sqlglot.transforms import (
remove_unique_constraints,
@@ -310,6 +311,7 @@ class Hive(Dialect):
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
),
"SEQUENCE": exp.GenerateSeries.from_arg_list,
"SIZE": exp.ArraySize.from_arg_list,
"SPLIT": exp.RegexpSplit.from_arg_list,
"STR_TO_MAP": lambda args: exp.StrToMap(
@@ -506,6 +508,7 @@ class Hive(Dialect):
exp.FileFormatProperty: lambda self,
e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.GenerateSeries: sequence_sql,
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.IsNan: rename_func("ISNAN"),

sqlglot/dialects/mysql.py

@@ -691,6 +691,7 @@ class MySQL(Dialect):
SUPPORTS_TO_NUMBER = False
PARSE_JSON_NAME = None
PAD_FILL_PATTERN_IS_REQUIRED = True
WRAP_DERIVED_VALUES = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,

sqlglot/dialects/postgres.py

@@ -365,6 +365,7 @@ class Postgres(Dialect):
"NOW": exp.CurrentTimestamp.from_arg_list,
"REGEXP_REPLACE": _build_regexp_replace,
"TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"),
"TO_DATE": build_formatted_time(exp.StrToDate, "postgres"),
"TO_TIMESTAMP": _build_to_timestamp,
"UNNEST": exp.Explode.from_arg_list,
"SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),

sqlglot/dialects/presto.py

@@ -28,6 +28,7 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
ts_or_ds_add_cast,
unit_to_str,
sequence_sql,
)
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
@@ -204,11 +205,11 @@ def _jsonextract_sql(self: Presto.Generator, expression: exp.JSONExtract) -> str
return f"{this}{expr}"
def _to_int(expression: exp.Expression) -> exp.Expression:
def _to_int(self: Presto.Generator, expression: exp.Expression) -> exp.Expression:
if not expression.type:
from sqlglot.optimizer.annotate_types import annotate_types
annotate_types(expression)
annotate_types(expression, dialect=self.dialect)
if expression.type and expression.type.this not in exp.DataType.INTEGER_TYPES:
return exp.cast(expression, to=exp.DataType.Type.BIGINT)
return expression
@@ -229,7 +230,7 @@ def _date_delta_sql(
name: str, negate_interval: bool = False
) -> t.Callable[[Presto.Generator, DATE_ADD_OR_SUB], str]:
def _delta_sql(self: Presto.Generator, expression: DATE_ADD_OR_SUB) -> str:
interval = _to_int(expression.expression)
interval = _to_int(self, expression.expression)
return self.func(
name,
unit_to_str(expression),
@@ -256,6 +257,21 @@ class Presto(Dialect):
# https://github.com/prestodb/presto/issues/2863
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
# The result of certain math functions in Presto/Trino is of type
# equal to the input type e.g: FLOOR(5.5/2) -> DECIMAL, FLOOR(5/2) -> BIGINT
ANNOTATORS = {
**Dialect.ANNOTATORS,
exp.Floor: lambda self, e: self._annotate_by_args(e, "this"),
exp.Ceil: lambda self, e: self._annotate_by_args(e, "this"),
exp.Mod: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.Round: lambda self, e: self._annotate_by_args(e, "this"),
exp.Sign: lambda self, e: self._annotate_by_args(e, "this"),
exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
exp.Rand: lambda self, e: self._annotate_by_args(e, "this")
if e.this
else self._set_type(e, exp.DataType.Type.DOUBLE),
}
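These overrides feed the now dialect-aware annotator (see the annotate_types hunk further down). A hedged sketch; the resulting type is my expectation from the comment above, not a verified output:

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

# With dialect="presto", FLOOR should inherit its argument's type instead of
# the generic INT mapping.
expr = annotate_types(sqlglot.parse_one("FLOOR(5 / 2)", read="presto"), dialect="presto")
print(expr.type)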
class Tokenizer(tokens.Tokenizer):
UNICODE_STRINGS = [
(prefix + q, q)
@@ -420,6 +436,7 @@ class Presto(Dialect):
exp.FirstValue: _first_last_sql,
exp.FromTimeZone: lambda self,
e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
exp.GenerateSeries: sequence_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
@@ -572,11 +589,20 @@ class Presto(Dialect):
# timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
# which seems to be using the same time mapping as Hive, as per:
# https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
value_as_text = exp.cast(expression.this, exp.DataType.Type.TEXT)
this = expression.this
value_as_text = exp.cast(this, exp.DataType.Type.TEXT)
value_as_timestamp = (
exp.cast(this, exp.DataType.Type.TIMESTAMP) if this.is_string else this
)
parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
formatted_value = self.func(
"DATE_FORMAT", value_as_timestamp, self.format_time(expression)
)
parse_with_tz = self.func(
"PARSE_DATETIME",
value_as_text,
formatted_value,
self.format_time(expression, Hive.INVERSE_TIME_MAPPING, Hive.INVERSE_TIME_TRIE),
)
coalesced = self.func("COALESCE", self.func("TRY", parse_without_tz), parse_with_tz)
@@ -636,26 +662,6 @@ class Presto(Dialect):
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"
def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step")
if isinstance(start, exp.Cast):
target_type = start.to
elif isinstance(end, exp.Cast):
target_type = end.to
else:
target_type = None
if target_type and target_type.is_type("timestamp"):
if target_type is start.to:
end = exp.cast(end, target_type)
else:
start = exp.cast(start, target_type)
return self.func("SEQUENCE", start, end, step)
def offset_limit_modifiers(
self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
) -> t.List[str]:

sqlglot/dialects/snowflake.py

@@ -504,43 +504,6 @@ class Snowflake(Dialect):
return lateral
def _parse_historical_data(self) -> t.Optional[exp.HistoricalData]:
# https://docs.snowflake.com/en/sql-reference/constructs/at-before
index = self._index
historical_data = None
if self._match_texts(self.HISTORICAL_DATA_PREFIX):
this = self._prev.text.upper()
kind = (
self._match(TokenType.L_PAREN)
and self._match_texts(self.HISTORICAL_DATA_KIND)
and self._prev.text.upper()
)
expression = self._match(TokenType.FARROW) and self._parse_bitwise()
if expression:
self._match_r_paren()
historical_data = self.expression(
exp.HistoricalData, this=this, kind=kind, expression=expression
)
else:
self._retreat(index)
return historical_data
def _parse_changes(self) -> t.Optional[exp.Changes]:
if not self._match_text_seq("CHANGES", "(", "INFORMATION", "=>"):
return None
information = self._parse_var(any_token=True)
self._match_r_paren()
return self.expression(
exp.Changes,
information=information,
at_before=self._parse_historical_data(),
end=self._parse_historical_data(),
)
def _parse_table_parts(
self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
) -> exp.Table:
@@ -573,14 +536,6 @@ class Snowflake(Dialect):
else:
table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
changes = self._parse_changes()
if changes:
table.set("changes", changes)
at_before = self._parse_historical_data()
if at_before:
table.set("when", at_before)
return table
def _parse_id_var(
@@ -659,7 +614,7 @@ class Snowflake(Dialect):
# can be joined in a query with a comma separator, as well as closing paren
# in case of subqueries
while self._is_connected() and not self._match_set(
(TokenType.COMMA, TokenType.R_PAREN), advance=False
(TokenType.COMMA, TokenType.L_PAREN, TokenType.R_PAREN), advance=False
):
parts.append(self._advance_any(ignore_reserved=True))

sqlglot/dialects/spark2.py

@@ -165,9 +165,6 @@ class Spark2(Hive):
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop, this=self._parse_schema(), kind="COLUMNS"

sqlglot/dialects/tsql.py

@@ -855,6 +855,7 @@ class TSQL(Dialect):
transforms.eliminate_qualify,
]
),
exp.Stddev: rename_func("STDEV"),
exp.StrPosition: lambda self, e: self.func(
"CHARINDEX", e.args.get("substr"), e.this, e.args.get("position")
),
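Together with the Stddev._sql_names change in expressions.py below, STDDEV and STDEV become interchangeable for T-SQL. Sketch (output assumed):

import sqlglot

print(sqlglot.transpile("SELECT STDDEV(x) FROM t", write="tsql")[0])
# expected: SELECT STDEV(x) FROM t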

sqlglot/expressions.py

@@ -33,7 +33,7 @@ from sqlglot.helper import (
seq_get,
subclasses,
)
from sqlglot.tokens import Token
from sqlglot.tokens import Token, TokenError
if t.TYPE_CHECKING:
from sqlglot._typing import E, Lit
@@ -1393,6 +1393,8 @@ class Create(DDL):
"begin": False,
"end": False,
"clone": False,
"concurrently": False,
"clustered": False,
}
@property
@@ -5483,6 +5485,16 @@ class JSONTable(Func):
}
# https://docs.snowflake.com/en/sql-reference/functions/object_insert
class ObjectInsert(Func):
arg_types = {
"this": True,
"key": True,
"value": True,
"update_flag": False,
}
class OpenJSONColumnDef(Expression):
arg_types = {"this": True, "kind": True, "path": False, "as_json": False}
@@ -5886,7 +5898,7 @@ class Sqrt(Func):
class Stddev(AggFunc):
pass
_sql_names = ["STDDEV", "STDEV"]
class StddevPop(AggFunc):
@@ -6881,7 +6893,7 @@ def parse_identifier(name: str | Identifier, dialect: DialectType = None) -> Ide
"""
try:
expression = maybe_parse(name, dialect=dialect, into=Identifier)
except ParseError:
except (ParseError, TokenError):
expression = to_identifier(name)
return expression
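parse_identifier now also degrades gracefully when the tokenizer itself rejects the input. A hedged sketch; the failing input is my own illustration:

from sqlglot.expressions import parse_identifier

# An unbalanced quote plausibly raises TokenError during tokenization; the new
# except clause falls back to a plain identifier instead of propagating it.
ident = parse_identifier('"unterminated')
print(ident.name)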

sqlglot/generator.py

@@ -1027,6 +1027,14 @@ class Generator(metaclass=_Generator):
replace = " OR REPLACE" if expression.args.get("replace") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
clustered = expression.args.get("clustered")
if clustered is None:
clustered_sql = ""
elif clustered:
clustered_sql = " CLUSTERED COLUMNSTORE"
else:
clustered_sql = " NONCLUSTERED COLUMNSTORE"
postcreate_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_CREATE):
postcreate_props_sql = self.properties(
@@ -1036,7 +1044,7 @@
wrapped=False,
)
modifiers = "".join((replace, unique, postcreate_props_sql))
modifiers = "".join((clustered_sql, replace, unique, postcreate_props_sql))
postexpression_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
@@ -1049,6 +1057,7 @@
wrapped=False,
)
concurrently = " CONCURRENTLY" if expression.args.get("concurrently") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
no_schema_binding = (
" WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
@@ -1057,7 +1066,7 @@
clone = self.sql(expression, "clone")
clone = f" {clone}" if clone else ""
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
expression_sql = f"CREATE{modifiers} {kind}{concurrently}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
return self.prepend_ctes(expression, expression_sql)
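The two new Create args thread through here roughly as follows; both statements are my own examples and the round-trips are assumptions, not outputs taken from the commit:

import sqlglot

# Postgres: CREATE INDEX CONCURRENTLY is now parsed and re-emitted.
print(sqlglot.transpile("CREATE INDEX CONCURRENTLY ix ON t (c)", read="postgres")[0])

# T-SQL: CLUSTERED COLUMNSTORE maps to the new `clustered` arg.
print(sqlglot.transpile("CREATE CLUSTERED COLUMNSTORE INDEX ix ON t", read="tsql")[0])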
def sequenceproperties_sql(self, expression: exp.SequenceProperties) -> str:
@@ -1734,8 +1743,7 @@ class Generator(metaclass=_Generator):
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=" ")
hints = f" {hints}" if hints and self.TABLE_HINTS else ""
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
pivots = self.expressions(expression, key="pivots", sep="", flat=True)
joins = self.indent(
self.expressions(expression, key="joins", sep="", flat=True), skip_first=True
)
@@ -1822,7 +1830,7 @@ class Generator(metaclass=_Generator):
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
direction = "UNPIVOT" if expression.unpivot else "PIVOT"
direction = self.seg("UNPIVOT" if expression.unpivot else "PIVOT")
field = self.sql(expression, "field")
include_nulls = expression.args.get("include_nulls")
if include_nulls is not None:
@@ -2409,10 +2417,7 @@
def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
pivots = self.expressions(expression, key="pivots", sep="", flat=True)
sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots)
return self.prepend_ctes(expression, sql)
@@ -3134,6 +3139,7 @@
expression,
key="actions",
prefix="ADD COLUMN ",
skip_first=True,
)
return f"ADD {self.expressions(expression, key='actions', flat=True)}"

sqlglot/optimizer/annotate_types.py

@@ -10,10 +10,10 @@ from sqlglot.helper import (
is_iso_date,
is_iso_datetime,
seq_get,
subclasses,
)
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
from sqlglot.dialects.dialect import Dialect
if t.TYPE_CHECKING:
from sqlglot._typing import B, E
@@ -24,12 +24,15 @@ if t.TYPE_CHECKING:
BinaryCoercionFunc,
]
from sqlglot.dialects.dialect import DialectType, AnnotatorsType
def annotate_types(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
annotators: t.Optional[AnnotatorsType] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
dialect: t.Optional[DialectType] = None,
) -> E:
"""
Infers the types of an expression, annotating its AST accordingly.
@@ -54,11 +57,7 @@
schema = ensure_schema(schema)
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
return lambda self, e: self._annotate_with_type(e, data_type)
return TypeAnnotator(schema, annotators, coerces_to, dialect=dialect).annotate(expression)
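Callers can now pass a dialect so annotation resolves that dialect's ANNOTATORS (via Dialect.get_or_raise, per the __init__ hunk below). Minimal sketch:

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# The dialect kwarg is threaded through to TypeAnnotator.
expr = annotate_types(parse_one("SELECT ABS(-1) AS a"), dialect="duckdb")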
def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
@@ -133,168 +132,6 @@ class _TypeAnnotator(type):
class TypeAnnotator(metaclass=_TypeAnnotator):
TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
exp.DataType.Type.BIGINT: {
exp.ApproxDistinct,
exp.ArraySize,
exp.Count,
exp.Length,
},
exp.DataType.Type.BOOLEAN: {
exp.Between,
exp.Boolean,
exp.In,
exp.RegexpLike,
},
exp.DataType.Type.DATE: {
exp.CurrentDate,
exp.Date,
exp.DateFromParts,
exp.DateStrToDate,
exp.DiToDate,
exp.StrToDate,
exp.TimeStrToDate,
exp.TsOrDsToDate,
},
exp.DataType.Type.DATETIME: {
exp.CurrentDatetime,
exp.Datetime,
exp.DatetimeAdd,
exp.DatetimeSub,
},
exp.DataType.Type.DOUBLE: {
exp.ApproxQuantile,
exp.Avg,
exp.Div,
exp.Exp,
exp.Ln,
exp.Log,
exp.Pow,
exp.Quantile,
exp.Round,
exp.SafeDivide,
exp.Sqrt,
exp.Stddev,
exp.StddevPop,
exp.StddevSamp,
exp.Variance,
exp.VariancePop,
},
exp.DataType.Type.INT: {
exp.Ceil,
exp.DatetimeDiff,
exp.DateDiff,
exp.TimestampDiff,
exp.TimeDiff,
exp.DateToDi,
exp.Floor,
exp.Levenshtein,
exp.Sign,
exp.StrPosition,
exp.TsOrDiToDi,
},
exp.DataType.Type.JSON: {
exp.ParseJSON,
},
exp.DataType.Type.TIME: {
exp.Time,
},
exp.DataType.Type.TIMESTAMP: {
exp.CurrentTime,
exp.CurrentTimestamp,
exp.StrToTime,
exp.TimeAdd,
exp.TimeStrToTime,
exp.TimeSub,
exp.TimestampAdd,
exp.TimestampSub,
exp.UnixToTime,
},
exp.DataType.Type.TINYINT: {
exp.Day,
exp.Month,
exp.Week,
exp.Year,
exp.Quarter,
},
exp.DataType.Type.VARCHAR: {
exp.ArrayConcat,
exp.Concat,
exp.ConcatWs,
exp.DateToDateStr,
exp.GroupConcat,
exp.Initcap,
exp.Lower,
exp.Substring,
exp.TimeToStr,
exp.TimeToTimeStr,
exp.Trim,
exp.TsOrDsToDateStr,
exp.UnixToStr,
exp.UnixToTimeStr,
exp.Upper,
},
}
ANNOTATORS: t.Dict = {
**{
expr_type: lambda self, e: self._annotate_unary(e)
for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
},
**{
expr_type: lambda self, e: self._annotate_binary(e)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
**{
expr_type: _annotate_with_type_lambda(data_type)
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
for expr_type in expressions
},
exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Bracket: lambda self, e: self._annotate_bracket(e),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
exp.DateSub: lambda self, e: self._annotate_timeunit(e),
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Div: lambda self, e: self._annotate_div(e),
exp.Dot: lambda self, e: self._annotate_dot(e),
exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Extract: lambda self, e: self._annotate_extract(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
e, exp.DataType.build("ARRAY<DATE>")
),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Literal: lambda self, e: self._annotate_literal(e),
exp.Map: lambda self, e: self._annotate_map(e),
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Struct: lambda self, e: self._annotate_struct(e),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
exp.ToMap: lambda self, e: self._annotate_to_map(e),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Unnest: lambda self, e: self._annotate_unnest(e),
exp.VarMap: lambda self, e: self._annotate_map(e),
}
NESTED_TYPES = {
exp.DataType.Type.ARRAY,
}
@@ -335,12 +172,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def __init__(
self,
schema: Schema,
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
annotators: t.Optional[AnnotatorsType] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
binary_coercions: t.Optional[BinaryCoercions] = None,
dialect: t.Optional[DialectType] = None,
) -> None:
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.annotators = annotators or Dialect.get_or_raise(dialect).ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
@@ -483,7 +321,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
def _annotate_with_type(
self, expression: E, target_type: exp.DataType | exp.DataType.Type
) -> E:
self._set_type(expression, target_type)
return self._annotate_args(expression)

sqlglot/parser.py

@@ -376,6 +376,7 @@ class Parser(metaclass=_Parser):
# Tokens that can represent identifiers
ID_VAR_TOKENS = {
TokenType.ALL,
TokenType.VAR,
TokenType.ANTI,
TokenType.APPLY,
@@ -929,7 +930,8 @@
enforced=self._match_text_seq("ENFORCED"),
),
"COLLATE": lambda self: self.expression(
exp.CollateColumnConstraint, this=self._parse_var(any_token=True)
exp.CollateColumnConstraint,
this=self._parse_identifier() or self._parse_column(),
),
"COMMENT": lambda self: self.expression(
exp.CommentColumnConstraint, this=self._parse_string()
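The COLLATE change means both bare and quoted collation names (and dotted ones, via _parse_column) should parse. A hedged example of my own:

import sqlglot

# A quoted collation name now parses as an identifier.
sqlglot.parse_one('CREATE TABLE t (c TEXT COLLATE "de_DE")', read="postgres")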
@@ -1138,7 +1140,9 @@
ISOLATED_LOADING_OPTIONS: OPTIONS_TYPE = {"FOR": ("ALL", "INSERT", "NONE")}
USABLES: OPTIONS_TYPE = dict.fromkeys(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"), tuple())
USABLES: OPTIONS_TYPE = dict.fromkeys(
("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA", "CATALOG"), tuple()
)
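With CATALOG added to USABLES, USE CATALOG statements should now parse. Hedged example (output assumed):

import sqlglot

print(sqlglot.transpile("USE CATALOG my_catalog")[0])  # expected: USE CATALOG my_catalog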
CAST_ACTIONS: OPTIONS_TYPE = dict.fromkeys(("RENAME", "ADD"), ("FIELDS",))
@@ -1147,6 +1151,17 @@
**dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()),
}
KEY_CONSTRAINT_OPTIONS: OPTIONS_TYPE = {
"NOT": ("ENFORCED",),
"MATCH": (
"FULL",
"PARTIAL",
"SIMPLE",
),
"INITIALLY": ("DEFERRED", "IMMEDIATE"),
**dict.fromkeys(("DEFERRABLE", "NORELY"), tuple()),
}
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
CLONE_KEYWORDS = {"CLONE", "COPY"}
@@ -1663,6 +1678,15 @@
unique = self._match(TokenType.UNIQUE)
if self._match_text_seq("CLUSTERED", "COLUMNSTORE"):
clustered = True
elif self._match_text_seq("NONCLUSTERED", "COLUMNSTORE") or self._match_text_seq(
"COLUMNSTORE"
):
clustered = False
else:
clustered = None
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._advance()
@@ -1677,6 +1701,7 @@
if not properties or not create_token:
return self._parse_as_command(start)
concurrently = self._match_text_seq("CONCURRENTLY")
exists = self._parse_exists(not_=True)
this = None
expression: t.Optional[exp.Expression] = None
@@ -1802,6 +1827,8 @@
begin=begin,
end=end,
clone=clone,
concurrently=concurrently,
clustered=clustered,
)
def _parse_sequence_properties(self) -> t.Optional[exp.SequenceProperties]:
@@ -2728,8 +2755,12 @@
comments = self._prev_comments
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match_set(self.DISTINCT_TOKENS)
if self._next and not self._next.token_type == TokenType.DOT:
all_ = self._match(TokenType.ALL)
distinct = self._match_set(self.DISTINCT_TOKENS)
else:
all_, distinct = None, None
kind = (
self._match(TokenType.ALIAS)
@@ -2827,6 +2858,7 @@
self.raise_error("Expected CTE to have alias")
self._match(TokenType.ALIAS)
comments = self._prev_comments
if self._match_text_seq("NOT", "MATERIALIZED"):
materialized = False
@@ -2840,6 +2872,7 @@
this=self._parse_wrapped(self._parse_statement),
alias=alias,
materialized=materialized,
comments=comments,
)
def _parse_table_alias(
@@ -3352,15 +3385,28 @@
if not db and is_db_reference:
self.raise_error(f"Expected database name but got {self._curr}")
return self.expression(
table = self.expression(
exp.Table,
comments=comments,
this=table,
db=db,
catalog=catalog,
pivots=self._parse_pivots(),
)
changes = self._parse_changes()
if changes:
table.set("changes", changes)
at_before = self._parse_historical_data()
if at_before:
table.set("when", at_before)
pivots = self._parse_pivots()
if pivots:
table.set("pivots", pivots)
return table
def _parse_table(
self,
schema: bool = False,
@@ -3490,6 +3536,43 @@
return self.expression(exp.Version, this=this, expression=expression, kind=kind)
def _parse_historical_data(self) -> t.Optional[exp.HistoricalData]:
# https://docs.snowflake.com/en/sql-reference/constructs/at-before
index = self._index
historical_data = None
if self._match_texts(self.HISTORICAL_DATA_PREFIX):
this = self._prev.text.upper()
kind = (
self._match(TokenType.L_PAREN)
and self._match_texts(self.HISTORICAL_DATA_KIND)
and self._prev.text.upper()
)
expression = self._match(TokenType.FARROW) and self._parse_bitwise()
if expression:
self._match_r_paren()
historical_data = self.expression(
exp.HistoricalData, this=this, kind=kind, expression=expression
)
else:
self._retreat(index)
return historical_data
def _parse_changes(self) -> t.Optional[exp.Changes]:
if not self._match_text_seq("CHANGES", "(", "INFORMATION", "=>"):
return None
information = self._parse_var(any_token=True)
self._match_r_paren()
return self.expression(
exp.Changes,
information=information,
at_before=self._parse_historical_data(),
end=self._parse_historical_data(),
)
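These two helpers moved up from the Snowflake dialect (see the removal hunk earlier), so AT/BEFORE time travel and CHANGES clauses are parsed by the base parser and attached in _parse_table_parts above. A hedged round-trip sketch:

import sqlglot

sql = "SELECT * FROM t AT (TIMESTAMP => CAST('2024-01-01' AS TIMESTAMP))"
print(sqlglot.transpile(sql, read="snowflake")[0])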
def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
if not self._match(TokenType.UNNEST):
return None
@@ -5216,18 +5299,13 @@
self.raise_error("Invalid key constraint")
options.append(f"ON {on} {action}")
elif self._match_text_seq("NOT", "ENFORCED"):
options.append("NOT ENFORCED")
elif self._match_text_seq("DEFERRABLE"):
options.append("DEFERRABLE")
elif self._match_text_seq("INITIALLY", "DEFERRED"):
options.append("INITIALLY DEFERRED")
elif self._match_text_seq("NORELY"):
options.append("NORELY")
elif self._match_text_seq("MATCH", "FULL"):
options.append("MATCH FULL")
else:
break
var = self._parse_var_from_options(
self.KEY_CONSTRAINT_OPTIONS, raise_unmatched=False
)
if not var:
break
options.append(var.name)
return options
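The removed elif ladder collapses into the table-driven KEY_CONSTRAINT_OPTIONS defined earlier in this file's diff. A hedged parse example of my own:

import sqlglot

# MATCH FULL (and friends) are now consumed via _parse_var_from_options.
sqlglot.parse_one("CREATE TABLE t (id INT, FOREIGN KEY (id) REFERENCES p (id) MATCH FULL)")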
@@ -6227,6 +6305,13 @@
self._retreat(index)
if not self.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and self._match_text_seq("ADD"):
return self._parse_wrapped_csv(self._parse_field_def, optional=True)
if self._match_text_seq("ADD", "COLUMNS"):
schema = self._parse_schema()
if schema:
return [schema]
return []
return self._parse_wrapped_csv(self._parse_add_column, optional=True)
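This absorbs Spark2's _parse_add_column override (removed earlier): ADD COLUMNS (...) is now handled generically. Hedged sketch:

import sqlglot

print(sqlglot.transpile("ALTER TABLE t ADD COLUMNS (col STRING)", read="spark")[0])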
def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]:

sqlglot/transforms.py

@@ -229,6 +229,23 @@ def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
"""Convert cross join unnest into lateral view explode."""
if isinstance(expression, exp.Select):
from_ = expression.args.get("from")
if from_ and isinstance(from_.this, exp.Unnest):
unnest = from_.this
alias = unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
this, *expressions = unnest.expressions
unnest.replace(
exp.Table(
this=udtf(
this=this,
expressions=expressions,
),
alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
)
)
for join in expression.args.get("joins") or []:
unnest = join.this