
Merging upstream version 25.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:39:30 +01:00
parent 7ab180cac9
commit 3b7539dcad
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
79 changed files with 28803 additions and 24929 deletions

sqlglot/dialects/bigquery.py

@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
date_add_interval_sql,
datestrtodate_sql,
build_formatted_time,
build_timestamp_from_parts,
filter_array_using_unnest,
if_sql,
inline_array_unless_query,
@@ -22,6 +23,7 @@ from sqlglot.dialects.dialect import (
build_date_delta_with_interval,
regexp_replace_sql,
rename_func,
sha256_sql,
timestrtotime_sql,
ts_or_ds_add_cast,
unit_to_var,
@@ -321,6 +323,7 @@ class BigQuery(Dialect):
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
"DATETIME": build_timestamp_from_parts,
"DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd),
"DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub),
"DIV": binary_from_function(exp.IntDiv),
@@ -637,9 +640,7 @@ class BigQuery(Dialect):
]
),
exp.SHA: rename_func("SHA1"),
exp.SHA2: lambda self, e: self.func(
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.SHA2: sha256_sql,
exp.StabilityProperty: lambda self, e: (
"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
),
@@ -649,6 +650,7 @@
),
exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
exp.TimeFromParts: rename_func("TIME"),
exp.TimestampFromParts: rename_func("DATETIME"),
exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),
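Note: with "DATETIME" now parsed via build_timestamp_from_parts and exp.TimestampFromParts rendered back as DATETIME, BigQuery's datetime constructor should transpile both ways. A minimal sketch (the Snowflake output shape is an assumption based on its TIMESTAMP_FROM_PARTS mapping):
import sqlglot
sql = "SELECT DATETIME(2024, 1, 1, 12, 0, 0)"
sqlglot.transpile(sql, read="bigquery", write="bigquery")[0]
# 'SELECT DATETIME(2024, 1, 1, 12, 0, 0)'
sqlglot.transpile(sql, read="bigquery", write="snowflake")[0]
# expected: 'SELECT TIMESTAMP_FROM_PARTS(2024, 1, 1, 12, 0, 0)'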

sqlglot/dialects/clickhouse.py

@@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
build_json_extract_path,
rename_func,
sha256_sql,
var_map_sql,
timestamptrunc_sql,
)
@@ -758,9 +759,7 @@ class ClickHouse(Dialect):
exp.MD5Digest: rename_func("MD5"),
exp.MD5: lambda self, e: self.func("LOWER", self.func("HEX", self.func("MD5", e.this))),
exp.SHA: rename_func("SHA1"),
exp.SHA2: lambda self, e: self.func(
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.SHA2: sha256_sql,
exp.UnixToTime: _unix_to_time_sql,
exp.TimestampTrunc: timestamptrunc_sql(zone=True),
exp.Variance: rename_func("varSamp"),

sqlglot/dialects/dialect.py

@@ -169,6 +169,7 @@ class _Dialect(type):
if enum not in ("", "athena", "presto", "trino"):
klass.generator_class.TRY_SUPPORTED = False
klass.generator_class.SUPPORTS_UESCAPE = False
if enum not in ("", "databricks", "hive", "spark", "spark2"):
modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
@@ -177,6 +178,14 @@
klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
if enum not in ("", "doris", "mysql"):
klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
TokenType.STRAIGHT_JOIN,
}
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.STRAIGHT_JOIN,
}
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.ANTI,
@@ -220,6 +229,9 @@ class Dialect(metaclass=_Dialect):
SUPPORTS_SEMI_ANTI_JOIN = True
"""Whether `SEMI` or `ANTI` joins are supported."""
SUPPORTS_COLUMN_JOIN_MARKS = False
"""Whether the old-style outer join (+) syntax is supported."""
NORMALIZE_FUNCTIONS: bool | str = "upper"
"""
Determines how function names are going to be normalized.
@@ -1178,3 +1190,16 @@ def build_default_decimal_type(
return exp.DataType.build(f"DECIMAL({params})")
return _builder
def build_timestamp_from_parts(args: t.List) -> exp.Func:
if len(args) == 2:
# Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
# so we parse this into Anonymous for now instead of introducing complexity
return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
return exp.TimestampFromParts.from_arg_list(args)
def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
return self.func(f"SHA{expression.text('length') or '256'}", expression.this)

sqlglot/dialects/duckdb.py

@@ -207,7 +207,7 @@ class DuckDB(Dialect):
"PIVOT_WIDER": TokenType.PIVOT,
"POSITIONAL": TokenType.POSITIONAL,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
"STRING": TokenType.TEXT,
"UBIGINT": TokenType.UBIGINT,
"UINTEGER": TokenType.UINT,
"USMALLINT": TokenType.USMALLINT,
@@ -216,6 +216,7 @@ class DuckDB(Dialect):
"TIMESTAMP_MS": TokenType.TIMESTAMP_MS,
"TIMESTAMP_NS": TokenType.TIMESTAMP_NS,
"TIMESTAMP_US": TokenType.TIMESTAMP,
"VARCHAR": TokenType.TEXT,
}
SINGLE_TOKENS = {
@@ -312,9 +313,11 @@
),
}
TYPE_CONVERTER = {
TYPE_CONVERTERS = {
# https://duckdb.org/docs/sql/data_types/numeric
exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=18, scale=3),
# https://duckdb.org/docs/sql/data_types/text
exp.DataType.Type.TEXT: lambda dtype: exp.DataType.build("TEXT"),
}
def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]:
@@ -495,6 +498,7 @@
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.BPCHAR: "TEXT",
exp.DataType.Type.CHAR: "TEXT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.NCHAR: "TEXT",
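Note: via the renamed TYPE_CONVERTERS hook (see the parser change further down), a bare DECIMAL now picks up DuckDB's documented default precision and scale. A small sketch:
import sqlglot
sqlglot.transpile("SELECT CAST(x AS DECIMAL)", read="duckdb", write="duckdb")[0]
# expected: 'SELECT CAST(x AS DECIMAL(18, 3))'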

sqlglot/dialects/mysql.py

@@ -202,6 +202,7 @@ class MySQL(Dialect):
"CHARSET": TokenType.CHARACTER_SET,
"FORCE": TokenType.FORCE,
"IGNORE": TokenType.IGNORE,
"KEY": TokenType.KEY,
"LOCK TABLES": TokenType.COMMAND,
"LONGBLOB": TokenType.LONGBLOB,
"LONGTEXT": TokenType.LONGTEXT,

sqlglot/dialects/oracle.py

@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
trim_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import OPTIONS_TYPE
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
@@ -32,10 +33,171 @@ def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar:
return exp.ToChar.from_arg_list(args)
def eliminate_join_marks(ast: exp.Expression) -> exp.Expression:
"""Remove join marks from an expression
SELECT * FROM a, b WHERE a.id = b.id(+)
becomes:
SELECT * FROM a LEFT JOIN b ON a.id = b.id
- for each scope
- for each column with a join mark
- find the predicate it belongs to
- remove the predicate from the where clause
- convert the predicate to a join with the (+) side as the left join table
- replace the existing join with the new join
Args:
ast: The AST to remove join marks from
Returns:
The AST with join marks removed"""
from sqlglot.optimizer.scope import traverse_scope
for scope in traverse_scope(ast):
_eliminate_join_marks_from_scope(scope)
return ast
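A minimal usage sketch of the transformation described in the docstring:
import sqlglot
from sqlglot.dialects.oracle import eliminate_join_marks
ast = sqlglot.parse_one("SELECT * FROM a, b WHERE a.id = b.id(+)", read="oracle")
eliminate_join_marks(ast).sql(dialect="oracle")
# expected: 'SELECT * FROM a LEFT JOIN b ON a.id = b.id'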
def _update_from(
select: exp.Select,
new_join_dict: t.Dict[str, exp.Join],
old_join_dict: t.Dict[str, exp.Join],
) -> None:
"""If the from clause needs to become a new join, find an appropriate table to use as the new from.
Updates select in place.
Args:
select: The select statement to update
new_join_dict: The dictionary of new joins
old_join_dict: The dictionary of old joins
"""
old_from = select.args["from"]
if old_from.alias_or_name not in new_join_dict:
return
in_old_not_new = old_join_dict.keys() - new_join_dict.keys()
if len(in_old_not_new) >= 1:
new_from_name = list(old_join_dict.keys() - new_join_dict.keys())[0]
new_from_this = old_join_dict[new_from_name].this
new_from = exp.From(this=new_from_this)
del old_join_dict[new_from_name]
select.set("from", new_from)
else:
raise ValueError("Cannot determine which table to use as the new from")
def _has_join_mark(col: exp.Expression) -> bool:
"""Check if the column has a join mark
Args:
col: The column to check
"""
return col.args.get("join_mark", False)
def _predicate_to_join(
eq: exp.Binary, old_joins: t.Dict[str, exp.Join], old_from: exp.From
) -> t.Optional[exp.Join]:
"""Convert an equality predicate to a join if it contains a join mark
Args:
eq: The equality expression to convert to a join
old_joins: The dictionary of existing joins, keyed by alias or name
old_from: The from clause of the enclosing select
Returns:
The join expression if the equality contains a join mark (otherwise None)
"""
# if not (isinstance(eq.left, exp.Column) or isinstance(eq.right, exp.Column)):
# return None
left_columns = [col for col in eq.left.find_all(exp.Column) if _has_join_mark(col)]
right_columns = [col for col in eq.right.find_all(exp.Column) if _has_join_mark(col)]
left_has_join_mark = len(left_columns) > 0
right_has_join_mark = len(right_columns) > 0
if left_has_join_mark:
for col in left_columns:
col.set("join_mark", False)
join_on = col.table
elif right_has_join_mark:
for col in right_columns:
col.set("join_mark", False)
join_on = col.table
else:
return None
join_this = old_joins.get(join_on, old_from).this
return exp.Join(this=join_this, on=eq, kind="LEFT")
if t.TYPE_CHECKING:
from sqlglot.optimizer.scope import Scope
def _eliminate_join_marks_from_scope(scope: Scope) -> None:
"""Remove join marks columns in scope's where clause.
Converts them to left joins and replaces any existing joins.
Updates scope in place.
Args:
scope: The scope to remove join marks from
"""
select_scope = scope.expression
where = select_scope.args.get("where")
joins = select_scope.args.get("joins")
if not where:
return
if not joins:
return
# dictionaries used to keep track of joins to be replaced
old_joins = {join.alias_or_name: join for join in list(joins)}
new_joins: t.Dict[str, exp.Join] = {}
for node in scope.find_all(exp.Column):
if _has_join_mark(node):
predicate = node.find_ancestor(exp.Predicate)
if not isinstance(predicate, exp.Binary):
continue
predicate_parent = predicate.parent
join_on = predicate.pop()
new_join = _predicate_to_join(
join_on, old_joins=old_joins, old_from=select_scope.args["from"]
)
# upsert new_join into new_joins dictionary
if new_join:
if new_join.alias_or_name in new_joins:
new_joins[new_join.alias_or_name].set(
"on",
exp.and_(
new_joins[new_join.alias_or_name].args["on"],
new_join.args["on"],
),
)
else:
new_joins[new_join.alias_or_name] = new_join
# If the parent is a binary node with only one child, promote the child to the parent
if predicate_parent:
if isinstance(predicate_parent, exp.Binary):
if predicate_parent.left is None:
predicate_parent.replace(predicate_parent.right)
elif predicate_parent.right is None:
predicate_parent.replace(predicate_parent.left)
_update_from(select_scope, new_joins, old_joins)
replacement_joins = [new_joins.get(join.alias_or_name, join) for join in old_joins.values()]
select_scope.set("joins", replacement_joins)
if not where.this:
where.pop()
class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
LOCKING_READS_SUPPORTED = True
TABLESAMPLE_SIZE_IS_PERCENT = True
SUPPORTS_COLUMN_JOIN_MARKS = True
# See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -70,6 +232,12 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
VAR_SINGLE_TOKENS = {"@", "$", "#"}
UNICODE_STRINGS = [
(prefix + q, q)
for q in t.cast(t.List[str], tokens.Tokenizer.QUOTES)
for prefix in ("U", "u")
]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"(+)": TokenType.JOIN_MARKER,
@@ -132,6 +300,7 @@ class Oracle(Dialect):
QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()),
TokenType.WITH: lambda self: ("options", [self._parse_query_restrictions()]),
}
TYPE_LITERAL_PARSERS = {
@@ -144,6 +313,13 @@ class Oracle(Dialect):
# Reference: https://stackoverflow.com/a/336455
DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}
QUERY_RESTRICTIONS: OPTIONS_TYPE = {
"WITH": (
("READ", "ONLY"),
("CHECK", "OPTION"),
),
}
def _parse_xml_table(self) -> exp.XMLTable:
this = self._parse_string()
@@ -173,12 +349,6 @@ class Oracle(Dialect):
**kwargs,
)
def _parse_column(self) -> t.Optional[exp.Expression]:
column = super()._parse_column()
if column:
column.set("join_mark", self._match(TokenType.JOIN_MARKER))
return column
def _parse_hint(self) -> t.Optional[exp.Hint]:
if self._match(TokenType.HINT):
start = self._curr
@@ -193,11 +363,22 @@
return None
def _parse_query_restrictions(self) -> t.Optional[exp.Expression]:
kind = self._parse_var_from_options(self.QUERY_RESTRICTIONS, raise_unmatched=False)
if not kind:
return None
return self.expression(
exp.QueryOption,
this=kind,
expression=self._match(TokenType.CONSTRAINT) and self._parse_field(),
)
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
COLUMN_JOIN_MARKS_SUPPORTED = True
DATA_TYPE_SPECIFIERS_ALLOWED = True
ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
LIMIT_FETCH = "FETCH"
@@ -282,3 +463,10 @@ class Oracle(Dialect):
if len(expression.args.get("actions", [])) > 1:
return f"ADD ({actions})"
return f"ADD {actions}"
def queryoption_sql(self, expression: exp.QueryOption) -> str:
option = self.sql(expression, "this")
value = self.sql(expression, "expression")
value = f" CONSTRAINT {value}" if value else ""
return f"{option}{value}"

sqlglot/dialects/postgres.py

@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
Dialect,
JSON_EXTRACT_TYPE,
any_value_to_max_sql,
binary_from_function,
bool_xor_sql,
datestrtodate_sql,
build_formatted_time,
@@ -25,6 +26,7 @@ from sqlglot.dialects.dialect import (
build_json_extract_path,
build_timestamp_trunc,
rename_func,
sha256_sql,
str_position_sql,
struct_extract_sql,
timestamptrunc_sql,
@@ -329,6 +331,7 @@ class Postgres(Dialect):
"REGTYPE": TokenType.OBJECT_IDENTIFIER,
"FLOAT": TokenType.DOUBLE,
}
KEYWORDS.pop("DIV")
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@@ -347,6 +350,9 @@
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": build_timestamp_trunc,
"DIV": lambda args: exp.cast(
binary_from_function(exp.IntDiv)(args), exp.DataType.Type.DECIMAL
),
"GENERATE_SERIES": _build_generate_series,
"JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
"JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
@@ -357,6 +363,9 @@
"TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"),
"TO_TIMESTAMP": _build_to_timestamp,
"UNNEST": exp.Explode.from_arg_list,
"SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
"SHA384": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(384)),
"SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
}
FUNCTION_PARSERS = {
@@ -494,6 +503,7 @@
exp.DateSub: _date_add_sql("-"),
exp.Explode: rename_func("UNNEST"),
exp.GroupConcat: _string_agg_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"),
exp.JSONExtractScalar: _json_extract_sql("JSON_EXTRACT_PATH_TEXT", "->>"),
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
@@ -528,6 +538,7 @@
transforms.eliminate_qualify,
]
),
exp.SHA2: sha256_sql,
exp.StrPosition: str_position_sql,
exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
@@ -621,3 +632,12 @@
return f"{self.expressions(expression, flat=True)}[{values}]"
return "ARRAY"
return super().datatype_sql(expression)
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
this = expression.this
# Postgres casts DIV() to decimal for transpilation but when roundtripping it's superfluous
if isinstance(this, exp.IntDiv) and expression.to == exp.DataType.build("decimal"):
return self.sql(this)
return super().cast_sql(expression, safe_prefix=safe_prefix)
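Note: the three DIV pieces (parse DIV into a DECIMAL-cast IntDiv, render IntDiv back as DIV, strip the now-superfluous cast in cast_sql) should make Postgres roundtrip cleanly while other targets keep the cast that models DIV's numeric return type. A sketch (the DuckDB "//" rendering is an assumption):
import sqlglot
sqlglot.transpile("SELECT DIV(9, 4)", read="postgres", write="postgres")[0]
# expected: 'SELECT DIV(9, 4)'
sqlglot.transpile("SELECT DIV(9, 4)", read="postgres", write="duckdb")[0]
# expected: something like 'SELECT CAST(9 // 4 AS DECIMAL)'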

sqlglot/dialects/presto.py

@@ -21,6 +21,7 @@ from sqlglot.dialects.dialect import (
regexp_extract_sql,
rename_func,
right_to_substring_sql,
sha256_sql,
struct_extract_sql,
str_position_sql,
timestamptrunc_sql,
@@ -452,9 +453,7 @@ class Presto(Dialect):
),
exp.MD5Digest: rename_func("MD5"),
exp.SHA: rename_func("SHA1"),
exp.SHA2: lambda self, e: self.func(
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.SHA2: sha256_sql,
}
RESERVED_KEYWORDS = {

sqlglot/dialects/redshift.py

@@ -40,6 +40,7 @@ class Redshift(Postgres):
INDEX_OFFSET = 0
COPY_PARAMS_ARE_CSV = False
HEX_LOWERCASE = True
SUPPORTS_COLUMN_JOIN_MARKS = True
TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
@@ -122,12 +123,13 @@ class Redshift(Postgres):
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS,
"(+)": TokenType.JOIN_MARKER,
"HLLSKETCH": TokenType.HLLSKETCH,
"MINUS": TokenType.EXCEPT,
"SUPER": TokenType.SUPER,
"TOP": TokenType.TOP,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
"MINUS": TokenType.EXCEPT,
}
KEYWORDS.pop("VALUES")
@@ -209,6 +211,7 @@ class Redshift(Postgres):
# Redshift supports LAST_DAY(..)
TRANSFORMS.pop(exp.LastDay)
TRANSFORMS.pop(exp.SHA2)
RESERVED_KEYWORDS = {
"aes128",

sqlglot/dialects/snowflake.py

@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
NormalizationStrategy,
binary_from_function,
build_default_decimal_type,
build_timestamp_from_parts,
date_delta_sql,
date_trunc_to_time,
datestrtodate_sql,
@@ -236,15 +237,6 @@ def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
return trunc
def _build_timestamp_from_parts(args: t.List) -> exp.Func:
if len(args) == 2:
# Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
# so we parse this into Anonymous for now instead of introducing complexity
return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
return exp.TimestampFromParts.from_arg_list(args)
def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
"""
Snowflake doesn't allow columns referenced in UNPIVOT to be qualified,
@@ -391,8 +383,8 @@ class Snowflake(Dialect):
"TIMEDIFF": _build_datediff,
"TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
"TIMESTAMPDIFF": _build_datediff,
"TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
"TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
"TIMESTAMPFROMPARTS": build_timestamp_from_parts,
"TIMESTAMP_FROM_PARTS": build_timestamp_from_parts,
"TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
"TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
"TO_NUMBER": lambda args: exp.ToNumber(
@@ -446,7 +438,7 @@ class Snowflake(Dialect):
"LOCATION": lambda self: self._parse_location_property(),
}
TYPE_CONVERTER = {
TYPE_CONVERTERS = {
# https://docs.snowflake.com/en/sql-reference/data-types-numeric#number
exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=38, scale=0),
}
@@ -510,15 +502,18 @@
self._retreat(self._index - 1)
if self._match_text_seq("MASKING", "POLICY"):
policy = self._parse_column()
return self.expression(
exp.MaskingPolicyColumnConstraint,
this=self._parse_id_var(),
this=policy.to_dot() if isinstance(policy, exp.Column) else policy,
expressions=self._match(TokenType.USING)
and self._parse_wrapped_csv(self._parse_id_var),
)
if self._match_text_seq("PROJECTION", "POLICY"):
policy = self._parse_column()
return self.expression(
exp.ProjectionPolicyColumnConstraint, this=self._parse_id_var()
exp.ProjectionPolicyColumnConstraint,
this=policy.to_dot() if isinstance(policy, exp.Column) else policy,
)
if self._match(TokenType.TAG):
return self.expression(

sqlglot/dialects/spark.py

@@ -41,6 +41,21 @@ def _build_datediff(args: t.List) -> exp.Expression:
)
def _build_dateadd(args: t.List) -> exp.Expression:
expression = seq_get(args, 1)
if len(args) == 2:
# DATE_ADD(startDate, numDays INTEGER)
# https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
return exp.TsOrDsAdd(
this=seq_get(args, 0), expression=expression, unit=exp.Literal.string("DAY")
)
# DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
# https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
return exp.TimestampAdd(this=seq_get(args, 2), expression=expression, unit=seq_get(args, 0))
def _normalize_partition(e: exp.Expression) -> exp.Expression:
"""Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
if isinstance(e, str):
@@ -50,6 +65,30 @@ def _normalize_partition(e: exp.Expression) -> exp.Expression:
return e
def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
if not expression.unit or (
isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY"
):
# Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
return self.func("DATE_ADD", expression.this, expression.expression)
this = self.func(
"DATE_ADD",
unit_to_var(expression),
expression.expression,
expression.this,
)
if isinstance(expression, exp.TsOrDsAdd):
# The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
# in other dialects
return_type = expression.return_type
if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
this = f"CAST({this} AS {return_type})"
return this
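Note: a sketch of the two DATE_ADD arities the new builder and generator are meant to handle, per the linked Databricks docs:
import sqlglot
sqlglot.transpile("SELECT DATE_ADD(d, 7)", read="spark", write="spark")[0]
# expected: 'SELECT DATE_ADD(d, 7)' (2-arg form, unit defaults to DAY)
sqlglot.transpile("SELECT DATEADD(MONTH, 1, d)", read="spark", write="spark")[0]
# expected: 'SELECT DATE_ADD(MONTH, 1, d)' (3-arg form, parsed as TimestampAdd)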
class Spark(Spark2):
class Tokenizer(Spark2.Tokenizer):
RAW_STRINGS = [
@@ -62,6 +101,9 @@ class Spark(Spark2):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
"ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
"DATE_ADD": _build_dateadd,
"DATEADD": _build_dateadd,
"TIMESTAMPADD": _build_dateadd,
"DATEDIFF": _build_datediff,
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
"TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
@@ -111,9 +153,8 @@ class Spark(Spark2):
exp.PartitionedByProperty: lambda self,
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", unit_to_var(e), e.expression, e.this
),
exp.TsOrDsAdd: _dateadd_sql,
exp.TimestampAdd: _dateadd_sql,
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
),

sqlglot/dialects/sqlite.py

@@ -75,6 +75,26 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
return expression
def _generated_to_auto_increment(expression: exp.Expression) -> exp.Expression:
if not isinstance(expression, exp.ColumnDef):
return expression
generated = expression.find(exp.GeneratedAsIdentityColumnConstraint)
if generated:
t.cast(exp.ColumnConstraint, generated.parent).pop()
not_null = expression.find(exp.NotNullColumnConstraint)
if not_null:
t.cast(exp.ColumnConstraint, not_null.parent).pop()
expression.append(
"constraints", exp.ColumnConstraint(kind=exp.AutoIncrementColumnConstraint())
)
return expression
class SQLite(Dialect):
# https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
@@ -141,6 +161,7 @@ class SQLite(Dialect):
exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTime: lambda *_: "CURRENT_TIME",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ColumnDef: transforms.preprocess([_generated_to_auto_increment]),
exp.DateAdd: _date_add_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.If: rename_func("IIF"),

sqlglot/dialects/tsql.py

@@ -1118,3 +1118,7 @@ class TSQL(Dialect):
kind = f"TABLE {kind}"
return f"{variable} AS {kind}{default}"
def options_modifier(self, expression: exp.Expression) -> str:
options = self.expressions(expression, key="options")
return f" OPTION{self.wrap(options)}" if options else ""

sqlglot/expressions.py

@@ -3119,22 +3119,6 @@ class Intersect(Union):
pass
class Unnest(UDTF):
arg_types = {
"expressions": True,
"alias": False,
"offset": False,
}
@property
def selects(self) -> t.List[Expression]:
columns = super().selects
offset = self.args.get("offset")
if offset:
columns = columns + [to_identifier("offset") if offset is True else offset]
return columns
class Update(Expression):
arg_types = {
"with": False,
@@ -5240,6 +5224,22 @@ class PosexplodeOuter(Posexplode, ExplodeOuter):
pass
class Unnest(Func, UDTF):
arg_types = {
"expressions": True,
"alias": False,
"offset": False,
}
@property
def selects(self) -> t.List[Expression]:
columns = super().selects
offset = self.args.get("offset")
if offset:
columns = columns + [to_identifier("offset") if offset is True else offset]
return columns
class Floor(Func):
arg_types = {"this": True, "decimals": False}
@@ -5765,7 +5765,7 @@ class StrPosition(Func):
class StrToDate(Func):
arg_types = {"this": True, "format": True}
arg_types = {"this": True, "format": False}
class StrToTime(Func):

sqlglot/generator.py

@@ -225,9 +225,6 @@ class Generator(metaclass=_Generator):
# Whether to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ...
RETURNING_END = True
# Whether to generate the (+) suffix for columns used in old-style join conditions
COLUMN_JOIN_MARKS_SUPPORTED = False
# Whether to generate an unquoted value for EXTRACT's date part argument
EXTRACT_ALLOWS_QUOTES = True
@@ -359,6 +356,9 @@
# Whether the conditional TRY(expression) function is supported
TRY_SUPPORTED = True
# Whether the UESCAPE syntax in unicode strings is supported
SUPPORTS_UESCAPE = True
# The keyword to use when generating a star projection with excluded columns
STAR_EXCEPT = "EXCEPT"
@@ -827,7 +827,7 @@
def column_sql(self, expression: exp.Column) -> str:
join_mark = " (+)" if expression.args.get("join_mark") else ""
if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED:
if join_mark and not self.dialect.SUPPORTS_COLUMN_JOIN_MARKS:
join_mark = ""
self.unsupported("Outer join syntax using the (+) operator is not supported.")
@@ -1146,16 +1146,23 @@
escape = expression.args.get("escape")
if self.dialect.UNICODE_START:
escape = f" UESCAPE {self.sql(escape)}" if escape else ""
return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}"
escape_substitute = r"\\\1"
left_quote, right_quote = self.dialect.UNICODE_START, self.dialect.UNICODE_END
else:
escape_substitute = r"\\u\1"
left_quote, right_quote = self.dialect.QUOTE_START, self.dialect.QUOTE_END
if escape:
pattern = re.compile(rf"{escape.name}(\d+)")
escape_pattern = re.compile(rf"{escape.name}(\d+)")
escape_sql = f" UESCAPE {self.sql(escape)}" if self.SUPPORTS_UESCAPE else ""
else:
pattern = ESCAPED_UNICODE_RE
escape_pattern = ESCAPED_UNICODE_RE
escape_sql = ""
this = pattern.sub(r"\\u\1", this)
return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}"
if not self.dialect.UNICODE_START or (escape and not self.SUPPORTS_UESCAPE):
this = escape_pattern.sub(escape_substitute, this)
return f"{left_quote}{this}{right_quote}{escape_sql}"
def rawstring_sql(self, expression: exp.RawString) -> str:
string = self.escape_str(expression.this.replace("\\", "\\\\"), escape_backslash=False)
@@ -1973,7 +1980,9 @@
return f", {this_sql}"
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
if op_sql != "STRAIGHT_JOIN":
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
@@ -2235,10 +2244,6 @@
elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))
options = self.expressions(expression, key="options")
if options:
options = f" OPTION{self.wrap(options)}"
return csv(
*sqls,
*[self.sql(join) for join in expression.args.get("joins") or []],
@@ -2253,10 +2258,14 @@
self.sql(expression, "order"),
*self.offset_limit_modifiers(expression, isinstance(limit, exp.Fetch), limit),
*self.after_limit_modifiers(expression),
options,
self.options_modifier(expression),
sep="",
)
def options_modifier(self, expression: exp.Expression) -> str:
options = self.expressions(expression, key="options")
return f" {options}" if options else ""
def queryoption_sql(self, expression: exp.QueryOption) -> str:
return ""

sqlglot/optimizer/simplify.py

@@ -1034,7 +1034,7 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return (
DATETRUNC_BINARY_COMPARISONS[comparison](
trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
trunc_arg, date, unit, dialect, extract_type(r)
)
or expression
)
@@ -1060,7 +1060,7 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return expression
ranges = merge_ranges(ranges)
target_type = extract_type(l, *rs)
target_type = extract_type(*rs)
return exp.or_(
*[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False

sqlglot/parser.py

@@ -588,11 +588,12 @@ class Parser(metaclass=_Parser):
}
JOIN_KINDS = {
TokenType.ANTI,
TokenType.CROSS,
TokenType.INNER,
TokenType.OUTER,
TokenType.CROSS,
TokenType.SEMI,
TokenType.ANTI,
TokenType.STRAIGHT_JOIN,
}
JOIN_HINTS: t.Set[str] = set()
@@ -1065,7 +1066,7 @@
exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this),
}
TYPE_CONVERTER: t.Dict[exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]] = {}
TYPE_CONVERTERS: t.Dict[exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]] = {}
DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN}
@@ -1138,7 +1139,14 @@
FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
ADD_CONSTRAINT_TOKENS = {
TokenType.CONSTRAINT,
TokenType.FOREIGN_KEY,
TokenType.INDEX,
TokenType.KEY,
TokenType.PRIMARY_KEY,
TokenType.UNIQUE,
}
DISTINCT_TOKENS = {TokenType.DISTINCT}
@@ -3099,7 +3107,7 @@
index = self._index
method, side, kind = self._parse_join_parts()
hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
join = self._match(TokenType.JOIN)
join = self._match(TokenType.JOIN) or (kind and kind.token_type == TokenType.STRAIGHT_JOIN)
if not skip_join_token and not join:
self._retreat(index)
@@ -3242,7 +3250,7 @@
while self._match_set(self.TABLE_INDEX_HINT_TOKENS):
hint = exp.IndexTableHint(this=self._prev.text.upper())
self._match_texts(("INDEX", "KEY"))
self._match_set((TokenType.INDEX, TokenType.KEY))
if self._match(TokenType.FOR):
hint.set("target", self._advance_any() and self._prev.text.upper())
@@ -4464,8 +4472,8 @@
)
self._match(TokenType.R_BRACKET)
if self.TYPE_CONVERTER and isinstance(this.this, exp.DataType.Type):
converter = self.TYPE_CONVERTER.get(this.this)
if self.TYPE_CONVERTERS and isinstance(this.this, exp.DataType.Type):
converter = self.TYPE_CONVERTERS.get(this.this)
if converter:
this = converter(t.cast(exp.DataType, this))
@@ -4496,7 +4504,12 @@
def _parse_column(self) -> t.Optional[exp.Expression]:
this = self._parse_column_reference()
return self._parse_column_ops(this) if this else self._parse_bracket(this)
column = self._parse_column_ops(this) if this else self._parse_bracket(this)
if self.dialect.SUPPORTS_COLUMN_JOIN_MARKS and column:
column.set("join_mark", self._match(TokenType.JOIN_MARKER))
return column
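Note: join-mark parsing moves here from Oracle's _parse_column override (removed earlier in this diff), gated on the new dialect-level SUPPORTS_COLUMN_JOIN_MARKS, so Redshift now handles (+) too. A sketch:
import sqlglot
sqlglot.transpile("SELECT * FROM a, b WHERE a.id = b.id(+)", read="redshift", write="redshift")[0]
# expected: 'SELECT * FROM a, b WHERE a.id = b.id (+)'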
def _parse_column_reference(self) -> t.Optional[exp.Expression]:
this = self._parse_field()
@@ -4522,7 +4535,11 @@
while self._match(TokenType.COLON):
start_index = self._index
path = self._parse_column_ops(self._parse_field(any_token=True))
# Snowflake allows reserved keywords as JSON keys, but advance_any() excludes TokenType.SELECT even when any_token=True
path = self._parse_column_ops(
self._parse_field(any_token=True, tokens=(TokenType.SELECT,))
)
# The cast :: operator has a lower precedence than the extraction operator :, so
# we rearrange the AST appropriately to avoid casting the JSON path
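Note: a sketch of the reserved-keyword JSON key this enables for Snowflake-style ":" extraction; previously the SELECT keyword could not follow the colon:
import sqlglot
sqlglot.parse_one("SELECT payload:select FROM t", read="snowflake")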

sqlglot/tokens.py

@@ -287,6 +287,7 @@ class TokenType(AutoName):
JOIN = auto()
JOIN_MARKER = auto()
KEEP = auto()
KEY = auto()
KILL = auto()
LANGUAGE = auto()
LATERAL = auto()
@@ -360,6 +361,7 @@
SORT_BY = auto()
START_WITH = auto()
STORAGE_INTEGRATION = auto()
STRAIGHT_JOIN = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TAG = auto()
@@ -764,6 +766,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
"START WITH": TokenType.START_WITH,
"STRAIGHT_JOIN": TokenType.STRAIGHT_JOIN,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
@@ -1270,18 +1273,6 @@ class Tokenizer(metaclass=_Tokenizer):
elif token_type == TokenType.BIT_STRING:
base = 2
elif token_type == TokenType.HEREDOC_STRING:
if (
self.HEREDOC_TAG_IS_IDENTIFIER
and not self._peek.isidentifier()
and not self._peek == end
):
if self.HEREDOC_STRING_ALTERNATIVE != token_type.VAR:
self._add(self.HEREDOC_STRING_ALTERNATIVE)
else:
self._scan_var()
return True
self._advance()
if self._char == end:
@@ -1293,7 +1284,10 @@
raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER,
)
if self._end and tag and self.HEREDOC_TAG_IS_IDENTIFIER:
if tag and self.HEREDOC_TAG_IS_IDENTIFIER and (self._end or not tag.isidentifier()):
if not self._end:
self._advance(-1)
self._advance(-len(tag))
self._add(self.HEREDOC_STRING_ALTERNATIVE)
return True

sqlglot/optimizer/canonicalize.py

@@ -505,7 +505,10 @@ def ensure_bools(expression: exp.Expression) -> exp.Expression:
def _ensure_bool(node: exp.Expression) -> None:
if (
node.is_number
or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
or (
not isinstance(node, exp.SubqueryPredicate)
and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
)
or (isinstance(node, exp.Column) and not node.type)
):
node.replace(node.neq(0))