Adding upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
f1aa09959c
commit
27c061b7af
187 changed files with 86502 additions and 71397 deletions
|
@ -20,8 +20,7 @@ from sqlglot.dialects.dialect import (
|
|||
timestrtotime_sql,
|
||||
var_map_sql,
|
||||
)
|
||||
from sqlglot.expressions import Literal
|
||||
from sqlglot.helper import flatten, is_int, seq_get
|
||||
from sqlglot.helper import flatten, is_float, is_int, seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
@ -29,33 +28,35 @@ if t.TYPE_CHECKING:
|
|||
|
||||
|
||||
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
|
||||
def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
|
||||
if len(args) == 2:
|
||||
first_arg, second_arg = args
|
||||
if second_arg.is_string:
|
||||
# case: <string_expr> [ , <format> ]
|
||||
return build_formatted_time(exp.StrToTime, "snowflake")(args)
|
||||
return exp.UnixToTime(this=first_arg, scale=second_arg)
|
||||
def _build_datetime(
|
||||
name: str, kind: exp.DataType.Type, safe: bool = False
|
||||
) -> t.Callable[[t.List], exp.Func]:
|
||||
def _builder(args: t.List) -> exp.Func:
|
||||
value = seq_get(args, 0)
|
||||
|
||||
from sqlglot.optimizer.simplify import simplify_literals
|
||||
if isinstance(value, exp.Literal):
|
||||
int_value = is_int(value.this)
|
||||
|
||||
# The first argument might be an expression like 40 * 365 * 86400, so we try to
|
||||
# reduce it using `simplify_literals` first and then check if it's a Literal.
|
||||
first_arg = seq_get(args, 0)
|
||||
if not isinstance(simplify_literals(first_arg, root=True), Literal):
|
||||
# case: <variant_expr> or other expressions such as columns
|
||||
return exp.TimeStrToTime.from_arg_list(args)
|
||||
# Converts calls like `TO_TIME('01:02:03')` into casts
|
||||
if len(args) == 1 and value.is_string and not int_value:
|
||||
return exp.cast(value, kind)
|
||||
|
||||
if first_arg.is_string:
|
||||
if is_int(first_arg.this):
|
||||
# case: <integer>
|
||||
return exp.UnixToTime.from_arg_list(args)
|
||||
# Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special
|
||||
# cases so we can transpile them, since they're relatively common
|
||||
if kind == exp.DataType.Type.TIMESTAMP:
|
||||
if int_value:
|
||||
return exp.UnixToTime(this=value, scale=seq_get(args, 1))
|
||||
if not is_float(value.this):
|
||||
return build_formatted_time(exp.StrToTime, "snowflake")(args)
|
||||
|
||||
# case: <date_expr>
|
||||
return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args)
|
||||
if len(args) == 2 and kind == exp.DataType.Type.DATE:
|
||||
formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args)
|
||||
formatted_exp.set("safe", safe)
|
||||
return formatted_exp
|
||||
|
||||
# case: <numeric_expr>
|
||||
return exp.UnixToTime.from_arg_list(args)
|
||||
return exp.Anonymous(this=name, expressions=args)
|
||||
|
||||
return _builder
|
||||
|
||||
|
||||
def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
|
||||
|
@ -77,6 +78,17 @@ def _build_datediff(args: t.List) -> exp.DateDiff:
|
|||
)
|
||||
|
||||
|
||||
def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
    """Build a parser for Snowflake's (unit, increment, base) date/time-add calls.

    DATEADD / TIMEADD / TIMESTAMPADD all take their arguments in the order
    (unit, delta, base expression), so the returned builder maps those
    positions onto the expression's ``this`` / ``expression`` / ``unit`` slots,
    normalizing the unit via ``_map_date_part``.
    """

    def _builder(args: t.List) -> E:
        unit = seq_get(args, 0)
        delta = seq_get(args, 1)
        base = seq_get(args, 2)
        return expr_type(this=base, expression=delta, unit=_map_date_part(unit))

    return _builder
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/div0
|
||||
def _build_if_from_div0(args: t.List) -> exp.If:
|
||||
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
|
||||
|
@ -97,14 +109,6 @@ def _build_if_from_nullifzero(args: t.List) -> exp.If:
|
|||
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
|
||||
|
||||
|
||||
def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str:
    """Render Snowflake's names for container types.

    Snowflake spells unparameterized arrays as ARRAY and maps as OBJECT;
    everything else falls through to the default generator.
    """
    if expression.is_type("map"):
        return "OBJECT"
    if expression.is_type("array"):
        return "ARRAY"
    return self.datatype_sql(expression)
|
||||
|
||||
|
||||
def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str:
|
||||
flag = expression.text("flag")
|
||||
|
||||
|
@ -258,6 +262,25 @@ def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
|
|||
return expression
|
||||
|
||||
|
||||
def _flatten_structured_types_unless_iceberg(expression: exp.Expression) -> exp.Expression:
    """Drop inner type parameters from nested column types in a CREATE statement.

    Iceberg tables keep their structured (parameterized) types intact; for any
    other CREATE with a column schema, nested types are reduced to their bare
    names by clearing their ``expressions`` in place.
    """
    assert isinstance(expression, exp.Create)

    def _strip_nested_params(dtype: exp.DataType) -> exp.DataType:
        # Mutates in place: a parameterized nested type becomes the bare type name.
        if dtype.this in exp.DataType.NESTED_TYPES:
            dtype.set("expressions", None)
        return dtype

    props = expression.args.get("properties")
    is_iceberg = bool(props and props.find(exp.IcebergProperty))

    if is_iceberg or not isinstance(expression.this, exp.Schema):
        return expression

    for schema_expression in expression.this.expressions:
        if not isinstance(schema_expression, exp.ColumnDef):
            continue
        column_type = schema_expression.kind
        if isinstance(column_type, exp.DataType):
            column_type.transform(_strip_nested_params, copy=False)

    return expression
|
||||
|
||||
|
||||
class Snowflake(Dialect):
|
||||
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
|
||||
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
|
||||
|
@ -312,7 +335,13 @@ class Snowflake(Dialect):
|
|||
class Parser(parser.Parser):
|
||||
IDENTIFY_PIVOT_STRINGS = True
|
||||
|
||||
ID_VAR_TOKENS = {
|
||||
*parser.Parser.ID_VAR_TOKENS,
|
||||
TokenType.MATCH_CONDITION,
|
||||
}
|
||||
|
||||
TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW}
|
||||
TABLE_ALIAS_TOKENS.discard(TokenType.MATCH_CONDITION)
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
|
@ -327,17 +356,13 @@ class Snowflake(Dialect):
|
|||
end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)),
|
||||
step=seq_get(args, 2),
|
||||
),
|
||||
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
|
||||
"BITXOR": binary_from_function(exp.BitwiseXor),
|
||||
"BIT_XOR": binary_from_function(exp.BitwiseXor),
|
||||
"BOOLXOR": binary_from_function(exp.Xor),
|
||||
"CONVERT_TIMEZONE": _build_convert_timezone,
|
||||
"DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
|
||||
"DATE_TRUNC": _date_trunc_to_time,
|
||||
"DATEADD": lambda args: exp.DateAdd(
|
||||
this=seq_get(args, 2),
|
||||
expression=seq_get(args, 1),
|
||||
unit=_map_date_part(seq_get(args, 0)),
|
||||
),
|
||||
"DATEADD": _build_date_time_add(exp.DateAdd),
|
||||
"DATEDIFF": _build_datediff,
|
||||
"DIV0": _build_if_from_div0,
|
||||
"FLATTEN": exp.Explode.from_arg_list,
|
||||
|
@ -349,17 +374,34 @@ class Snowflake(Dialect):
|
|||
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
|
||||
),
|
||||
"LISTAGG": exp.GroupConcat.from_arg_list,
|
||||
"MEDIAN": lambda args: exp.PercentileCont(
|
||||
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
|
||||
),
|
||||
"NULLIFZERO": _build_if_from_nullifzero,
|
||||
"OBJECT_CONSTRUCT": _build_object_construct,
|
||||
"REGEXP_REPLACE": _build_regexp_replace,
|
||||
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
|
||||
"RLIKE": exp.RegexpLike.from_arg_list,
|
||||
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
|
||||
"TIMEADD": _build_date_time_add(exp.TimeAdd),
|
||||
"TIMEDIFF": _build_datediff,
|
||||
"TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
|
||||
"TIMESTAMPDIFF": _build_datediff,
|
||||
"TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
|
||||
"TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
|
||||
"TO_TIMESTAMP": _build_to_timestamp,
|
||||
"TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
|
||||
"TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
|
||||
"TO_NUMBER": lambda args: exp.ToNumber(
|
||||
this=seq_get(args, 0),
|
||||
format=seq_get(args, 1),
|
||||
precision=seq_get(args, 2),
|
||||
scale=seq_get(args, 3),
|
||||
),
|
||||
"TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME),
|
||||
"TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP),
|
||||
"TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ),
|
||||
"TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP),
|
||||
"TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ),
|
||||
"TO_VARCHAR": exp.ToChar.from_arg_list,
|
||||
"ZEROIFNULL": _build_if_from_zeroifnull,
|
||||
}
|
||||
|
@ -377,7 +419,6 @@ class Snowflake(Dialect):
|
|||
**parser.Parser.RANGE_PARSERS,
|
||||
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
|
||||
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
|
||||
TokenType.COLON: lambda self, this: self._parse_colon_get_path(this),
|
||||
}
|
||||
|
||||
ALTER_PARSERS = {
|
||||
|
@ -434,35 +475,35 @@ class Snowflake(Dialect):
|
|||
|
||||
SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"}
|
||||
|
||||
def _parse_colon_get_path(
|
||||
self: parser.Parser, this: t.Optional[exp.Expression]
|
||||
) -> t.Optional[exp.Expression]:
|
||||
while True:
|
||||
path = self._parse_bitwise() or self._parse_var(any_token=True)
|
||||
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
|
||||
this = super()._parse_column_ops(this)
|
||||
|
||||
casts = []
|
||||
json_path = []
|
||||
|
||||
while self._match(TokenType.COLON):
|
||||
path = super()._parse_column_ops(self._parse_field(any_token=True))
|
||||
|
||||
# The cast :: operator has a lower precedence than the extraction operator :, so
|
||||
# we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
|
||||
if isinstance(path, exp.Cast):
|
||||
target_type = path.to
|
||||
while isinstance(path, exp.Cast):
|
||||
casts.append(path.to)
|
||||
path = path.this
|
||||
else:
|
||||
target_type = None
|
||||
|
||||
if isinstance(path, exp.Expression):
|
||||
path = exp.Literal.string(path.sql(dialect="snowflake"))
|
||||
if path:
|
||||
json_path.append(path.sql(dialect="snowflake", copy=False))
|
||||
|
||||
# The extraction operator : is left-associative
|
||||
if json_path:
|
||||
this = self.expression(
|
||||
exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
|
||||
exp.JSONExtract,
|
||||
this=this,
|
||||
expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))),
|
||||
)
|
||||
|
||||
if target_type:
|
||||
this = exp.cast(this, target_type)
|
||||
while casts:
|
||||
this = self.expression(exp.Cast, this=this, to=casts.pop())
|
||||
|
||||
if not self._match(TokenType.COLON):
|
||||
break
|
||||
|
||||
return self._parse_range(this)
|
||||
return this
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
|
||||
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
|
||||
|
@ -663,6 +704,7 @@ class Snowflake(Dialect):
|
|||
"EXCLUDE": TokenType.EXCEPT,
|
||||
"ILIKE ANY": TokenType.ILIKE_ANY,
|
||||
"LIKE ANY": TokenType.LIKE_ANY,
|
||||
"MATCH_CONDITION": TokenType.MATCH_CONDITION,
|
||||
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
|
||||
"MINUS": TokenType.EXCEPT,
|
||||
"NCHAR VARYING": TokenType.VARCHAR,
|
||||
|
@ -703,6 +745,7 @@ class Snowflake(Dialect):
|
|||
LIMIT_ONLY_LITERALS = True
|
||||
JSON_KEY_VALUE_PAIR_SEP = ","
|
||||
INSERT_OVERWRITE = " OVERWRITE INTO"
|
||||
STRUCT_DELIMITER = ("(", ")")
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
|
@ -711,15 +754,14 @@ class Snowflake(Dialect):
|
|||
exp.Array: inline_array_sql,
|
||||
exp.ArrayConcat: rename_func("ARRAY_CAT"),
|
||||
exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
|
||||
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
|
||||
exp.AtTimeZone: lambda self, e: self.func(
|
||||
"CONVERT_TIMEZONE", e.args.get("zone"), e.this
|
||||
),
|
||||
exp.BitwiseXor: rename_func("BITXOR"),
|
||||
exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]),
|
||||
exp.DateAdd: date_delta_sql("DATEADD"),
|
||||
exp.DateDiff: date_delta_sql("DATEDIFF"),
|
||||
exp.DateStrToDate: datestrtodate_sql,
|
||||
exp.DataType: _datatype_sql,
|
||||
exp.DayOfMonth: rename_func("DAYOFMONTH"),
|
||||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
exp.DayOfYear: rename_func("DAYOFYEAR"),
|
||||
|
@ -769,6 +811,7 @@ class Snowflake(Dialect):
|
|||
),
|
||||
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
|
||||
exp.Stuff: rename_func("INSERT"),
|
||||
exp.TimeAdd: date_delta_sql("TIMEADD"),
|
||||
exp.TimestampDiff: lambda self, e: self.func(
|
||||
"TIMESTAMPDIFF", e.unit, e.expression, e.this
|
||||
),
|
||||
|
@ -783,6 +826,9 @@ class Snowflake(Dialect):
|
|||
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
|
||||
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
|
||||
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
|
||||
exp.TsOrDsToDate: lambda self, e: self.func(
|
||||
"TRY_TO_DATE" if e.args.get("safe") else "TO_DATE", e.this, self.format_time(e)
|
||||
),
|
||||
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
|
||||
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
|
||||
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
|
||||
|
@ -797,6 +843,8 @@ class Snowflake(Dialect):
|
|||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.NESTED: "OBJECT",
|
||||
exp.DataType.Type.STRUCT: "OBJECT",
|
||||
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
|
||||
}
|
||||
|
||||
|
@ -811,6 +859,37 @@ class Snowflake(Dialect):
|
|||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
UNSUPPORTED_VALUES_EXPRESSIONS = {
|
||||
exp.Struct,
|
||||
}
|
||||
|
||||
def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str:
    """Generate VALUES, disabling the table form when the rows contain
    expressions Snowflake does not accept inside a VALUES table (e.g. structs).
    """
    has_unsupported = bool(expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS))
    return super().values_sql(
        expression, values_as_table=values_as_table and not has_unsupported
    )
|
||||
|
||||
def datatype_sql(self, expression: exp.DataType) -> str:
    """Render data types, collapsing parameterized struct types to bare OBJECT.

    When a struct-like type carries nested ``exp.DataType`` parameters, emit
    plain OBJECT instead — the parameterized form here would not match
    Snowflake's accepted syntax, OBJECT [ (<key> <value_type [NOT NULL] [, ...]) ].
    """
    inner = expression.expressions
    has_typed_fields = bool(inner) and any(
        isinstance(field_type, exp.DataType) for field_type in inner
    )
    if has_typed_fields and expression.is_type(*exp.DataType.STRUCT_TYPES):
        return "OBJECT"

    return super().datatype_sql(expression)
|
||||
|
||||
def tonumber_sql(self, expression: exp.ToNumber) -> str:
    """Render TO_NUMBER(value [, format [, precision [, scale]]]).

    Optional arguments are pulled from the expression's args; absent ones are
    None and are dropped by ``self.func``.
    """
    optional = (expression.args.get(key) for key in ("format", "precision", "scale"))
    return self.func("TO_NUMBER", expression.this, *optional)
|
||||
|
||||
def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
|
||||
milli = expression.args.get("milli")
|
||||
if milli is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue