1
0
Fork 0

Adding upstream version 20.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:16:46 +01:00
parent 6a89523da4
commit 5bd573dda1
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
127 changed files with 73384 additions and 73067 deletions

View file

@ -3,9 +3,12 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
binary_from_function,
date_delta_sql,
date_trunc_to_time,
datestrtodate_sql,
format_time_lambda,
@ -21,7 +24,6 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
@ -50,7 +52,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
elif second_arg.name == "3":
timescale = exp.UnixToTime.MILLIS
elif second_arg.name == "9":
timescale = exp.UnixToTime.MICROS
timescale = exp.UnixToTime.NANOS
return exp.UnixToTime(this=first_arg, scale=timescale)
@ -95,14 +97,17 @@ def _parse_datediff(args: t.List) -> exp.DateDiff:
def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
if scale in (None, exp.UnixToTime.SECONDS):
return f"TO_TIMESTAMP({timestamp})"
if scale == exp.UnixToTime.MILLIS:
return f"TO_TIMESTAMP({timestamp}, 3)"
if scale == exp.UnixToTime.MICROS:
return f"TO_TIMESTAMP({timestamp} / 1000, 3)"
if scale == exp.UnixToTime.NANOS:
return f"TO_TIMESTAMP({timestamp}, 9)"
raise ValueError("Improper scale for timestamp")
self.unsupported(f"Unsupported scale for timestamp: {scale}.")
return ""
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
@ -201,7 +206,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
SUPPORTS_USER_DEFINED_TYPES = False
@ -236,6 +241,18 @@ class Snowflake(Dialect):
"ff6": "%f",
}
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """Quote an identifier, leaving the special DUAL pseudo-table untouched.

    Snowflake treats an unquoted DUAL keyword in a FROM clause in a special
    way and does not map it to a user-defined table, so quoting it would
    change the query's meaning.
    """
    if isinstance(expression, exp.Identifier) and isinstance(expression.parent, exp.Table):
        if expression.name.lower() == "dual":
            return t.cast(E, expression)

    return super().quote_identifier(expression, identify=identify)
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
@ -245,6 +262,9 @@ class Snowflake(Dialect):
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"ARRAY_CONTAINS": lambda args: exp.ArrayContains(
this=seq_get(args, 1), expression=seq_get(args, 0)
),
"ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries(
# ARRAY_GENERATE_RANGE has an exclusive end; we normalize it to be inclusive
start=seq_get(args, 0),
@ -296,8 +316,8 @@ class Snowflake(Dialect):
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny),
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
}
ALTER_PARSERS = {
@ -317,6 +337,11 @@ class Snowflake(Dialect):
TokenType.SHOW: lambda self: self._parse_show(),
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
"LOCATION": lambda self: self._parse_location(),
}
SHOW_PARSERS = {
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
@ -349,7 +374,7 @@ class Snowflake(Dialect):
table: t.Optional[exp.Expression] = None
if self._match_text_seq("@"):
table_name = "@"
while True:
while self._curr:
self._advance()
table_name += self._prev.text
if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
@ -411,6 +436,20 @@ class Snowflake(Dialect):
self._match_text_seq("WITH")
return self.expression(exp.SwapTable, this=self._parse_table(schema=True))
def _parse_location(self) -> exp.LocationProperty:
    """Parse a LOCATION property value, e.g. ``LOCATION = @stage/some/path``."""
    self._match(TokenType.EQ)

    segments = [self._parse_var(any_token=True)]
    while self._match(TokenType.SLASH):
        # A token immediately adjacent to the slash continues the path;
        # otherwise the segment between two slashes is empty.
        adjacent = self._curr and self._prev.end + 1 == self._curr.start
        segments.append(self._parse_var(any_token=True) if adjacent else exp.Var(this=""))

    joined_path = "/".join(str(segment) for segment in segments)
    return self.expression(exp.LocationProperty, this=exp.var(joined_path))
class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
@ -457,6 +496,7 @@ class Snowflake(Dialect):
AGGREGATE_FILTER_SUPPORTED = False
SUPPORTS_TABLE_COPY = False
COLLATE_IS_FUNC = True
LIMIT_ONLY_LITERALS = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -464,15 +504,14 @@ class Snowflake(Dialect):
exp.ArgMin: rename_func("MIN_BY"),
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.AtTimeZone: lambda self, e: self.func(
"CONVERT_TIMEZONE", e.args.get("zone"), e.this
),
exp.BitwiseXor: rename_func("BITXOR"),
exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
exp.DateDiff: lambda self, e: self.func(
"DATEDIFF", e.text("unit"), e.expression, e.this
),
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
@ -501,10 +540,11 @@ class Snowflake(Dialect):
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
transforms.explode_to_unnest(0),
transforms.explode_to_unnest(),
transforms.eliminate_semi_and_anti_joins,
]
),
exp.SHA: rename_func("SHA1"),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StartsWith: rename_func("STARTSWITH"),
exp.StrPosition: lambda self, e: self.func(
@ -524,6 +564,8 @@ class Snowflake(Dialect):
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@ -547,6 +589,20 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def trycast_sql(self, expression: exp.TryCast) -> str:
    """Generate TRY_CAST, falling back to a plain CAST for non-string operands.

    TRY_CAST only works for string values in Snowflake, so operands that are
    known to be non-text are emitted as CAST instead.
    """
    operand = expression.this

    if operand.type is None:
        # Type information is needed to decide; annotate lazily to avoid
        # importing the optimizer unless required.
        from sqlglot.optimizer.annotate_types import annotate_types

        operand = annotate_types(operand)

    if not operand.is_type(*exp.DataType.TEXT_TYPES, exp.DataType.Type.UNKNOWN):
        return self.cast_sql(expression)

    return super().trycast_sql(expression)
def log_sql(self, expression: exp.Log) -> str:
if not expression.expression:
return self.func("LN", expression.this)
@ -554,24 +610,28 @@ class Snowflake(Dialect):
return super().log_sql(expression)
def unnest_sql(self, expression: exp.Unnest) -> str:
selects = ["value"]
unnest_alias = expression.args.get("alias")
offset = expression.args.get("offset")
if offset:
if unnest_alias:
unnest_alias.append("columns", offset.pop())
selects.append("index")
columns = [
exp.to_identifier("seq"),
exp.to_identifier("key"),
exp.to_identifier("path"),
offset.pop() if isinstance(offset, exp.Expression) else exp.to_identifier("index"),
seq_get(unnest_alias.columns if unnest_alias else [], 0)
or exp.to_identifier("value"),
exp.to_identifier("this"),
]
subquery = exp.Subquery(
this=exp.select(*selects).from_(
f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
),
)
if unnest_alias:
unnest_alias.set("columns", columns)
else:
unnest_alias = exp.TableAlias(this="_u", columns=columns)
explode = f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
alias = self.sql(unnest_alias)
alias = f" AS {alias}" if alias else ""
return f"{self.sql(subquery)}{alias}"
return f"{explode}{alias}"
def show_sql(self, expression: exp.Show) -> str:
scope = self.sql(expression, "scope")
@ -632,3 +692,6 @@ class Snowflake(Dialect):
def swaptable_sql(self, expression: exp.SwapTable) -> str:
    """Render the target clause of ALTER TABLE ... SWAP WITH <table>."""
    return "SWAP WITH " + self.sql(expression, "this")
def with_properties(self, properties: exp.Properties) -> str:
    """Render properties unwrapped and space-separated, on a fresh segment."""
    leading_segment = self.seg("")
    return self.properties(properties, wrapped=False, prefix=leading_segment, sep=" ")