Merging upstream version 25.16.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 7688e2bdf8
commit bad79d1f7c

110 changed files with 75353 additions and 68092 deletions
@@ -6,6 +6,12 @@ from sqlglot.tokens import TokenType


class Athena(Trino):
    class Tokenizer(Trino.Tokenizer):
        KEYWORDS = {
            **Trino.Tokenizer.KEYWORDS,
            "UNLOAD": TokenType.COMMAND,
        }

    class Parser(Trino.Parser):
        STATEMENT_PARSERS = {
            **Trino.Parser.STATEMENT_PARSERS,
@@ -252,6 +252,7 @@ class BigQuery(Dialect):
    TIME_MAPPING = {
        "%D": "%m/%d/%y",
        "%E6S": "%S.%f",
        "%e": "%-d",
    }

    FORMAT_MAPPING = {

@@ -401,6 +402,9 @@ class BigQuery(Dialect):
            ),
            "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(this=seq_get(args, 0)),
            "TO_JSON_STRING": exp.JSONFormat.from_arg_list,
            "FORMAT_DATETIME": lambda args: exp.TimeToStr(
                this=exp.TsOrDsToTimestamp(this=seq_get(args, 1)), format=seq_get(args, 0)
            ),
        }

        FUNCTION_PARSERS = {

@@ -500,7 +504,7 @@ class BigQuery(Dialect):
                        table.set("db", exp.Identifier(this=parts[0]))
                        table.set("this", exp.Identifier(this=parts[1]))

            if any("." in p.name for p in table.parts):
            if isinstance(table.this, exp.Identifier) and any("." in p.name for p in table.parts):
                catalog, db, this, *rest = (
                    exp.to_identifier(p, quoted=True)
                    for p in split_num_words(".".join(p.name for p in table.parts), ".", 3)

@@ -583,6 +587,28 @@ class BigQuery(Dialect):

            return bracket

        def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
            unnest = super()._parse_unnest(with_alias=with_alias)

            if not unnest:
                return None

            unnest_expr = seq_get(unnest.expressions, 0)
            if unnest_expr:
                from sqlglot.optimizer.annotate_types import annotate_types

                unnest_expr = annotate_types(unnest_expr)

                # Unnesting a nested array (i.e array of structs) explodes the top-level struct fields,
                # in contrast to other dialects such as DuckDB which flattens only the array by default
                if unnest_expr.is_type(exp.DataType.Type.ARRAY) and any(
                    array_elem.is_type(exp.DataType.Type.STRUCT)
                    for array_elem in unnest_expr._type.expressions
                ):
                    unnest.set("explode_array", True)

            return unnest

    class Generator(generator.Generator):
        EXPLICIT_SET_OP = True
        INTERVAL_ALLOWS_PLURAL_FORM = False

@@ -606,6 +632,7 @@ class BigQuery(Dialect):
        NAMED_PLACEHOLDER_TOKEN = "@"
        HEX_FUNC = "TO_HEX"
        WITH_PROPERTIES_PREFIX = "OPTIONS"
        SUPPORTS_EXPLODING_PROJECTIONS = False

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,

@@ -878,8 +905,16 @@ class BigQuery(Dialect):
            return super().table_parts(expression)

        def timetostr_sql(self, expression: exp.TimeToStr) -> str:
            this = expression.this if isinstance(expression.this, exp.TsOrDsToDate) else expression
            return self.func("FORMAT_DATE", self.format_time(expression), this.this)
            if isinstance(expression.this, exp.TsOrDsToTimestamp):
                func_name = "FORMAT_DATETIME"
            else:
                func_name = "FORMAT_DATE"
            this = (
                expression.this
                if isinstance(expression.this, (exp.TsOrDsToTimestamp, exp.TsOrDsToDate))
                else expression
            )
            return self.func(func_name, self.format_time(expression), this.this)

        def eq_sql(self, expression: exp.EQ) -> str:
            # Operands of = cannot be NULL in BigQuery
@@ -1,10 +1,12 @@
from __future__ import annotations

import typing as t
import datetime

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    NormalizationStrategy,
    arg_max_or_min_no_count,
    build_date_delta,
    build_formatted_time,

@@ -31,7 +33,7 @@ def _build_date_format(args: t.List) -> exp.TimeToStr:

    timezone = seq_get(args, 2)
    if timezone:
        expr.set("timezone", timezone)
        expr.set("zone", timezone)

    return expr

@@ -104,6 +106,28 @@ def _datetime_delta_sql(name: str) -> t.Callable[[Generator, DATETIME_DELTA]
    return _delta_sql


def _timestrtotime_sql(self: ClickHouse.Generator, expression: exp.TimeStrToTime):
    tz = expression.args.get("zone")
    datatype = exp.DataType.build(exp.DataType.Type.TIMESTAMP)
    ts = expression.this
    if tz:
        # build a datatype that encodes the timezone as a type parameter, eg DateTime('America/Los_Angeles')
        datatype = exp.DataType.build(
            exp.DataType.Type.TIMESTAMPTZ,  # Type.TIMESTAMPTZ maps to DateTime
            expressions=[exp.DataTypeParam(this=tz)],
        )

    if isinstance(ts, exp.Literal):
        # strip the timezone out of the literal, eg turn '2020-01-01 12:13:14-08:00' into '2020-01-01 12:13:14'
        # this is because Clickhouse encodes the timezone as a data type parameter and throws an error if it's part of the timestamp string
        ts_without_tz = (
            datetime.datetime.fromisoformat(ts.name).replace(tzinfo=None).isoformat(sep=" ")
        )
        ts = exp.Literal.string(ts_without_tz)

    return self.sql(exp.cast(ts, datatype, dialect=self.dialect))


class ClickHouse(Dialect):
    NORMALIZE_FUNCTIONS: bool | str = False
    NULL_ORDERING = "nulls_are_last"
|
@ -112,10 +136,15 @@ class ClickHouse(Dialect):
|
|||
LOG_BASE_FIRST: t.Optional[bool] = None
|
||||
FORCE_EARLY_ALIAS_REF_EXPANSION = True
|
||||
|
||||
# https://github.com/ClickHouse/ClickHouse/issues/33935#issue-1112165779
|
||||
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_SENSITIVE
|
||||
|
||||
UNESCAPED_SEQUENCES = {
|
||||
"\\0": "\0",
|
||||
}
|
||||
|
||||
CREATABLE_KIND_MAPPING = {"DATABASE": "SCHEMA"}
|
||||
|
||||
class Tokenizer(tokens.Tokenizer):
|
||||
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
|
||||
IDENTIFIERS = ['"', "`"]
|
||||
|
@ -424,6 +453,27 @@ class ClickHouse(Dialect):
|
|||
"INDEX",
|
||||
}
|
||||
|
||||
PLACEHOLDER_PARSERS = {
|
||||
**parser.Parser.PLACEHOLDER_PARSERS,
|
||||
TokenType.L_BRACE: lambda self: self._parse_query_parameter(),
|
||||
}
|
||||
|
||||
def _parse_types(
|
||||
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
|
||||
) -> t.Optional[exp.Expression]:
|
||||
dtype = super()._parse_types(
|
||||
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
|
||||
)
|
||||
if isinstance(dtype, exp.DataType):
|
||||
# Mark every type as non-nullable which is ClickHouse's default. This marker
|
||||
# helps us transpile types from other dialects to ClickHouse, so that we can
|
||||
# e.g. produce `CAST(x AS Nullable(String))` from `CAST(x AS TEXT)`. If there
|
||||
# is a `NULL` value in `x`, the former would fail in ClickHouse without the
|
||||
# `Nullable` type constructor
|
||||
dtype.set("nullable", False)
|
||||
|
||||
return dtype
|
||||
|
||||
def _parse_extract(self) -> exp.Extract | exp.Anonymous:
|
||||
index = self._index
|
||||
this = self._parse_bitwise()
|
||||
|
@ -433,7 +483,7 @@ class ClickHouse(Dialect):
|
|||
|
||||
# We return Anonymous here because extract and regexpExtract have different semantics,
|
||||
# so parsing extract(foo, bar) into RegexpExtract can potentially break queries. E.g.,
|
||||
# `extract('foobar', 'b')` works, but CH crashes for `regexpExtract('foobar', 'b')`.
|
||||
# `extract('foobar', 'b')` works, but ClickHouse crashes for `regexpExtract('foobar', 'b')`.
|
||||
#
|
||||
# TODO: can we somehow convert the former into an equivalent `regexpExtract` call?
|
||||
self._match(TokenType.COMMA)
|
||||
|
@ -454,14 +504,11 @@ class ClickHouse(Dialect):
|
|||
|
||||
return this
|
||||
|
||||
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
|
||||
def _parse_query_parameter(self) -> t.Optional[exp.Expression]:
|
||||
"""
|
||||
Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier}
|
||||
https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters
|
||||
"""
|
||||
if not self._match(TokenType.L_BRACE):
|
||||
return None
|
||||
|
||||
this = self._parse_id_var()
|
||||
self._match(TokenType.COLON)
|
||||
kind = self._parse_types(check_func=False, allow_identifiers=False) or (
|
||||
|
@ -589,7 +636,7 @@ class ClickHouse(Dialect):
|
|||
|
||||
if isinstance(expr, exp.Window):
|
||||
# The window's func was parsed as Anonymous in base parser, fix its
|
||||
# type to be CH style CombinedAnonymousAggFunc / AnonymousAggFunc
|
||||
# type to be ClickHouse style CombinedAnonymousAggFunc / AnonymousAggFunc
|
||||
expr.set("this", func)
|
||||
elif params:
|
||||
# Params have blocked super()._parse_function() from parsing the following window
|
||||
|
@ -715,6 +762,7 @@ class ClickHouse(Dialect):
|
|||
GROUPINGS_SEP = ""
|
||||
SET_OP_MODIFIERS = False
|
||||
SUPPORTS_TABLE_ALIAS_COLUMNS = False
|
||||
VALUES_AS_TABLE = False
|
||||
|
||||
STRING_TYPE_MAPPING = {
|
||||
exp.DataType.Type.CHAR: "String",
|
||||
|
@ -741,7 +789,10 @@ class ClickHouse(Dialect):
|
|||
exp.DataType.Type.ARRAY: "Array",
|
||||
exp.DataType.Type.BIGINT: "Int64",
|
||||
exp.DataType.Type.DATE32: "Date32",
|
||||
exp.DataType.Type.DATETIME: "DateTime",
|
||||
exp.DataType.Type.DATETIME64: "DateTime64",
|
||||
exp.DataType.Type.TIMESTAMP: "DateTime",
|
||||
exp.DataType.Type.TIMESTAMPTZ: "DateTime",
|
||||
exp.DataType.Type.DOUBLE: "Float64",
|
||||
exp.DataType.Type.ENUM: "Enum",
|
||||
exp.DataType.Type.ENUM8: "Enum8",
|
||||
|
@ -790,6 +841,7 @@ class ClickHouse(Dialect):
|
|||
exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
|
||||
exp.DateAdd: _datetime_delta_sql("DATE_ADD"),
|
||||
exp.DateDiff: _datetime_delta_sql("DATE_DIFF"),
|
||||
exp.DateStrToDate: rename_func("toDate"),
|
||||
exp.DateSub: _datetime_delta_sql("DATE_SUB"),
|
||||
exp.Explode: rename_func("arrayJoin"),
|
||||
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
|
||||
|
@ -810,8 +862,9 @@ class ClickHouse(Dialect):
|
|||
"position", e.this, e.args.get("substr"), e.args.get("position")
|
||||
),
|
||||
exp.TimeToStr: lambda self, e: self.func(
|
||||
"DATE_FORMAT", e.this, self.format_time(e), e.args.get("timezone")
|
||||
"DATE_FORMAT", e.this, self.format_time(e), e.args.get("zone")
|
||||
),
|
||||
exp.TimeStrToTime: _timestrtotime_sql,
|
||||
exp.TimestampAdd: _datetime_delta_sql("TIMESTAMP_ADD"),
|
||||
exp.TimestampSub: _datetime_delta_sql("TIMESTAMP_SUB"),
|
||||
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
|
||||
|
@ -823,6 +876,7 @@ class ClickHouse(Dialect):
|
|||
exp.UnixToTime: _unix_to_time_sql,
|
||||
exp.TimestampTrunc: timestamptrunc_sql(zone=True),
|
||||
exp.Variance: rename_func("varSamp"),
|
||||
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
|
||||
exp.Stddev: rename_func("stddevSamp"),
|
||||
}
|
||||
|
||||
|
@ -833,7 +887,7 @@ class ClickHouse(Dialect):
|
|||
exp.OnCluster: exp.Properties.Location.POST_NAME,
|
||||
}
|
||||
|
||||
# there's no list in docs, but it can be found in Clickhouse code
|
||||
# There's no list in docs, but it can be found in Clickhouse code
|
||||
# see `ClickHouse/src/Parsers/ParserCreate*.cpp`
|
||||
ON_CLUSTER_TARGETS = {
|
||||
"DATABASE",
|
||||
|
@ -845,6 +899,14 @@ class ClickHouse(Dialect):
|
|||
"NAMED COLLECTION",
|
||||
}
|
||||
|
||||
# https://clickhouse.com/docs/en/sql-reference/data-types/nullable
|
||||
NON_NULLABLE_TYPES = {
|
||||
exp.DataType.Type.ARRAY,
|
||||
exp.DataType.Type.MAP,
|
||||
exp.DataType.Type.NULLABLE,
|
||||
exp.DataType.Type.STRUCT,
|
||||
}
|
||||
|
||||
def strtodate_sql(self, expression: exp.StrToDate) -> str:
|
||||
strtodate_sql = self.function_fallback_sql(expression)
|
||||
|
||||
|
@ -863,6 +925,14 @@ class ClickHouse(Dialect):
|
|||
|
||||
return super().cast_sql(expression, safe_prefix=safe_prefix)
|
||||
|
||||
def trycast_sql(self, expression: exp.TryCast) -> str:
|
||||
dtype = expression.to
|
||||
if not dtype.is_type(*self.NON_NULLABLE_TYPES, check_nullable=True):
|
||||
# Casting x into Nullable(T) appears to behave similarly to TRY_CAST(x AS T)
|
||||
dtype.set("nullable", True)
|
||||
|
||||
return super().cast_sql(expression)
|
||||
|
||||
def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
|
||||
this = self.json_path_part(expression.this)
|
||||
return str(int(this) + 1) if is_int(this) else this
|
||||
|
@ -904,9 +974,30 @@ class ClickHouse(Dialect):
|
|||
#
|
||||
# https://clickhouse.com/docs/en/sql-reference/data-types/string
|
||||
if expression.this in self.STRING_TYPE_MAPPING:
|
||||
return "String"
|
||||
dtype = "String"
|
||||
else:
|
||||
dtype = super().datatype_sql(expression)
|
||||
|
||||
return super().datatype_sql(expression)
|
||||
# This section changes the type to `Nullable(...)` if the following conditions hold:
|
||||
# - It's marked as nullable - this ensures we won't wrap ClickHouse types with `Nullable`
|
||||
# and change their semantics
|
||||
# - It's not the key type of a `Map`. This is because ClickHouse enforces the following
|
||||
# constraint: "Type of Map key must be a type, that can be represented by integer or
|
||||
# String or FixedString (possibly LowCardinality) or UUID or IPv6"
|
||||
# - It's not a composite type, e.g. `Nullable(Array(...))` is not a valid type
|
||||
parent = expression.parent
|
||||
if (
|
||||
expression.args.get("nullable") is not False
|
||||
and not (
|
||||
isinstance(parent, exp.DataType)
|
||||
and parent.is_type(exp.DataType.Type.MAP, check_nullable=True)
|
||||
and expression.index in (None, 0)
|
||||
)
|
||||
and not expression.is_type(*self.NON_NULLABLE_TYPES, check_nullable=True)
|
||||
):
|
||||
dtype = f"Nullable({dtype})"
|
||||
|
||||
return dtype
|
||||
|
||||
def cte_sql(self, expression: exp.CTE) -> str:
|
||||
if expression.args.get("scalar"):
|
||||
|
@ -953,7 +1044,10 @@ class ClickHouse(Dialect):
|
|||
if expression.kind in self.ON_CLUSTER_TARGETS and locations.get(
|
||||
exp.Properties.Location.POST_NAME
|
||||
):
|
||||
this_name = self.sql(expression.this, "this")
|
||||
this_name = self.sql(
|
||||
expression.this if isinstance(expression.this, exp.Schema) else expression,
|
||||
"this",
|
||||
)
|
||||
this_properties = " ".join(
|
||||
[self.sql(prop) for prop in locations[exp.Properties.Location.POST_NAME]]
|
||||
)
|
||||
|
@ -962,6 +1056,24 @@ class ClickHouse(Dialect):
|
|||
|
||||
return super().createable_sql(expression, locations)
|
||||
|
||||
def create_sql(self, expression: exp.Create) -> str:
|
||||
# The comment property comes last in CTAS statements, i.e. after the query
|
||||
query = expression.expression
|
||||
if isinstance(query, exp.Query):
|
||||
comment_prop = expression.find(exp.SchemaCommentProperty)
|
||||
if comment_prop:
|
||||
comment_prop.pop()
|
||||
query.replace(exp.paren(query))
|
||||
else:
|
||||
comment_prop = None
|
||||
|
||||
create_sql = super().create_sql(expression)
|
||||
|
||||
comment_sql = self.sql(comment_prop)
|
||||
comment_sql = f" {comment_sql}" if comment_sql else ""
|
||||
|
||||
return f"{create_sql}{comment_sql}"
|
||||
|
||||
def prewhere_sql(self, expression: exp.PreWhere) -> str:
|
||||
this = self.indent(self.sql(expression, "this"))
|
||||
return f"{self.seg('PREWHERE')}{self.sep()}{this}"
|
||||
|
|
|
@ -133,6 +133,10 @@ class _Dialect(type):
|
|||
klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
|
||||
klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)
|
||||
|
||||
klass.INVERSE_CREATABLE_KIND_MAPPING = {
|
||||
v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
|
||||
}
|
||||
|
||||
base = seq_get(bases, 0)
|
||||
base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
|
||||
base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
|
||||
|
@ -183,6 +187,9 @@ class _Dialect(type):
|
|||
if enum not in ("", "bigquery"):
|
||||
klass.generator_class.SELECT_KINDS = ()
|
||||
|
||||
if enum not in ("", "clickhouse"):
|
||||
klass.generator_class.SUPPORTS_NULLABLE_TYPES = False
|
||||
|
||||
if enum not in ("", "athena", "presto", "trino"):
|
||||
klass.generator_class.TRY_SUPPORTED = False
|
||||
klass.generator_class.SUPPORTS_UESCAPE = False
|
||||
|
@ -369,6 +376,24 @@ class Dialect(metaclass=_Dialect):
|
|||
Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
|
||||
"""
|
||||
|
||||
HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
|
||||
"""
|
||||
Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3)
|
||||
as the former is of type INT[] vs the latter which is SUPER
|
||||
"""
|
||||
|
||||
SUPPORTS_FIXED_SIZE_ARRAYS = False
|
||||
"""
|
||||
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In
|
||||
dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator
|
||||
"""
|
||||
|
||||
CREATABLE_KIND_MAPPING: dict[str, str] = {}
|
||||
"""
|
||||
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse
|
||||
equivalent of CREATE SCHEMA is CREATE DATABASE.
|
||||
"""
|
||||
|
||||
# --- Autofilled ---
|
||||
|
||||
tokenizer_class = Tokenizer
|
||||
|
@ -385,6 +410,8 @@ class Dialect(metaclass=_Dialect):
|
|||
INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
|
||||
INVERSE_FORMAT_TRIE: t.Dict = {}
|
||||
|
||||
INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {}
|
||||
|
||||
ESCAPED_SEQUENCES: t.Dict[str, str] = {}
|
||||
|
||||
# Delimiters for string literals and identifiers
|
||||
|
@ -635,6 +662,9 @@ class Dialect(metaclass=_Dialect):
|
|||
exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
|
||||
e, exp.DataType.build("ARRAY<DATE>")
|
||||
),
|
||||
exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
|
||||
e, exp.DataType.build("ARRAY<TIMESTAMP>")
|
||||
),
|
||||
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
|
||||
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
|
||||
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
|
||||
|
@ -1214,7 +1244,13 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
|
|||
|
||||
|
||||
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
|
||||
return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
|
||||
datatype = (
|
||||
exp.DataType.Type.TIMESTAMPTZ
|
||||
if expression.args.get("zone")
|
||||
else exp.DataType.Type.TIMESTAMP
|
||||
)
|
||||
|
||||
return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
|
||||
|
||||
|
||||
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
|
||||
|
@ -1464,14 +1500,19 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
|
|||
targets.add(normalize(alias.this))
|
||||
|
||||
for when in expression.expressions:
|
||||
when.transform(
|
||||
lambda node: (
|
||||
exp.column(node.this)
|
||||
if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
|
||||
else node
|
||||
),
|
||||
copy=False,
|
||||
)
|
||||
# only remove the target names from the THEN clause
|
||||
# theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED
|
||||
# ref: https://github.com/TobikoData/sqlmesh/issues/2934
|
||||
then = when.args.get("then")
|
||||
if then:
|
||||
then.transform(
|
||||
lambda node: (
|
||||
exp.column(node.this)
|
||||
if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
|
||||
else node
|
||||
),
|
||||
copy=False,
|
||||
)
|
||||
|
||||
return self.merge_sql(expression)
|
||||
|
||||
|
@ -1590,9 +1631,9 @@ def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
|
|||
return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
|
||||
|
||||
|
||||
def sequence_sql(self: Generator, expression: exp.GenerateSeries):
|
||||
start = expression.args["start"]
|
||||
end = expression.args["end"]
|
||||
def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
|
||||
start = expression.args.get("start")
|
||||
end = expression.args.get("end")
|
||||
step = expression.args.get("step")
|
||||
|
||||
if isinstance(start, exp.Cast):
|
||||
|
@ -1602,8 +1643,8 @@ def sequence_sql(self: Generator, expression: exp.GenerateSeries):
|
|||
else:
|
||||
target_type = None
|
||||
|
||||
if target_type and target_type.is_type("timestamp"):
|
||||
if target_type is start.to:
|
||||
if start and end and target_type and target_type.is_type("date", "timestamp"):
|
||||
if isinstance(start, exp.Cast) and target_type is start.to:
|
||||
end = exp.cast(end, target_type)
|
||||
else:
|
||||
start = exp.cast(start, target_type)
|
||||
|
|
|
@ -32,6 +32,7 @@ from sqlglot.dialects.dialect import (
|
|||
timestamptrunc_sql,
|
||||
timestrtotime_sql,
|
||||
unit_to_var,
|
||||
unit_to_str,
|
||||
)
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
@ -81,6 +82,16 @@ def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str:
|
|||
return result
|
||||
|
||||
|
||||
# BigQuery -> DuckDB conversion for the TIME_DIFF function
|
||||
def _timediff_sql(self: DuckDB.Generator, expression: exp.TimeDiff) -> str:
|
||||
this = exp.cast(expression.this, exp.DataType.Type.TIME)
|
||||
expr = exp.cast(expression.expression, exp.DataType.Type.TIME)
|
||||
|
||||
# Although the 2 dialects share similar signatures, BQ seems to inverse
|
||||
# the sign of the result so the start/end time operands are flipped
|
||||
return self.func("DATE_DIFF", unit_to_str(expression), expr, this)
|
||||
|
||||
|
||||
def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str:
|
||||
if expression.expression:
|
||||
self.unsupported("DuckDB ARRAY_SORT does not support a comparator")
|
||||
|
@ -160,8 +171,10 @@ def _datatype_sql(self: DuckDB.Generator, expression: exp.DataType) -> str:
|
|||
if expression.is_type("array"):
|
||||
return f"{self.expressions(expression, flat=True)}[{self.expressions(expression, key='values', flat=True)}]"
|
||||
|
||||
# Type TIMESTAMP / TIME WITH TIME ZONE does not support any modifiers
|
||||
if expression.is_type("timestamptz", "timetz"):
|
||||
# Modifiers are not supported for TIME, [TIME | TIMESTAMP] WITH TIME ZONE
|
||||
if expression.is_type(
|
||||
exp.DataType.Type.TIME, exp.DataType.Type.TIMETZ, exp.DataType.Type.TIMESTAMPTZ
|
||||
):
|
||||
return expression.this.value
|
||||
|
||||
return self.datatype_sql(expression)
|
||||
|
@ -198,6 +211,41 @@ def _arrow_json_extract_sql(self: DuckDB.Generator, expression: JSON_EXTRACT_TYP
|
|||
return arrow_sql
|
||||
|
||||
|
||||
def _implicit_datetime_cast(
|
||||
arg: t.Optional[exp.Expression], type: exp.DataType.Type = exp.DataType.Type.DATE
|
||||
) -> t.Optional[exp.Expression]:
|
||||
return exp.cast(arg, type) if isinstance(arg, exp.Literal) else arg
|
||||
|
||||
|
||||
def _date_diff_sql(self: DuckDB.Generator, expression: exp.DateDiff) -> str:
|
||||
this = _implicit_datetime_cast(expression.this)
|
||||
expr = _implicit_datetime_cast(expression.expression)
|
||||
|
||||
return self.func("DATE_DIFF", unit_to_str(expression), expr, this)
|
||||
|
||||
|
||||
def _generate_datetime_array_sql(
|
||||
self: DuckDB.Generator, expression: t.Union[exp.GenerateDateArray, exp.GenerateTimestampArray]
|
||||
) -> str:
|
||||
is_generate_date_array = isinstance(expression, exp.GenerateDateArray)
|
||||
|
||||
type = exp.DataType.Type.DATE if is_generate_date_array else exp.DataType.Type.TIMESTAMP
|
||||
start = _implicit_datetime_cast(expression.args.get("start"), type=type)
|
||||
end = _implicit_datetime_cast(expression.args.get("end"), type=type)
|
||||
|
||||
# BQ's GENERATE_DATE_ARRAY & GENERATE_TIMESTAMP_ARRAY are transformed to DuckDB'S GENERATE_SERIES
|
||||
gen_series: t.Union[exp.GenerateSeries, exp.Cast] = exp.GenerateSeries(
|
||||
start=start, end=end, step=expression.args.get("step")
|
||||
)
|
||||
|
||||
if is_generate_date_array:
|
||||
# The GENERATE_SERIES result type is TIMESTAMP array, so to match BQ's semantics for
|
||||
# GENERATE_DATE_ARRAY we must cast it back to DATE array
|
||||
gen_series = exp.cast(gen_series, exp.DataType.build("ARRAY<DATE>"))
|
||||
|
||||
return self.sql(gen_series)
|
||||
|
||||
|
||||
class DuckDB(Dialect):
|
||||
NULL_ORDERING = "nulls_are_last"
|
||||
SUPPORTS_USER_DEFINED_TYPES = False
|
||||
|
@ -205,6 +253,7 @@ class DuckDB(Dialect):
|
|||
INDEX_OFFSET = 1
|
||||
CONCAT_COALESCE = True
|
||||
SUPPORTS_ORDER_BY_ALL = True
|
||||
SUPPORTS_FIXED_SIZE_ARRAYS = True
|
||||
|
||||
# https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
|
||||
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
|
||||
|
@ -293,7 +342,7 @@ class DuckDB(Dialect):
|
|||
"LIST_HAS": exp.ArrayContains.from_arg_list,
|
||||
"LIST_REVERSE_SORT": _build_sort_array_desc,
|
||||
"LIST_SORT": exp.SortArray.from_arg_list,
|
||||
"LIST_VALUE": exp.Array.from_arg_list,
|
||||
"LIST_VALUE": lambda args: exp.Array(expressions=args),
|
||||
"MAKE_TIME": exp.TimeFromParts.from_arg_list,
|
||||
"MAKE_TIMESTAMP": _build_make_timestamp,
|
||||
"MEDIAN": lambda args: exp.PercentileCont(
|
||||
|
@ -416,6 +465,7 @@ class DuckDB(Dialect):
|
|||
COPY_HAS_INTO_KEYWORD = False
|
||||
STAR_EXCEPT = "EXCLUDE"
|
||||
PAD_FILL_PATTERN_IS_REQUIRED = True
|
||||
ARRAY_CONCAT_IS_VAR_LEN = False
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
|
@ -441,9 +491,7 @@ class DuckDB(Dialect):
|
|||
exp.DateAdd: _date_delta_sql,
|
||||
exp.DateFromParts: rename_func("MAKE_DATE"),
|
||||
exp.DateSub: _date_delta_sql,
|
||||
exp.DateDiff: lambda self, e: self.func(
|
||||
"DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
|
||||
),
|
||||
exp.DateDiff: _date_diff_sql,
|
||||
exp.DateStrToDate: datestrtodate_sql,
|
||||
exp.Datetime: no_datetime_sql,
|
||||
exp.DatetimeSub: _date_delta_sql,
|
||||
|
@ -454,6 +502,8 @@ class DuckDB(Dialect):
|
|||
exp.DiToDate: lambda self,
|
||||
e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
|
||||
exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
|
||||
exp.GenerateDateArray: _generate_datetime_array_sql,
|
||||
exp.GenerateTimestampArray: _generate_datetime_array_sql,
|
||||
exp.Explode: rename_func("UNNEST"),
|
||||
exp.IntDiv: lambda self, e: self.binary(e, "//"),
|
||||
exp.IsInf: rename_func("ISINF"),
|
||||
|
@ -498,6 +548,7 @@ class DuckDB(Dialect):
|
|||
exp.Struct: _struct_sql,
|
||||
exp.TimeAdd: _date_delta_sql,
|
||||
exp.Time: no_time_sql,
|
||||
exp.TimeDiff: _timediff_sql,
|
||||
exp.Timestamp: no_timestamp_sql,
|
||||
exp.TimestampDiff: lambda self, e: self.func(
|
||||
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
|
||||
|
@ -522,6 +573,9 @@ class DuckDB(Dialect):
|
|||
exp.UnixToStr: lambda self, e: self.func(
|
||||
"STRFTIME", self.func("TO_TIMESTAMP", e.this), self.format_time(e)
|
||||
),
|
||||
exp.DatetimeTrunc: lambda self, e: self.func(
|
||||
"DATE_TRUNC", unit_to_str(e), exp.cast(e.this, exp.DataType.Type.DATETIME)
|
||||
),
|
||||
exp.UnixToTime: _unix_to_time_sql,
|
||||
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
|
||||
exp.VariancePop: rename_func("VAR_POP"),
|
||||
|
@ -650,6 +704,9 @@ class DuckDB(Dialect):
|
|||
PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE
|
||||
PROPERTIES_LOCATION[exp.ReturnsProperty] = exp.Properties.Location.POST_ALIAS
|
||||
|
||||
def fromiso8601timestamp_sql(self, expression: exp.FromISO8601Timestamp) -> str:
|
||||
return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ))
|
||||
|
||||
def strtotime_sql(self, expression: exp.StrToTime) -> str:
|
||||
if expression.args.get("safe"):
|
||||
formatted_time = self.format_time(expression)
|
||||
|
@ -832,3 +889,24 @@ class DuckDB(Dialect):
|
|||
return self.func("STRUCT_PACK", kv_sql)
|
||||
|
||||
return self.func("STRUCT_INSERT", this, kv_sql)
|
||||
|
||||
def unnest_sql(self, expression: exp.Unnest) -> str:
|
||||
explode_array = expression.args.get("explode_array")
|
||||
if explode_array:
|
||||
# In BigQuery, UNNESTing a nested array leads to explosion of the top-level array & struct
|
||||
# This is transpiled to DDB by transforming "FROM UNNEST(...)" to "FROM (SELECT UNNEST(..., max_depth => 2))"
|
||||
expression.expressions.append(
|
||||
exp.Kwarg(this=exp.var("max_depth"), expression=exp.Literal.number(2))
|
||||
)
|
||||
|
||||
# If BQ's UNNEST is aliased, we transform it from a column alias to a table alias in DDB
|
||||
alias = expression.args.get("alias")
|
||||
if alias:
|
||||
expression.set("alias", None)
|
||||
alias = exp.TableAlias(this=seq_get(alias.args.get("columns"), 0))
|
||||
|
||||
unnest_sql = super().unnest_sql(expression)
|
||||
select = exp.Select(expressions=[unnest_sql]).subquery(alias)
|
||||
return self.sql(select)
|
||||
|
||||
return super().unnest_sql(expression)
|
||||
|
|
|
@ -252,6 +252,7 @@ class Hive(Dialect):
|
|||
"ADD FILES": TokenType.COMMAND,
|
||||
"ADD JAR": TokenType.COMMAND,
|
||||
"ADD JARS": TokenType.COMMAND,
|
||||
"MINUS": TokenType.EXCEPT,
|
||||
"MSCK REPAIR": TokenType.COMMAND,
|
||||
"REFRESH": TokenType.REFRESH,
|
||||
"TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT,
|
||||
|
@ -509,6 +510,7 @@ class Hive(Dialect):
|
|||
e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
|
||||
exp.FromBase64: rename_func("UNBASE64"),
|
||||
exp.GenerateSeries: sequence_sql,
|
||||
exp.GenerateDateArray: sequence_sql,
|
||||
exp.If: if_sql(),
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.IsNan: rename_func("ISNAN"),
|
||||
|
@ -552,6 +554,7 @@ class Hive(Dialect):
|
|||
exp.StrToTime: _str_to_time_sql,
|
||||
exp.StrToUnix: _str_to_unix_sql,
|
||||
exp.StructExtract: struct_extract_sql,
|
||||
exp.Table: transforms.preprocess([transforms.unnest_generate_series]),
|
||||
exp.TimeStrToDate: rename_func("TO_DATE"),
|
||||
exp.TimeStrToTime: timestrtotime_sql,
|
||||
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
|
||||
|
@ -570,6 +573,7 @@ class Hive(Dialect):
|
|||
),
|
||||
exp.UnixToTime: _unix_to_time_sql,
|
||||
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
|
||||
exp.Unnest: rename_func("EXPLODE"),
|
||||
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
|
||||
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
|
||||
exp.National: lambda self, e: self.national_sql(e, prefix=""),
|
||||
|
@ -593,6 +597,9 @@ class Hive(Dialect):
|
|||
exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
def unnest_sql(self, expression: exp.Unnest) -> str:
|
||||
return rename_func("EXPLODE")(self, expression)
|
||||
|
||||
def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
|
||||
if isinstance(expression.this, exp.JSONPathWildcard):
|
||||
self.unsupported("Unsupported wildcard in JSONPathKey expression")
|
||||
|
|
|
@ -302,6 +302,9 @@ class MySQL(Dialect):
|
|||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"CONVERT_TZ": lambda args: exp.ConvertTimezone(
|
||||
source_tz=seq_get(args, 1), target_tz=seq_get(args, 2), timestamp=seq_get(args, 0)
|
||||
),
|
||||
"DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
"DATE_ADD": build_date_delta_with_interval(exp.DateAdd),
|
||||
"DATE_FORMAT": build_formatted_time(exp.TimeToStr, "mysql"),
|
||||
|
@ -724,6 +727,7 @@ class MySQL(Dialect):
|
|||
transforms.eliminate_semi_and_anti_joins,
|
||||
transforms.eliminate_qualify,
|
||||
transforms.eliminate_full_outer_join,
|
||||
transforms.unnest_generate_date_array_using_recursive_cte,
|
||||
]
|
||||
),
|
||||
exp.StrPosition: strposition_to_locate_sql,
|
||||
|
@ -1213,3 +1217,10 @@ class MySQL(Dialect):
|
|||
dateadd = build_date_delta_with_interval(exp.DateAdd)([start_ts, interval])
|
||||
|
||||
return self.sql(dateadd)
|
||||
|
||||
def converttimezone_sql(self, expression: exp.ConvertTimezone) -> str:
|
||||
from_tz = expression.args.get("source_tz")
|
||||
to_tz = expression.args.get("target_tz")
|
||||
dt = expression.args.get("timestamp")
|
||||
|
||||
return self.func("CONVERT_TZ", dt, from_tz, to_tz)
|
||||
|
|
|
@ -78,6 +78,8 @@ class Oracle(Dialect):
|
|||
for prefix in ("U", "u")
|
||||
]
|
||||
|
||||
NESTED_COMMENTS = False
|
||||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"(+)": TokenType.JOIN_MARKER,
|
||||
|
@ -90,7 +92,6 @@ class Oracle(Dialect):
|
|||
"ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY,
|
||||
"SAMPLE": TokenType.TABLE_SAMPLE,
|
||||
"START": TokenType.BEGIN,
|
||||
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
|
||||
"TOP": TokenType.TOP,
|
||||
"VARCHAR2": TokenType.VARCHAR,
|
||||
}
|
||||
|
@ -106,8 +107,15 @@ class Oracle(Dialect):
|
|||
"TO_CHAR": _build_timetostr_or_tochar,
|
||||
"TO_TIMESTAMP": build_formatted_time(exp.StrToTime, "oracle"),
|
||||
"TO_DATE": build_formatted_time(exp.StrToDate, "oracle"),
|
||||
"NVL": lambda args: exp.Coalesce(
|
||||
this=seq_get(args, 0), expressions=args[1:], is_nvl=True
|
||||
),
|
||||
}
|
||||
|
||||
NO_PAREN_FUNCTION_PARSERS = {
|
||||
**parser.Parser.NO_PAREN_FUNCTION_PARSERS,
|
||||
"SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, sysdate=True),
|
||||
}
|
||||
FUNCTIONS.pop("NVL")
|
||||
|
||||
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
|
||||
**parser.Parser.FUNCTION_PARSERS,
|
||||
|
@ -247,6 +255,7 @@ class Oracle(Dialect):
|
|||
),
|
||||
exp.Group: transforms.preprocess([transforms.unalias_group]),
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.Mod: rename_func("MOD"),
|
||||
exp.Select: transforms.preprocess(
|
||||
[
|
||||
transforms.eliminate_distinct_on,
|
||||
|
@ -274,6 +283,9 @@ class Oracle(Dialect):
|
|||
}
|
||||
|
||||
def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str:
|
||||
if expression.args.get("sysdate"):
|
||||
return "SYSDATE"
|
||||
|
||||
this = expression.this
|
||||
return self.func("CURRENT_TIMESTAMP", this) if this else "CURRENT_TIMESTAMP"
|
||||
|
||||
|
@ -291,7 +303,7 @@ class Oracle(Dialect):
|
|||
)
|
||||
return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
|
||||
|
||||
def add_column_sql(self, expression: exp.AlterTable) -> str:
|
||||
def add_column_sql(self, expression: exp.Alter) -> str:
|
||||
actions = self.expressions(expression, key="actions", flat=True)
|
||||
if len(expression.args.get("actions", [])) > 1:
|
||||
return f"ADD ({actions})"
|
||||
|
@ -303,3 +315,7 @@ class Oracle(Dialect):
|
|||
value = f" CONSTRAINT {value}" if value else ""
|
||||
|
||||
return f"{option}{value}"
|
||||
|
||||
def coalesce_sql(self, expression: exp.Coalesce) -> str:
|
||||
func_name = "NVL" if expression.args.get("is_nvl") else "COALESCE"
|
||||
return rename_func(func_name)(self, expression)
|
||||
|
|
|
@ -166,7 +166,7 @@ def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
|
|||
return expression
|
||||
|
||||
|
||||
def _build_generate_series(args: t.List) -> exp.GenerateSeries:
|
||||
def _build_generate_series(args: t.List) -> exp.ExplodingGenerateSeries:
|
||||
# The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
|
||||
# Note: postgres allows calls with just two arguments -- the "step" argument defaults to 1
|
||||
step = seq_get(args, 2)
|
||||
|
@ -176,7 +176,7 @@ def _build_generate_series(args: t.List) -> exp.GenerateSeries:
|
|||
elif isinstance(step, exp.Interval) and not step.args.get("unit"):
|
||||
args[2] = exp.to_interval(step.this.this)
|
||||
|
||||
return exp.GenerateSeries.from_arg_list(args)
|
||||
return exp.ExplodingGenerateSeries.from_arg_list(args)
|
||||
|
||||
|
||||
def _build_to_timestamp(args: t.List) -> exp.UnixToTime | exp.StrToTime:
|
||||
|
@ -440,7 +440,7 @@ class Postgres(Dialect):
|
|||
self._match(TokenType.COMMA)
|
||||
value = self._parse_bitwise()
|
||||
|
||||
if part and part.is_string:
|
||||
if part and isinstance(part, (exp.Column, exp.Literal)):
|
||||
part = exp.var(part.name)
|
||||
|
||||
return self.expression(exp.Extract, this=part, expression=value)
|
||||
|
@ -466,6 +466,7 @@ class Postgres(Dialect):
|
|||
MULTI_ARG_DISTINCT = False
|
||||
CAN_IMPLEMENT_ARRAY_ANY = True
|
||||
COPY_HAS_INTO_KEYWORD = False
|
||||
ARRAY_CONCAT_IS_VAR_LEN = False
|
||||
|
||||
SUPPORTED_JSON_PATH_PARTS = {
|
||||
exp.JSONPathKey,
|
||||
|
@ -487,12 +488,7 @@ class Postgres(Dialect):
|
|||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.AnyValue: any_value_to_max_sql,
|
||||
exp.Array: lambda self, e: (
|
||||
f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
|
||||
if isinstance(seq_get(e.expressions, 0), exp.Select)
|
||||
else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]"
|
||||
),
|
||||
exp.ArrayConcat: rename_func("ARRAY_CAT"),
|
||||
exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CAT"),
|
||||
exp.ArrayContainsAll: lambda self, e: self.binary(e, "@>"),
|
||||
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
|
||||
exp.ArrayFilter: filter_array_using_unnest,
|
||||
|
@ -507,6 +503,7 @@ class Postgres(Dialect):
|
|||
exp.DateStrToDate: datestrtodate_sql,
|
||||
exp.DateSub: _date_add_sql("-"),
|
||||
exp.Explode: rename_func("UNNEST"),
|
||||
exp.ExplodingGenerateSeries: rename_func("GENERATE_SERIES"),
|
||||
exp.GroupConcat: _string_agg_sql,
|
||||
exp.IntDiv: rename_func("DIV"),
|
||||
exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"),
|
||||
|
@ -587,21 +584,32 @@ class Postgres(Dialect):
|
|||
|
||||
def unnest_sql(self, expression: exp.Unnest) -> str:
|
||||
if len(expression.expressions) == 1:
|
||||
arg = expression.expressions[0]
|
||||
if isinstance(arg, exp.GenerateDateArray):
|
||||
generate_series: exp.Expression = exp.GenerateSeries(**arg.args)
|
||||
if isinstance(expression.parent, (exp.From, exp.Join)):
|
||||
generate_series = (
|
||||
exp.select("value::date")
|
||||
.from_(generate_series.as_("value"))
|
||||
.subquery(expression.args.get("alias") or "_unnested_generate_series")
|
||||
)
|
||||
return self.sql(generate_series)
|
||||
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
|
||||
this = annotate_types(expression.expressions[0])
|
||||
this = annotate_types(arg)
|
||||
if this.is_type("array<json>"):
|
||||
while isinstance(this, exp.Cast):
|
||||
this = this.this
|
||||
|
||||
arg = self.sql(exp.cast(this, exp.DataType.Type.JSON))
|
||||
arg_as_json = self.sql(exp.cast(this, exp.DataType.Type.JSON))
|
||||
alias = self.sql(expression, "alias")
|
||||
alias = f" AS {alias}" if alias else ""
|
||||
|
||||
if expression.args.get("offset"):
|
||||
self.unsupported("Unsupported JSON_ARRAY_ELEMENTS with offset")
|
||||
|
||||
return f"JSON_ARRAY_ELEMENTS({arg}){alias}"
|
||||
return f"JSON_ARRAY_ELEMENTS({arg_as_json}){alias}"
|
||||
|
||||
return super().unnest_sql(expression)
|
||||
|
||||
|
@ -646,3 +654,11 @@ class Postgres(Dialect):
|
|||
return self.sql(this)
|
||||
|
||||
return super().cast_sql(expression, safe_prefix=safe_prefix)
|
||||
|
||||
def array_sql(self, expression: exp.Array) -> str:
|
||||
exprs = expression.expressions
|
||||
return (
|
||||
f"{self.normalize_func('ARRAY')}({self.sql(exprs[0])})"
|
||||
if isinstance(seq_get(exprs, 0), exp.Select)
|
||||
else f"{self.normalize_func('ARRAY')}[{self.expressions(expression, flat=True)}]"
|
||||
)
|
||||
|
|
|
@ -142,17 +142,6 @@ def _build_from_unixtime(args: t.List) -> exp.Expression:
|
|||
return exp.UnixToTime.from_arg_list(args)
|
||||
|
||||
|
||||
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
|
||||
if isinstance(expression, exp.Table):
|
||||
if isinstance(expression.this, exp.GenerateSeries):
|
||||
unnest = exp.Unnest(expressions=[expression.this])
|
||||
|
||||
if expression.alias:
|
||||
return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
|
||||
return unnest
|
||||
return expression
|
||||
|
||||
|
||||
def _first_last_sql(self: Presto.Generator, expression: exp.Func) -> str:
|
||||
"""
|
||||
Trino doesn't support FIRST / LAST as functions, but they're valid in the context
|
||||
|
@ -245,13 +234,17 @@ class Presto(Dialect):
|
|||
INDEX_OFFSET = 1
|
||||
NULL_ORDERING = "nulls_are_last"
|
||||
TIME_FORMAT = MySQL.TIME_FORMAT
|
||||
TIME_MAPPING = MySQL.TIME_MAPPING
|
||||
STRICT_STRING_CONCAT = True
|
||||
SUPPORTS_SEMI_ANTI_JOIN = False
|
||||
TYPED_DIVISION = True
|
||||
TABLESAMPLE_SIZE_IS_PERCENT = True
|
||||
LOG_BASE_FIRST: t.Optional[bool] = None
|
||||
|
||||
TIME_MAPPING = {
|
||||
**MySQL.TIME_MAPPING,
|
||||
"%W": "%A",
|
||||
}
|
||||
|
||||
# https://github.com/trinodb/trino/issues/17
|
||||
# https://github.com/trinodb/trino/issues/12289
|
||||
# https://github.com/prestodb/presto/issues/2863
|
||||
|
@ -434,6 +427,7 @@ class Presto(Dialect):
|
|||
exp.FromTimeZone: lambda self,
|
||||
e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
|
||||
exp.GenerateSeries: sequence_sql,
|
||||
exp.GenerateDateArray: sequence_sql,
|
||||
exp.Group: transforms.preprocess([transforms.unalias_group]),
|
||||
exp.GroupConcat: lambda self, e: self.func(
|
||||
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
|
||||
|
@ -471,7 +465,7 @@ class Presto(Dialect):
|
|||
exp.StrToMap: rename_func("SPLIT_TO_MAP"),
|
||||
exp.StrToTime: _str_to_time_sql,
|
||||
exp.StructExtract: struct_extract_sql,
|
||||
exp.Table: transforms.preprocess([_unnest_sequence]),
|
||||
exp.Table: transforms.preprocess([transforms.unnest_generate_series]),
|
||||
exp.Timestamp: no_timestamp_sql,
|
||||
exp.TimestampAdd: _date_delta_sql("DATE_ADD"),
|
||||
exp.TimestampTrunc: timestamptrunc_sql(),
|
||||
|
|
|
@ -17,6 +17,7 @@ from sqlglot.dialects.dialect import (
|
|||
from sqlglot.dialects.postgres import Postgres
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
from sqlglot.parser import build_convert_timezone
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import E
|
||||
|
@ -45,13 +46,11 @@ class Redshift(Postgres):
|
|||
INDEX_OFFSET = 0
|
||||
COPY_PARAMS_ARE_CSV = False
|
||||
HEX_LOWERCASE = True
|
||||
HAS_DISTINCT_ARRAY_CONSTRUCTORS = True
|
||||
|
||||
TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
|
||||
TIME_MAPPING = {
|
||||
**Postgres.TIME_MAPPING,
|
||||
"MON": "%b",
|
||||
"HH": "%H",
|
||||
}
|
||||
# ref: https://docs.aws.amazon.com/redshift/latest/dg/r_FORMAT_strings.html
|
||||
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
|
||||
TIME_MAPPING = {**Postgres.TIME_MAPPING, "MON": "%b", "HH24": "%H", "HH": "%I"}
|
||||
|
||||
class Parser(Postgres.Parser):
|
||||
FUNCTIONS = {
|
||||
|
@ -62,6 +61,7 @@ class Redshift(Postgres):
|
|||
unit=exp.var("month"),
|
||||
return_type=exp.DataType.build("TIMESTAMP"),
|
||||
),
|
||||
"CONVERT_TIMEZONE": lambda args: build_convert_timezone(args, "UTC"),
|
||||
"DATEADD": _build_date_delta(exp.TsOrDsAdd),
|
||||
"DATE_ADD": _build_date_delta(exp.TsOrDsAdd),
|
||||
"DATEDIFF": _build_date_delta(exp.TsOrDsDiff),
|
||||
|
@ -77,7 +77,7 @@ class Redshift(Postgres):
|
|||
NO_PAREN_FUNCTION_PARSERS = {
|
||||
**Postgres.Parser.NO_PAREN_FUNCTION_PARSERS,
|
||||
"APPROXIMATE": lambda self: self._parse_approximate_count(),
|
||||
"SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, transaction=True),
|
||||
"SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, sysdate=True),
|
||||
}
|
||||
|
||||
SUPPORTS_IMPLICIT_UNNEST = True
|
||||
|
@ -153,6 +153,8 @@ class Redshift(Postgres):
|
|||
COPY_PARAMS_ARE_WRAPPED = False
|
||||
HEX_FUNC = "TO_HEX"
|
||||
PARSE_JSON_NAME = "JSON_PARSE"
|
||||
ARRAY_CONCAT_IS_VAR_LEN = False
|
||||
SUPPORTS_CONVERT_TIMEZONE = True
|
||||
|
||||
# Redshift doesn't have `WITH` as part of their with_properties so we remove it
|
||||
WITH_PROPERTIES_PREFIX = " "
|
||||
|
@ -169,12 +171,13 @@ class Redshift(Postgres):
|
|||
|
||||
TRANSFORMS = {
|
||||
**Postgres.Generator.TRANSFORMS,
|
||||
exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CONCAT"),
|
||||
exp.Concat: concat_to_dpipe_sql,
|
||||
exp.ConcatWs: concat_ws_to_dpipe_sql,
|
||||
exp.ApproxDistinct: lambda self,
|
||||
e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
|
||||
exp.CurrentTimestamp: lambda self, e: (
|
||||
"SYSDATE" if e.args.get("transaction") else "GETDATE()"
|
||||
"SYSDATE" if e.args.get("sysdate") else "GETDATE()"
|
||||
),
|
||||
exp.DateAdd: date_delta_sql("DATEADD"),
|
||||
exp.DateDiff: date_delta_sql("DATEDIFF"),
|
||||
|
@ -191,6 +194,7 @@ class Redshift(Postgres):
|
|||
transforms.eliminate_distinct_on,
|
||||
transforms.eliminate_semi_and_anti_joins,
|
||||
transforms.unqualify_unnest,
|
||||
transforms.unnest_generate_date_array_using_recursive_cte,
|
||||
]
|
||||
),
|
||||
exp.SortKeyProperty: lambda self,
|
||||
|
@ -423,3 +427,9 @@ class Redshift(Postgres):
|
|||
file_format = f" FILE FORMAT {file_format}" if file_format else ""
|
||||
|
||||
return f"SET{exprs}{location}{file_format}"
|
||||
|
||||
def array_sql(self, expression: exp.Array) -> str:
|
||||
if expression.args.get("bracket_notation"):
|
||||
return super().array_sql(expression)
|
||||
|
||||
return rename_func("ARRAY")(self, expression)
|
||||
|
|
|
@ -122,12 +122,6 @@ def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) ->
|
|||
)
|
||||
|
||||
|
||||
def _build_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
|
||||
if len(args) == 3:
|
||||
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
|
||||
return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
|
||||
|
||||
|
||||
def _build_regexp_replace(args: t.List) -> exp.RegexpReplace:
|
||||
regexp_replace = exp.RegexpReplace.from_arg_list(args)
|
||||
|
||||
|
@ -186,6 +180,47 @@ def _flatten_structured_types_unless_iceberg(expression: exp.Expression) -> exp.
|
|||
return expression
|
||||
|
||||
|
||||
def _unnest_generate_date_array(expression: exp.Expression) -> exp.Expression:
|
||||
if isinstance(expression, exp.Select):
|
||||
for unnest in expression.find_all(exp.Unnest):
|
||||
if (
|
||||
isinstance(unnest.parent, (exp.From, exp.Join))
|
||||
and len(unnest.expressions) == 1
|
||||
and isinstance(unnest.expressions[0], exp.GenerateDateArray)
|
||||
):
|
||||
generate_date_array = unnest.expressions[0]
|
||||
start = generate_date_array.args.get("start")
|
||||
end = generate_date_array.args.get("end")
|
||||
step = generate_date_array.args.get("step")
|
||||
|
||||
if not start or not end or not isinstance(step, exp.Interval) or step.name != "1":
|
||||
continue
|
||||
|
||||
unit = step.args.get("unit")
|
||||
|
||||
unnest_alias = unnest.args.get("alias")
|
||||
if unnest_alias:
|
||||
unnest_alias = unnest_alias.copy()
|
||||
sequence_value_name = seq_get(unnest_alias.columns, 0) or "value"
|
||||
else:
|
||||
sequence_value_name = "value"
|
||||
|
||||
# We'll add the next sequence value to the starting date and project the result
|
||||
date_add = _build_date_time_add(exp.DateAdd)(
|
||||
[unit, exp.cast(sequence_value_name, "int"), exp.cast(start, "date")]
|
||||
).as_(sequence_value_name)
|
||||
|
||||
# We use DATEDIFF to compute the number of sequence values needed
|
||||
number_sequence = Snowflake.Parser.FUNCTIONS["ARRAY_GENERATE_RANGE"](
|
||||
[exp.Literal.number(0), _build_datediff([unit, start, end]) + 1]
|
||||
)
|
||||
|
||||
unnest.set("expressions", [number_sequence])
|
||||
unnest.replace(exp.select(date_add).from_(unnest.copy()).subquery(unnest_alias))
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
class Snowflake(Dialect):
|
||||
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
|
||||
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
|
||||
|
@ -255,7 +290,7 @@ class Snowflake(Dialect):
|
|||
**parser.Parser.FUNCTIONS,
|
||||
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
|
||||
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
|
||||
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
|
||||
"ARRAY_CONSTRUCT": lambda args: exp.Array(expressions=args),
|
||||
"ARRAY_CONTAINS": lambda args: exp.ArrayContains(
|
||||
this=seq_get(args, 1), expression=seq_get(args, 0)
|
||||
),
|
||||
|
@ -268,7 +303,6 @@ class Snowflake(Dialect):
|
|||
"BITXOR": binary_from_function(exp.BitwiseXor),
|
||||
"BIT_XOR": binary_from_function(exp.BitwiseXor),
|
||||
"BOOLXOR": binary_from_function(exp.Xor),
|
||||
"CONVERT_TIMEZONE": _build_convert_timezone,
|
||||
"DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
|
||||
"DATE_TRUNC": _date_trunc_to_time,
|
||||
"DATEADD": _build_date_time_add(exp.DateAdd),
|
||||
|
@ -413,6 +447,26 @@ class Snowflake(Dialect):
|
|||
),
|
||||
}
|
||||
|
||||
def _negate_range(
|
||||
self, this: t.Optional[exp.Expression] = None
|
||||
) -> t.Optional[exp.Expression]:
|
||||
if not this:
|
||||
return this
|
||||
|
||||
query = this.args.get("query")
|
||||
if isinstance(this, exp.In) and isinstance(query, exp.Query):
|
||||
# Snowflake treats `value NOT IN (subquery)` as `VALUE <> ALL (subquery)`, so
|
||||
# we do this conversion here to avoid parsing it into `NOT value IN (subquery)`
|
||||
# which can produce different results (most likely a SnowFlake bug).
|
||||
#
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/in
|
||||
# Context: https://github.com/tobymao/sqlglot/issues/3890
|
||||
return self.expression(
|
||||
exp.NEQ, this=this.this, expression=exp.All(this=query.unnest())
|
||||
)
|
||||
|
||||
return self.expression(exp.Not, this=this)
|
||||
|
||||
def _parse_with_constraint(self) -> t.Optional[exp.Expression]:
|
||||
if self._prev.token_type != TokenType.WITH:
|
||||
self._retreat(self._index - 1)
|
||||
|
@ -638,6 +692,7 @@ class Snowflake(Dialect):
|
|||
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
|
||||
RAW_STRINGS = ["$$"]
|
||||
COMMENTS = ["--", "//", ("/*", "*/")]
|
||||
NESTED_COMMENTS = False
|
||||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
|
@ -692,6 +747,9 @@ class Snowflake(Dialect):
|
|||
COPY_PARAMS_ARE_WRAPPED = False
|
||||
COPY_PARAMS_EQ_REQUIRED = True
|
||||
STAR_EXCEPT = "EXCLUDE"
|
||||
SUPPORTS_EXPLODING_PROJECTIONS = False
|
||||
ARRAY_CONCAT_IS_VAR_LEN = False
|
||||
SUPPORTS_CONVERT_TIMEZONE = True
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
|
@ -699,7 +757,7 @@ class Snowflake(Dialect):
|
|||
exp.ArgMax: rename_func("MAX_BY"),
|
||||
exp.ArgMin: rename_func("MIN_BY"),
|
||||
exp.Array: inline_array_sql,
|
||||
exp.ArrayConcat: rename_func("ARRAY_CAT"),
|
||||
exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CAT"),
|
||||
exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
|
||||
exp.AtTimeZone: lambda self, e: self.func(
|
||||
"CONVERT_TIMEZONE", e.args.get("zone"), e.this
|
||||
|
@ -751,6 +809,7 @@ class Snowflake(Dialect):
|
|||
transforms.eliminate_distinct_on,
|
||||
transforms.explode_to_unnest(),
|
||||
transforms.eliminate_semi_and_anti_joins,
|
||||
_unnest_generate_date_array,
|
||||
]
|
||||
),
|
||||
exp.SHA: rename_func("SHA1"),
|
||||
|
|
|
@@ -132,6 +132,7 @@ class Spark(Spark2):
    class Generator(Spark2.Generator):
        SUPPORTS_TO_NUMBER = True
        PAD_FILL_PATTERN_IS_REQUIRED = False
        SUPPORTS_CONVERT_TIMEZONE = True

        TYPE_MAPPING = {
            **Spark2.Generator.TYPE_MAPPING,
|
@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
|
|||
build_date_delta,
|
||||
rename_func,
|
||||
trim_sql,
|
||||
timestrtotime_sql,
|
||||
)
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.time import format_time
|
||||
|
@ -339,6 +340,16 @@ def _json_extract_sql(
|
|||
return self.func("ISNULL", json_query, json_value)
|
||||
|
||||
|
||||
def _timestrtotime_sql(self: TSQL.Generator, expression: exp.TimeStrToTime):
|
||||
sql = timestrtotime_sql(self, expression)
|
||||
if expression.args.get("zone"):
|
||||
# If there is a timezone, produce an expression like:
|
||||
# CAST('2020-01-01 12:13:14-08:00' AS DATETIMEOFFSET) AT TIME ZONE 'UTC'
|
||||
# If you don't have AT TIME ZONE 'UTC', wrapping that expression in another cast back to DATETIME2 just drops the timezone information
|
||||
return self.sql(exp.AtTimeZone(this=sql, zone=exp.Literal.string("UTC")))
|
||||
return sql
|
||||
|
||||
|
||||
class TSQL(Dialect):
|
||||
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
|
||||
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
|
||||
|
@ -863,6 +874,7 @@ class TSQL(Dialect):
|
|||
transforms.eliminate_distinct_on,
|
||||
transforms.eliminate_semi_and_anti_joins,
|
||||
transforms.eliminate_qualify,
|
||||
transforms.unnest_generate_date_array_using_recursive_cte,
|
||||
]
|
||||
),
|
||||
exp.Stddev: rename_func("STDEV"),
|
||||
|
@ -875,9 +887,7 @@ class TSQL(Dialect):
|
|||
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
|
||||
),
|
||||
exp.TemporaryProperty: lambda self, e: "",
|
||||
exp.TimeStrToTime: lambda self, e: self.sql(
|
||||
exp.cast(e.this, exp.DataType.Type.DATETIME)
|
||||
),
|
||||
exp.TimeStrToTime: _timestrtotime_sql,
|
||||
exp.TimeToStr: _format_sql,
|
||||
exp.Trim: trim_sql,
|
||||
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
|
||||
|
@ -1139,11 +1149,11 @@ class TSQL(Dialect):
|
|||
def partition_sql(self, expression: exp.Partition) -> str:
|
||||
return f"WITH (PARTITIONS({self.expressions(expression, flat=True)}))"
|
||||
|
||||
def altertable_sql(self, expression: exp.AlterTable) -> str:
|
||||
def alter_sql(self, expression: exp.Alter) -> str:
|
||||
action = seq_get(expression.args.get("actions") or [], 0)
|
||||
if isinstance(action, exp.RenameTable):
|
||||
return f"EXEC sp_rename '{self.sql(expression.this)}', '{action.this.name}'"
|
||||
return super().altertable_sql(expression)
|
||||
return super().alter_sql(expression)
|
||||
|
||||
def drop_sql(self, expression: exp.Drop) -> str:
|
||||
if expression.args["kind"] == "VIEW":
|
||||
|
|
|
@@ -74,7 +74,9 @@ def execute(
        raise ExecuteError("Tables must support the same table args as schema")

    now = time.time()
    expression = optimize(sql, schema, leave_tables_isolated=True, dialect=read)
    expression = optimize(
        sql, schema, leave_tables_isolated=True, infer_csv_schemas=True, dialect=read
    )

    logger.debug("Optimization finished: %f", time.time() - now)
    logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
|
@ -1387,6 +1387,7 @@ class Create(DDL):
|
|||
"exists": False,
|
||||
"properties": False,
|
||||
"replace": False,
|
||||
"refresh": False,
|
||||
"unique": False,
|
||||
"indexes": False,
|
||||
"no_schema_binding": False,
|
||||
|
@ -1436,7 +1437,13 @@ class Clone(Expression):
|
|||
|
||||
|
||||
class Describe(Expression):
|
||||
arg_types = {"this": True, "style": False, "kind": False, "expressions": False}
|
||||
arg_types = {
|
||||
"this": True,
|
||||
"style": False,
|
||||
"kind": False,
|
||||
"expressions": False,
|
||||
"partition": False,
|
||||
}
|
||||
|
||||
|
||||
# https://duckdb.org/docs/guides/meta/summarize.html
|
||||
|
@ -2000,6 +2007,11 @@ class Drop(Expression):
|
|||
"cluster": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def kind(self) -> t.Optional[str]:
|
||||
kind = self.args.get("kind")
|
||||
return kind and kind.upper()
|
||||
|
||||
|
||||
class Filter(Expression):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
|
@ -2158,6 +2170,8 @@ class Insert(DDL, DML):
|
|||
"ignore": False,
|
||||
"by_name": False,
|
||||
"stored": False,
|
||||
"partition": False,
|
||||
"settings": False,
|
||||
}
|
||||
|
||||
def with_(
|
||||
|
@ -2464,17 +2478,17 @@ class Offset(Expression):


class Order(Expression):
arg_types = {
"this": False,
"expressions": True,
"interpolate": False,
"siblings": False,
}
arg_types = {"this": False, "expressions": True, "siblings": False}


# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier
class WithFill(Expression):
arg_types = {"from": False, "to": False, "step": False}
arg_types = {
"from": False,
"to": False,
"step": False,
"interpolate": False,
}


# hive specific sorts

@ -2669,6 +2683,11 @@ class OnCluster(Property):
arg_types = {"this": True}


# Clickhouse EMPTY table "property"
class EmptyProperty(Property):
arg_types = {}


class LikeProperty(Property):
arg_types = {"this": True, "expressions": False}

@ -2735,6 +2754,10 @@ class PartitionedOfProperty(Property):
arg_types = {"this": True, "expression": True}


class StreamingTableProperty(Property):
arg_types = {}


class RemoteWithConnectionModelProperty(Property):
arg_types = {"this": True}

@ -3137,11 +3160,11 @@ class SetOperation(Query):
return self.this.unnest().selects

@property
def left(self) -> Expression:
def left(self) -> Query:
return self.this

@property
def right(self) -> Expression:
def right(self) -> Query:
return self.expression

@ -3859,6 +3882,7 @@ class Pivot(Expression):
"group": False,
"columns": False,
"include_nulls": False,
"default_on_null": False,
}

@property

@ -3948,6 +3972,8 @@ class DataTypeParam(Expression):
return self.this.name


# The `nullable` arg is helpful when transpiling types from other dialects to ClickHouse, which
# assumes non-nullable types by default. Values `None` and `True` mean the type is nullable.
class DataType(Expression):
arg_types = {
"this": True,

@ -3956,6 +3982,7 @@ class DataType(Expression):
"values": False,
"prefix": False,
"kind": False,
"nullable": False,
}

class Type(AutoName):
@ -4194,28 +4221,45 @@ class DataType(Expression):

return DataType(**{**data_type_exp.args, **kwargs})

def is_type(self, *dtypes: DATA_TYPE) -> bool:
def is_type(self, *dtypes: DATA_TYPE, check_nullable: bool = False) -> bool:
"""
Checks whether this DataType matches one of the provided data types. Nested types or precision
will be compared using "structural equivalence" semantics, so e.g. array<int> != array<float>.

Args:
dtypes: the data types to compare this DataType to.
check_nullable: whether to take the NULLABLE type constructor into account for the comparison.
If false, it means that NULLABLE<INT> is equivalent to INT.

Returns:
True, if and only if there is a type in `dtypes` which is equal to this DataType.
"""
if (
not check_nullable
and self.this == DataType.Type.NULLABLE
and len(self.expressions) == 1
):
this_type = self.expressions[0]
else:
this_type = self

for dtype in dtypes:
other = DataType.build(dtype, copy=False, udt=True)
other_type = DataType.build(dtype, copy=False, udt=True)
if (
not check_nullable
and other_type.this == DataType.Type.NULLABLE
and len(other_type.expressions) == 1
):
other_type = other_type.expressions[0]

if (
other.expressions
or self.this == DataType.Type.USERDEFINED
or other.this == DataType.Type.USERDEFINED
other_type.expressions
or this_type.this == DataType.Type.USERDEFINED
or other_type.this == DataType.Type.USERDEFINED
):
matches = self == other
matches = this_type == other_type
else:
matches = self.this == other.this
matches = this_type.this == other_type.this

if matches:
return True
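The new `check_nullable` flag only changes how a `NULLABLE(...)` wrapper participates in the comparison. A minimal sketch of that behavior (the constructed value below is illustrative and not part of this patch):

from sqlglot import exp

# NULLABLE(INT), built explicitly so no dialect-specific parsing is assumed
nullable_int = exp.DataType(
    this=exp.DataType.Type.NULLABLE,
    expressions=[exp.DataType.build("INT")],
)

nullable_int.is_type("INT")                       # True: the NULLABLE wrapper is unwrapped by default
nullable_int.is_type("INT", check_nullable=True)  # False: the wrapper now takes part in the comparison
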
@ -4270,9 +4314,10 @@ class Rollback(Expression):
arg_types = {"savepoint": False, "this": False}


class AlterTable(Expression):
class Alter(Expression):
arg_types = {
"this": True,
"kind": True,
"actions": True,
"exists": False,
"only": False,
@ -4536,6 +4581,12 @@ class PivotAlias(Alias):
|
|||
pass
|
||||
|
||||
|
||||
# Represents Snowflake's ANY [ ORDER BY ... ] syntax
|
||||
# https://docs.snowflake.com/en/sql-reference/constructs/pivot
|
||||
class PivotAny(Expression):
|
||||
arg_types = {"this": False}
|
||||
|
||||
|
||||
class Aliases(Expression):
|
||||
arg_types = {"this": True, "expressions": True}
|
||||
|
||||
|
@ -4790,7 +4841,7 @@ class ApproxDistinct(AggFunc):
|
|||
|
||||
|
||||
class Array(Func):
|
||||
arg_types = {"expressions": False}
|
||||
arg_types = {"expressions": False, "bracket_notation": False}
|
||||
is_var_len_args = True
|
||||
|
||||
|
||||
|
@ -4833,10 +4884,21 @@ class Convert(Func):
|
|||
arg_types = {"this": True, "expression": True, "style": False}
|
||||
|
||||
|
||||
class ConvertTimezone(Func):
|
||||
arg_types = {"source_tz": False, "target_tz": True, "timestamp": True}
|
||||
|
||||
|
||||
class GenerateSeries(Func):
|
||||
arg_types = {"start": True, "end": True, "step": False, "is_end_exclusive": False}
|
||||
|
||||
|
||||
# Postgres' GENERATE_SERIES function returns a row set, i.e. it implicitly explodes when it's
|
||||
# used in a projection, so this expression is a helper that facilitates transpilation to other
|
||||
# dialects. For example, we'd generate UNNEST(GENERATE_SERIES(...)) in DuckDB
|
||||
class ExplodingGenerateSeries(GenerateSeries):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayAgg(AggFunc):
|
||||
pass
|
||||
|
||||
|
@ -5025,7 +5087,7 @@ class Ceil(Func):
|
|||
|
||||
|
||||
class Coalesce(Func):
|
||||
arg_types = {"this": True, "expressions": False}
|
||||
arg_types = {"this": True, "expressions": False, "is_nvl": False}
|
||||
is_var_len_args = True
|
||||
_sql_names = ["COALESCE", "IFNULL", "NVL"]
|
||||
|
||||
|
@ -5077,7 +5139,7 @@ class CurrentTime(Func):
|
|||
|
||||
|
||||
class CurrentTimestamp(Func):
|
||||
arg_types = {"this": False, "transaction": False}
|
||||
arg_types = {"this": False, "sysdate": False}
|
||||
|
||||
|
||||
class CurrentUser(Func):
|
||||
|
@ -5286,6 +5348,7 @@ class Unnest(Func, UDTF):
|
|||
"expressions": True,
|
||||
"alias": False,
|
||||
"offset": False,
|
||||
"explode_array": False,
|
||||
}
|
||||
|
||||
@property
|
||||
|
@ -5309,6 +5372,11 @@ class ToBase64(Func):
|
|||
pass
|
||||
|
||||
|
||||
# https://trino.io/docs/current/functions/datetime.html#from_iso8601_timestamp
|
||||
class FromISO8601Timestamp(Func):
|
||||
_sql_names = ["FROM_ISO8601_TIMESTAMP"]
|
||||
|
||||
|
||||
class GapFill(Func):
|
||||
arg_types = {
|
||||
"this": True,
|
||||
|
@ -5321,8 +5389,14 @@ class GapFill(Func):
|
|||
}
|
||||
|
||||
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/array_functions#generate_date_array
|
||||
class GenerateDateArray(Func):
|
||||
arg_types = {"start": True, "end": True, "interval": False}
|
||||
arg_types = {"start": True, "end": True, "step": False}
|
||||
|
||||
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/array_functions#generate_timestamp_array
|
||||
class GenerateTimestampArray(Func):
|
||||
arg_types = {"start": True, "end": True, "step": True}
|
||||
|
||||
|
||||
class Greatest(Func):
|
||||
|
@ -5639,6 +5713,10 @@ class ScopeResolution(Expression):
|
|||
arg_types = {"this": False, "expression": True}
|
||||
|
||||
|
||||
class Stream(Expression):
|
||||
pass
|
||||
|
||||
|
||||
class StarMap(Func):
|
||||
pass
|
||||
|
||||
|
@ -5920,7 +5998,7 @@ class Time(Func):
|
|||
|
||||
|
||||
class TimeToStr(Func):
|
||||
arg_types = {"this": True, "format": True, "culture": False, "timezone": False}
|
||||
arg_types = {"this": True, "format": True, "culture": False, "zone": False}
|
||||
|
||||
|
||||
class TimeToTimeStr(Func):
|
||||
|
@ -5936,7 +6014,7 @@ class TimeStrToDate(Func):
|
|||
|
||||
|
||||
class TimeStrToTime(Func):
|
||||
pass
|
||||
arg_types = {"this": True, "zone": False}
|
||||
|
||||
|
||||
class TimeStrToUnix(Func):
|
||||
|
@ -7144,7 +7222,9 @@ def column(
return this


def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast:
def cast(
expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, dialect: DialectType = None, **opts
) -> Cast:
"""Cast an expression to a data type.

Example:

@ -7155,15 +7235,37 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast
expression: The expression to cast.
to: The datatype to cast to.
copy: Whether to copy the supplied expressions.
dialect: The target dialect. This is used to prevent a re-cast in the following scenario:
- The expression to be cast is already a exp.Cast expression
- The existing cast is to a type that is logically equivalent to new type

For example, if :expression='CAST(x as DATETIME)' and :to=Type.TIMESTAMP,
but in the target dialect DATETIME is mapped to TIMESTAMP, then we will NOT return `CAST(x (as DATETIME) as TIMESTAMP)`
and instead just return the original expression `CAST(x as DATETIME)`.

This is to prevent it being output as a double cast `CAST(x (as TIMESTAMP) as TIMESTAMP)` once the DATETIME -> TIMESTAMP
mapping is applied in the target dialect generator.

Returns:
The new Cast instance.
"""
expr = maybe_parse(expression, copy=copy, **opts)
data_type = DataType.build(to, copy=copy, **opts)
expr = maybe_parse(expression, copy=copy, dialect=dialect, **opts)
data_type = DataType.build(to, copy=copy, dialect=dialect, **opts)

if expr.is_type(data_type):
return expr
# dont re-cast if the expression is already a cast to the correct type
if isinstance(expr, Cast):
from sqlglot.dialects.dialect import Dialect

target_dialect = Dialect.get_or_raise(dialect)
type_mapping = target_dialect.generator_class.TYPE_MAPPING

existing_cast_type: DataType.Type = expr.to.this
new_cast_type: DataType.Type = data_type.this
types_are_equivalent = type_mapping.get(
existing_cast_type, existing_cast_type
) == type_mapping.get(new_cast_type, new_cast_type)
if expr.is_type(data_type) or types_are_equivalent:
return expr

expr = Cast(this=expr, to=data_type)
expr.type = data_type
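A rough sketch of what the new `dialect` argument changes; whether two types actually collapse depends entirely on the chosen dialect's generator TYPE_MAPPING, so no specific outcome is claimed for the last call:

from sqlglot import exp

inner = exp.cast("x", "DATETIME")
inner.sql()                                # 'CAST(x AS DATETIME)'

# Without a dialect, the requested type differs from the existing one, so the casts nest
exp.cast(inner.copy(), "TIMESTAMP").sql()  # 'CAST(CAST(x AS DATETIME) AS TIMESTAMP)'

# Passing dialect= makes cast() consult that dialect's TYPE_MAPPING; if DATETIME and TIMESTAMP
# render to the same name there, the original CAST(x AS DATETIME) is returned unchanged
exp.cast(inner.copy(), "TIMESTAMP", dialect="tsql")  # collapses only if the mapping makes them equivalent
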
@ -7259,7 +7361,7 @@ def rename_table(
|
|||
old_name: str | Table,
|
||||
new_name: str | Table,
|
||||
dialect: DialectType = None,
|
||||
) -> AlterTable:
|
||||
) -> Alter:
|
||||
"""Build ALTER TABLE... RENAME... expression
|
||||
|
||||
Args:
|
||||
|
@ -7272,8 +7374,9 @@ def rename_table(
|
|||
"""
|
||||
old_table = to_table(old_name, dialect=dialect)
|
||||
new_table = to_table(new_name, dialect=dialect)
|
||||
return AlterTable(
|
||||
return Alter(
|
||||
this=old_table,
|
||||
kind="TABLE",
|
||||
actions=[
|
||||
RenameTable(this=new_table),
|
||||
],
|
||||
|
@ -7286,7 +7389,7 @@ def rename_column(
|
|||
new_column_name: str | Column,
|
||||
exists: t.Optional[bool] = None,
|
||||
dialect: DialectType = None,
|
||||
) -> AlterTable:
|
||||
) -> Alter:
|
||||
"""Build ALTER TABLE... RENAME COLUMN... expression
|
||||
|
||||
Args:
|
||||
|
@ -7302,8 +7405,9 @@ def rename_column(
|
|||
table = to_table(table_name, dialect=dialect)
|
||||
old_column = to_column(old_column_name, dialect=dialect)
|
||||
new_column = to_column(new_column_name, dialect=dialect)
|
||||
return AlterTable(
|
||||
return Alter(
|
||||
this=table,
|
||||
kind="TABLE",
|
||||
actions=[
|
||||
RenameColumn(this=old_column, to=new_column, exists=exists),
|
||||
],
|
||||
|
@ -7335,12 +7439,15 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
|
|||
if isinstance(value, bytes):
|
||||
return HexString(this=value.hex())
|
||||
if isinstance(value, datetime.datetime):
|
||||
datetime_literal = Literal.string(
|
||||
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat(
|
||||
sep=" "
|
||||
)
|
||||
)
|
||||
return TimeStrToTime(this=datetime_literal)
|
||||
datetime_literal = Literal.string(value.isoformat(sep=" "))
|
||||
|
||||
tz = None
|
||||
if value.tzinfo:
|
||||
# this works for zoneinfo.ZoneInfo, pytz.timezone and datetime.datetime.utc to return IANA timezone names like "America/Los_Angeles"
|
||||
# instead of abbreviations like "PDT". This is for consistency with other timezone handling functions in SQLGlot
|
||||
tz = Literal.string(str(value.tzinfo))
|
||||
|
||||
return TimeStrToTime(this=datetime_literal, zone=tz)
|
||||
if isinstance(value, datetime.date):
|
||||
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
|
||||
return DateStrToDate(this=date_literal)
|
||||
|
|
|
@ -92,6 +92,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
|
||||
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
|
||||
exp.DynamicProperty: lambda *_: "DYNAMIC",
|
||||
exp.EmptyProperty: lambda *_: "EMPTY",
|
||||
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
|
||||
exp.EphemeralColumnConstraint: lambda self,
|
||||
e: f"EPHEMERAL{(' ' + self.sql(e, 'this')) if e.this else ''}",
|
||||
|
@ -117,8 +118,10 @@ class Generator(metaclass=_Generator):
|
|||
e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
|
||||
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
|
||||
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
|
||||
exp.Operator: lambda self, e: self.binary(e, ""), # The operator is produced in `binary`
|
||||
exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}",
|
||||
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
|
||||
exp.PivotAny: lambda self, e: f"ANY{self.sql(e, 'this')}",
|
||||
exp.ProjectionPolicyColumnConstraint: lambda self,
|
||||
e: f"PROJECTION POLICY {self.sql(e, 'this')}",
|
||||
exp.RemoteWithConnectionModelProperty: lambda self,
|
||||
|
@ -136,6 +139,8 @@ class Generator(metaclass=_Generator):
|
|||
exp.SqlSecurityProperty: lambda _,
|
||||
e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
|
||||
exp.StabilityProperty: lambda _, e: e.name,
|
||||
exp.Stream: lambda self, e: f"STREAM {self.sql(e, 'this')}",
|
||||
exp.StreamingTableProperty: lambda *_: "STREAMING",
|
||||
exp.StrictProperty: lambda *_: "STRICT",
|
||||
exp.TemporaryProperty: lambda *_: "TEMPORARY",
|
||||
exp.TagColumnConstraint: lambda self, e: f"TAG ({self.expressions(e, flat=True)})",
|
||||
|
@ -371,6 +376,18 @@ class Generator(metaclass=_Generator):
|
|||
# Whether the text pattern/fill (3rd) parameter of RPAD()/LPAD() is optional (defaults to space)
|
||||
PAD_FILL_PATTERN_IS_REQUIRED = False
|
||||
|
||||
# Whether a projection can explode into multiple rows, e.g. by unnesting an array.
|
||||
SUPPORTS_EXPLODING_PROJECTIONS = True
|
||||
|
||||
# Whether ARRAY_CONCAT can be generated with varlen args or if it should be reduced to 2-arg version
|
||||
ARRAY_CONCAT_IS_VAR_LEN = True
|
||||
|
||||
# Whether CONVERT_TIMEZONE() is supported; if not, it will be generated as exp.AtTimeZone
|
||||
SUPPORTS_CONVERT_TIMEZONE = False
|
||||
|
||||
# Whether nullable types can be constructed, e.g. `Nullable(Int64)`
|
||||
SUPPORTS_NULLABLE_TYPES = True
|
||||
|
||||
# The name to generate for the JSONPath expression. If `None`, only `this` will be generated
|
||||
PARSE_JSON_NAME: t.Optional[str] = "PARSE_JSON"
|
||||
|
||||
|
@ -439,6 +456,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.DynamicProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.EmptyProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.ExternalProperty: exp.Properties.Location.POST_CREATE,
|
||||
|
@ -488,6 +506,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.StreamingTableProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.StrictProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
|
@ -962,6 +981,7 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
def create_sql(self, expression: exp.Create) -> str:
|
||||
kind = self.sql(expression, "kind")
|
||||
kind = self.dialect.INVERSE_CREATABLE_KIND_MAPPING.get(kind) or kind
|
||||
properties = expression.args.get("properties")
|
||||
properties_locs = self.locate_properties(properties) if properties else defaultdict()
|
||||
|
||||
|
@ -1018,6 +1038,7 @@ class Generator(metaclass=_Generator):
|
|||
index_sql = indexes + postindex_props_sql
|
||||
|
||||
replace = " OR REPLACE" if expression.args.get("replace") else ""
|
||||
refresh = " OR REFRESH" if expression.args.get("refresh") else ""
|
||||
unique = " UNIQUE" if expression.args.get("unique") else ""
|
||||
|
||||
clustered = expression.args.get("clustered")
|
||||
|
@ -1037,7 +1058,7 @@ class Generator(metaclass=_Generator):
|
|||
wrapped=False,
|
||||
)
|
||||
|
||||
modifiers = "".join((clustered_sql, replace, unique, postcreate_props_sql))
|
||||
modifiers = "".join((clustered_sql, replace, refresh, unique, postcreate_props_sql))
|
||||
|
||||
postexpression_props_sql = ""
|
||||
if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
|
||||
|
@ -1096,7 +1117,9 @@ class Generator(metaclass=_Generator):
|
|||
def describe_sql(self, expression: exp.Describe) -> str:
|
||||
style = expression.args.get("style")
|
||||
style = f" {style}" if style else ""
|
||||
return f"DESCRIBE{style} {self.sql(expression, 'this')}"
|
||||
partition = self.sql(expression, "partition")
|
||||
partition = f" {partition}" if partition else ""
|
||||
return f"DESCRIBE{style} {self.sql(expression, 'this')}{partition}"
|
||||
|
||||
def heredoc_sql(self, expression: exp.Heredoc) -> str:
|
||||
tag = self.sql(expression, "tag")
|
||||
|
@ -1195,20 +1218,21 @@ class Generator(metaclass=_Generator):
|
|||
return f"{this}{specifier}"
|
||||
|
||||
def datatype_sql(self, expression: exp.DataType) -> str:
|
||||
type_value = expression.this
|
||||
nested = ""
|
||||
values = ""
|
||||
interior = self.expressions(expression, flat=True)
|
||||
|
||||
type_value = expression.this
|
||||
if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"):
|
||||
type_sql = self.sql(expression, "kind")
|
||||
else:
|
||||
elif type_value != exp.DataType.Type.NULLABLE or self.SUPPORTS_NULLABLE_TYPES:
|
||||
type_sql = (
|
||||
self.TYPE_MAPPING.get(type_value, type_value.value)
|
||||
if isinstance(type_value, exp.DataType.Type)
|
||||
else type_value
|
||||
)
|
||||
|
||||
nested = ""
|
||||
interior = self.expressions(expression, flat=True)
|
||||
values = ""
|
||||
else:
|
||||
return interior
|
||||
|
||||
if interior:
|
||||
if expression.args.get("nested"):
|
||||
|
@ -1258,6 +1282,7 @@ class Generator(metaclass=_Generator):
|
|||
expressions = self.expressions(expression, flat=True)
|
||||
expressions = f" ({expressions})" if expressions else ""
|
||||
kind = expression.args["kind"]
|
||||
kind = self.dialect.INVERSE_CREATABLE_KIND_MAPPING.get(kind) or kind
|
||||
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
|
||||
on_cluster = self.sql(expression, "cluster")
|
||||
on_cluster = f" {on_cluster}" if on_cluster else ""
|
||||
|
@ -1277,7 +1302,7 @@ class Generator(metaclass=_Generator):
|
|||
def fetch_sql(self, expression: exp.Fetch) -> str:
|
||||
direction = expression.args.get("direction")
|
||||
direction = f" {direction}" if direction else ""
|
||||
count = expression.args.get("count")
|
||||
count = self.sql(expression, "count")
|
||||
count = f" {count}" if count else ""
|
||||
if expression.args.get("percent"):
|
||||
count = f"{count} PERCENT"
|
||||
|
@ -1639,7 +1664,12 @@ class Generator(metaclass=_Generator):
|
|||
else:
|
||||
expression_sql = f"{returning}{expression_sql}{on_conflict}"
|
||||
|
||||
sql = f"INSERT{hint}{alternative}{ignore}{this}{stored}{by_name}{exists}{where}{expression_sql}"
|
||||
partition_by = self.sql(expression, "partition")
|
||||
partition_by = f" {partition_by}" if partition_by else ""
|
||||
settings = self.sql(expression, "settings")
|
||||
settings = f" {settings}" if settings else ""
|
||||
|
||||
sql = f"INSERT{hint}{alternative}{ignore}{this}{stored}{by_name}{exists}{partition_by}{settings}{where}{expression_sql}"
|
||||
return self.prepend_ctes(expression, sql)
|
||||
|
||||
def intersect_sql(self, expression: exp.Intersect) -> str:
|
||||
|
@ -1824,13 +1854,20 @@ class Generator(metaclass=_Generator):
|
|||
alias = self.sql(expression, "alias")
|
||||
alias = f" AS {alias}" if alias else ""
|
||||
direction = self.seg("UNPIVOT" if expression.unpivot else "PIVOT")
|
||||
|
||||
field = self.sql(expression, "field")
|
||||
if field and isinstance(expression.args.get("field"), exp.PivotAny):
|
||||
field = f"IN ({field})"
|
||||
|
||||
include_nulls = expression.args.get("include_nulls")
|
||||
if include_nulls is not None:
|
||||
nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS "
|
||||
else:
|
||||
nulls = ""
|
||||
return f"{direction}{nulls}({expressions} FOR {field}){alias}"
|
||||
|
||||
default_on_null = self.sql(expression, "default_on_null")
|
||||
default_on_null = f" DEFAULT ON NULL ({default_on_null})" if default_on_null else ""
|
||||
return f"{direction}{nulls}({expressions} FOR {field}{default_on_null}){alias}"
|
||||
|
||||
def version_sql(self, expression: exp.Version) -> str:
|
||||
this = f"FOR {expression.name}"
|
||||
|
@ -2148,15 +2185,7 @@ class Generator(metaclass=_Generator):
|
|||
this = self.sql(expression, "this")
|
||||
this = f"{this} " if this else this
|
||||
siblings = "SIBLINGS " if expression.args.get("siblings") else ""
|
||||
order = self.op_expressions(f"{this}ORDER {siblings}BY", expression, flat=this or flat) # type: ignore
|
||||
interpolated_values = [
|
||||
f"{self.sql(named_expression, 'alias')} AS {self.sql(named_expression, 'this')}"
|
||||
for named_expression in expression.args.get("interpolate") or []
|
||||
]
|
||||
interpolate = (
|
||||
f" INTERPOLATE ({', '.join(interpolated_values)})" if interpolated_values else ""
|
||||
)
|
||||
return f"{order}{interpolate}"
|
||||
return self.op_expressions(f"{this}ORDER {siblings}BY", expression, flat=this or flat) # type: ignore
|
||||
|
||||
def withfill_sql(self, expression: exp.WithFill) -> str:
|
||||
from_sql = self.sql(expression, "from")
|
||||
|
@ -2165,7 +2194,14 @@ class Generator(metaclass=_Generator):
|
|||
to_sql = f" TO {to_sql}" if to_sql else ""
|
||||
step_sql = self.sql(expression, "step")
|
||||
step_sql = f" STEP {step_sql}" if step_sql else ""
|
||||
return f"WITH FILL{from_sql}{to_sql}{step_sql}"
|
||||
interpolated_values = [
|
||||
f"{self.sql(named_expression, 'alias')} AS {self.sql(named_expression, 'this')}"
|
||||
for named_expression in expression.args.get("interpolate") or []
|
||||
]
|
||||
interpolate = (
|
||||
f" INTERPOLATE ({', '.join(interpolated_values)})" if interpolated_values else ""
|
||||
)
|
||||
return f"WITH FILL{from_sql}{to_sql}{step_sql}{interpolate}"
|
||||
|
||||
def cluster_sql(self, expression: exp.Cluster) -> str:
|
||||
return self.op_expressions("CLUSTER BY", expression)
|
||||
|
@ -2875,11 +2911,13 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
def pivotalias_sql(self, expression: exp.PivotAlias) -> str:
|
||||
alias = expression.args["alias"]
|
||||
|
||||
identifier_alias = isinstance(alias, exp.Identifier)
|
||||
literal_alias = isinstance(alias, exp.Literal)
|
||||
|
||||
if identifier_alias and not self.UNPIVOT_ALIASES_ARE_IDENTIFIERS:
|
||||
alias.replace(exp.Literal.string(alias.output_name))
|
||||
elif not identifier_alias and self.UNPIVOT_ALIASES_ARE_IDENTIFIERS:
|
||||
elif not identifier_alias and literal_alias and self.UNPIVOT_ALIASES_ARE_IDENTIFIERS:
|
||||
alias.replace(exp.to_identifier(alias.output_name))
|
||||
|
||||
return self.alias_sql(expression)
|
||||
|
@ -3103,7 +3141,7 @@ class Generator(metaclass=_Generator):
|
|||
exprs = self.expressions(expression, flat=True)
|
||||
return f"SET {exprs}"
|
||||
|
||||
def altertable_sql(self, expression: exp.AlterTable) -> str:
|
||||
def alter_sql(self, expression: exp.Alter) -> str:
|
||||
actions = expression.args["actions"]
|
||||
|
||||
if isinstance(actions[0], exp.ColumnDef):
|
||||
|
@ -3112,6 +3150,8 @@ class Generator(metaclass=_Generator):
|
|||
actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
|
||||
elif isinstance(actions[0], exp.Delete):
|
||||
actions = self.expressions(expression, key="actions", flat=True)
|
||||
elif isinstance(actions[0], exp.Query):
|
||||
actions = "AS " + self.expressions(expression, key="actions")
|
||||
else:
|
||||
actions = self.expressions(expression, key="actions", flat=True)
|
||||
|
||||
|
@ -3121,9 +3161,10 @@ class Generator(metaclass=_Generator):
|
|||
only = " ONLY" if expression.args.get("only") else ""
|
||||
options = self.expressions(expression, key="options")
|
||||
options = f", {options}" if options else ""
|
||||
return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')}{on_cluster} {actions}{options}"
|
||||
kind = self.sql(expression, "kind")
|
||||
return f"ALTER {kind}{exists}{only} {self.sql(expression, 'this')}{on_cluster} {actions}{options}"
|
||||
|
||||
def add_column_sql(self, expression: exp.AlterTable) -> str:
|
||||
def add_column_sql(self, expression: exp.Alter) -> str:
|
||||
if self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD:
|
||||
return self.expressions(
|
||||
expression,
|
||||
|
@ -3312,8 +3353,25 @@ class Generator(metaclass=_Generator):
|
|||
return f"USE{kind}{this}"
|
||||
|
||||
def binary(self, expression: exp.Binary, op: str) -> str:
|
||||
op = self.maybe_comment(op, comments=expression.comments)
|
||||
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
|
||||
sqls: t.List[str] = []
|
||||
stack: t.List[t.Union[str, exp.Expression]] = [expression]
|
||||
binary_type = type(expression)
|
||||
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
|
||||
if type(node) is binary_type:
|
||||
op_func = node.args.get("operator")
|
||||
if op_func:
|
||||
op = f"OPERATOR({self.sql(op_func)})"
|
||||
|
||||
stack.append(node.right)
|
||||
stack.append(f" {self.maybe_comment(op, comments=node.comments)} ")
|
||||
stack.append(node.left)
|
||||
else:
|
||||
sqls.append(self.sql(node))
|
||||
|
||||
return "".join(sqls)
|
||||
|
||||
def function_fallback_sql(self, expression: exp.Func) -> str:
|
||||
args = []
|
||||
|
@ -3660,9 +3718,6 @@ class Generator(metaclass=_Generator):
|
|||
table = "" if isinstance(expression.this, exp.Literal) else "TABLE "
|
||||
return f"REFRESH {table}{this}"
|
||||
|
||||
def operator_sql(self, expression: exp.Operator) -> str:
|
||||
return self.binary(expression, f"OPERATOR({self.sql(expression, 'operator')})")
|
||||
|
||||
def toarray_sql(self, expression: exp.ToArray) -> str:
|
||||
arg = expression.this
|
||||
if not arg.type:
|
||||
|
@ -4041,3 +4096,44 @@ class Generator(metaclass=_Generator):
|
|||
def summarize_sql(self, expression: exp.Summarize) -> str:
|
||||
table = " TABLE" if expression.args.get("table") else ""
|
||||
return f"SUMMARIZE{table} {self.sql(expression.this)}"
|
||||
|
||||
def explodinggenerateseries_sql(self, expression: exp.ExplodingGenerateSeries) -> str:
|
||||
generate_series = exp.GenerateSeries(**expression.args)
|
||||
|
||||
parent = expression.parent
|
||||
if isinstance(parent, (exp.Alias, exp.TableAlias)):
|
||||
parent = parent.parent
|
||||
|
||||
if self.SUPPORTS_EXPLODING_PROJECTIONS and not isinstance(parent, (exp.Table, exp.Unnest)):
|
||||
return self.sql(exp.Unnest(expressions=[generate_series]))
|
||||
|
||||
if isinstance(parent, exp.Select):
|
||||
self.unsupported("GenerateSeries projection unnesting is not supported.")
|
||||
|
||||
return self.sql(generate_series)
|
||||
|
||||
def arrayconcat_sql(self, expression: exp.ArrayConcat, name: str = "ARRAY_CONCAT") -> str:
|
||||
exprs = expression.expressions
|
||||
if not self.ARRAY_CONCAT_IS_VAR_LEN:
|
||||
rhs = reduce(lambda x, y: exp.ArrayConcat(this=x, expressions=[y]), exprs)
|
||||
else:
|
||||
rhs = self.expressions(expression)
|
||||
|
||||
return self.func(name, expression.this, rhs)
|
||||
|
||||
def converttimezone_sql(self, expression: exp.ConvertTimezone) -> str:
|
||||
if self.SUPPORTS_CONVERT_TIMEZONE:
|
||||
return self.function_fallback_sql(expression)
|
||||
|
||||
source_tz = expression.args.get("source_tz")
|
||||
target_tz = expression.args.get("target_tz")
|
||||
timestamp = expression.args.get("timestamp")
|
||||
|
||||
if source_tz and timestamp:
|
||||
timestamp = exp.AtTimeZone(
|
||||
this=exp.cast(timestamp, exp.DataType.Type.TIMESTAMPNTZ), zone=source_tz
|
||||
)
|
||||
|
||||
expr = exp.AtTimeZone(this=timestamp, zone=target_tz)
|
||||
|
||||
return self.sql(expr)
|
||||
|
|
|
@ -128,6 +128,13 @@ class _TypeAnnotator(type):
|
|||
klass.COERCES_TO[data_type] = coerces_to.copy()
|
||||
coerces_to |= {data_type}
|
||||
|
||||
# NULL can be coerced to any type, so e.g. NULL + 1 will have type INT
|
||||
klass.COERCES_TO[exp.DataType.Type.NULL] = {
|
||||
*text_precedence,
|
||||
*numeric_precedence,
|
||||
*timelike_precedence,
|
||||
}
|
||||
|
||||
return klass
|
||||
|
||||
|
||||
|
@ -201,31 +208,47 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
for name, source in scope.sources.items():
|
||||
if not isinstance(source, Scope):
|
||||
continue
|
||||
if isinstance(source.expression, exp.UDTF):
|
||||
|
||||
expression = source.expression
|
||||
if isinstance(expression, exp.UDTF):
|
||||
values = []
|
||||
|
||||
if isinstance(source.expression, exp.Lateral):
|
||||
if isinstance(source.expression.this, exp.Explode):
|
||||
values = [source.expression.this.this]
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
values = [source.expression]
|
||||
if isinstance(expression, exp.Lateral):
|
||||
if isinstance(expression.this, exp.Explode):
|
||||
values = [expression.this.this]
|
||||
elif isinstance(expression, exp.Unnest):
|
||||
values = [expression]
|
||||
else:
|
||||
values = source.expression.expressions[0].expressions
|
||||
values = expression.expressions[0].expressions
|
||||
|
||||
if not values:
|
||||
continue
|
||||
|
||||
selects[name] = {
|
||||
alias: column
|
||||
for alias, column in zip(
|
||||
source.expression.alias_column_names,
|
||||
values,
|
||||
)
|
||||
alias: column.type
|
||||
for alias, column in zip(expression.alias_column_names, values)
|
||||
}
|
||||
elif isinstance(expression, exp.SetOperation) and len(expression.left.selects) == len(
|
||||
expression.right.selects
|
||||
):
|
||||
if expression.args.get("by_name"):
|
||||
r_type_by_select = {s.alias_or_name: s.type for s in expression.right.selects}
|
||||
selects[name] = {
|
||||
s.alias_or_name: self._maybe_coerce(
|
||||
t.cast(exp.DataType, s.type),
|
||||
r_type_by_select.get(s.alias_or_name) or exp.DataType.Type.UNKNOWN,
|
||||
)
|
||||
for s in expression.left.selects
|
||||
}
|
||||
else:
|
||||
selects[name] = {
|
||||
ls.alias_or_name: self._maybe_coerce(
|
||||
t.cast(exp.DataType, ls.type), t.cast(exp.DataType, rs.type)
|
||||
)
|
||||
for ls, rs in zip(expression.left.selects, expression.right.selects)
|
||||
}
|
||||
else:
|
||||
selects[name] = {
|
||||
select.alias_or_name: select for select in source.expression.selects
|
||||
}
|
||||
selects[name] = {s.alias_or_name: s.type for s in expression.selects}
|
||||
|
||||
# First annotate the current scope's column references
|
||||
for col in scope.columns:
|
||||
|
@ -237,7 +260,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
self._set_type(col, self.schema.get_column_type(source, col))
|
||||
elif source:
|
||||
if col.table in selects and col.name in selects[col.table]:
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
self._set_type(col, selects[col.table][col.name])
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
self._set_type(col, source.expression.type)
|
||||
|
||||
|
@ -264,15 +287,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
def _maybe_coerce(
|
||||
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
|
||||
) -> exp.DataType | exp.DataType.Type:
|
||||
) -> exp.DataType:
|
||||
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
|
||||
type2_value = type2.this if isinstance(type2, exp.DataType) else type2
|
||||
|
||||
# We propagate the NULL / UNKNOWN types upwards if found
|
||||
if exp.DataType.Type.NULL in (type1_value, type2_value):
|
||||
return exp.DataType.Type.NULL
|
||||
# We propagate the UNKNOWN type upwards if found
|
||||
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
|
||||
return exp.DataType.Type.UNKNOWN
|
||||
return exp.DataType.build("unknown")
|
||||
|
||||
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
|
||||
|
||||
|
@ -282,17 +303,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
left, right = expression.left, expression.right
|
||||
left_type, right_type = left.type.this, right.type.this # type: ignore
|
||||
|
||||
if isinstance(expression, exp.Connector):
|
||||
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
|
||||
self._set_type(expression, exp.DataType.Type.NULL)
|
||||
elif exp.DataType.Type.NULL in (left_type, right_type):
|
||||
self._set_type(
|
||||
expression,
|
||||
exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
|
||||
)
|
||||
else:
|
||||
self._set_type(expression, exp.DataType.Type.BOOLEAN)
|
||||
elif isinstance(expression, exp.Predicate):
|
||||
if isinstance(expression, (exp.Connector, exp.Predicate)):
|
||||
self._set_type(expression, exp.DataType.Type.BOOLEAN)
|
||||
elif (left_type, right_type) in self.binary_coercions:
|
||||
self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
|
||||
|
@ -351,7 +362,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
last_datatype = expr_type
|
||||
break
|
||||
|
||||
if not expr_type.is_type(exp.DataType.Type.NULL, exp.DataType.Type.UNKNOWN):
|
||||
if not expr_type.is_type(exp.DataType.Type.UNKNOWN):
|
||||
last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
|
||||
|
||||
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
|
||||
|
|
|
@ -40,6 +40,8 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
|
|||
isinstance(node, (exp.Date, exp.TsOrDsToDate))
|
||||
and not node.expressions
|
||||
and not node.args.get("zone")
|
||||
and node.this.is_string
|
||||
and is_iso_date(node.this.name)
|
||||
):
|
||||
return exp.cast(node.this, to=exp.DataType.Type.DATE)
|
||||
if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
|
||||
|
@ -90,6 +92,12 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
|
|||
and expression.to.this == expression.this.type.this
|
||||
):
|
||||
return expression.this
|
||||
if (
|
||||
isinstance(expression, (exp.Date, exp.TsOrDsToDate))
|
||||
and expression.this.type
|
||||
and expression.this.type.this == exp.DataType.Type.DATE
|
||||
):
|
||||
return expression.this
|
||||
return expression
|
||||
|
||||
|
||||
|
|
|
@ -19,8 +19,12 @@ def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.I

def normalize_identifiers(expression, dialect=None):
"""
Normalize all unquoted identifiers to either lower or upper case, depending
on the dialect. This essentially makes those identifiers case-insensitive.
Normalize identifiers by converting them to either lower or upper case,
ensuring the semantics are preserved in each case (e.g. by respecting
case-sensitivity).

This transformation reflects how identifiers would be resolved by the engine corresponding
to each SQL dialect, and plays a very important role in the standardization of the AST.

It's possible to make this a no-op by adding a special comment next to the
identifier of interest:

@ -30,7 +34,7 @@ def normalize_identifiers(expression, dialect=None):
In this example, the identifier `a` will not be normalized.

Note:
Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even
Some dialects (e.g. DuckDB) treat all identifiers as case-insensitive even
when they're quoted, so in these cases all identifiers are normalized.

Example:

@ -30,6 +30,7 @@ def qualify(
|
|||
validate_qualify_columns: bool = True,
|
||||
quote_identifiers: bool = True,
|
||||
identify: bool = True,
|
||||
infer_csv_schemas: bool = False,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Rewrite sqlglot AST to have normalized and qualified tables and columns.
|
||||
|
@ -60,13 +61,21 @@ def qualify(
|
|||
This step is necessary to ensure correctness for case sensitive queries.
|
||||
But this flag is provided in case this step is performed at a later time.
|
||||
identify: If True, quote all identifiers, else only necessary ones.
|
||||
infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
|
||||
|
||||
Returns:
|
||||
The qualified expression.
|
||||
"""
|
||||
schema = ensure_schema(schema, dialect=dialect)
|
||||
expression = normalize_identifiers(expression, dialect=dialect)
|
||||
expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema, dialect=dialect)
|
||||
expression = qualify_tables(
|
||||
expression,
|
||||
db=db,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
dialect=dialect,
|
||||
infer_csv_schemas=infer_csv_schemas,
|
||||
)
|
||||
|
||||
if isolate_tables:
|
||||
expression = isolate_table_selects(expression, schema=schema)
|
||||
|
|
|
@ -275,6 +275,17 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
|
|||
if isinstance(projection, exp.Alias):
|
||||
alias_to_expression[projection.alias] = (projection.this, i + 1)
|
||||
|
||||
parent_scope = scope
|
||||
while parent_scope.is_union:
|
||||
parent_scope = parent_scope.parent
|
||||
|
||||
# We shouldn't expand aliases if they match the recursive CTE's columns
|
||||
if parent_scope.is_cte:
|
||||
cte = parent_scope.expression.parent
|
||||
if cte.find_ancestor(exp.With).recursive:
|
||||
for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
|
||||
alias_to_expression.pop(recursive_cte_column.output_name, None)
|
||||
|
||||
replace_columns(expression.args.get("where"))
|
||||
replace_columns(expression.args.get("group"), literal_index=True)
|
||||
replace_columns(expression.args.get("having"), resolve_table=True)
|
||||
|
|
|
@ -18,6 +18,7 @@ def qualify_tables(
|
|||
db: t.Optional[str | exp.Identifier] = None,
|
||||
catalog: t.Optional[str | exp.Identifier] = None,
|
||||
schema: t.Optional[Schema] = None,
|
||||
infer_csv_schemas: bool = False,
|
||||
dialect: DialectType = None,
|
||||
) -> E:
|
||||
"""
|
||||
|
@ -39,6 +40,7 @@ def qualify_tables(
|
|||
db: Database name
|
||||
catalog: Catalog name
|
||||
schema: A schema to populate
|
||||
infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
|
||||
dialect: The dialect to parse catalog and schema into.
|
||||
|
||||
Returns:
|
||||
|
@ -102,7 +104,7 @@ def qualify_tables(
|
|||
"alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
|
||||
)
|
||||
|
||||
if schema and isinstance(source.this, exp.ReadCSV):
|
||||
if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
|
||||
with csv_reader(source.this) as reader:
|
||||
header = next(reader)
|
||||
columns = next(reader)
|
||||
|
|
|
@ -65,6 +65,7 @@ class Scope:
|
|||
scope_type=ScopeType.ROOT,
|
||||
lateral_sources=None,
|
||||
cte_sources=None,
|
||||
can_be_correlated=None,
|
||||
):
|
||||
self.expression = expression
|
||||
self.sources = sources or {}
|
||||
|
@ -81,6 +82,7 @@ class Scope:
|
|||
self.cte_scopes = []
|
||||
self.union_scopes = []
|
||||
self.udtf_scopes = []
|
||||
self.can_be_correlated = can_be_correlated
|
||||
self.clear_cache()
|
||||
|
||||
def clear_cache(self):
|
||||
|
@ -110,6 +112,8 @@ class Scope:
|
|||
scope_type=scope_type,
|
||||
cte_sources={**self.cte_sources, **(cte_sources or {})},
|
||||
lateral_sources=lateral_sources.copy() if lateral_sources else None,
|
||||
can_be_correlated=self.can_be_correlated
|
||||
or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -261,7 +265,11 @@ class Scope:
|
|||
|
||||
external_columns = [
|
||||
column
|
||||
for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
|
||||
for scope in itertools.chain(
|
||||
self.subquery_scopes,
|
||||
self.udtf_scopes,
|
||||
(dts for dts in self.derived_table_scopes if dts.can_be_correlated),
|
||||
)
|
||||
for column in scope.external_columns
|
||||
]
|
||||
|
||||
|
@ -425,10 +433,7 @@ class Scope:
|
|||
@property
|
||||
def is_correlated_subquery(self):
|
||||
"""Determine if this scope is a correlated subquery"""
|
||||
return bool(
|
||||
(self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
|
||||
and self.external_columns
|
||||
)
|
||||
return bool(self.can_be_correlated and self.external_columns)
|
||||
|
||||
def rename_source(self, old_name, new_name):
|
||||
"""Rename a source in this scope"""
|
||||
|
|
|
@ -117,6 +117,29 @@ def build_pad(args: t.List, is_left: bool = True):
|
|||
)
|
||||
|
||||
|
||||
def build_array_constructor(
|
||||
exp_class: t.Type[E], args: t.List, bracket_kind: TokenType, dialect: Dialect
|
||||
) -> exp.Expression:
|
||||
array_exp = exp_class(expressions=args)
|
||||
|
||||
if exp_class == exp.Array and dialect.HAS_DISTINCT_ARRAY_CONSTRUCTORS:
|
||||
array_exp.set("bracket_notation", bracket_kind == TokenType.L_BRACKET)
|
||||
|
||||
return array_exp
|
||||
|
||||
|
||||
def build_convert_timezone(
|
||||
args: t.List, default_source_tz: t.Optional[str] = None
|
||||
) -> t.Union[exp.ConvertTimezone, exp.Anonymous]:
|
||||
if len(args) == 2:
|
||||
source_tz = exp.Literal.string(default_source_tz) if default_source_tz else None
|
||||
return exp.ConvertTimezone(
|
||||
source_tz=source_tz, target_tz=seq_get(args, 0), timestamp=seq_get(args, 1)
|
||||
)
|
||||
|
||||
return exp.ConvertTimezone.from_arg_list(args)
|
||||
|
||||
|
||||
class _Parser(type):
|
||||
def __new__(cls, clsname, bases, attrs):
|
||||
klass = super().__new__(cls, clsname, bases, attrs)
|
||||
|
@ -144,6 +167,7 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
FUNCTIONS: t.Dict[str, t.Callable] = {
|
||||
**{name: func.from_arg_list for name, func in exp.FUNCTION_BY_NAME.items()},
|
||||
"ARRAY": lambda args, dialect: exp.Array(expressions=args),
|
||||
"CONCAT": lambda args, dialect: exp.Concat(
|
||||
expressions=args,
|
||||
safe=not dialect.STRICT_STRING_CONCAT,
|
||||
|
@ -154,10 +178,16 @@ class Parser(metaclass=_Parser):
|
|||
safe=not dialect.STRICT_STRING_CONCAT,
|
||||
coalesce=dialect.CONCAT_COALESCE,
|
||||
),
|
||||
"CONVERT_TIMEZONE": build_convert_timezone,
|
||||
"DATE_TO_DATE_STR": lambda args: exp.Cast(
|
||||
this=seq_get(args, 0),
|
||||
to=exp.DataType(this=exp.DataType.Type.TEXT),
|
||||
),
|
||||
"GENERATE_DATE_ARRAY": lambda args: exp.GenerateDateArray(
|
||||
start=seq_get(args, 0),
|
||||
end=seq_get(args, 1),
|
||||
step=seq_get(args, 2) or exp.Interval(this=exp.Literal.number(1), unit=exp.var("DAY")),
|
||||
),
|
||||
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
|
||||
"HEX": build_hex,
|
||||
"JSON_EXTRACT": build_extract_json_with_path(exp.JSONExtract),
|
||||
|
@ -192,6 +222,7 @@ class Parser(metaclass=_Parser):
|
|||
"UNNEST": lambda args: exp.Unnest(expressions=ensure_list(seq_get(args, 0))),
|
||||
"UPPER": build_upper,
|
||||
"VAR_MAP": build_var_map,
|
||||
"COALESCE": lambda args: exp.Coalesce(this=seq_get(args, 0), expressions=args[1:]),
|
||||
}
|
||||
|
||||
NO_PAREN_FUNCTIONS = {
|
||||
|
@ -374,6 +405,11 @@ class Parser(metaclass=_Parser):
*DB_CREATABLES,
}

ALTERABLES = {
TokenType.TABLE,
TokenType.VIEW,
}

# Tokens that can represent identifiers
ID_VAR_TOKENS = {
TokenType.ALL,
@ -433,6 +469,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.RECURSIVE,
|
||||
TokenType.REFERENCES,
|
||||
TokenType.REFRESH,
|
||||
TokenType.RENAME,
|
||||
TokenType.REPLACE,
|
||||
TokenType.RIGHT,
|
||||
TokenType.ROLLUP,
|
||||
|
@ -842,6 +879,7 @@ class Parser(metaclass=_Parser):
|
|||
"DYNAMIC": lambda self: self.expression(exp.DynamicProperty),
|
||||
"DISTKEY": lambda self: self._parse_distkey(),
|
||||
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
|
||||
"EMPTY": lambda self: self.expression(exp.EmptyProperty),
|
||||
"ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
|
||||
"EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
|
||||
"EXTERNAL": lambda self: self.expression(exp.ExternalProperty),
|
||||
|
@ -885,6 +923,7 @@ class Parser(metaclass=_Parser):
|
|||
"REMOTE": lambda self: self._parse_remote_with_connection(),
|
||||
"RETURNS": lambda self: self._parse_returns(),
|
||||
"STRICT": lambda self: self.expression(exp.StrictProperty),
|
||||
"STREAMING": lambda self: self.expression(exp.StreamingTableProperty),
|
||||
"ROW": lambda self: self._parse_row(),
|
||||
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
|
||||
"SAMPLE": lambda self: self.expression(
|
||||
|
@ -892,9 +931,7 @@ class Parser(metaclass=_Parser):
|
|||
),
|
||||
"SECURE": lambda self: self.expression(exp.SecureProperty),
|
||||
"SET": lambda self: self.expression(exp.SetProperty, multi=False),
|
||||
"SETTINGS": lambda self: self.expression(
|
||||
exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
|
||||
),
|
||||
"SETTINGS": lambda self: self._parse_settings_property(),
|
||||
"SHARING": lambda self: self._parse_property_assignment(exp.SharingProperty),
|
||||
"SORTKEY": lambda self: self._parse_sortkey(),
|
||||
"SOURCE": lambda self: self._parse_dict_property(this="SOURCE"),
|
||||
|
@ -992,6 +1029,7 @@ class Parser(metaclass=_Parser):
|
|||
"DROP": lambda self: self._parse_alter_table_drop(),
|
||||
"RENAME": lambda self: self._parse_alter_table_rename(),
|
||||
"SET": lambda self: self._parse_alter_table_set(),
|
||||
"AS": lambda self: self._parse_select(),
|
||||
}
|
||||
|
||||
ALTER_ALTER_PARSERS = {
|
||||
|
@ -1628,7 +1666,7 @@ class Parser(metaclass=_Parser):
|
|||
temporary = self._match(TokenType.TEMPORARY)
|
||||
materialized = self._match_text_seq("MATERIALIZED")
|
||||
|
||||
kind = self._match_set(self.CREATABLES) and self._prev.text
|
||||
kind = self._match_set(self.CREATABLES) and self._prev.text.upper()
|
||||
if not kind:
|
||||
return self._parse_as_command(start)
|
||||
|
||||
|
@ -1650,7 +1688,7 @@ class Parser(metaclass=_Parser):
|
|||
exists=if_exists,
|
||||
this=table,
|
||||
expressions=expressions,
|
||||
kind=kind.upper(),
|
||||
kind=self.dialect.CREATABLE_KIND_MAPPING.get(kind) or kind,
|
||||
temporary=temporary,
|
||||
materialized=materialized,
|
||||
cascade=self._match_text_seq("CASCADE"),
|
||||
|
@ -1676,6 +1714,7 @@ class Parser(metaclass=_Parser):
|
|||
or self._match_pair(TokenType.OR, TokenType.REPLACE)
|
||||
or self._match_pair(TokenType.OR, TokenType.ALTER)
|
||||
)
|
||||
refresh = self._match_pair(TokenType.OR, TokenType.REFRESH)
|
||||
|
||||
unique = self._match(TokenType.UNIQUE)
|
||||
|
||||
|
@ -1792,7 +1831,6 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
# exp.Properties.Location.POST_INDEX
|
||||
extend_props(self._parse_properties())
|
||||
|
||||
if not index:
|
||||
break
|
||||
else:
|
||||
|
@ -1813,12 +1851,14 @@ class Parser(metaclass=_Parser):
|
|||
if self._curr and not self._match_set((TokenType.R_PAREN, TokenType.COMMA), advance=False):
|
||||
return self._parse_as_command(start)
|
||||
|
||||
create_kind_text = create_token.text.upper()
|
||||
return self.expression(
|
||||
exp.Create,
|
||||
comments=comments,
|
||||
this=this,
|
||||
kind=create_token.text.upper(),
|
||||
kind=self.dialect.CREATABLE_KIND_MAPPING.get(create_kind_text) or create_kind_text,
|
||||
replace=replace,
|
||||
refresh=refresh,
|
||||
unique=unique,
|
||||
expression=expression,
|
||||
exists=exists,
|
||||
|
@ -1979,6 +2019,11 @@ class Parser(metaclass=_Parser):
|
|||
exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
|
||||
)
|
||||
|
||||
def _parse_settings_property(self) -> exp.SettingsProperty:
|
||||
return self.expression(
|
||||
exp.SettingsProperty, expressions=self._parse_csv(self._parse_assignment)
|
||||
)
|
||||
|
||||
def _parse_volatile_property(self) -> exp.VolatileProperty | exp.StabilityProperty:
|
||||
if self._index >= 2:
|
||||
pre_volatile_token = self._tokens[self._index - 2]
|
||||
|
@ -2451,8 +2496,14 @@ class Parser(metaclass=_Parser):
|
|||
this = self._parse_table(schema=True)
|
||||
properties = self._parse_properties()
|
||||
expressions = properties.expressions if properties else None
|
||||
partition = self._parse_partition()
|
||||
return self.expression(
|
||||
exp.Describe, this=this, style=style, kind=kind, expressions=expressions
|
||||
exp.Describe,
|
||||
this=this,
|
||||
style=style,
|
||||
kind=kind,
|
||||
expressions=expressions,
|
||||
partition=partition,
|
||||
)
|
||||
|
||||
def _parse_insert(self) -> exp.Insert:
|
||||
|
@ -2498,6 +2549,8 @@ class Parser(metaclass=_Parser):
|
|||
by_name=self._match_text_seq("BY", "NAME"),
|
||||
exists=self._parse_exists(),
|
||||
where=self._match_pair(TokenType.REPLACE, TokenType.WHERE) and self._parse_assignment(),
|
||||
partition=self._match(TokenType.PARTITION_BY) and self._parse_partitioned_by(),
|
||||
settings=self._match_text_seq("SETTINGS") and self._parse_settings_property(),
|
||||
expression=self._parse_derived_table_values() or self._parse_ddl_select(),
|
||||
conflict=self._parse_on_conflict(),
|
||||
returning=returning or self._parse_returning(),
|
||||
|
@ -2830,6 +2883,10 @@ class Parser(metaclass=_Parser):
|
|||
table = self._match(TokenType.TABLE)
|
||||
this = self._parse_select() or self._parse_string() or self._parse_table()
|
||||
return self.expression(exp.Summarize, this=this, table=table)
|
||||
elif self._match(TokenType.DESCRIBE):
|
||||
this = self._parse_describe()
|
||||
elif self._match_text_seq("STREAM"):
|
||||
this = self.expression(exp.Stream, this=self._parse_function())
|
||||
else:
|
||||
this = None
|
||||
|
||||
|
@ -3173,6 +3230,15 @@ class Parser(metaclass=_Parser):
|
|||
self._match_set(self.JOIN_KINDS) and self._prev,
|
||||
)
|
||||
|
||||
def _parse_using_identifiers(self) -> t.List[exp.Expression]:
|
||||
def _parse_column_as_identifier() -> t.Optional[exp.Expression]:
|
||||
this = self._parse_column()
|
||||
if isinstance(this, exp.Column):
|
||||
return this.this
|
||||
return this
|
||||
|
||||
return self._parse_wrapped_csv(_parse_column_as_identifier, optional=True)
|
||||
|
||||
def _parse_join(
|
||||
self, skip_join_token: bool = False, parse_bracket: bool = False
|
||||
) -> t.Optional[exp.Join]:
|
||||
|
@ -3213,9 +3279,11 @@ class Parser(metaclass=_Parser):
|
|||
if self._match(TokenType.ON):
|
||||
kwargs["on"] = self._parse_assignment()
|
||||
elif self._match(TokenType.USING):
|
||||
kwargs["using"] = self._parse_wrapped_id_vars()
|
||||
elif not isinstance(kwargs["this"], exp.Unnest) and not (
|
||||
kind and kind.token_type == TokenType.CROSS
|
||||
kwargs["using"] = self._parse_using_identifiers()
|
||||
elif (
|
||||
not (outer_apply or cross_apply)
|
||||
and not isinstance(kwargs["this"], exp.Unnest)
|
||||
and not (kind and kind.token_type == TokenType.CROSS)
|
||||
):
|
||||
index = self._index
|
||||
joins: t.Optional[list] = list(self._parse_joins())
|
||||
|
@ -3223,7 +3291,7 @@ class Parser(metaclass=_Parser):
|
|||
if joins and self._match(TokenType.ON):
|
||||
kwargs["on"] = self._parse_assignment()
|
||||
elif joins and self._match(TokenType.USING):
|
||||
kwargs["using"] = self._parse_wrapped_id_vars()
|
||||
kwargs["using"] = self._parse_using_identifiers()
|
||||
else:
|
||||
joins = None
|
||||
self._retreat(index)
|
||||
|
@ -3607,7 +3675,10 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
def _parse_derived_table_values(self) -> t.Optional[exp.Values]:
|
||||
is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
|
||||
if not is_derived and not self._match_text_seq("VALUES"):
|
||||
if not is_derived and not (
|
||||
# ClickHouse's `FORMAT Values` is equivalent to `VALUES`
|
||||
self._match_text_seq("VALUES") or self._match_text_seq("FORMAT", "VALUES")
|
||||
):
|
||||
return None
|
||||
|
||||
expressions = self._parse_csv(self._parse_value)
|
||||
|
@ -3707,13 +3778,15 @@ class Parser(metaclass=_Parser):
|
|||
exp.Pivot, this=this, expressions=expressions, using=using, group=group
|
||||
)
|
||||
|
||||
def _parse_pivot_in(self) -> exp.In:
|
||||
def _parse_pivot_in(self) -> exp.In | exp.PivotAny:
|
||||
def _parse_aliased_expression() -> t.Optional[exp.Expression]:
|
||||
this = self._parse_assignment()
|
||||
this = self._parse_select_or_expression()
|
||||
|
||||
self._match(TokenType.ALIAS)
|
||||
alias = self._parse_field()
|
||||
alias = self._parse_bitwise()
|
||||
if alias:
|
||||
if isinstance(alias, exp.Column) and not alias.db:
|
||||
alias = alias.this
|
||||
return self.expression(exp.PivotAlias, this=this, alias=alias)
|
||||
|
||||
return this
|
||||
|
@ -3723,10 +3796,14 @@ class Parser(metaclass=_Parser):
|
|||
if not self._match_pair(TokenType.IN, TokenType.L_PAREN):
|
||||
self.raise_error("Expecting IN (")
|
||||
|
||||
aliased_expressions = self._parse_csv(_parse_aliased_expression)
|
||||
if self._match(TokenType.ANY):
|
||||
expr: exp.PivotAny | exp.In = self.expression(exp.PivotAny, this=self._parse_order())
|
||||
else:
|
||||
aliased_expressions = self._parse_csv(_parse_aliased_expression)
|
||||
expr = self.expression(exp.In, this=value, expressions=aliased_expressions)
|
||||
|
||||
self._match_r_paren()
|
||||
return self.expression(exp.In, this=value, expressions=aliased_expressions)
|
||||
return expr
|
||||
|
||||
def _parse_pivot(self) -> t.Optional[exp.Pivot]:
|
||||
index = self._index
|
||||
|
@ -3763,6 +3840,9 @@ class Parser(metaclass=_Parser):
|
|||
self.raise_error("Expecting FOR")
|
||||
|
||||
field = self._parse_pivot_in()
|
||||
default_on_null = self._match_text_seq("DEFAULT", "ON", "NULL") and self._parse_wrapped(
|
||||
self._parse_bitwise
|
||||
)
|
||||
|
||||
self._match_r_paren()
|
||||
|
||||
|
@ -3772,6 +3852,7 @@ class Parser(metaclass=_Parser):
|
|||
field=field,
|
||||
unpivot=unpivot,
|
||||
include_nulls=include_nulls,
|
||||
default_on_null=default_on_null,
|
||||
)
|
||||
|
||||
if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
|
||||
|
@ -3934,7 +4015,6 @@ class Parser(metaclass=_Parser):
|
|||
exp.Order,
|
||||
this=this,
|
||||
expressions=self._parse_csv(self._parse_ordered),
|
||||
interpolate=self._parse_interpolate(),
|
||||
siblings=siblings,
|
||||
)
|
||||
|
||||
|
@ -3979,6 +4059,7 @@ class Parser(metaclass=_Parser):
|
|||
"from": self._match(TokenType.FROM) and self._parse_bitwise(),
|
||||
"to": self._match_text_seq("TO") and self._parse_bitwise(),
|
||||
"step": self._match_text_seq("STEP") and self._parse_bitwise(),
|
||||
"interpolate": self._parse_interpolate(),
|
||||
},
|
||||
)
|
||||
else:
|
||||
|
@ -4132,6 +4213,11 @@ class Parser(metaclass=_Parser):

    def _parse_assignment(self) -> t.Optional[exp.Expression]:
        this = self._parse_disjunction()
        if not this and self._next and self._next.token_type in self.ASSIGNMENT:
            # This allows us to parse <non-identifier token> := <expr>
            this = exp.column(
                t.cast(str, self._advance_any(ignore_reserved=True) and self._prev.text)
            )

        while self._match_set(self.ASSIGNMENT):
            this = self.expression(
@ -4175,13 +4261,19 @@ class Parser(metaclass=_Parser):
            this = self.expression(exp.Not, this=this)

        if negate:
            this = self.expression(exp.Not, this=this)
            this = self._negate_range(this)

        if self._match(TokenType.IS):
            this = self._parse_is(this)

        return this

    def _negate_range(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
        if not this:
            return this

        return self.expression(exp.Not, this=this)

    def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        index = self._index - 1
        negate = self._match(TokenType.NOT)
@ -4322,7 +4414,26 @@ class Parser(metaclass=_Parser):
        return this

    def _parse_term(self) -> t.Optional[exp.Expression]:
        return self._parse_tokens(self._parse_factor, self.TERM)
        this = self._parse_factor()

        while self._match_set(self.TERM):
            klass = self.TERM[self._prev.token_type]
            comments = self._prev_comments
            expression = self._parse_factor()

            this = self.expression(klass, this=this, comments=comments, expression=expression)

            if isinstance(this, exp.Collate):
                expr = this.expression

                # Preserve collations such as pg_catalog."default" (Postgres) as columns, otherwise
                # fallback to Identifier / Var
                if isinstance(expr, exp.Column) and len(expr.parts) == 1:
                    ident = expr.this
                    if isinstance(ident, exp.Identifier):
                        this.set("expression", ident if ident.quoted else exp.var(ident.name))

        return this

    def _parse_factor(self) -> t.Optional[exp.Expression]:
        parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary
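
A small sketch of what the Collate special-casing above preserves; the exact rendered SQL is an assumption.

import sqlglot

# A dotted collation such as pg_catalog."default" has more than one part, so it stays
# a column reference and should round-trip unchanged; a single quoted part keeps its
# Identifier, and a bare name like de_DE becomes a Var.
sql = 'SELECT name COLLATE pg_catalog."default" FROM users'
print(sqlglot.parse_one(sql, read="postgres").sql(dialect="postgres"))
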
@ -4610,6 +4721,7 @@ class Parser(metaclass=_Parser):
        matched_array = self._match(TokenType.ARRAY)

        while self._curr:
            datatype_token = self._prev.token_type
            matched_l_bracket = self._match(TokenType.L_BRACKET)
            if not matched_l_bracket and not matched_array:
                break
@ -4619,8 +4731,12 @@ class Parser(metaclass=_Parser):
            if (
                values
                and not schema
                and this.is_type(exp.DataType.Type.ARRAY, exp.DataType.Type.MAP)
                and (
                    not self.dialect.SUPPORTS_FIXED_SIZE_ARRAYS or datatype_token == TokenType.ARRAY
                )
            ):
                # Retreating here means that we should not parse the following values as part of the data type, e.g. in DuckDB
                # ARRAY[1] should retreat and instead be parsed into exp.Array in contrast to INT[x][y] which denotes a fixed-size array data type
                self._retreat(index)
                break
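
A sketch of the two cases the retreat separates in a dialect with fixed-size arrays (DuckDB here); the printed values are assumptions.

import sqlglot
from sqlglot import exp

# ARRAY[1, 2] is a value constructor, so the type parser retreats and the expression
# becomes an exp.Array; INT[3] has no ARRAY keyword in front and stays a fixed-size
# array data type.
literal = sqlglot.parse_one("SELECT ARRAY[1, 2]", read="duckdb").expressions[0]
casted = sqlglot.parse_one("SELECT CAST(c AS INT[3]) FROM t", read="duckdb").expressions[0]
print(isinstance(literal, exp.Array), casted.to.sql(dialect="duckdb"))
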
@ -5407,11 +5523,18 @@ class Parser(metaclass=_Parser):
        if bracket_kind == TokenType.L_BRACE:
            this = self.expression(exp.Struct, expressions=self._kv_to_prop_eq(expressions))
        elif not this:
            this = self.expression(exp.Array, expressions=expressions)
            this = build_array_constructor(
                exp.Array, args=expressions, bracket_kind=bracket_kind, dialect=self.dialect
            )
        else:
            constructor_type = self.ARRAY_CONSTRUCTORS.get(this.name.upper())
            if constructor_type:
                return self.expression(constructor_type, expressions=expressions)
                return build_array_constructor(
                    constructor_type,
                    args=expressions,
                    bracket_kind=bracket_kind,
                    dialect=self.dialect,
                )

            expressions = apply_index_offset(this, expressions, -self.dialect.INDEX_OFFSET)
            this = self.expression(exp.Bracket, this=this, expressions=expressions)
@ -6440,10 +6563,11 @@ class Parser(metaclass=_Parser):

        return alter_set

    def _parse_alter(self) -> exp.AlterTable | exp.Command:
    def _parse_alter(self) -> exp.Alter | exp.Command:
        start = self._prev

        if not self._match(TokenType.TABLE):
        alter_token = self._match_set(self.ALTERABLES) and self._prev
        if not alter_token:
            return self._parse_as_command(start)

        exists = self._parse_exists()
@ -6461,8 +6585,9 @@ class Parser(metaclass=_Parser):

        if not self._curr and actions:
            return self.expression(
                exp.AlterTable,
                exp.Alter,
                this=this,
                kind=alter_token.text.upper(),
                exists=exists,
                actions=actions,
                only=only,

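
For orientation, how the renamed node surfaces to callers, grounded in the hunk above (exp.Alter carrying a `kind`); the sample statement and printed values are assumptions.

import sqlglot
from sqlglot import exp

# ALTER statements now build exp.Alter; what the old exp.AlterTable class name implied
# is carried by the `kind` arg instead (e.g. "TABLE" for ALTER TABLE).
stmt = sqlglot.parse_one("ALTER TABLE t ADD COLUMN c INT")
print(isinstance(stmt, exp.Alter), stmt.args.get("kind"))
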
@ -340,6 +340,7 @@ class TokenType(AutoName):
    RANGE = auto()
    RECURSIVE = auto()
    REFRESH = auto()
    RENAME = auto()
    REPLACE = auto()
    RETURNING = auto()
    REFERENCES = auto()
@ -529,6 +530,7 @@ class _Tokenizer(type):
                },
                heredoc_tag_is_identifier=klass.HEREDOC_TAG_IS_IDENTIFIER,
                string_escapes_allowed_in_raw_strings=klass.STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS,
                nested_comments=klass.NESTED_COMMENTS,
            )
            token_types = RsTokenTypeSettings(
                bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
@ -608,6 +610,8 @@ class Tokenizer(metaclass=_Tokenizer):
    # Whether string escape characters function as such when placed within raw strings
    STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = True

    NESTED_COMMENTS = True

    # Autofilled
    _COMMENTS: t.Dict[str, str] = {}
    _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
@ -753,6 +757,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "RANGE": TokenType.RANGE,
        "RECURSIVE": TokenType.RECURSIVE,
        "REGEXP": TokenType.RLIKE,
        "RENAME": TokenType.RENAME,
        "REPLACE": TokenType.REPLACE,
        "RETURNING": TokenType.RETURNING,
        "REFERENCES": TokenType.REFERENCES,
@ -913,6 +918,7 @@ class Tokenizer(metaclass=_Tokenizer):
        TokenType.EXECUTE,
        TokenType.FETCH,
        TokenType.SHOW,
        TokenType.RENAME,
    }

    COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN}
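
A sketch of what registering RENAME as a command token buys; whether a given dialect parses RENAME statements into something richer than a generic command is an assumption outside this hunk.

import sqlglot
from sqlglot import exp

# With RENAME tokenized as a keyword and listed among the command tokens, a statement
# that starts with it can fall back to a generic exp.Command instead of failing outright.
stmt = sqlglot.parse_one("RENAME TABLE t1 TO t2")
print(isinstance(stmt, exp.Command))
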
@ -1181,7 +1187,11 @@ class Tokenizer(metaclass=_Tokenizer):
                self._advance(alnum=True)

                # Nested comments are allowed by some dialects, e.g. databricks, duckdb, postgres
                if not self._end and self._chars(comment_end_size) == comment_start:
                if (
                    self.NESTED_COMMENTS
                    and not self._end
                    and self._chars(comment_end_size) == comment_start
                ):
                    self._advance(comment_start_size)
                    comment_count += 1

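
A minimal sketch of the behaviour NESTED_COMMENTS toggles; the base tokenizer keeps the previous nesting behaviour (True), and a dialect that sets it to False would instead close the comment at the first */.

import sqlglot

# With nesting enabled, the inner /* ... */ does not terminate the outer comment,
# so the statement below is just SELECT 1 with one attached comment.
sql = "SELECT 1 /* outer /* inner */ still inside the outer comment */"
print(sqlglot.parse_one(sql).sql())
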
@ -55,6 +55,76 @@ def preprocess(
    return _to_sql


def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression


def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.
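
A rough usage sketch of unnest_generate_date_array_using_recursive_cte above. The AST is built by hand so the example does not depend on any single dialect's parser; the import path and the exact SQL produced are assumptions.

from sqlglot import exp
from sqlglot.transforms import unnest_generate_date_array_using_recursive_cte

# Equivalent to: SELECT day FROM UNNEST(GENERATE_DATE_ARRAY('2020-01-01', '2020-01-10', INTERVAL 1 DAY)) AS t(day)
gen = exp.GenerateDateArray(
    start=exp.Literal.string("2020-01-01"),
    end=exp.Literal.string("2020-01-10"),
    step=exp.Interval(this=exp.Literal.number(1), unit=exp.var("DAY")),
)
unnest = exp.Unnest(
    expressions=[gen],
    alias=exp.TableAlias(columns=[exp.to_identifier("day")]),
)
query = exp.select("day").from_(unnest)

# The UNNEST source is swapped for a recursive `_generated_dates` CTE that walks from
# the start date to the end date with DATE_ADD, which dialects lacking
# GENERATE_DATE_ARRAY can execute directly.
print(unnest_generate_date_array_using_recursive_cte(query).sql())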