Merging upstream version 10.4.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
de4e42d4d3
commit
0c79f8b507
88 changed files with 1637 additions and 436 deletions
|
@ -3,13 +3,15 @@ from __future__ import annotations
|
|||
from sqlglot import exp, generator, parser, tokens
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
datestrtodate_sql,
|
||||
format_time_lambda,
|
||||
inline_array_sql,
|
||||
rename_func,
|
||||
timestrtotime_sql,
|
||||
var_map_sql,
|
||||
)
|
||||
from sqlglot.expressions import Literal
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.helper import flatten, seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
|
@ -183,7 +185,7 @@ class Snowflake(Dialect):
|
|||
|
||||
class Tokenizer(tokens.Tokenizer):
|
||||
QUOTES = ["'", "$$"]
|
||||
ESCAPES = ["\\"]
|
||||
ESCAPES = ["\\", "'"]
|
||||
|
||||
SINGLE_TOKENS = {
|
||||
**tokens.Tokenizer.SINGLE_TOKENS,
|
||||
|
@ -206,9 +208,10 @@ class Snowflake(Dialect):
|
|||
CREATE_TRANSIENT = True
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
exp.Array: inline_array_sql,
|
||||
exp.ArrayConcat: rename_func("ARRAY_CAT"),
|
||||
exp.DateStrToDate: datestrtodate_sql,
|
||||
exp.DataType: _datatype_sql,
|
||||
exp.If: rename_func("IFF"),
|
||||
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
|
||||
|
@ -218,13 +221,14 @@ class Snowflake(Dialect):
|
|||
exp.Matches: rename_func("DECODE"),
|
||||
exp.StrPosition: rename_func("POSITION"),
|
||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.TimeStrToTime: timestrtotime_sql,
|
||||
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
|
||||
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
|
||||
exp.UnixToTime: _unix_to_time_sql,
|
||||
}
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
|
||||
}
|
||||
|
||||
|
@ -246,3 +250,47 @@ class Snowflake(Dialect):
|
|||
if not expression.args.get("distinct", False):
|
||||
self.unsupported("INTERSECT with All is not supported in Snowflake")
|
||||
return super().intersect_op(expression)
|
||||
|
||||
def values_sql(self, expression: exp.Values) -> str:
|
||||
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
|
||||
|
||||
We also want to make sure that after we find matches where we need to unquote a column that we prevent users
|
||||
from adding quotes to the column by using the `identify` argument when generating the SQL.
|
||||
"""
|
||||
alias = expression.args.get("alias")
|
||||
if alias and alias.args.get("columns"):
|
||||
expression = expression.transform(
|
||||
lambda node: exp.Identifier(**{**node.args, "quoted": False})
|
||||
if isinstance(node, exp.Identifier)
|
||||
and isinstance(node.parent, exp.TableAlias)
|
||||
and node.arg_key == "columns"
|
||||
else node,
|
||||
)
|
||||
return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
|
||||
return super().values_sql(expression)
|
||||
|
||||
def select_sql(self, expression: exp.Select) -> str:
|
||||
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
|
||||
that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
|
||||
to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
|
||||
generating the SQL.
|
||||
|
||||
Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
|
||||
expression. This might not be true in a case where the same column name can be sourced from another table that can
|
||||
properly quote but should be true in most cases.
|
||||
"""
|
||||
values_expressions = expression.find_all(exp.Values)
|
||||
values_identifiers = set(
|
||||
flatten(
|
||||
v.args.get("alias", exp.Alias()).args.get("columns", [])
|
||||
for v in values_expressions
|
||||
)
|
||||
)
|
||||
if values_identifiers:
|
||||
expression = expression.transform(
|
||||
lambda node: exp.Identifier(**{**node.args, "quoted": False})
|
||||
if isinstance(node, exp.Identifier) and node in values_identifiers
|
||||
else node,
|
||||
)
|
||||
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
|
||||
return super().select_sql(expression)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue