Merging upstream version 25.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>

This commit is contained in:
parent 7ab180cac9
commit 3b7539dcad

79 changed files with 28803 additions and 24929 deletions
sqlglot/dialects/bigquery.py

@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
     date_add_interval_sql,
     datestrtodate_sql,
     build_formatted_time,
+    build_timestamp_from_parts,
     filter_array_using_unnest,
     if_sql,
     inline_array_unless_query,
@@ -22,6 +23,7 @@ from sqlglot.dialects.dialect import (
     build_date_delta_with_interval,
     regexp_replace_sql,
     rename_func,
+    sha256_sql,
     timestrtotime_sql,
     ts_or_ds_add_cast,
     unit_to_var,
@@ -321,6 +323,7 @@ class BigQuery(Dialect):
                 unit=exp.Literal.string(str(seq_get(args, 1))),
                 this=seq_get(args, 0),
             ),
+            "DATETIME": build_timestamp_from_parts,
             "DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd),
             "DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub),
             "DIV": binary_from_function(exp.IntDiv),
@@ -637,9 +640,7 @@ class BigQuery(Dialect):
                 ]
             ),
             exp.SHA: rename_func("SHA1"),
-            exp.SHA2: lambda self, e: self.func(
-                "SHA256" if e.text("length") == "256" else "SHA512", e.this
-            ),
+            exp.SHA2: sha256_sql,
             exp.StabilityProperty: lambda self, e: (
                 "DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
             ),
@@ -649,6 +650,7 @@ class BigQuery(Dialect):
             ),
             exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
             exp.TimeFromParts: rename_func("TIME"),
+            exp.TimestampFromParts: rename_func("DATETIME"),
             exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
             exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
             exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),
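BigQuery previously carried its own `exp.SHA2` lambda; the same block appears in ClickHouse, Postgres, and Presto below, and all of them now delegate to the shared `sha256_sql` helper added to `dialect.py`. A minimal sketch of the resulting behavior (outputs indicative, not pinned to this exact version):

```python
import sqlglot

# SHA2 with length 256 maps to SHA256; other lengths yield SHA<length>,
# and a missing length falls back to 256 inside sha256_sql.
print(sqlglot.transpile("SELECT SHA2(col, 256)", write="bigquery"))
# indicative: ['SELECT SHA256(col)']
print(sqlglot.transpile("SELECT SHA2(col, 512)", write="bigquery"))
# indicative: ['SELECT SHA512(col)']
```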
sqlglot/dialects/clickhouse.py

@@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import (
     no_pivot_sql,
     build_json_extract_path,
     rename_func,
+    sha256_sql,
     var_map_sql,
     timestamptrunc_sql,
 )
@@ -758,9 +759,7 @@ class ClickHouse(Dialect):
             exp.MD5Digest: rename_func("MD5"),
             exp.MD5: lambda self, e: self.func("LOWER", self.func("HEX", self.func("MD5", e.this))),
             exp.SHA: rename_func("SHA1"),
-            exp.SHA2: lambda self, e: self.func(
-                "SHA256" if e.text("length") == "256" else "SHA512", e.this
-            ),
+            exp.SHA2: sha256_sql,
             exp.UnixToTime: _unix_to_time_sql,
             exp.TimestampTrunc: timestamptrunc_sql(zone=True),
             exp.Variance: rename_func("varSamp"),
sqlglot/dialects/dialect.py

@@ -169,6 +169,7 @@ class _Dialect(type):

         if enum not in ("", "athena", "presto", "trino"):
             klass.generator_class.TRY_SUPPORTED = False
+            klass.generator_class.SUPPORTS_UESCAPE = False

         if enum not in ("", "databricks", "hive", "spark", "spark2"):
             modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
@@ -177,6 +178,14 @@ class _Dialect(type):

             klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

+        if enum not in ("", "doris", "mysql"):
+            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
+                TokenType.STRAIGHT_JOIN,
+            }
+            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
+                TokenType.STRAIGHT_JOIN,
+            }
+
         if not klass.SUPPORTS_SEMI_ANTI_JOIN:
             klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                 TokenType.ANTI,
@@ -220,6 +229,9 @@ class Dialect(metaclass=_Dialect):
     SUPPORTS_SEMI_ANTI_JOIN = True
     """Whether `SEMI` or `ANTI` joins are supported."""

+    SUPPORTS_COLUMN_JOIN_MARKS = False
+    """Whether the old-style outer join (+) syntax is supported."""
+
     NORMALIZE_FUNCTIONS: bool | str = "upper"
     """
     Determines how function names are going to be normalized.
@@ -1178,3 +1190,16 @@ def build_default_decimal_type(
         return exp.DataType.build(f"DECIMAL({params})")

     return _builder
+
+
+def build_timestamp_from_parts(args: t.List) -> exp.Func:
+    if len(args) == 2:
+        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
+        # so we parse this into Anonymous for now instead of introducing complexity
+        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
+
+    return exp.TimestampFromParts.from_arg_list(args)
+
+
+def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
+    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
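The metaclass hook keeps `STRAIGHT_JOIN` a join keyword only for MySQL and Doris and re-registers it as an ordinary identifier token everywhere else, while `SUPPORTS_COLUMN_JOIN_MARKS` becomes a dialect-level flag (it previously lived on the generator as `COLUMN_JOIN_MARKS_SUPPORTED`). A minimal sketch of both sides of the STRAIGHT_JOIN gating (outputs indicative):

```python
import sqlglot

# MySQL (and Doris) still parse STRAIGHT_JOIN as a join kind.
print(sqlglot.transpile("SELECT * FROM a STRAIGHT_JOIN b", read="mysql", write="mysql"))
# indicative: ['SELECT * FROM a STRAIGHT_JOIN b']

# Elsewhere the token is added to ID_VAR_TOKENS, so it still works as an alias.
print(sqlglot.transpile("SELECT * FROM t AS straight_join", read="duckdb", write="duckdb"))
# indicative: ['SELECT * FROM t AS straight_join']
```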
sqlglot/dialects/duckdb.py

@@ -207,7 +207,7 @@ class DuckDB(Dialect):
         "PIVOT_WIDER": TokenType.PIVOT,
         "POSITIONAL": TokenType.POSITIONAL,
         "SIGNED": TokenType.INT,
-        "STRING": TokenType.VARCHAR,
+        "STRING": TokenType.TEXT,
         "UBIGINT": TokenType.UBIGINT,
         "UINTEGER": TokenType.UINT,
         "USMALLINT": TokenType.USMALLINT,
@@ -216,6 +216,7 @@ class DuckDB(Dialect):
         "TIMESTAMP_MS": TokenType.TIMESTAMP_MS,
         "TIMESTAMP_NS": TokenType.TIMESTAMP_NS,
         "TIMESTAMP_US": TokenType.TIMESTAMP,
+        "VARCHAR": TokenType.TEXT,
     }

     SINGLE_TOKENS = {
@@ -312,9 +313,11 @@ class DuckDB(Dialect):
             ),
         }

-        TYPE_CONVERTER = {
+        TYPE_CONVERTERS = {
             # https://duckdb.org/docs/sql/data_types/numeric
             exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=18, scale=3),
+            # https://duckdb.org/docs/sql/data_types/text
+            exp.DataType.Type.TEXT: lambda dtype: exp.DataType.build("TEXT"),
         }

         def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]:
@@ -495,6 +498,7 @@ class DuckDB(Dialect):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BINARY: "BLOB",
+            exp.DataType.Type.BPCHAR: "TEXT",
             exp.DataType.Type.CHAR: "TEXT",
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.NCHAR: "TEXT",
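`TYPE_CONVERTERS` (renamed from the singular `TYPE_CONVERTER`) lets a dialect fill in missing type parameters at parse time; DuckDB uses it to materialize its documented DECIMAL default. A minimal sketch (output indicative):

```python
import sqlglot

# Parsing as DuckDB applies the DECIMAL entry in TYPE_CONVERTERS, so a bare
# DECIMAL picks up DuckDB's default precision/scale of (18, 3).
print(sqlglot.transpile("SELECT CAST(x AS DECIMAL)", read="duckdb", write="duckdb"))
# indicative: ['SELECT CAST(x AS DECIMAL(18, 3))']
```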
sqlglot/dialects/mysql.py

@@ -202,6 +202,7 @@ class MySQL(Dialect):
         "CHARSET": TokenType.CHARACTER_SET,
         "FORCE": TokenType.FORCE,
         "IGNORE": TokenType.IGNORE,
+        "KEY": TokenType.KEY,
         "LOCK TABLES": TokenType.COMMAND,
         "LONGBLOB": TokenType.LONGBLOB,
         "LONGTEXT": TokenType.LONGTEXT,
sqlglot/dialects/oracle.py

@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
     trim_sql,
 )
 from sqlglot.helper import seq_get
+from sqlglot.parser import OPTIONS_TYPE
 from sqlglot.tokens import TokenType

 if t.TYPE_CHECKING:
@@ -32,10 +33,171 @@ def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar:
     return exp.ToChar.from_arg_list(args)


+def eliminate_join_marks(ast: exp.Expression) -> exp.Expression:
+    from sqlglot.optimizer.scope import traverse_scope
+
+    """Remove join marks from an expression
+
+    SELECT * FROM a, b WHERE a.id = b.id(+)
+    becomes:
+    SELECT * FROM a LEFT JOIN b ON a.id = b.id
+
+    - for each scope
+        - for each column with a join mark
+            - find the predicate it belongs to
+            - remove the predicate from the where clause
+            - convert the predicate to a join with the (+) side as the left join table
+            - replace the existing join with the new join
+
+    Args:
+        ast: The AST to remove join marks from
+
+    Returns:
+        The AST with join marks removed"""
+    for scope in traverse_scope(ast):
+        _eliminate_join_marks_from_scope(scope)
+    return ast
+
+
+def _update_from(
+    select: exp.Select,
+    new_join_dict: t.Dict[str, exp.Join],
+    old_join_dict: t.Dict[str, exp.Join],
+) -> None:
+    """If the from clause needs to become a new join, find an appropriate table to use as the new from.
+    updates select in place
+
+    Args:
+        select: The select statement to update
+        new_join_dict: The dictionary of new joins
+        old_join_dict: The dictionary of old joins
+    """
+    old_from = select.args["from"]
+    if old_from.alias_or_name not in new_join_dict:
+        return
+    in_old_not_new = old_join_dict.keys() - new_join_dict.keys()
+    if len(in_old_not_new) >= 1:
+        new_from_name = list(old_join_dict.keys() - new_join_dict.keys())[0]
+        new_from_this = old_join_dict[new_from_name].this
+        new_from = exp.From(this=new_from_this)
+        del old_join_dict[new_from_name]
+        select.set("from", new_from)
+    else:
+        raise ValueError("Cannot determine which table to use as the new from")
+
+
+def _has_join_mark(col: exp.Expression) -> bool:
+    """Check if the column has a join mark
+
+    Args:
+        The column to check
+    """
+    return col.args.get("join_mark", False)
+
+
+def _predicate_to_join(
+    eq: exp.Binary, old_joins: t.Dict[str, exp.Join], old_from: exp.From
+) -> t.Optional[exp.Join]:
+    """Convert an equality predicate to a join if it contains a join mark
+
+    Args:
+        eq: The equality expression to convert to a join
+
+    Returns:
+        The join expression if the equality contains a join mark (otherwise None)
+    """
+
+    # if not (isinstance(eq.left, exp.Column) or isinstance(eq.right, exp.Column)):
+    #     return None
+
+    left_columns = [col for col in eq.left.find_all(exp.Column) if _has_join_mark(col)]
+    right_columns = [col for col in eq.right.find_all(exp.Column) if _has_join_mark(col)]
+
+    left_has_join_mark = len(left_columns) > 0
+    right_has_join_mark = len(right_columns) > 0
+
+    if left_has_join_mark:
+        for col in left_columns:
+            col.set("join_mark", False)
+            join_on = col.table
+    elif right_has_join_mark:
+        for col in right_columns:
+            col.set("join_mark", False)
+            join_on = col.table
+    else:
+        return None
+
+    join_this = old_joins.get(join_on, old_from).this
+    return exp.Join(this=join_this, on=eq, kind="LEFT")
+
+
+if t.TYPE_CHECKING:
+    from sqlglot.optimizer.scope import Scope
+
+
+def _eliminate_join_marks_from_scope(scope: Scope) -> None:
+    """Remove join marks columns in scope's where clause.
+    Converts them to left joins and replaces any existing joins.
+    Updates scope in place.
+
+    Args:
+        scope: The scope to remove join marks from
+    """
+    select_scope = scope.expression
+    where = select_scope.args.get("where")
+    joins = select_scope.args.get("joins")
+    if not where:
+        return
+    if not joins:
+        return
+
+    # dictionaries used to keep track of joins to be replaced
+    old_joins = {join.alias_or_name: join for join in list(joins)}
+    new_joins: t.Dict[str, exp.Join] = {}
+
+    for node in scope.find_all(exp.Column):
+        if _has_join_mark(node):
+            predicate = node.find_ancestor(exp.Predicate)
+            if not isinstance(predicate, exp.Binary):
+                continue
+            predicate_parent = predicate.parent
+
+            join_on = predicate.pop()
+            new_join = _predicate_to_join(
+                join_on, old_joins=old_joins, old_from=select_scope.args["from"]
+            )
+            # upsert new_join into new_joins dictionary
+            if new_join:
+                if new_join.alias_or_name in new_joins:
+                    new_joins[new_join.alias_or_name].set(
+                        "on",
+                        exp.and_(
+                            new_joins[new_join.alias_or_name].args["on"],
+                            new_join.args["on"],
+                        ),
+                    )
+                else:
+                    new_joins[new_join.alias_or_name] = new_join
+            # If the parent is a binary node with only one child, promote the child to the parent
+            if predicate_parent:
+                if isinstance(predicate_parent, exp.Binary):
+                    if predicate_parent.left is None:
+                        predicate_parent.replace(predicate_parent.right)
+                    elif predicate_parent.right is None:
+                        predicate_parent.replace(predicate_parent.left)
+
+    _update_from(select_scope, new_joins, old_joins)
+    replacement_joins = [new_joins.get(join.alias_or_name, join) for join in old_joins.values()]
+    select_scope.set("joins", replacement_joins)
+    if not where.this:
+        where.pop()
+
+
 class Oracle(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     LOCKING_READS_SUPPORTED = True
     TABLESAMPLE_SIZE_IS_PERCENT = True
+    SUPPORTS_COLUMN_JOIN_MARKS = True

     # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
     NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -70,6 +232,12 @@ class Oracle(Dialect):
     class Tokenizer(tokens.Tokenizer):
         VAR_SINGLE_TOKENS = {"@", "$", "#"}

+        UNICODE_STRINGS = [
+            (prefix + q, q)
+            for q in t.cast(t.List[str], tokens.Tokenizer.QUOTES)
+            for prefix in ("U", "u")
+        ]
+
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "(+)": TokenType.JOIN_MARKER,
@@ -132,6 +300,7 @@ class Oracle(Dialect):
         QUERY_MODIFIER_PARSERS = {
             **parser.Parser.QUERY_MODIFIER_PARSERS,
             TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()),
+            TokenType.WITH: lambda self: ("options", [self._parse_query_restrictions()]),
         }

         TYPE_LITERAL_PARSERS = {
@@ -144,6 +313,13 @@ class Oracle(Dialect):
         # Reference: https://stackoverflow.com/a/336455
         DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}

+        QUERY_RESTRICTIONS: OPTIONS_TYPE = {
+            "WITH": (
+                ("READ", "ONLY"),
+                ("CHECK", "OPTION"),
+            ),
+        }
+
         def _parse_xml_table(self) -> exp.XMLTable:
             this = self._parse_string()

@@ -173,12 +349,6 @@ class Oracle(Dialect):
                 **kwargs,
             )

-        def _parse_column(self) -> t.Optional[exp.Expression]:
-            column = super()._parse_column()
-            if column:
-                column.set("join_mark", self._match(TokenType.JOIN_MARKER))
-            return column
-
         def _parse_hint(self) -> t.Optional[exp.Hint]:
             if self._match(TokenType.HINT):
                 start = self._curr
@@ -193,11 +363,22 @@ class Oracle(Dialect):

             return None

+        def _parse_query_restrictions(self) -> t.Optional[exp.Expression]:
+            kind = self._parse_var_from_options(self.QUERY_RESTRICTIONS, raise_unmatched=False)
+
+            if not kind:
+                return None
+
+            return self.expression(
+                exp.QueryOption,
+                this=kind,
+                expression=self._match(TokenType.CONSTRAINT) and self._parse_field(),
+            )
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
         JOIN_HINTS = False
         TABLE_HINTS = False
-        COLUMN_JOIN_MARKS_SUPPORTED = True
         DATA_TYPE_SPECIFIERS_ALLOWED = True
         ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
         LIMIT_FETCH = "FETCH"
@@ -282,3 +463,10 @@ class Oracle(Dialect):
             if len(expression.args.get("actions", [])) > 1:
                 return f"ADD ({actions})"
             return f"ADD {actions}"
+
+        def queryoption_sql(self, expression: exp.QueryOption) -> str:
+            option = self.sql(expression, "this")
+            value = self.sql(expression, "expression")
+            value = f" CONSTRAINT {value}" if value else ""
+
+            return f"{option}{value}"
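`eliminate_join_marks` is the headline addition here: it walks each scope, pops `(+)` predicates out of the WHERE clause, and rebuilds them as explicit LEFT JOINs. A minimal usage sketch based on the function's own docstring (output indicative):

```python
import sqlglot
from sqlglot.dialects.oracle import eliminate_join_marks

# Old-style Oracle outer join, rewritten into an explicit LEFT JOIN.
ast = sqlglot.parse_one("SELECT * FROM a, b WHERE a.id = b.id(+)", read="oracle")
print(eliminate_join_marks(ast).sql(dialect="oracle"))
# indicative: SELECT * FROM a LEFT JOIN b ON a.id = b.id
```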
sqlglot/dialects/postgres.py

@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
     Dialect,
+    JSON_EXTRACT_TYPE,
     any_value_to_max_sql,
     binary_from_function,
     bool_xor_sql,
     datestrtodate_sql,
     build_formatted_time,
@@ -25,6 +26,7 @@ from sqlglot.dialects.dialect import (
     build_json_extract_path,
     build_timestamp_trunc,
     rename_func,
+    sha256_sql,
     str_position_sql,
     struct_extract_sql,
     timestamptrunc_sql,
@@ -329,6 +331,7 @@ class Postgres(Dialect):
         "REGTYPE": TokenType.OBJECT_IDENTIFIER,
+        "FLOAT": TokenType.DOUBLE,
     }
     KEYWORDS.pop("DIV")

     SINGLE_TOKENS = {
         **tokens.Tokenizer.SINGLE_TOKENS,
@@ -347,6 +350,9 @@ class Postgres(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "DATE_TRUNC": build_timestamp_trunc,
+            "DIV": lambda args: exp.cast(
+                binary_from_function(exp.IntDiv)(args), exp.DataType.Type.DECIMAL
+            ),
             "GENERATE_SERIES": _build_generate_series,
             "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
             "JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
@@ -357,6 +363,9 @@ class Postgres(Dialect):
             "TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"),
             "TO_TIMESTAMP": _build_to_timestamp,
             "UNNEST": exp.Explode.from_arg_list,
+            "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
+            "SHA384": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(384)),
+            "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
         }

         FUNCTION_PARSERS = {
@@ -494,6 +503,7 @@ class Postgres(Dialect):
             exp.DateSub: _date_add_sql("-"),
             exp.Explode: rename_func("UNNEST"),
             exp.GroupConcat: _string_agg_sql,
+            exp.IntDiv: rename_func("DIV"),
             exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"),
             exp.JSONExtractScalar: _json_extract_sql("JSON_EXTRACT_PATH_TEXT", "->>"),
             exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
@@ -528,6 +538,7 @@ class Postgres(Dialect):
                     transforms.eliminate_qualify,
                 ]
             ),
+            exp.SHA2: sha256_sql,
             exp.StrPosition: str_position_sql,
             exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
             exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
@@ -621,3 +632,12 @@ class Postgres(Dialect):
                 return f"{self.expressions(expression, flat=True)}[{values}]"
             return "ARRAY"
             return super().datatype_sql(expression)
+
+        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+            this = expression.this
+
+            # Postgres casts DIV() to decimal for transpilation but when roundtripping it's superfluous
+            if isinstance(this, exp.IntDiv) and expression.to == exp.DataType.build("decimal"):
+                return self.sql(this)
+
+            return super().cast_sql(expression, safe_prefix=safe_prefix)
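The DIV changes are symmetric: the parser wraps `DIV()` in a DECIMAL cast so the expression carries Postgres's actual return type, and the new `cast_sql` override drops that cast again when generating Postgres, keeping roundtrips clean. A minimal sketch (outputs indicative):

```python
import sqlglot

# Roundtrip: the synthetic DECIMAL cast is recognized and stripped.
print(sqlglot.transpile("SELECT DIV(9, 4)", read="postgres", write="postgres"))
# indicative: ['SELECT DIV(9, 4)']

# Transpiled to another dialect, the cast surfaces explicitly.
print(sqlglot.transpile("SELECT DIV(9, 4)", read="postgres", write="duckdb"))
# indicative: ['SELECT CAST(9 // 4 AS DECIMAL)']
```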
sqlglot/dialects/presto.py

@@ -21,6 +21,7 @@ from sqlglot.dialects.dialect import (
     regexp_extract_sql,
     rename_func,
     right_to_substring_sql,
+    sha256_sql,
     struct_extract_sql,
     str_position_sql,
     timestamptrunc_sql,
@@ -452,9 +453,7 @@ class Presto(Dialect):
             ),
             exp.MD5Digest: rename_func("MD5"),
             exp.SHA: rename_func("SHA1"),
-            exp.SHA2: lambda self, e: self.func(
-                "SHA256" if e.text("length") == "256" else "SHA512", e.this
-            ),
+            exp.SHA2: sha256_sql,
         }

         RESERVED_KEYWORDS = {
sqlglot/dialects/redshift.py

@@ -40,6 +40,7 @@ class Redshift(Postgres):
     INDEX_OFFSET = 0
     COPY_PARAMS_ARE_CSV = False
     HEX_LOWERCASE = True
+    SUPPORTS_COLUMN_JOIN_MARKS = True

     TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
     TIME_MAPPING = {
@@ -122,12 +123,13 @@ class Redshift(Postgres):

         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,
+            "(+)": TokenType.JOIN_MARKER,
             "HLLSKETCH": TokenType.HLLSKETCH,
+            "MINUS": TokenType.EXCEPT,
             "SUPER": TokenType.SUPER,
             "TOP": TokenType.TOP,
             "UNLOAD": TokenType.COMMAND,
             "VARBYTE": TokenType.VARBINARY,
-            "MINUS": TokenType.EXCEPT,
         }
         KEYWORDS.pop("VALUES")

@@ -209,6 +211,7 @@ class Redshift(Postgres):

         # Redshift supports LAST_DAY(..)
         TRANSFORMS.pop(exp.LastDay)
+        TRANSFORMS.pop(exp.SHA2)

         RESERVED_KEYWORDS = {
             "aes128",
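With the `(+)` tokenizer entry and `SUPPORTS_COLUMN_JOIN_MARKS = True`, old-style outer joins can now pass through Redshift instead of triggering the generator's unsupported warning. A hedged sketch (exact whitespace in the output may differ):

```python
import sqlglot

# The join marker survives a Redshift roundtrip now that the dialect opts in.
print(sqlglot.transpile(
    "SELECT * FROM a, b WHERE a.id = b.id (+)", read="redshift", write="redshift"
))
# indicative: ['SELECT * FROM a, b WHERE a.id = b.id (+)']
```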
sqlglot/dialects/snowflake.py

@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
     NormalizationStrategy,
     binary_from_function,
     build_default_decimal_type,
+    build_timestamp_from_parts,
     date_delta_sql,
     date_trunc_to_time,
     datestrtodate_sql,
@@ -236,15 +237,6 @@ def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
     return trunc


-def _build_timestamp_from_parts(args: t.List) -> exp.Func:
-    if len(args) == 2:
-        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
-        # so we parse this into Anonymous for now instead of introducing complexity
-        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
-
-    return exp.TimestampFromParts.from_arg_list(args)
-
-
 def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
     """
     Snowflake doesn't allow columns referenced in UNPIVOT to be qualified,
@@ -391,8 +383,8 @@ class Snowflake(Dialect):
             "TIMEDIFF": _build_datediff,
             "TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
             "TIMESTAMPDIFF": _build_datediff,
-            "TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
-            "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
+            "TIMESTAMPFROMPARTS": build_timestamp_from_parts,
+            "TIMESTAMP_FROM_PARTS": build_timestamp_from_parts,
             "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
             "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
             "TO_NUMBER": lambda args: exp.ToNumber(
@@ -446,7 +438,7 @@ class Snowflake(Dialect):
             "LOCATION": lambda self: self._parse_location_property(),
         }

-        TYPE_CONVERTER = {
+        TYPE_CONVERTERS = {
             # https://docs.snowflake.com/en/sql-reference/data-types-numeric#number
             exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=38, scale=0),
         }
@@ -510,15 +502,18 @@ class Snowflake(Dialect):
             self._retreat(self._index - 1)

             if self._match_text_seq("MASKING", "POLICY"):
+                policy = self._parse_column()
                 return self.expression(
                     exp.MaskingPolicyColumnConstraint,
-                    this=self._parse_id_var(),
+                    this=policy.to_dot() if isinstance(policy, exp.Column) else policy,
                     expressions=self._match(TokenType.USING)
                     and self._parse_wrapped_csv(self._parse_id_var),
                 )
             if self._match_text_seq("PROJECTION", "POLICY"):
+                policy = self._parse_column()
                 return self.expression(
-                    exp.ProjectionPolicyColumnConstraint, this=self._parse_id_var()
+                    exp.ProjectionPolicyColumnConstraint,
+                    this=policy.to_dot() if isinstance(policy, exp.Column) else policy,
                 )
             if self._match(TokenType.TAG):
                 return self.expression(
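Moving `_build_timestamp_from_parts` into `dialect.py` as `build_timestamp_from_parts` lets other dialects reuse Snowflake's parsing rules, which pairs with BigQuery's new `exp.TimestampFromParts: rename_func("DATETIME")` transform above. A minimal sketch (output indicative):

```python
import sqlglot

# The 6-arg form parses into exp.TimestampFromParts; BigQuery renders it as
# DATETIME(...). The 2-arg (date, time) form stays an Anonymous node.
print(sqlglot.transpile(
    "SELECT TIMESTAMP_FROM_PARTS(2024, 1, 2, 3, 4, 5)",
    read="snowflake",
    write="bigquery",
))
# indicative: ['SELECT DATETIME(2024, 1, 2, 3, 4, 5)']
```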
sqlglot/dialects/spark.py

@@ -41,6 +41,21 @@ def _build_datediff(args: t.List) -> exp.Expression:
     )


+def _build_dateadd(args: t.List) -> exp.Expression:
+    expression = seq_get(args, 1)
+
+    if len(args) == 2:
+        # DATE_ADD(startDate, numDays INTEGER)
+        # https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
+        return exp.TsOrDsAdd(
+            this=seq_get(args, 0), expression=expression, unit=exp.Literal.string("DAY")
+        )
+
+    # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
+    # https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
+    return exp.TimestampAdd(this=seq_get(args, 2), expression=expression, unit=seq_get(args, 0))
+
+
 def _normalize_partition(e: exp.Expression) -> exp.Expression:
     """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
     if isinstance(e, str):
@@ -50,6 +65,30 @@ def _normalize_partition(e: exp.Expression) -> exp.Expression:
     return e


+def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
+    if not expression.unit or (
+        isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY"
+    ):
+        # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
+        return self.func("DATE_ADD", expression.this, expression.expression)
+
+    this = self.func(
+        "DATE_ADD",
+        unit_to_var(expression),
+        expression.expression,
+        expression.this,
+    )
+
+    if isinstance(expression, exp.TsOrDsAdd):
+        # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
+        # in other dialects
+        return_type = expression.return_type
+        if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
+            this = f"CAST({this} AS {return_type})"
+
+    return this
+
+
 class Spark(Spark2):
     class Tokenizer(Spark2.Tokenizer):
         RAW_STRINGS = [
@@ -62,6 +101,9 @@ class Spark(Spark2):
         FUNCTIONS = {
             **Spark2.Parser.FUNCTIONS,
             "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
+            "DATE_ADD": _build_dateadd,
+            "DATEADD": _build_dateadd,
+            "TIMESTAMPADD": _build_dateadd,
             "DATEDIFF": _build_datediff,
             "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
             "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
@@ -111,9 +153,8 @@ class Spark(Spark2):
             exp.PartitionedByProperty: lambda self,
             e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
             exp.StartsWith: rename_func("STARTSWITH"),
-            exp.TimestampAdd: lambda self, e: self.func(
-                "DATEADD", unit_to_var(e), e.expression, e.this
-            ),
+            exp.TsOrDsAdd: _dateadd_sql,
+            exp.TimestampAdd: _dateadd_sql,
             exp.TryCast: lambda self, e: (
                 self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
             ),
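`_build_dateadd` splits the two Databricks/Spark signatures: the 2-arg form keeps day-based `TsOrDsAdd` semantics, while the 3-arg form becomes `exp.TimestampAdd`, and `_dateadd_sql` renders both (adding a cast when a `TsOrDsAdd` would not otherwise return a timestamp). A minimal sketch (outputs indicative):

```python
import sqlglot

# 2-arg DATE_ADD: day-based, roundtrips unchanged.
print(sqlglot.transpile("SELECT DATE_ADD('2020-01-01', 5)", read="spark", write="spark"))
# indicative: ["SELECT DATE_ADD('2020-01-01', 5)"]

# 3-arg form: parsed as TimestampAdd, rendered through _dateadd_sql.
print(sqlglot.transpile("SELECT DATEADD(MONTH, 2, col)", read="spark", write="spark"))
# indicative: ['SELECT DATE_ADD(MONTH, 2, col)']
```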
sqlglot/dialects/sqlite.py

@@ -75,6 +75,26 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
     return expression


+def _generated_to_auto_increment(expression: exp.Expression) -> exp.Expression:
+    if not isinstance(expression, exp.ColumnDef):
+        return expression
+
+    generated = expression.find(exp.GeneratedAsIdentityColumnConstraint)
+
+    if generated:
+        t.cast(exp.ColumnConstraint, generated.parent).pop()
+
+        not_null = expression.find(exp.NotNullColumnConstraint)
+        if not_null:
+            t.cast(exp.ColumnConstraint, not_null.parent).pop()
+
+        expression.append(
+            "constraints", exp.ColumnConstraint(kind=exp.AutoIncrementColumnConstraint())
+        )
+
+    return expression
+
+
 class SQLite(Dialect):
     # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
     NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
@@ -141,6 +161,7 @@ class SQLite(Dialect):
         exp.CurrentDate: lambda *_: "CURRENT_DATE",
         exp.CurrentTime: lambda *_: "CURRENT_TIME",
         exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+        exp.ColumnDef: transforms.preprocess([_generated_to_auto_increment]),
         exp.DateAdd: _date_add_sql,
         exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
         exp.If: rename_func("IIF"),
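`_generated_to_auto_increment` bridges a real gap: SQLite has no `GENERATED ... AS IDENTITY`, so the new `exp.ColumnDef` preprocess swaps the identity (and any accompanying NOT NULL constraint) for AUTOINCREMENT. A hedged sketch (output indicative; exact constraint rendering may differ):

```python
import sqlglot

# The identity and NOT NULL constraints are popped and replaced with an
# AutoIncrementColumnConstraint when writing SQLite.
print(sqlglot.transpile(
    "CREATE TABLE t (id INT GENERATED ALWAYS AS IDENTITY NOT NULL)",
    read="postgres",
    write="sqlite",
))
# indicative: ['CREATE TABLE t (id INTEGER AUTOINCREMENT)']
```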
sqlglot/dialects/tsql.py

@@ -1118,3 +1118,7 @@ class TSQL(Dialect):
                 kind = f"TABLE {kind}"

             return f"{variable} AS {kind}{default}"
+
+        def options_modifier(self, expression: exp.Expression) -> str:
+            options = self.expressions(expression, key="options")
+            return f" OPTION{self.wrap(options)}" if options else ""
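This override plugs into the new `options_modifier` hook on the base Generator (see the generator.py hunks below), so T-SQL's `OPTION(...)` hints are emitted by the dialect rather than being special-cased in `select_sql`. A hedged sketch, assuming the T-SQL parser populates the select's `options` arg as before:

```python
import sqlglot

# Query hints roundtrip through the options_modifier hook.
print(sqlglot.transpile("SELECT * FROM t OPTION (MAXDOP 1)", read="tsql", write="tsql"))
# indicative: ['SELECT * FROM t OPTION(MAXDOP 1)']
```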
sqlglot/expressions.py

@@ -3119,22 +3119,6 @@ class Intersect(Union):
     pass


-class Unnest(UDTF):
-    arg_types = {
-        "expressions": True,
-        "alias": False,
-        "offset": False,
-    }
-
-    @property
-    def selects(self) -> t.List[Expression]:
-        columns = super().selects
-        offset = self.args.get("offset")
-        if offset:
-            columns = columns + [to_identifier("offset") if offset is True else offset]
-        return columns
-
-
 class Update(Expression):
     arg_types = {
         "with": False,
@@ -5240,6 +5224,22 @@ class PosexplodeOuter(Posexplode, ExplodeOuter):
     pass


+class Unnest(Func, UDTF):
+    arg_types = {
+        "expressions": True,
+        "alias": False,
+        "offset": False,
+    }
+
+    @property
+    def selects(self) -> t.List[Expression]:
+        columns = super().selects
+        offset = self.args.get("offset")
+        if offset:
+            columns = columns + [to_identifier("offset") if offset is True else offset]
+        return columns
+
+
 class Floor(Func):
     arg_types = {"this": True, "decimals": False}

@@ -5765,7 +5765,7 @@ class StrPosition(Func):


 class StrToDate(Func):
-    arg_types = {"this": True, "format": True}
+    arg_types = {"this": True, "format": False}


 class StrToTime(Func):
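Two expression-level changes: `Unnest` moves down with the other function classes and now also subclasses `Func`, and `StrToDate` no longer requires a format argument. A minimal sketch of the latter (output indicative):

```python
from sqlglot import exp

# "format" is optional now, so a bare StrToDate node is valid and renders as
# a plain TO_DATE call in dialects like Postgres.
node = exp.StrToDate(this=exp.column("x"))
print(node.sql(dialect="postgres"))
# indicative: TO_DATE(x)
```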
sqlglot/generator.py

@@ -225,9 +225,6 @@ class Generator(metaclass=_Generator):
     # Whether to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ...
     RETURNING_END = True

-    # Whether to generate the (+) suffix for columns used in old-style join conditions
-    COLUMN_JOIN_MARKS_SUPPORTED = False
-
     # Whether to generate an unquoted value for EXTRACT's date part argument
     EXTRACT_ALLOWS_QUOTES = True

@@ -359,6 +356,9 @@ class Generator(metaclass=_Generator):
     # Whether the conditional TRY(expression) function is supported
     TRY_SUPPORTED = True

+    # Whether the UESCAPE syntax in unicode strings is supported
+    SUPPORTS_UESCAPE = True
+
     # The keyword to use when generating a star projection with excluded columns
     STAR_EXCEPT = "EXCEPT"

@@ -827,7 +827,7 @@ class Generator(metaclass=_Generator):
     def column_sql(self, expression: exp.Column) -> str:
         join_mark = " (+)" if expression.args.get("join_mark") else ""

-        if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED:
+        if join_mark and not self.dialect.SUPPORTS_COLUMN_JOIN_MARKS:
             join_mark = ""
             self.unsupported("Outer join syntax using the (+) operator is not supported.")

@@ -1146,16 +1146,23 @@ class Generator(metaclass=_Generator):
         escape = expression.args.get("escape")

         if self.dialect.UNICODE_START:
-            escape = f" UESCAPE {self.sql(escape)}" if escape else ""
-            return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}"
+            escape_substitute = r"\\\1"
+            left_quote, right_quote = self.dialect.UNICODE_START, self.dialect.UNICODE_END
+        else:
+            escape_substitute = r"\\u\1"
+            left_quote, right_quote = self.dialect.QUOTE_START, self.dialect.QUOTE_END

         if escape:
-            pattern = re.compile(rf"{escape.name}(\d+)")
+            escape_pattern = re.compile(rf"{escape.name}(\d+)")
+            escape_sql = f" UESCAPE {self.sql(escape)}" if self.SUPPORTS_UESCAPE else ""
         else:
-            pattern = ESCAPED_UNICODE_RE
+            escape_pattern = ESCAPED_UNICODE_RE
+            escape_sql = ""

-        this = pattern.sub(r"\\u\1", this)
-        return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}"
+        if not self.dialect.UNICODE_START or (escape and not self.SUPPORTS_UESCAPE):
+            this = escape_pattern.sub(escape_substitute, this)
+
+        return f"{left_quote}{this}{right_quote}{escape_sql}"

     def rawstring_sql(self, expression: exp.RawString) -> str:
         string = self.escape_str(expression.this.replace("\\", "\\\\"), escape_backslash=False)
@@ -1973,7 +1980,9 @@ class Generator(metaclass=_Generator):

             return f", {this_sql}"

-        op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
+        if op_sql != "STRAIGHT_JOIN":
+            op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
+
         return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"

     def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
@@ -2235,10 +2244,6 @@ class Generator(metaclass=_Generator):
         elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
             limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))

-        options = self.expressions(expression, key="options")
-        if options:
-            options = f" OPTION{self.wrap(options)}"
-
         return csv(
             *sqls,
             *[self.sql(join) for join in expression.args.get("joins") or []],
@@ -2253,10 +2258,14 @@ class Generator(metaclass=_Generator):
             self.sql(expression, "order"),
             *self.offset_limit_modifiers(expression, isinstance(limit, exp.Fetch), limit),
             *self.after_limit_modifiers(expression),
-            options,
+            self.options_modifier(expression),
             sep="",
         )

+    def options_modifier(self, expression: exp.Expression) -> str:
+        options = self.expressions(expression, key="options")
+        return f" {options}" if options else ""
+
+    def queryoption_sql(self, expression: exp.QueryOption) -> str:
+        return ""
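The rewritten `unicodestring_sql` folds all four cases (with or without native `U&''` support, with or without UESCAPE) into one exit path: when the target dialect cannot express the escape, the custom escape character is substituted back into standard `\uXXXX` escapes inside a regular string. A hedged sketch (output indicative, escaping details may vary by version):

```python
import sqlglot

# Presto supports U&'...' UESCAPE; Snowflake has no unicode-string literal,
# so the '!' escapes are rewritten into plain \uXXXX escapes.
print(sqlglot.transpile("SELECT U&'d!0061t!0061' UESCAPE '!'", read="presto", write="snowflake"))
# indicative: ["SELECT 'd\\u0061t\\u0061'"]
```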
sqlglot/optimizer/simplify.py

@@ -1034,7 +1034,7 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr

         return (
             DATETRUNC_BINARY_COMPARISONS[comparison](
-                trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
+                trunc_arg, date, unit, dialect, extract_type(r)
             )
             or expression
         )
@@ -1060,7 +1060,7 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
             return expression

         ranges = merge_ranges(ranges)
-        target_type = extract_type(l, *rs)
+        target_type = extract_type(*rs)

         return exp.or_(
             *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
sqlglot/parser.py

@@ -588,11 +588,12 @@ class Parser(metaclass=_Parser):
     }

     JOIN_KINDS = {
+        TokenType.ANTI,
+        TokenType.CROSS,
         TokenType.INNER,
         TokenType.OUTER,
-        TokenType.CROSS,
         TokenType.SEMI,
-        TokenType.ANTI,
+        TokenType.STRAIGHT_JOIN,
     }

     JOIN_HINTS: t.Set[str] = set()
@@ -1065,7 +1066,7 @@ class Parser(metaclass=_Parser):
         exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this),
     }

-    TYPE_CONVERTER: t.Dict[exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]] = {}
+    TYPE_CONVERTERS: t.Dict[exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]] = {}

    DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN}

@@ -1138,7 +1139,14 @@ class Parser(metaclass=_Parser):

     FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT}

-    ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
+    ADD_CONSTRAINT_TOKENS = {
+        TokenType.CONSTRAINT,
+        TokenType.FOREIGN_KEY,
+        TokenType.INDEX,
+        TokenType.KEY,
+        TokenType.PRIMARY_KEY,
+        TokenType.UNIQUE,
+    }

     DISTINCT_TOKENS = {TokenType.DISTINCT}

@@ -3099,7 +3107,7 @@ class Parser(metaclass=_Parser):
         index = self._index
         method, side, kind = self._parse_join_parts()
         hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
-        join = self._match(TokenType.JOIN)
+        join = self._match(TokenType.JOIN) or (kind and kind.token_type == TokenType.STRAIGHT_JOIN)

         if not skip_join_token and not join:
             self._retreat(index)
@@ -3242,7 +3250,7 @@ class Parser(metaclass=_Parser):
         while self._match_set(self.TABLE_INDEX_HINT_TOKENS):
             hint = exp.IndexTableHint(this=self._prev.text.upper())

-            self._match_texts(("INDEX", "KEY"))
+            self._match_set((TokenType.INDEX, TokenType.KEY))
             if self._match(TokenType.FOR):
                 hint.set("target", self._advance_any() and self._prev.text.upper())

@@ -4464,8 +4472,8 @@ class Parser(metaclass=_Parser):
             )
             self._match(TokenType.R_BRACKET)

-        if self.TYPE_CONVERTER and isinstance(this.this, exp.DataType.Type):
-            converter = self.TYPE_CONVERTER.get(this.this)
+        if self.TYPE_CONVERTERS and isinstance(this.this, exp.DataType.Type):
+            converter = self.TYPE_CONVERTERS.get(this.this)
             if converter:
                 this = converter(t.cast(exp.DataType, this))

@@ -4496,7 +4504,12 @@ class Parser(metaclass=_Parser):

     def _parse_column(self) -> t.Optional[exp.Expression]:
         this = self._parse_column_reference()
-        return self._parse_column_ops(this) if this else self._parse_bracket(this)
+        column = self._parse_column_ops(this) if this else self._parse_bracket(this)
+
+        if self.dialect.SUPPORTS_COLUMN_JOIN_MARKS and column:
+            column.set("join_mark", self._match(TokenType.JOIN_MARKER))
+
+        return column

     def _parse_column_reference(self) -> t.Optional[exp.Expression]:
         this = self._parse_field()
@@ -4522,7 +4535,11 @@ class Parser(metaclass=_Parser):

         while self._match(TokenType.COLON):
             start_index = self._index
-            path = self._parse_column_ops(self._parse_field(any_token=True))
+
+            # Snowflake allows reserved keywords as json keys but advance_any() excludes TokenType.SELECT from any_tokens=True
+            path = self._parse_column_ops(
+                self._parse_field(any_token=True, tokens=(TokenType.SELECT,))
+            )

             # The cast :: operator has a lower precedence than the extraction operator :, so
             # we rearrange the AST appropriately to avoid casting the JSON path
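The widened `ADD_CONSTRAINT_TOKENS` (together with MySQL's new `KEY` keyword above) means MySQL-style `ALTER TABLE ... ADD INDEX/KEY` now parses. A minimal sketch (output indicative; exact spacing may differ):

```python
import sqlglot

# INDEX, KEY and UNIQUE are now accepted after ADD.
print(sqlglot.transpile("ALTER TABLE t ADD INDEX idx_a (a)", read="mysql", write="mysql"))
# indicative: ['ALTER TABLE t ADD INDEX idx_a(a)']
```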
sqlglot/tokens.py

@@ -287,6 +287,7 @@ class TokenType(AutoName):
     JOIN = auto()
     JOIN_MARKER = auto()
     KEEP = auto()
+    KEY = auto()
     KILL = auto()
     LANGUAGE = auto()
     LATERAL = auto()
@@ -360,6 +361,7 @@ class TokenType(AutoName):
     SORT_BY = auto()
     START_WITH = auto()
     STORAGE_INTEGRATION = auto()
+    STRAIGHT_JOIN = auto()
     STRUCT = auto()
     TABLE_SAMPLE = auto()
     TAG = auto()
@@ -764,6 +766,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "SOME": TokenType.SOME,
         "SORT BY": TokenType.SORT_BY,
         "START WITH": TokenType.START_WITH,
+        "STRAIGHT_JOIN": TokenType.STRAIGHT_JOIN,
         "TABLE": TokenType.TABLE,
         "TABLESAMPLE": TokenType.TABLE_SAMPLE,
         "TEMP": TokenType.TEMPORARY,
@@ -1270,18 +1273,6 @@ class Tokenizer(metaclass=_Tokenizer):
         elif token_type == TokenType.BIT_STRING:
             base = 2
         elif token_type == TokenType.HEREDOC_STRING:
-            if (
-                self.HEREDOC_TAG_IS_IDENTIFIER
-                and not self._peek.isidentifier()
-                and not self._peek == end
-            ):
-                if self.HEREDOC_STRING_ALTERNATIVE != token_type.VAR:
-                    self._add(self.HEREDOC_STRING_ALTERNATIVE)
-                else:
-                    self._scan_var()
-
-                return True
-
             self._advance()

         if self._char == end:
@@ -1293,7 +1284,10 @@ class Tokenizer(metaclass=_Tokenizer):
                 raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER,
             )

-            if self._end and tag and self.HEREDOC_TAG_IS_IDENTIFIER:
+            if tag and self.HEREDOC_TAG_IS_IDENTIFIER and (self._end or not tag.isidentifier()):
+                if not self._end:
+                    self._advance(-1)
+
                 self._advance(-len(tag))
                 self._add(self.HEREDOC_STRING_ALTERNATIVE)
                 return True
sqlglot/transforms.py

@@ -505,7 +505,10 @@ def ensure_bools(expression: exp.Expression) -> exp.Expression:
     def _ensure_bool(node: exp.Expression) -> None:
         if (
             node.is_number
-            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
+            or (
+                not isinstance(node, exp.SubqueryPredicate)
+                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
+            )
             or (isinstance(node, exp.Column) and not node.type)
         ):
             node.replace(node.neq(0))
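`ensure_bools` coerces numeric predicates into explicit comparisons for dialects like T-SQL; the guard added here stops it from mangling subquery predicates (EXISTS, IN, ...), which are boolean already even when their annotated type is numeric. A minimal sketch (outputs indicative):

```python
import sqlglot

# Bare column predicate: still rewritten to an explicit comparison.
print(sqlglot.transpile("SELECT * FROM t WHERE c", read="mysql", write="tsql"))
# indicative: ['SELECT * FROM t WHERE c <> 0']

# Subquery predicate: left untouched instead of becoming EXISTS(...) <> 0.
print(sqlglot.transpile("SELECT * FROM t WHERE EXISTS(SELECT 1)", read="mysql", write="tsql"))
```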