Merging upstream version 23.10.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 6cbc5d6f97
commit 49aa147013

91 changed files with 52881 additions and 50396 deletions
sqlglot/dataframe/sql/functions.py  (path inferred from hunk content)
@@ -536,7 +536,7 @@ def year(col: ColumnOrName) -> Column:
 
 
 def quarter(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "QUARTER")
+    return Column.invoke_expression_over_column(col, expression.Quarter)
 
 
 def month(col: ColumnOrName) -> Column:
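A minimal sketch (not part of the diff) of what this change buys; it assumes only public sqlglot APIs, and the rendered SQL may differ slightly by version:

import sqlglot
from sqlglot import exp

# QUARTER(...) now builds the typed exp.Quarter node introduced in this release,
# instead of an anonymous function, so dialects can re-express it (see the
# Teradata hunks later in this diff).
ast = sqlglot.parse_one("SELECT QUARTER(d) FROM t")
assert isinstance(ast.find(exp.Quarter), exp.Quarter)
print(ast.sql(dialect="teradata"))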
sqlglot/dialects/bigquery.py  (path inferred from hunk content)
@@ -15,7 +15,7 @@ from sqlglot.dialects.dialect import (
     build_formatted_time,
     filter_array_using_unnest,
     if_sql,
-    inline_array_sql,
+    inline_array_unless_query,
     max_or_greatest,
     min_or_least,
     no_ilike_sql,
@@ -80,29 +80,6 @@ def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
     return self.create_sql(expression)
 
 
-def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
-    """Remove references to unnest table aliases since bigquery doesn't allow them.
-
-    These are added by the optimizer's qualify_column step.
-    """
-    from sqlglot.optimizer.scope import find_all_in_scope
-
-    if isinstance(expression, exp.Select):
-        unnest_aliases = {
-            unnest.alias
-            for unnest in find_all_in_scope(expression, exp.Unnest)
-            if isinstance(unnest.parent, (exp.From, exp.Join))
-        }
-        if unnest_aliases:
-            for column in expression.find_all(exp.Column):
-                if column.table in unnest_aliases:
-                    column.set("table", None)
-                elif column.db in unnest_aliases:
-                    column.set("db", None)
-
-    return expression
-
-
 # https://issuetracker.google.com/issues/162294746
 # workaround for bigquery bug when grouping by an expression and then ordering
 # WITH x AS (SELECT 1 y)
@@ -197,8 +174,8 @@ def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> str:
 
 
 def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str:
-    expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True))
-    expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True))
+    expression.this.replace(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
+    expression.expression.replace(exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP))
     unit = unit_to_var(expression)
     return self.func("DATE_DIFF", expression.this, expression.expression, unit)
 
@@ -214,7 +191,9 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str:
     if scale == exp.UnixToTime.MICROS:
         return self.func("TIMESTAMP_MICROS", timestamp)
 
-    unix_seconds = exp.cast(exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), "int64")
+    unix_seconds = exp.cast(
+        exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), exp.DataType.Type.BIGINT
+    )
     return self.func("TIMESTAMP_SECONDS", unix_seconds)
 
 
@@ -576,6 +555,7 @@ class BigQuery(Dialect):
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
             exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
+            exp.Array: inline_array_unless_query,
             exp.ArrayContains: _array_contains_sql,
             exp.ArrayFilter: filter_array_using_unnest,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -629,7 +609,7 @@ class BigQuery(Dialect):
             exp.Select: transforms.preprocess(
                 [
                     transforms.explode_to_unnest(),
-                    _unqualify_unnest,
+                    transforms.unqualify_unnest,
                     transforms.eliminate_distinct_on,
                     _alias_ordered_group,
                     transforms.eliminate_semi_and_anti_joins,
@@ -843,13 +823,6 @@ class BigQuery(Dialect):
         def trycast_sql(self, expression: exp.TryCast) -> str:
            return self.cast_sql(expression, safe_prefix="SAFE_")
 
-        def array_sql(self, expression: exp.Array) -> str:
-            first_arg = seq_get(expression.expressions, 0)
-            if isinstance(first_arg, exp.Query):
-                return f"ARRAY{self.wrap(self.sql(first_arg))}"
-
-            return inline_array_sql(self, expression)
-
         def bracket_sql(self, expression: exp.Bracket) -> str:
             this = expression.this
             expressions = expression.expressions
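A hedged example of the new inline_array_unless_query transform; the inputs and expected shapes below are assumptions, not taken from the commit:

import sqlglot

# Plain array literals stay inline, but an array built from a subquery must
# keep the ARRAY(...) function form in BigQuery.
print(sqlglot.transpile("SELECT [1, 2, 3]", read="duckdb", write="bigquery")[0])
print(sqlglot.transpile("SELECT ARRAY(SELECT 1)", read="bigquery", write="bigquery")[0])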
sqlglot/dialects/clickhouse.py  (path inferred from hunk content)
@@ -629,7 +629,8 @@ class ClickHouse(Dialect):
             exp.CountIf: rename_func("countIf"),
             exp.CompressColumnConstraint: lambda self,
             e: f"CODEC({self.expressions(e, key='this', flat=True)})",
-            exp.ComputedColumnConstraint: lambda self, e: f"ALIAS {self.sql(e, 'this')}",
+            exp.ComputedColumnConstraint: lambda self,
+            e: f"{'MATERIALIZED' if e.args.get('persisted') else 'ALIAS'} {self.sql(e, 'this')}",
             exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
             exp.DateAdd: date_delta_sql("DATE_ADD"),
             exp.DateDiff: date_delta_sql("DATE_DIFF"),
@@ -667,6 +668,7 @@ class ClickHouse(Dialect):
         TABLE_HINTS = False
         EXPLICIT_UNION = True
         GROUPINGS_SEP = ""
+        OUTER_UNION_MODIFIERS = False
 
         # there's no list in docs, but it can be found in Clickhouse code
         # see `ClickHouse/src/Parsers/ParserCreate*.cpp`
sqlglot/dialects/dialect.py  (path inferred from hunk content)
@@ -562,7 +562,7 @@ def if_sql(
 def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
     this = expression.this
     if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
-        this.replace(exp.cast(this, "json"))
+        this.replace(exp.cast(this, exp.DataType.Type.JSON))
 
     return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 
@@ -571,6 +571,13 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
     return f"[{self.expressions(expression, flat=True)}]"
 
 
+def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
+    elem = seq_get(expression.expressions, 0)
+    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
+        return self.func("ARRAY", elem)
+    return inline_array_sql(self, expression)
+
+
 def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
     return self.like_sql(
         exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
@@ -765,11 +772,11 @@ def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
         from sqlglot.optimizer.annotate_types import annotate_types
 
         target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
-        return self.sql(exp.cast(expression.this, to=target_type))
+        return self.sql(exp.cast(expression.this, target_type))
     if expression.text("expression").lower() in TIMEZONES:
         return self.sql(
             exp.AtTimeZone(
-                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
+                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                 zone=expression.expression,
             )
         )
@@ -806,11 +813,11 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 
 
 def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
-    return self.sql(exp.cast(expression.this, "timestamp"))
+    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
 
 
 def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
-    return self.sql(exp.cast(expression.this, "date"))
+    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
 
 
 # Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
@@ -1023,7 +1030,7 @@ def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
     plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
     minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
 
-    return self.sql(exp.cast(minus_one_day, "date"))
+    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
 
 
 def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
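The recurring "timestamp" → exp.DataType.Type.TIMESTAMP edits across this diff are a pure refactor: exp.cast builds the same DataType either way, the enum just skips re-parsing a type string. A quick check using the public API:

from sqlglot import exp

assert (
    exp.cast(exp.column("x"), "timestamp").sql()
    == exp.cast(exp.column("x"), exp.DataType.Type.TIMESTAMP).sql()
)  # both render CAST(x AS TIMESTAMP)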
sqlglot/dialects/drill.py  (path inferred from hunk content)
@@ -19,7 +19,7 @@ def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format == Drill.DATE_FORMAT:
-        return self.sql(exp.cast(this, "date"))
+        return self.sql(exp.cast(this, exp.DataType.Type.DATE))
     return self.func("TO_DATE", this, time_format)
 
 
@@ -134,7 +134,7 @@ class Drill(Dialect):
                 [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
             ),
             exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
-            exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")),
+            exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)),
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
sqlglot/dialects/duckdb.py  (path inferred from hunk content)
@@ -15,7 +15,7 @@ from sqlglot.dialects.dialect import (
     datestrtodate_sql,
     encode_decode_sql,
     build_formatted_time,
-    inline_array_sql,
+    inline_array_unless_query,
     no_comment_column_constraint_sql,
     no_safe_divide_sql,
     no_timestamp_sql,
@@ -312,6 +312,15 @@ class DuckDB(Dialect):
             ),
         }
 
+        def _parse_bracket(
+            self, this: t.Optional[exp.Expression] = None
+        ) -> t.Optional[exp.Expression]:
+            bracket = super()._parse_bracket(this)
+            if isinstance(bracket, exp.Bracket):
+                bracket.set("returns_list_for_maps", True)
+
+            return bracket
+
         def _parse_map(self) -> exp.ToMap | exp.Map:
             if self._match(TokenType.L_BRACE, advance=False):
                 return self.expression(exp.ToMap, this=self._parse_bracket())
@@ -370,11 +379,7 @@ class DuckDB(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
-            exp.Array: lambda self, e: (
-                self.func("ARRAY", e.expressions[0])
-                if e.expressions and e.expressions[0].find(exp.Select)
-                else inline_array_sql(self, e)
-            ),
+            exp.Array: inline_array_unless_query,
             exp.ArrayFilter: rename_func("LIST_FILTER"),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
@@ -416,8 +421,8 @@ class DuckDB(Dialect):
             exp.MonthsBetween: lambda self, e: self.func(
                 "DATEDIFF",
                 "'month'",
-                exp.cast(e.expression, "timestamp", copy=True),
-                exp.cast(e.this, "timestamp", copy=True),
+                exp.cast(e.expression, exp.DataType.Type.TIMESTAMP, copy=True),
+                exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True),
             ),
             exp.ParseJSON: rename_func("JSON"),
             exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"),
@@ -452,9 +457,11 @@ class DuckDB(Dialect):
                 "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
             ),
             exp.TimestampTrunc: timestamptrunc_sql,
-            exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")),
+            exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)),
             exp.TimeStrToTime: timestrtotime_sql,
-            exp.TimeStrToUnix: lambda self, e: self.func("EPOCH", exp.cast(e.this, "timestamp")),
+            exp.TimeStrToUnix: lambda self, e: self.func(
+                "EPOCH", exp.cast(e.this, exp.DataType.Type.TIMESTAMP)
+            ),
             exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.this, self.format_time(e)),
             exp.TimeToUnix: rename_func("EPOCH"),
             exp.TsOrDiToDi: lambda self,
@@ -463,8 +470,8 @@ class DuckDB(Dialect):
             exp.TsOrDsDiff: lambda self, e: self.func(
                 "DATE_DIFF",
                 f"'{e.args.get('unit') or 'DAY'}'",
-                exp.cast(e.expression, "TIMESTAMP"),
-                exp.cast(e.this, "TIMESTAMP"),
+                exp.cast(e.expression, exp.DataType.Type.TIMESTAMP),
+                exp.cast(e.this, exp.DataType.Type.TIMESTAMP),
             ),
             exp.UnixToStr: lambda self, e: self.func(
                 "STRFTIME", self.func("TO_TIMESTAMP", e.this), self.format_time(e)
@@ -593,7 +600,19 @@ class DuckDB(Dialect):
             return super().generateseries_sql(expression)
 
         def bracket_sql(self, expression: exp.Bracket) -> str:
-            if isinstance(expression.this, exp.Array):
-                expression.this.replace(exp.paren(expression.this))
+            this = expression.this
+            if isinstance(this, exp.Array):
+                this.replace(exp.paren(this))
 
-            return super().bracket_sql(expression)
+            bracket = super().bracket_sql(expression)
+
+            if not expression.args.get("returns_list_for_maps"):
+                if not this.type:
+                    from sqlglot.optimizer.annotate_types import annotate_types
+
+                    this = annotate_types(this)
+
+                if this.is_type(exp.DataType.Type.MAP):
+                    bracket = f"({bracket})[1]"
+
+            return bracket
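A sketch of the new DuckDB bracket handling, derived from the hunks above (the example is an assumption, not from the commit):

from sqlglot import exp, parse_one

# In DuckDB, subscripting a MAP returns a single-element LIST. The new
# _parse_bracket hook tags brackets DuckDB itself parsed, so bracket_sql knows
# they already carry that semantic and must not be wrapped as (m['k'])[1]
# again; only map subscripts arriving from other dialects get the [1] suffix.
node = parse_one("SELECT m['k'] FROM t", read="duckdb").find(exp.Bracket)
assert node.args.get("returns_list_for_maps") is True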
sqlglot/dialects/mysql.py  (path inferred from hunk content)
@@ -710,7 +710,9 @@ class MySQL(Dialect):
             ),
             exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
-            exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
+            exp.TimeStrToTime: lambda self, e: self.sql(
+                exp.cast(e.this, exp.DataType.Type.DATETIME, copy=True)
+            ),
             exp.TimeToStr: _remove_ts_or_ds_to_date(
                 lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e))
             ),
sqlglot/dialects/postgres.py  (path inferred from hunk content)
@@ -510,6 +510,9 @@ class Postgres(Dialect):
             exp.TsOrDsAdd: _date_add_sql("+"),
             exp.TsOrDsDiff: _date_diff_sql,
             exp.UnixToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this),
+            exp.TimeToUnix: lambda self, e: self.func(
+                "DATE_PART", exp.Literal.string("epoch"), e.this
+            ),
             exp.VariancePop: rename_func("VAR_POP"),
             exp.Variance: rename_func("VAR_SAMP"),
             exp.Xor: bool_xor_sql,
sqlglot/dialects/presto.py  (path inferred from hunk content)
@@ -90,8 +90,10 @@ def _str_to_time_sql(
 def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str:
     time_format = self.format_time(expression)
     if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
-        return self.sql(exp.cast(_str_to_time_sql(self, expression), "DATE"))
-    return self.sql(exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE"))
+        return self.sql(exp.cast(_str_to_time_sql(self, expression), exp.DataType.Type.DATE))
+    return self.sql(
+        exp.cast(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), exp.DataType.Type.DATE)
+    )
 
 
 def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
@@ -101,8 +103,8 @@ def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
 
 
 def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
-    this = exp.cast(expression.this, "TIMESTAMP")
-    expr = exp.cast(expression.expression, "TIMESTAMP")
+    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)
+    expr = exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP)
     unit = unit_to_str(expression)
     return self.func("DATE_DIFF", unit, expr, this)
 
@@ -222,6 +224,8 @@ class Presto(Dialect):
             "IPPREFIX": TokenType.IPPREFIX,
         }
 
+        KEYWORDS.pop("QUALIFY")
+
     class Parser(parser.Parser):
         VALUES_FOLLOWED_BY_PAREN = False
 
@@ -445,7 +449,7 @@ class Presto(Dialect):
             # timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
             # which seems to be using the same time mapping as Hive, as per:
             # https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
-            value_as_text = exp.cast(expression.this, "text")
+            value_as_text = exp.cast(expression.this, exp.DataType.Type.TEXT)
             parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
             parse_with_tz = self.func(
                 "PARSE_DATETIME",
sqlglot/dialects/prql.py  (path inferred from hunk content)
@@ -7,7 +7,13 @@ from sqlglot.dialects.dialect import Dialect
 from sqlglot.tokens import TokenType
 
 
+def _select_all(table: exp.Expression) -> t.Optional[exp.Select]:
+    return exp.select("*").from_(table, copy=False) if table else None
+
+
 class PRQL(Dialect):
+    DPIPE_IS_STRING_CONCAT = False
+
     class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ["`"]
         QUOTES = ["'", '"']
@@ -26,10 +32,27 @@ class PRQL(Dialect):
         }
 
     class Parser(parser.Parser):
+        CONJUNCTION = {
+            **parser.Parser.CONJUNCTION,
+            TokenType.DAMP: exp.And,
+            TokenType.DPIPE: exp.Or,
+        }
+
         TRANSFORM_PARSERS = {
             "DERIVE": lambda self, query: self._parse_selection(query),
             "SELECT": lambda self, query: self._parse_selection(query, append=False),
             "TAKE": lambda self, query: self._parse_take(query),
+            "FILTER": lambda self, query: query.where(self._parse_conjunction()),
+            "APPEND": lambda self, query: query.union(
+                _select_all(self._parse_table()), distinct=False, copy=False
+            ),
+            "REMOVE": lambda self, query: query.except_(
+                _select_all(self._parse_table()), distinct=False, copy=False
+            ),
+            "INTERSECT": lambda self, query: query.intersect(
+                _select_all(self._parse_table()), distinct=False, copy=False
+            ),
+            "SORT": lambda self, query: self._parse_order_by(query),
         }
 
         def _parse_statement(self) -> t.Optional[exp.Expression]:
@@ -81,6 +104,24 @@ class PRQL(Dialect):
             num = self._parse_number()  # TODO: TAKE for ranges a..b
             return query.limit(num) if num else None
 
+        def _parse_ordered(
+            self, parse_method: t.Optional[t.Callable] = None
+        ) -> t.Optional[exp.Ordered]:
+            asc = self._match(TokenType.PLUS)
+            desc = self._match(TokenType.DASH) or (asc and False)
+            term = term = super()._parse_ordered(parse_method=parse_method)
+            if term and desc:
+                term.set("desc", True)
+                term.set("nulls_first", False)
+            return term
+
+        def _parse_order_by(self, query: exp.Select) -> t.Optional[exp.Query]:
+            l_brace = self._match(TokenType.L_BRACE)
+            expressions = self._parse_csv(self._parse_ordered)
+            if l_brace and not self._match(TokenType.R_BRACE):
+                self.raise_error("Expecting }")
+            return query.order_by(self.expression(exp.Order, expressions=expressions), copy=False)
+
         def _parse_expression(self) -> t.Optional[exp.Expression]:
             if self._next and self._next.token_type == TokenType.ALIAS:
                 alias = self._parse_id_var(True)
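A hedged sketch of the new PRQL transforms in use; the pipeline syntax and the output shown are assumptions based on the handlers above:

import sqlglot

sql = sqlglot.transpile("from t\nfilter a > 1\nsort {-a}\ntake 10", read="prql", write="duckdb")[0]
print(sql)
# roughly: SELECT * FROM t WHERE a > 1 ORDER BY a DESC LIMIT 10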
sqlglot/dialects/redshift.py  (path inferred from hunk content)
@@ -167,7 +167,11 @@ class Redshift(Postgres):
             exp.GroupConcat: rename_func("LISTAGG"),
             exp.ParseJSON: rename_func("JSON_PARSE"),
             exp.Select: transforms.preprocess(
-                [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+                [
+                    transforms.eliminate_distinct_on,
+                    transforms.eliminate_semi_and_anti_joins,
+                    transforms.unqualify_unnest,
+                ]
             ),
             exp.SortKeyProperty: lambda self,
             e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
@@ -203,7 +207,7 @@ class Redshift(Postgres):
                 return ""
 
             arg = self.sql(seq_get(args, 0))
-            alias = self.expressions(expression.args.get("alias"), key="columns")
+            alias = self.expressions(expression.args.get("alias"), key="columns", flat=True)
             return f"{arg} AS {alias}" if alias else arg
 
         def with_properties(self, properties: exp.Properties) -> str:
sqlglot/dialects/snowflake.py  (path inferred from hunk content)
@@ -818,7 +818,7 @@ class Snowflake(Dialect):
             exp.TimestampTrunc: timestamptrunc_sql,
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeToStr: lambda self, e: self.func(
-                "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
+                "TO_CHAR", exp.cast(e.this, exp.DataType.Type.TIMESTAMP), self.format_time(e)
             ),
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
             exp.ToArray: rename_func("TO_ARRAY"),
sqlglot/dialects/spark.py  (path inferred from hunk content)
@@ -6,7 +6,7 @@ from sqlglot import exp
 from sqlglot.dialects.dialect import rename_func, unit_to_var
 from sqlglot.dialects.hive import _build_with_ignore_nulls
 from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
-from sqlglot.helper import seq_get
+from sqlglot.helper import ensure_list, seq_get
 from sqlglot.transforms import (
     ctas_with_tmp_tables_to_create_tmp_view,
     remove_unique_constraints,
@@ -63,6 +63,9 @@ class Spark(Spark2):
             **Spark2.Parser.FUNCTIONS,
             "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
             "DATEDIFF": _build_datediff,
+            "TRY_ELEMENT_AT": lambda args: exp.Bracket(
+                this=seq_get(args, 0), expressions=ensure_list(seq_get(args, 1)), safe=True
+            ),
         }
 
         def _parse_generated_as_identity(
@@ -112,6 +115,13 @@ class Spark(Spark2):
         TRANSFORMS.pop(exp.DateDiff)
         TRANSFORMS.pop(exp.Group)
 
+        def bracket_sql(self, expression: exp.Bracket) -> str:
+            if expression.args.get("safe"):
+                key = seq_get(self.bracket_offset_expressions(expression), 0)
+                return self.func("TRY_ELEMENT_AT", expression.this, key)
+
+            return super().bracket_sql(expression)
+
         def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
             return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})"
 
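A small round-trip sketch of the TRY_ELEMENT_AT support added above (expected output is an assumption):

import sqlglot

# TRY_ELEMENT_AT(arr, 1) now parses to a safe exp.Bracket, and bracket_sql
# turns a safe bracket back into TRY_ELEMENT_AT, so Spark round-trips cleanly.
print(sqlglot.transpile("SELECT TRY_ELEMENT_AT(arr, 1)", read="spark", write="spark")[0])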
sqlglot/dialects/spark2.py  (path inferred from hunk content)
@@ -48,7 +48,7 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str:
     timestamp = expression.this
 
     if scale is None:
-        return self.sql(exp.cast(exp.func("from_unixtime", timestamp), "timestamp"))
+        return self.sql(exp.cast(exp.func("from_unixtime", timestamp), exp.DataType.Type.TIMESTAMP))
     if scale == exp.UnixToTime.SECONDS:
         return self.func("TIMESTAMP_SECONDS", timestamp)
     if scale == exp.UnixToTime.MILLIS:
@@ -129,11 +129,7 @@ class Spark2(Hive):
             "DOUBLE": _build_as_cast("double"),
             "FLOAT": _build_as_cast("float"),
             "FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
-                this=exp.cast_unless(
-                    seq_get(args, 0) or exp.Var(this=""),
-                    exp.DataType.build("timestamp"),
-                    exp.DataType.build("timestamp"),
-                ),
+                this=exp.cast(seq_get(args, 0) or exp.Var(this=""), exp.DataType.Type.TIMESTAMP),
                 zone=seq_get(args, 1),
             ),
             "INT": _build_as_cast("int"),
@@ -150,11 +146,7 @@ class Spark2(Hive):
             ),
             "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
             "TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
-                this=exp.cast_unless(
-                    seq_get(args, 0) or exp.Var(this=""),
-                    exp.DataType.build("timestamp"),
-                    exp.DataType.build("timestamp"),
-                ),
+                this=exp.cast(seq_get(args, 0) or exp.Var(this=""), exp.DataType.Type.TIMESTAMP),
                 zone=seq_get(args, 1),
             ),
             "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
sqlglot/dialects/teradata.py  (path inferred from hunk content)
@@ -13,6 +13,29 @@ from sqlglot.dialects.dialect import (
 from sqlglot.tokens import TokenType
 
 
+def _date_add_sql(
+    kind: t.Literal["+", "-"],
+) -> t.Callable[[Teradata.Generator, exp.DateAdd | exp.DateSub], str]:
+    def func(self: Teradata.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+        this = self.sql(expression, "this")
+        unit = expression.args.get("unit")
+        value = self._simplify_unless_literal(expression.expression)
+
+        if not isinstance(value, exp.Literal):
+            self.unsupported("Cannot add non literal")
+
+        if value.is_negative:
+            kind_to_op = {"+": "-", "-": "+"}
+            value = exp.Literal.string(value.name[1:])
+        else:
+            kind_to_op = {"+": "+", "-": "-"}
+            value.set("is_string", True)
+
+        return f"{this} {kind_to_op[kind]} {self.sql(exp.Interval(this=value, unit=unit))}"
+
+    return func
+
+
 class Teradata(Dialect):
     SUPPORTS_SEMI_ANTI_JOIN = False
     TYPED_DIVISION = True
@@ -189,6 +212,7 @@ class Teradata(Dialect):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
+            exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
         }
 
         PROPERTIES_LOCATION = {
@@ -214,6 +238,10 @@ class Teradata(Dialect):
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.ToNumber: to_number_with_nls_param,
             exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
+            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+            exp.DateAdd: _date_add_sql("+"),
+            exp.DateSub: _date_add_sql("-"),
+            exp.Quarter: lambda self, e: self.sql(exp.Extract(this="QUARTER", expression=e.this)),
         }
 
        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
@@ -276,3 +304,25 @@ class Teradata(Dialect):
                 return f"{this_name}{this_properties}{self.sep()}{this_schema}"
 
             return super().createable_sql(expression, locations)
+
+        def extract_sql(self, expression: exp.Extract) -> str:
+            this = self.sql(expression, "this")
+            if this.upper() != "QUARTER":
+                return super().extract_sql(expression)
+
+            to_char = exp.func("to_char", expression.expression, exp.Literal.string("Q"))
+            return self.sql(exp.cast(to_char, exp.DataType.Type.INT))
+
+        def interval_sql(self, expression: exp.Interval) -> str:
+            multiplier = 0
+            unit = expression.text("unit")
+
+            if unit.startswith("WEEK"):
+                multiplier = 7
+            elif unit.startswith("QUARTER"):
+                multiplier = 90
+
+            if multiplier:
+                return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})"
+
+            return super().interval_sql(expression)
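A hedged sketch of _date_add_sql in action; the input dialect choice and exact output are assumptions:

import sqlglot

# DATE_ADD/DATE_SUB become Teradata interval arithmetic; the operator flips
# for negative literal values, per the kind_to_op table above.
print(sqlglot.transpile("SELECT DATE_ADD(d, INTERVAL 5 DAY)", read="mysql", write="teradata")[0])
# roughly: SELECT d + INTERVAL '5' DAY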
sqlglot/dialects/tsql.py  (path inferred from hunk content)
@@ -109,7 +109,7 @@ def _build_formatted_time(
     assert len(args) == 2
 
     return exp_class(
-        this=exp.cast(args[1], "datetime"),
+        this=exp.cast(args[1], exp.DataType.Type.DATETIME),
         format=exp.Literal.string(
             format_time(
                 args[0].name.lower(),
@@ -726,6 +726,7 @@ class TSQL(Dialect):
         SUPPORTS_SELECT_INTO = True
         JSON_PATH_BRACKETED_KEY_SUPPORTED = False
         SUPPORTS_TO_NUMBER = False
+        OUTER_UNION_MODIFIERS = False
 
         EXPRESSIONS_WITHOUT_NESTED_CTES = {
             exp.Delete,
@@ -882,13 +883,6 @@ class TSQL(Dialect):
 
             return rename_func("DATETIMEFROMPARTS")(self, expression)
 
-        def set_operations(self, expression: exp.Union) -> str:
-            limit = expression.args.get("limit")
-            if limit:
-                return self.sql(expression.limit(limit.pop(), copy=False))
-
-            return super().set_operations(expression)
-
         def setitem_sql(self, expression: exp.SetItem) -> str:
             this = expression.this
             if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter):
sqlglot/expressions.py  (path inferred from hunk content)
@@ -58,6 +58,7 @@ class _Expression(type):
 
 SQLGLOT_META = "sqlglot.meta"
 TABLE_PARTS = ("this", "db", "catalog")
+COLUMN_PARTS = ("this", "table", "db", "catalog")
 
 
 class Expression(metaclass=_Expression):
@@ -175,6 +176,15 @@ class Expression(metaclass=_Expression):
         """
         return isinstance(self, Literal) and not self.args["is_string"]
 
+    @property
+    def is_negative(self) -> bool:
+        """
+        Checks whether an expression is negative.
+
+        Handles both exp.Neg and Literal numbers with "-" which come from optimizer.simplify.
+        """
+        return isinstance(self, Neg) or (self.is_number and self.this.startswith("-"))
+
     @property
     def is_int(self) -> bool:
         """
@@ -845,10 +855,14 @@ class Expression(metaclass=_Expression):
         copy: bool = True,
         **opts,
     ) -> In:
+        subquery = maybe_parse(query, copy=copy, **opts) if query else None
+        if subquery and not isinstance(subquery, Subquery):
+            subquery = subquery.subquery(copy=False)
+
         return In(
             this=maybe_copy(self, copy),
             expressions=[convert(e, copy=copy) for e in expressions],
-            query=maybe_parse(query, copy=copy, **opts) if query else None,
+            query=subquery,
             unnest=(
                 Unnest(
                     expressions=[
@@ -1018,14 +1032,14 @@ class Query(Expression):
         return Subquery(this=instance, alias=alias)
 
     def limit(
-        self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
-    ) -> Select:
+        self: Q, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
+    ) -> Q:
         """
         Adds a LIMIT clause to this query.
 
         Example:
             >>> select("1").union(select("1")).limit(1).sql()
-            'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
+            'SELECT 1 UNION SELECT 1 LIMIT 1'
 
         Args:
             expression: the SQL code string to parse.
@@ -1039,10 +1053,90 @@ class Query(Expression):
         Returns:
             A limited Select expression.
         """
-        return (
-            select("*")
-            .from_(self.subquery(alias="_l_0", copy=copy))
-            .limit(expression, dialect=dialect, copy=False, **opts)
+        return _apply_builder(
+            expression=expression,
+            instance=self,
+            arg="limit",
+            into=Limit,
+            prefix="LIMIT",
+            dialect=dialect,
+            copy=copy,
+            into_arg="expression",
+            **opts,
         )
+
+    def offset(
+        self: Q, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
+    ) -> Q:
+        """
+        Set the OFFSET expression.
+
+        Example:
+            >>> Select().from_("tbl").select("x").offset(10).sql()
+            'SELECT x FROM tbl OFFSET 10'
+
+        Args:
+            expression: the SQL code string to parse.
+                This can also be an integer.
+                If a `Offset` instance is passed, this is used as-is.
+                If another `Expression` instance is passed, it will be wrapped in a `Offset`.
+            dialect: the dialect used to parse the input expression.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            The modified Select expression.
+        """
+        return _apply_builder(
+            expression=expression,
+            instance=self,
+            arg="offset",
+            into=Offset,
+            prefix="OFFSET",
+            dialect=dialect,
+            copy=copy,
+            into_arg="expression",
+            **opts,
+        )
+
+    def order_by(
+        self: Q,
+        *expressions: t.Optional[ExpOrStr],
+        append: bool = True,
+        dialect: DialectType = None,
+        copy: bool = True,
+        **opts,
+    ) -> Q:
+        """
+        Set the ORDER BY expression.
+
+        Example:
+            >>> Select().from_("tbl").select("x").order_by("x DESC").sql()
+            'SELECT x FROM tbl ORDER BY x DESC'
+
+        Args:
+            *expressions: the SQL code strings to parse.
+                If a `Group` instance is passed, this is used as-is.
+                If another `Expression` instance is passed, it will be wrapped in a `Order`.
+            append: if `True`, add to any existing expressions.
+                Otherwise, this flattens all the `Order` expression into a single expression.
+            dialect: the dialect used to parse the input expression.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            The modified Select expression.
+        """
+        return _apply_child_list_builder(
+            *expressions,
+            instance=self,
+            arg="order",
+            append=append,
+            copy=copy,
+            prefix="ORDER BY",
+            into=Order,
+            dialect=dialect,
+            **opts,
+        )
 
     @property
@@ -1536,7 +1630,13 @@ class SwapTable(Expression):
 
 
 class Comment(Expression):
-    arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
+    arg_types = {
+        "this": True,
+        "kind": True,
+        "expression": True,
+        "exists": False,
+        "materialized": False,
+    }
 
 
 class Comprehension(Expression):
@@ -1642,6 +1742,10 @@ class ExcludeColumnConstraint(ColumnConstraintKind):
     pass
 
 
+class EphemeralColumnConstraint(ColumnConstraintKind):
+    arg_types = {"this": False}
+
+
 class WithOperator(Expression):
     arg_types = {"this": True, "op": True}
 
@@ -2221,6 +2325,13 @@ class Lateral(UDTF):
     }
 
 
+class MatchRecognizeMeasure(Expression):
+    arg_types = {
+        "this": True,
+        "window_frame": False,
+    }
+
+
 class MatchRecognize(Expression):
     arg_types = {
         "partition_by": False,
@@ -3051,46 +3162,6 @@ class Select(Query):
             **opts,
         )
 
-    def order_by(
-        self,
-        *expressions: t.Optional[ExpOrStr],
-        append: bool = True,
-        dialect: DialectType = None,
-        copy: bool = True,
-        **opts,
-    ) -> Select:
-        """
-        Set the ORDER BY expression.
-
-        Example:
-            >>> Select().from_("tbl").select("x").order_by("x DESC").sql()
-            'SELECT x FROM tbl ORDER BY x DESC'
-
-        Args:
-            *expressions: the SQL code strings to parse.
-                If a `Group` instance is passed, this is used as-is.
-                If another `Expression` instance is passed, it will be wrapped in a `Order`.
-            append: if `True`, add to any existing expressions.
-                Otherwise, this flattens all the `Order` expression into a single expression.
-            dialect: the dialect used to parse the input expression.
-            copy: if `False`, modify this expression instance in-place.
-            opts: other options to use to parse the input expressions.
-
-        Returns:
-            The modified Select expression.
-        """
-        return _apply_child_list_builder(
-            *expressions,
-            instance=self,
-            arg="order",
-            append=append,
-            copy=copy,
-            prefix="ORDER BY",
-            into=Order,
-            dialect=dialect,
-            **opts,
-        )
-
     def sort_by(
         self,
         *expressions: t.Optional[ExpOrStr],
@@ -3171,55 +3242,6 @@ class Select(Query):
             **opts,
         )
 
-    def limit(
-        self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
-    ) -> Select:
-        return _apply_builder(
-            expression=expression,
-            instance=self,
-            arg="limit",
-            into=Limit,
-            prefix="LIMIT",
-            dialect=dialect,
-            copy=copy,
-            into_arg="expression",
-            **opts,
-        )
-
-    def offset(
-        self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
-    ) -> Select:
-        """
-        Set the OFFSET expression.
-
-        Example:
-            >>> Select().from_("tbl").select("x").offset(10).sql()
-            'SELECT x FROM tbl OFFSET 10'
-
-        Args:
-            expression: the SQL code string to parse.
-                This can also be an integer.
-                If a `Offset` instance is passed, this is used as-is.
-                If another `Expression` instance is passed, it will be wrapped in a `Offset`.
-            dialect: the dialect used to parse the input expression.
-            copy: if `False`, modify this expression instance in-place.
-            opts: other options to use to parse the input expressions.
-
-        Returns:
-            The modified Select expression.
-        """
-        return _apply_builder(
-            expression=expression,
-            instance=self,
-            arg="offset",
-            into=Offset,
-            prefix="OFFSET",
-            dialect=dialect,
-            copy=copy,
-            into_arg="expression",
-            **opts,
-        )
-
     def select(
         self,
         *expressions: t.Optional[ExpOrStr],
@@ -4214,7 +4236,7 @@ class Dot(Binary):
 
         parts.reverse()
 
-        for arg in ("this", "table", "db", "catalog"):
+        for arg in COLUMN_PARTS:
             part = this.args.get(arg)
 
             if isinstance(part, Expression):
@@ -4395,7 +4417,13 @@ class Between(Predicate):
 
 class Bracket(Condition):
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator
-    arg_types = {"this": True, "expressions": True, "offset": False, "safe": False}
+    arg_types = {
+        "this": True,
+        "expressions": True,
+        "offset": False,
+        "safe": False,
+        "returns_list_for_maps": False,
+    }
 
     @property
     def output_name(self) -> str:
@@ -5458,6 +5486,10 @@ class ApproxQuantile(Quantile):
     arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False}
 
 
+class Quarter(Func):
+    pass
+
+
 class Rand(Func):
     _sql_names = ["RAND", "RANDOM"]
     arg_types = {"this": False}
@@ -6620,17 +6652,9 @@ def to_interval(interval: str | Literal) -> Interval:
     )
 
 
-@t.overload
-def to_table(sql_path: str | Table, **kwargs) -> Table: ...
-
-
-@t.overload
-def to_table(sql_path: None, **kwargs) -> None: ...
-
-
 def to_table(
-    sql_path: t.Optional[str | Table], dialect: DialectType = None, copy: bool = True, **kwargs
-) -> t.Optional[Table]:
+    sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs
+) -> Table:
     """
     Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
     If a table is passed in then that table is returned.
@@ -6644,35 +6668,54 @@ def to_table(
     Returns:
         A table expression.
     """
-    if sql_path is None or isinstance(sql_path, Table):
+    if isinstance(sql_path, Table):
         return maybe_copy(sql_path, copy=copy)
-    if not isinstance(sql_path, str):
-        raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
 
     table = maybe_parse(sql_path, into=Table, dialect=dialect)
-    if table:
-        for k, v in kwargs.items():
-            table.set(k, v)
+
+    for k, v in kwargs.items():
+        table.set(k, v)
 
     return table
 
 
-def to_column(sql_path: str | Column, **kwargs) -> Column:
+def to_column(
+    sql_path: str | Column,
+    quoted: t.Optional[bool] = None,
+    dialect: DialectType = None,
+    copy: bool = True,
+    **kwargs,
+) -> Column:
     """
-    Create a column from a `[table].[column]` sql path. Schema is optional.
-
+    Create a column from a `[table].[column]` sql path. Table is optional.
     If a column is passed in then that column is returned.
 
     Args:
-        sql_path: `[table].[column]` string
+        sql_path: a `[table].[column]` string.
+        quoted: Whether or not to force quote identifiers.
+        dialect: the source dialect according to which the column name will be parsed.
+        copy: Whether to copy a column if it is passed in.
+        kwargs: the kwargs to instantiate the resulting `Column` expression with.
+
     Returns:
-        Table: A column expression
+        A column expression.
     """
-    if sql_path is None or isinstance(sql_path, Column):
-        return sql_path
-    if not isinstance(sql_path, str):
-        raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
-    return column(*reversed(sql_path.split(".")), **kwargs)  # type: ignore
+    if isinstance(sql_path, Column):
+        return maybe_copy(sql_path, copy=copy)
+
+    try:
+        col = maybe_parse(sql_path, into=Column, dialect=dialect)
+    except ParseError:
+        return column(*reversed(sql_path.split(".")), quoted=quoted, **kwargs)
+
+    for k, v in kwargs.items():
+        col.set(k, v)
+
+    if quoted:
+        for i in col.find_all(Identifier):
+            i.set("quoted", True)
+
+    return col
 
 
 def alias_(
@@ -6756,7 +6799,7 @@ def subquery(
         A new Select instance with the subquery expression included.
     """
 
-    expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias)
+    expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias, **opts)
     return Select().from_(expression, dialect=dialect, **opts)
 
 
@@ -6821,7 +6864,9 @@ def column(
     )
 
     if fields:
-        this = Dot.build((this, *(to_identifier(field, copy=copy) for field in fields)))
+        this = Dot.build(
+            (this, *(to_identifier(field, quoted=quoted, copy=copy) for field in fields))
+        )
     return this
 
 
@@ -6840,11 +6885,16 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast:
     Returns:
         The new Cast instance.
     """
-    expression = maybe_parse(expression, copy=copy, **opts)
+    expr = maybe_parse(expression, copy=copy, **opts)
     data_type = DataType.build(to, copy=copy, **opts)
-    expression = Cast(this=expression, to=data_type)
-    expression.type = data_type
-    return expression
+
+    if expr.is_type(data_type):
+        return expr
+
+    expr = Cast(this=expr, to=data_type)
+    expr.type = data_type
+
+    return expr
 
 
 def table_(
@@ -6931,18 +6981,23 @@ def var(name: t.Optional[ExpOrStr]) -> Var:
     return Var(this=name)
 
 
-def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
+def rename_table(
+    old_name: str | Table,
+    new_name: str | Table,
+    dialect: DialectType = None,
+) -> AlterTable:
     """Build ALTER TABLE... RENAME... expression
 
     Args:
         old_name: The old name of the table
         new_name: The new name of the table
+        dialect: The dialect to parse the table.
 
     Returns:
         Alter table expression
     """
-    old_table = to_table(old_name)
-    new_table = to_table(new_name)
+    old_table = to_table(old_name, dialect=dialect)
+    new_table = to_table(new_name, dialect=dialect)
     return AlterTable(
         this=old_table,
         actions=[
@@ -6956,6 +7011,7 @@ def rename_column(
     old_column_name: str | Column,
    new_column_name: str | Column,
     exists: t.Optional[bool] = None,
+    dialect: DialectType = None,
 ) -> AlterTable:
     """Build ALTER TABLE... RENAME COLUMN... expression
 
@@ -6964,13 +7020,14 @@ def rename_column(
         old_column: The old name of the column
         new_column: The new name of the column
         exists: Whether to add the `IF EXISTS` clause
+        dialect: The dialect to parse the table/column.
 
     Returns:
         Alter table expression
     """
-    table = to_table(table_name)
-    old_column = to_column(old_column_name)
-    new_column = to_column(new_column_name)
+    table = to_table(table_name, dialect=dialect)
+    old_column = to_column(old_column_name, dialect=dialect)
+    new_column = to_column(new_column_name, dialect=dialect)
    return AlterTable(
         this=table,
         actions=[
@@ -7366,27 +7423,6 @@ def case(
     return Case(this=this, ifs=[])
 
 
-def cast_unless(
-    expression: ExpOrStr,
-    to: DATA_TYPE,
-    *types: DATA_TYPE,
-    **opts: t.Any,
-) -> Expression | Cast:
-    """
-    Cast an expression to a data type unless it is a specified type.
-
-    Args:
-        expression: The expression to cast.
-        to: The data type to cast to.
-        **types: The types to exclude from casting.
-        **opts: Extra keyword arguments for parsing `expression`
-    """
-    expr = maybe_parse(expression, **opts)
-    if expr.is_type(*types):
-        return expr
-    return cast(expr, to, **opts)
-
-
 def array(
     *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs
 ) -> Array:
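The Query.limit change above is directly observable through the builder API; this check restates the new docstring example:

from sqlglot import select

# limit/offset/order_by now live on Query, so set operations take modifiers
# directly instead of being wrapped in a "_l_0" subquery.
assert select("1").union(select("1")).limit(1).sql() == "SELECT 1 UNION SELECT 1 LIMIT 1"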
sqlglot/generator.py  (path inferred from hunk content)
@@ -89,6 +89,8 @@ class Generator(metaclass=_Generator):
         exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
         exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
         exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
+        exp.EphemeralColumnConstraint: lambda self,
+        e: f"EPHEMERAL{(' ' + self.sql(e, 'this')) if e.this else ''}",
         exp.ExcludeColumnConstraint: lambda self, e: f"EXCLUDE {self.sql(e, 'this').lstrip()}",
         exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
         exp.ExternalProperty: lambda *_: "EXTERNAL",
@@ -332,6 +334,11 @@ class Generator(metaclass=_Generator):
     # Whether the function TO_NUMBER is supported
     SUPPORTS_TO_NUMBER = True
 
+    # Whether or not union modifiers apply to the outer union or select.
+    # SELECT * FROM x UNION SELECT * FROM y LIMIT 1
+    # True means limit 1 happens after the union, False means it it happens on y.
+    OUTER_UNION_MODIFIERS = True
+
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -1801,10 +1808,15 @@ class Generator(metaclass=_Generator):
         return f"{self.seg('FROM')} {self.sql(expression, 'this')}"
 
     def group_sql(self, expression: exp.Group) -> str:
-        group_by = self.op_expressions("GROUP BY", expression)
+        group_by_all = expression.args.get("all")
+        if group_by_all is True:
+            modifier = " ALL"
+        elif group_by_all is False:
+            modifier = " DISTINCT"
+        else:
+            modifier = ""
 
-        if expression.args.get("all"):
-            return f"{group_by} ALL"
+        group_by = self.op_expressions(f"GROUP BY{modifier}", expression)
 
         grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
         grouping_sets = (
@@ -2109,6 +2121,14 @@ class Generator(metaclass=_Generator):
 
         return f"{this}{sort_order}{nulls_sort_change}{with_fill}"
 
+    def matchrecognizemeasure_sql(self, expression: exp.MatchRecognizeMeasure) -> str:
+        window_frame = self.sql(expression, "window_frame")
+        window_frame = f"{window_frame} " if window_frame else ""
+
+        this = self.sql(expression, "this")
+
+        return f"{window_frame}{this}"
+
     def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
         partition = self.partition_by_sql(expression)
         order = self.sql(expression, "order")
@@ -2297,6 +2317,19 @@ class Generator(metaclass=_Generator):
         return f"{self.seg('QUALIFY')}{self.sep()}{this}"
 
     def set_operations(self, expression: exp.Union) -> str:
+        if not self.OUTER_UNION_MODIFIERS:
+            limit = expression.args.get("limit")
+            order = expression.args.get("order")
+
+            if limit or order:
+                select = exp.subquery(expression, "_l_0", copy=False).select("*", copy=False)
+
+                if limit:
+                    select = select.limit(limit.pop(), copy=False)
+                if order:
+                    select = select.order_by(order.pop(), copy=False)
+                return self.sql(select)
+
         sqls: t.List[str] = []
         stack: t.List[t.Union[str, exp.Expression]] = [expression]
 
@@ -2412,12 +2445,15 @@ class Generator(metaclass=_Generator):
         high = self.sql(expression, "high")
         return f"{this} BETWEEN {low} AND {high}"
 
-    def bracket_sql(self, expression: exp.Bracket) -> str:
-        expressions = apply_index_offset(
+    def bracket_offset_expressions(self, expression: exp.Bracket) -> t.List[exp.Expression]:
+        return apply_index_offset(
            expression.this,
            expression.expressions,
            self.dialect.INDEX_OFFSET - expression.args.get("offset", 0),
        )
+
+    def bracket_sql(self, expression: exp.Bracket) -> str:
+        expressions = self.bracket_offset_expressions(expression)
         expressions_sql = ", ".join(self.sql(e) for e in expressions)
         return f"{self.sql(expression, 'this')}[{expressions_sql}]"
 
@@ -2486,7 +2522,7 @@ class Generator(metaclass=_Generator):
             args = args[1:]  # Skip the delimiter
 
         if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
-            args = [exp.cast(e, "text") for e in args]
+            args = [exp.cast(e, exp.DataType.Type.TEXT) for e in args]
 
         if not self.dialect.CONCAT_COALESCE and expression.args.get("coalesce"):
             args = [exp.func("coalesce", e, exp.Literal.string("")) for e in args]
@@ -2670,7 +2706,7 @@ class Generator(metaclass=_Generator):
         is_global = " GLOBAL" if expression.args.get("is_global") else ""
 
         if query:
-            in_sql = self.wrap(self.sql(query))
+            in_sql = self.sql(query)
         elif unnest:
             in_sql = self.in_unnest_op(unnest)
         elif field:
@@ -2859,9 +2895,10 @@ class Generator(metaclass=_Generator):
     def comment_sql(self, expression: exp.Comment) -> str:
         this = self.sql(expression, "this")
         kind = expression.args["kind"]
+        materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
         exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
         expression_sql = self.sql(expression, "expression")
-        return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}"
+        return f"COMMENT{exists_sql}ON{materialized} {kind} {this} IS {expression_sql}"
 
     def mergetreettlaction_sql(self, expression: exp.MergeTreeTTLAction) -> str:
         this = self.sql(expression, "this")
@@ -3011,7 +3048,9 @@ class Generator(metaclass=_Generator):
 
     def dpipe_sql(self, expression: exp.DPipe) -> str:
         if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
-            return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
+            return self.func(
+                "CONCAT", *(exp.cast(e, exp.DataType.Type.TEXT) for e in expression.flatten())
+            )
         return self.binary(expression, "||")
 
     def div_sql(self, expression: exp.Div) -> str:
@@ -3210,11 +3249,8 @@ class Generator(metaclass=_Generator):
         num_sqls = len(expressions)
 
         # These are calculated once in case we have the leading_comma / pretty option set, correspondingly
-        if self.pretty:
-            if self.leading_comma:
-                pad = " " * len(sep)
-            else:
-                stripped_sep = sep.strip()
+        if self.pretty and not self.leading_comma:
+            stripped_sep = sep.strip()
 
         result_sqls = []
         for i, e in enumerate(expressions):
@@ -3226,7 +3262,7 @@ class Generator(metaclass=_Generator):
 
             if self.pretty:
                 if self.leading_comma:
-                    result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
+                    result_sqls.append(f"{sep if i > 0 else ''}{prefix}{sql}{comments}")
                 else:
                     result_sqls.append(
                         f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}"
@@ -3314,17 +3350,17 @@ class Generator(metaclass=_Generator):
         if expression.args.get("format"):
             self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function")
 
-        return self.sql(exp.cast(expression.this, "text"))
+        return self.sql(exp.cast(expression.this, exp.DataType.Type.TEXT))
 
     def tonumber_sql(self, expression: exp.ToNumber) -> str:
         if not self.SUPPORTS_TO_NUMBER:
             self.unsupported("Unsupported TO_NUMBER function")
-            return self.sql(exp.cast(expression.this, "double"))
+            return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE))
 
         fmt = expression.args.get("format")
         if not fmt:
             self.unsupported("Conversion format is required for TO_NUMBER")
-            return self.sql(exp.cast(expression.this, "double"))
+            return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE))
 
         return self.func("TO_NUMBER", expression.this, fmt)
 
@@ -3495,14 +3531,14 @@ class Generator(metaclass=_Generator):
         if isinstance(this, exp.TsOrDsToTime) or this.is_type(exp.DataType.Type.TIME):
             return self.sql(this)
 
-        return self.sql(exp.cast(this, "time"))
+        return self.sql(exp.cast(this, exp.DataType.Type.TIME))
 
     def tsordstotimestamp_sql(self, expression: exp.TsOrDsToTimestamp) -> str:
         this = expression.this
         if isinstance(this, exp.TsOrDsToTimestamp) or this.is_type(exp.DataType.Type.TIMESTAMP):
             return self.sql(this)
 
-        return self.sql(exp.cast(this, "timestamp"))
+        return self.sql(exp.cast(this, exp.DataType.Type.TIMESTAMP))
 
     def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str:
         this = expression.this
@@ -3510,20 +3546,23 @@ class Generator(metaclass=_Generator):
 
         if time_format and time_format not in (self.dialect.TIME_FORMAT, self.dialect.DATE_FORMAT):
             return self.sql(
-                exp.cast(exp.StrToTime(this=this, format=expression.args["format"]), "date")
+                exp.cast(
+                    exp.StrToTime(this=this, format=expression.args["format"]),
+                    exp.DataType.Type.DATE,
+                )
             )
 
         if isinstance(this, exp.TsOrDsToDate) or this.is_type(exp.DataType.Type.DATE):
             return self.sql(this)
 
-        return self.sql(exp.cast(this, "date"))
+        return self.sql(exp.cast(this, exp.DataType.Type.DATE))
 
     def unixdate_sql(self, expression: exp.UnixDate) -> str:
         return self.sql(
             exp.func(
                 "DATEDIFF",
                 expression.this,
-                exp.cast(exp.Literal.string("1970-01-01"), "date"),
+                exp.cast(exp.Literal.string("1970-01-01"), exp.DataType.Type.DATE),
                 "day",
             )
         )
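A hedged sketch of the new OUTER_UNION_MODIFIERS handling; the exact alias and spelling of the output are assumptions:

import sqlglot

# Dialects that set OUTER_UNION_MODIFIERS = False in this diff (ClickHouse,
# TSQL) would apply a trailing LIMIT/ORDER BY to the last operand, so
# set_operations pushes the modifier into a wrapping SELECT instead.
print(sqlglot.transpile("SELECT x FROM a UNION ALL SELECT x FROM b LIMIT 1", write="clickhouse")[0])
# roughly: SELECT * FROM (SELECT x FROM a UNION ALL SELECT x FROM b) AS _l_0 LIMIT 1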
sqlglot/optimizer/annotate_types.py  (path inferred from hunk content)
@@ -212,6 +212,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
             exp.Month,
             exp.Week,
             exp.Year,
+            exp.Quarter,
         },
         exp.DataType.Type.VARCHAR: {
             exp.ArrayConcat,
@@ -504,7 +505,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
                 last_datatype = expr_type
                 break
 
-            last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
+            if not expr_type.is_type(exp.DataType.Type.NULL, exp.DataType.Type.UNKNOWN):
+                last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
 
         self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
 
sqlglot/optimizer/qualify.py  (path inferred from hunk content)
@@ -66,7 +66,7 @@ def qualify(
     """
     schema = ensure_schema(schema, dialect=dialect)
     expression = normalize_identifiers(expression, dialect=dialect)
-    expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
+    expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema, dialect=dialect)
 
     if isolate_tables:
         expression = isolate_table_selects(expression, schema=schema)
@ -7,6 +7,7 @@ from sqlglot import alias, exp
|
|||
from sqlglot.dialects.dialect import Dialect, DialectType
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import seq_get, SingleValuedMapping
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
|
||||
from sqlglot.optimizer.simplify import simplify_parens
|
||||
from sqlglot.schema import Schema, ensure_schema
|
||||
|
@@ -652,8 +653,19 @@ class Resolver:

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                expression = source.expression
                annotate_types(expression)

                if expression.is_type(exp.DataType.Type.STRUCT):
                    for k in expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects
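A sketch (not from the diff) of what the BigQuery branch enables: a field of a statically defined, unnested struct can be selected directly, and qualification now resolves it.

from sqlglot import parse_one
from sqlglot.optimizer.qualify import qualify

# The struct field x is addressable without going through a source alias.
sql = "SELECT x FROM UNNEST([STRUCT(1 AS x)])"
print(qualify(parse_one(sql, read="bigquery"), dialect="bigquery").sql("bigquery"))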
@@ -55,7 +55,7 @@ def qualify_tables(
        if not table.args.get("catalog") and table.args.get("db"):
            table.set("catalog", catalog)

    if not isinstance(expression, exp.Query):
    if (db or catalog) and not isinstance(expression, exp.Query):
        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table):
                _qualify(node)
@@ -78,10 +78,10 @@ def qualify_tables(
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

        table_aliases = {}

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                _qualify(source)

                pivots = source.args.get("pivots")
                if not source.alias:
                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
@@ -91,6 +91,12 @@ def qualify_tables(
                    # Mutates the source by attaching an alias to it
                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)

                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
                    source.alias
                )

                _qualify(source)

                if pivots and not pivots[0].alias:
                    pivots[0].set(
                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
@@ -127,4 +133,13 @@ def qualify_tables(
                # Mutates the table by attaching an alias to it
                alias(node, node.name, copy=False, table=True)

        for column in scope.columns:
            if column.db:
                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if table_alias:
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)
                    column.set("table", table_alias)

    return expression
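A sketch (not from the diff) of the new alias bookkeeping above: a column qualified with the table's full path is rewritten to use the alias attached to that table.

from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables

# db.tbl is aliased as tbl, and the column's db/table parts collapse onto
# that alias.
print(qualify_tables(parse_one("SELECT db.tbl.col FROM db.tbl")).sql())
# expected: SELECT tbl.col FROM db.tbl AS tbl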
@@ -600,7 +600,7 @@ def _traverse_ctes(scope):
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None
        cte_name = cte.alias

        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
        # thus the recursive scope is the first section of the union.
@@ -609,7 +609,7 @@ def _traverse_ctes(scope):
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

@@ -623,15 +623,9 @@ def _traverse_ctes(scope):
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)
                child_scope.cte_sources[alias] = recursive_scope

        # append the final child_scope yielded
        if child_scope:
            sources[cte_name] = child_scope
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
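A sketch (not from the diff) of the scoping these hunks adjust: the recursive branch of a CTE is registered under the CTE's name up front, so the self-reference in the second UNION arm resolves against it.

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope

sql = (
    "WITH RECURSIVE t AS ("
    "SELECT 1 AS n UNION ALL SELECT n + 1 FROM t WHERE n < 3"
    ") SELECT n FROM t"
)
root = build_scope(parse_one(sql))
print(root.sources["t"])  # the CTE scope registered under its name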
@@ -41,8 +41,6 @@ def unnest(select, parent_select, next_alias_name):
        return

    predicate = select.find_ancestor(exp.Condition)
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
@@ -50,6 +48,10 @@ def unnest(select, parent_select, next_alias_name):
    ):
        return

    if isinstance(select, exp.Union):
        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))

    alias = next_alias_name()
    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
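A sketch (not from the diff, output not asserted) of the case the new branch handles: a UNION inside a scalar predicate subquery is first wrapped in a derived table so it can be unnested like a plain SELECT.

from sqlglot import parse_one
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

sql = "SELECT * FROM x WHERE x.a > (SELECT a FROM y UNION SELECT a FROM z)"
print(unnest_subqueries(parse_one(sql)).sql())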
@@ -344,6 +344,7 @@ class Parser(metaclass=_Parser):
        TokenType.FINAL,
        TokenType.FORMAT,
        TokenType.FULL,
        TokenType.IDENTIFIER,
        TokenType.IS,
        TokenType.ISNULL,
        TokenType.INTERVAL,
@@ -852,6 +853,9 @@ class Parser(metaclass=_Parser):
            exp.DefaultColumnConstraint, this=self._parse_bitwise()
        ),
        "ENCODE": lambda self: self.expression(exp.EncodeColumnConstraint, this=self._parse_var()),
        "EPHEMERAL": lambda self: self.expression(
            exp.EphemeralColumnConstraint, this=self._parse_bitwise()
        ),
        "EXCLUDE": lambda self: self.expression(
            exp.ExcludeColumnConstraint, this=self._parse_index_params()
        ),
@@ -1384,6 +1388,7 @@ class Parser(metaclass=_Parser):

        self._match(TokenType.ON)

        materialized = self._match_text_seq("MATERIALIZED")
        kind = self._match_set(self.CREATABLES) and self._prev
        if not kind:
            return self._parse_as_command(start)
@@ -1400,7 +1405,12 @@ class Parser(metaclass=_Parser):
        self._match(TokenType.IS)

        return self.expression(
            exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists
            exp.Comment,
            this=this,
            kind=kind.text,
            expression=self._parse_string(),
            exists=exists,
            materialized=materialized,
        )

    def _parse_to_table(
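A sketch (not from the diff) of the statement shape these two hunks support: the MATERIALIZED keyword on COMMENT is now captured on the Comment expression.

import sqlglot

stmt = sqlglot.parse_one("COMMENT ON MATERIALIZED VIEW v IS 'doc'")
print(type(stmt).__name__, stmt.args.get("materialized"))  # expected: Comment True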
@@ -2188,7 +2198,10 @@ class Parser(metaclass=_Parser):

    def _parse_describe(self) -> exp.Describe:
        kind = self._match_set(self.CREATABLES) and self._prev.text
        style = self._match_texts(("EXTENDED", "FORMATTED")) and self._prev.text.upper()
        style = self._match_texts(("EXTENDED", "FORMATTED", "HISTORY")) and self._prev.text.upper()
        if not self._match_set(self.ID_VAR_TOKENS, advance=False):
            style = None
            self._retreat(self._index - 1)
        this = self._parse_table(schema=True)
        properties = self._parse_properties()
        expressions = properties.expressions if properties else None
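A sketch (not from the diff) of the new style keyword: HISTORY joins EXTENDED and FORMATTED, and a style word followed by nothing identifier-like is rolled back so it can be re-read as the table name.

import sqlglot

# Delta-style history inspection should now round-trip.
print(sqlglot.parse_one("DESCRIBE HISTORY t").sql())  # expected: DESCRIBE HISTORY t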
@@ -2731,6 +2744,13 @@ class Parser(metaclass=_Parser):
            exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins)
        )

    def _parse_match_recognize_measure(self) -> exp.MatchRecognizeMeasure:
        return self.expression(
            exp.MatchRecognizeMeasure,
            window_frame=self._match_texts(("FINAL", "RUNNING")) and self._prev.text.upper(),
            this=self._parse_expression(),
        )

    def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]:
        if not self._match(TokenType.MATCH_RECOGNIZE):
            return None
@@ -2739,7 +2759,12 @@ class Parser(metaclass=_Parser):

        partition = self._parse_partition_by()
        order = self._parse_order()
        measures = self._parse_expressions() if self._match_text_seq("MEASURES") else None

        measures = (
            self._parse_csv(self._parse_match_recognize_measure)
            if self._match_text_seq("MEASURES")
            else None
        )

        if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):
            rows = exp.var("ONE ROW PER MATCH")
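A sketch (not from the diff, output not asserted) of a measure carrying the FINAL/RUNNING qualifier that the new _parse_match_recognize_measure captures:

import sqlglot

sql = (
    "SELECT * FROM t MATCH_RECOGNIZE ("
    "MEASURES FINAL LAST(x) AS lx "
    "ONE ROW PER MATCH PATTERN (A) DEFINE A AS x > 0)"
)
print(sqlglot.parse_one(sql).sql())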
@@ -3444,10 +3469,12 @@ class Parser(metaclass=_Parser):
        if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
            return None

        elements = defaultdict(list)
        elements: t.Dict[str, t.Any] = defaultdict(list)

        if self._match(TokenType.ALL):
            return self.expression(exp.Group, all=True)
            elements["all"] = True
        elif self._match(TokenType.DISTINCT):
            elements["all"] = False

        while True:
            expressions = self._parse_csv(self._parse_conjunction)
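A sketch (not from the diff) of the behavior change: GROUP BY ALL and GROUP BY DISTINCT are now recorded as an "all" flag on the Group expression instead of short-circuiting the parse.

import sqlglot

print(sqlglot.parse_one("SELECT a, SUM(b) FROM t GROUP BY ALL").sql())
# expected: SELECT a, SUM(b) FROM t GROUP BY ALL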
@@ -3808,7 +3835,7 @@ class Parser(metaclass=_Parser):
        expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias))

        if len(expressions) == 1 and isinstance(expressions[0], exp.Query):
            this = self.expression(exp.In, this=this, query=expressions[0])
            this = self.expression(exp.In, this=this, query=expressions[0].subquery(copy=False))
        else:
            this = self.expression(exp.In, this=this, expressions=expressions)

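A sketch (not from the diff) of the node-shape change: the query arm of IN is now wrapped in a Subquery node at parse time.

import sqlglot
from sqlglot import exp

expr = sqlglot.parse_one("SELECT * FROM x WHERE a IN (SELECT a FROM y)")
print(isinstance(expr.find(exp.In).args["query"], exp.Subquery))  # expected: True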
@@ -4504,12 +4531,15 @@ class Parser(metaclass=_Parser):

        constraints: t.List[exp.Expression] = []

        if (not kind and self._match(TokenType.ALIAS)) or self._match_text_seq("ALIAS"):
        if (not kind and self._match(TokenType.ALIAS)) or self._match_texts(
            ("ALIAS", "MATERIALIZED")
        ):
            persisted = self._prev.text.upper() == "MATERIALIZED"
            constraints.append(
                self.expression(
                    exp.ComputedColumnConstraint,
                    this=self._parse_conjunction(),
                    persisted=self._match_text_seq("PERSISTED"),
                    persisted=persisted or self._match_text_seq("PERSISTED"),
                    not_null=self._match_pair(TokenType.NOT, TokenType.NULL),
                )
            )
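A sketch (not from the diff, dialect behavior assumed) of what the MATERIALIZED branch accepts: a ClickHouse MATERIALIZED column parses as a computed column with persisted=True.

import sqlglot

sql = "CREATE TABLE t (a Int32, b Int32 MATERIALIZED a + 1)"
print(sqlglot.parse_one(sql, read="clickhouse").sql("clickhouse"))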
@@ -140,6 +140,26 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression
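A sketch (not from the diff) of the relocated transform in use; the alias _q_0 and column arr here are illustrative stand-ins for what qualify_columns generates.

from sqlglot import parse_one
from sqlglot.transforms import unqualify_unnest

expr = parse_one("SELECT _q_0.x FROM UNNEST(arr) AS _q_0")
print(unqualify_unnest(expr).sql())  # expected: SELECT x FROM UNNEST(arr) AS _q_0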


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""
    if isinstance(expression, exp.Select):