Adding upstream version 21.1.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
92ffd7746f
commit
b01402dc30
103 changed files with 18237 additions and 17794 deletions
|
@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import (
|
|||
date_delta_sql,
|
||||
date_trunc_to_time,
|
||||
datestrtodate_sql,
|
||||
format_time_lambda,
|
||||
build_formatted_time,
|
||||
if_sql,
|
||||
inline_array_sql,
|
||||
max_or_greatest,
|
||||
|
@ -29,12 +29,12 @@ if t.TYPE_CHECKING:
|
|||
|
||||
|
||||
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
|
||||
def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
|
||||
def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
|
||||
if len(args) == 2:
|
||||
first_arg, second_arg = args
|
||||
if second_arg.is_string:
|
||||
# case: <string_expr> [ , <format> ]
|
||||
return format_time_lambda(exp.StrToTime, "snowflake")(args)
|
||||
return build_formatted_time(exp.StrToTime, "snowflake")(args)
|
||||
return exp.UnixToTime(this=first_arg, scale=second_arg)
|
||||
|
||||
from sqlglot.optimizer.simplify import simplify_literals
|
||||
|
@ -52,14 +52,14 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
|
|||
return exp.UnixToTime.from_arg_list(args)
|
||||
|
||||
# case: <date_expr>
|
||||
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
|
||||
return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args)
|
||||
|
||||
# case: <numeric_expr>
|
||||
return exp.UnixToTime.from_arg_list(args)
|
||||
|
||||
|
||||
def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
|
||||
expression = parser.parse_var_map(args)
|
||||
def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
|
||||
expression = parser.build_var_map(args)
|
||||
|
||||
if isinstance(expression, exp.StarMap):
|
||||
return expression
|
||||
|
@ -71,48 +71,14 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
|
|||
)
|
||||
|
||||
|
||||
def _parse_datediff(args: t.List) -> exp.DateDiff:
|
||||
def _build_datediff(args: t.List) -> exp.DateDiff:
|
||||
return exp.DateDiff(
|
||||
this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
|
||||
)
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
|
||||
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
|
||||
def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
|
||||
this = self._parse_var() or self._parse_type()
|
||||
|
||||
if not this:
|
||||
return None
|
||||
|
||||
self._match(TokenType.COMMA)
|
||||
expression = self._parse_bitwise()
|
||||
this = _map_date_part(this)
|
||||
name = this.name.upper()
|
||||
|
||||
if name.startswith("EPOCH"):
|
||||
if name == "EPOCH_MILLISECOND":
|
||||
scale = 10**3
|
||||
elif name == "EPOCH_MICROSECOND":
|
||||
scale = 10**6
|
||||
elif name == "EPOCH_NANOSECOND":
|
||||
scale = 10**9
|
||||
else:
|
||||
scale = None
|
||||
|
||||
ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
|
||||
to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
|
||||
|
||||
if scale:
|
||||
to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
|
||||
|
||||
return to_unix
|
||||
|
||||
return self.expression(exp.Extract, this=this, expression=expression)
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/div0
|
||||
def _div0_to_if(args: t.List) -> exp.If:
|
||||
def _build_if_from_div0(args: t.List) -> exp.If:
|
||||
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
|
||||
true = exp.Literal.number(0)
|
||||
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
|
||||
|
@ -120,13 +86,13 @@ def _div0_to_if(args: t.List) -> exp.If:
|
|||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
|
||||
def _zeroifnull_to_if(args: t.List) -> exp.If:
|
||||
def _build_if_from_zeroifnull(args: t.List) -> exp.If:
|
||||
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
|
||||
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
|
||||
def _nullifzero_to_if(args: t.List) -> exp.If:
|
||||
def _build_if_from_nullifzero(args: t.List) -> exp.If:
|
||||
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
|
||||
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
|
||||
|
||||
|
@ -150,13 +116,13 @@ def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) ->
|
|||
)
|
||||
|
||||
|
||||
def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
|
||||
def _build_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
|
||||
if len(args) == 3:
|
||||
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
|
||||
return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
|
||||
|
||||
|
||||
def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
|
||||
def _build_regexp_replace(args: t.List) -> exp.RegexpReplace:
|
||||
regexp_replace = exp.RegexpReplace.from_arg_list(args)
|
||||
|
||||
if not regexp_replace.args.get("replacement"):
|
||||
|
@ -266,38 +232,7 @@ def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
|
|||
return trunc
|
||||
|
||||
|
||||
def _parse_colon_get_path(
|
||||
self: parser.Parser, this: t.Optional[exp.Expression]
|
||||
) -> t.Optional[exp.Expression]:
|
||||
while True:
|
||||
path = self._parse_bitwise()
|
||||
|
||||
# The cast :: operator has a lower precedence than the extraction operator :, so
|
||||
# we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
|
||||
if isinstance(path, exp.Cast):
|
||||
target_type = path.to
|
||||
path = path.this
|
||||
else:
|
||||
target_type = None
|
||||
|
||||
if isinstance(path, exp.Expression):
|
||||
path = exp.Literal.string(path.sql(dialect="snowflake"))
|
||||
|
||||
# The extraction operator : is left-associative
|
||||
this = self.expression(
|
||||
exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
|
||||
)
|
||||
|
||||
if target_type:
|
||||
this = exp.cast(this, target_type)
|
||||
|
||||
if not self._match(TokenType.COLON):
|
||||
break
|
||||
|
||||
return self._parse_range(this)
|
||||
|
||||
|
||||
def _parse_timestamp_from_parts(args: t.List) -> exp.Func:
|
||||
def _build_timestamp_from_parts(args: t.List) -> exp.Func:
|
||||
if len(args) == 2:
|
||||
# Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
|
||||
# so we parse this into Anonymous for now instead of introducing complexity
|
||||
|
@ -396,15 +331,15 @@ class Snowflake(Dialect):
|
|||
"BITXOR": binary_from_function(exp.BitwiseXor),
|
||||
"BIT_XOR": binary_from_function(exp.BitwiseXor),
|
||||
"BOOLXOR": binary_from_function(exp.Xor),
|
||||
"CONVERT_TIMEZONE": _parse_convert_timezone,
|
||||
"CONVERT_TIMEZONE": _build_convert_timezone,
|
||||
"DATE_TRUNC": _date_trunc_to_time,
|
||||
"DATEADD": lambda args: exp.DateAdd(
|
||||
this=seq_get(args, 2),
|
||||
expression=seq_get(args, 1),
|
||||
unit=_map_date_part(seq_get(args, 0)),
|
||||
),
|
||||
"DATEDIFF": _parse_datediff,
|
||||
"DIV0": _div0_to_if,
|
||||
"DATEDIFF": _build_datediff,
|
||||
"DIV0": _build_if_from_div0,
|
||||
"FLATTEN": exp.Explode.from_arg_list,
|
||||
"GET_PATH": lambda args, dialect: exp.JSONExtract(
|
||||
this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
|
||||
|
@ -414,24 +349,24 @@ class Snowflake(Dialect):
|
|||
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
|
||||
),
|
||||
"LISTAGG": exp.GroupConcat.from_arg_list,
|
||||
"NULLIFZERO": _nullifzero_to_if,
|
||||
"OBJECT_CONSTRUCT": _parse_object_construct,
|
||||
"REGEXP_REPLACE": _parse_regexp_replace,
|
||||
"NULLIFZERO": _build_if_from_nullifzero,
|
||||
"OBJECT_CONSTRUCT": _build_object_construct,
|
||||
"REGEXP_REPLACE": _build_regexp_replace,
|
||||
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
|
||||
"RLIKE": exp.RegexpLike.from_arg_list,
|
||||
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
|
||||
"TIMEDIFF": _parse_datediff,
|
||||
"TIMESTAMPDIFF": _parse_datediff,
|
||||
"TIMESTAMPFROMPARTS": _parse_timestamp_from_parts,
|
||||
"TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts,
|
||||
"TO_TIMESTAMP": _parse_to_timestamp,
|
||||
"TIMEDIFF": _build_datediff,
|
||||
"TIMESTAMPDIFF": _build_datediff,
|
||||
"TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
|
||||
"TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
|
||||
"TO_TIMESTAMP": _build_to_timestamp,
|
||||
"TO_VARCHAR": exp.ToChar.from_arg_list,
|
||||
"ZEROIFNULL": _zeroifnull_to_if,
|
||||
"ZEROIFNULL": _build_if_from_zeroifnull,
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
**parser.Parser.FUNCTION_PARSERS,
|
||||
"DATE_PART": _parse_date_part,
|
||||
"DATE_PART": lambda self: self._parse_date_part(),
|
||||
"OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(),
|
||||
}
|
||||
FUNCTION_PARSERS.pop("TRIM")
|
||||
|
@ -442,7 +377,7 @@ class Snowflake(Dialect):
|
|||
**parser.Parser.RANGE_PARSERS,
|
||||
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
|
||||
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
|
||||
TokenType.COLON: _parse_colon_get_path,
|
||||
TokenType.COLON: lambda self, this: self._parse_colon_get_path(this),
|
||||
}
|
||||
|
||||
ALTER_PARSERS = {
|
||||
|
@ -489,6 +424,69 @@ class Snowflake(Dialect):
|
|||
|
||||
FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
|
||||
|
||||
def _parse_colon_get_path(
|
||||
self: parser.Parser, this: t.Optional[exp.Expression]
|
||||
) -> t.Optional[exp.Expression]:
|
||||
while True:
|
||||
path = self._parse_bitwise()
|
||||
|
||||
# The cast :: operator has a lower precedence than the extraction operator :, so
|
||||
# we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
|
||||
if isinstance(path, exp.Cast):
|
||||
target_type = path.to
|
||||
path = path.this
|
||||
else:
|
||||
target_type = None
|
||||
|
||||
if isinstance(path, exp.Expression):
|
||||
path = exp.Literal.string(path.sql(dialect="snowflake"))
|
||||
|
||||
# The extraction operator : is left-associative
|
||||
this = self.expression(
|
||||
exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
|
||||
)
|
||||
|
||||
if target_type:
|
||||
this = exp.cast(this, target_type)
|
||||
|
||||
if not self._match(TokenType.COLON):
|
||||
break
|
||||
|
||||
return self._parse_range(this)
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
|
||||
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
|
||||
def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
|
||||
this = self._parse_var() or self._parse_type()
|
||||
|
||||
if not this:
|
||||
return None
|
||||
|
||||
self._match(TokenType.COMMA)
|
||||
expression = self._parse_bitwise()
|
||||
this = _map_date_part(this)
|
||||
name = this.name.upper()
|
||||
|
||||
if name.startswith("EPOCH"):
|
||||
if name == "EPOCH_MILLISECOND":
|
||||
scale = 10**3
|
||||
elif name == "EPOCH_MICROSECOND":
|
||||
scale = 10**6
|
||||
elif name == "EPOCH_NANOSECOND":
|
||||
scale = 10**9
|
||||
else:
|
||||
scale = None
|
||||
|
||||
ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
|
||||
to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
|
||||
|
||||
if scale:
|
||||
to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
|
||||
|
||||
return to_unix
|
||||
|
||||
return self.expression(exp.Extract, this=this, expression=expression)
|
||||
|
||||
def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
|
||||
if is_map:
|
||||
# Keys are strings in Snowflake's objects, see also:
|
||||
|
@ -665,6 +663,7 @@ class Snowflake(Dialect):
|
|||
"SAMPLE": TokenType.TABLE_SAMPLE,
|
||||
"SQL_DOUBLE": TokenType.DOUBLE,
|
||||
"SQL_VARCHAR": TokenType.VARCHAR,
|
||||
"STORAGE INTEGRATION": TokenType.STORAGE_INTEGRATION,
|
||||
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
|
||||
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
|
||||
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
|
||||
|
@ -724,8 +723,10 @@ class Snowflake(Dialect):
|
|||
),
|
||||
exp.GroupConcat: rename_func("LISTAGG"),
|
||||
exp.If: if_sql(name="IFF", false_value="NULL"),
|
||||
exp.JSONExtract: rename_func("GET_PATH"),
|
||||
exp.JSONExtractScalar: rename_func("JSON_EXTRACT_PATH_TEXT"),
|
||||
exp.JSONExtract: lambda self, e: self.func("GET_PATH", e.this, e.expression),
|
||||
exp.JSONExtractScalar: lambda self, e: self.func(
|
||||
"JSON_EXTRACT_PATH_TEXT", e.this, e.expression
|
||||
),
|
||||
exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
|
||||
exp.JSONPathRoot: lambda *_: "",
|
||||
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
|
||||
|
@ -756,8 +757,7 @@ class Snowflake(Dialect):
|
|||
exp.StrPosition: lambda self, e: self.func(
|
||||
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
|
||||
),
|
||||
exp.StrToTime: lambda self,
|
||||
e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
|
||||
exp.Struct: lambda self, e: self.func(
|
||||
"OBJECT_CONSTRUCT",
|
||||
*(arg for expression in e.expressions for arg in expression.flatten()),
|
||||
|
@ -901,12 +901,12 @@ class Snowflake(Dialect):
|
|||
)
|
||||
|
||||
def except_op(self, expression: exp.Except) -> str:
|
||||
if not expression.args.get("distinct", False):
|
||||
if not expression.args.get("distinct"):
|
||||
self.unsupported("EXCEPT with All is not supported in Snowflake")
|
||||
return super().except_op(expression)
|
||||
|
||||
def intersect_op(self, expression: exp.Intersect) -> str:
|
||||
if not expression.args.get("distinct", False):
|
||||
if not expression.args.get("distinct"):
|
||||
self.unsupported("INTERSECT with All is not supported in Snowflake")
|
||||
return super().intersect_op(expression)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue