Merging upstream version 22.2.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
b13ba670fd
commit
2c28c49d7e
148 changed files with 68457 additions and 63176 deletions
|
@ -88,13 +88,11 @@ def parse(
|
|||
|
||||
|
||||
@t.overload
|
||||
def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
|
||||
...
|
||||
def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def parse_one(sql: str, **opts) -> Expression:
|
||||
...
|
||||
def parse_one(sql: str, **opts) -> Expression: ...
|
||||
|
||||
|
||||
def parse_one(
|
||||
|
|
|
@ -140,12 +140,10 @@ class DataFrame:
|
|||
return cte, name
|
||||
|
||||
@t.overload
|
||||
def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
|
||||
...
|
||||
def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
|
||||
|
||||
@t.overload
|
||||
def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
|
||||
...
|
||||
def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
|
||||
|
||||
def _ensure_list_of_columns(self, cols):
|
||||
return Column.ensure_cols(ensure_list(cols))
|
||||
|
|
|
@ -210,7 +210,7 @@ def sec(col: ColumnOrName) -> Column:
|
|||
|
||||
|
||||
def signum(col: ColumnOrName) -> Column:
|
||||
return Column.invoke_anonymous_function(col, "SIGNUM")
|
||||
return Column.invoke_expression_over_column(col, expression.Sign)
|
||||
|
||||
|
||||
def sin(col: ColumnOrName) -> Column:
|
||||
|
@ -592,7 +592,7 @@ def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
|
|||
|
||||
|
||||
def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
|
||||
return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
|
||||
return Column.invoke_expression_over_column(start, expression.AddMonths, expression=months)
|
||||
|
||||
|
||||
def months_between(
|
||||
|
|
|
@ -42,7 +42,10 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
|
|||
alias = expression.args.get("alias")
|
||||
for tup in expression.find_all(exp.Tuple):
|
||||
field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions)))
|
||||
expressions = [exp.alias_(fld, name) for fld, name in zip(tup.expressions, field_aliases)]
|
||||
expressions = [
|
||||
exp.PropertyEQ(this=exp.to_identifier(name), expression=fld)
|
||||
for name, fld in zip(field_aliases, tup.expressions)
|
||||
]
|
||||
structs.append(exp.Struct(expressions=expressions))
|
||||
|
||||
return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)]))
|
||||
|
@ -111,6 +114,8 @@ def _alias_ordered_group(expression: exp.Expression) -> exp.Expression:
|
|||
}
|
||||
|
||||
for grouped in group.expressions:
|
||||
if grouped.is_int:
|
||||
continue
|
||||
alias = aliases.get(grouped)
|
||||
if alias:
|
||||
grouped.replace(exp.column(alias))
|
||||
|
@ -226,8 +231,11 @@ class BigQuery(Dialect):
|
|||
# bigquery udfs are case sensitive
|
||||
NORMALIZE_FUNCTIONS = False
|
||||
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_elements_date_time
|
||||
TIME_MAPPING = {
|
||||
"%D": "%m/%d/%y",
|
||||
"%E*S": "%S.%f",
|
||||
"%E6S": "%S.%f",
|
||||
}
|
||||
|
||||
ESCAPE_SEQUENCES = {
|
||||
|
@ -266,14 +274,20 @@ class BigQuery(Dialect):
|
|||
while isinstance(parent, exp.Dot):
|
||||
parent = parent.parent
|
||||
|
||||
# In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
|
||||
# The following check is essentially a heuristic to detect tables based on whether or
|
||||
# not they're qualified. It also avoids normalizing UDFs, because they're case-sensitive.
|
||||
if (
|
||||
not isinstance(parent, exp.UserDefinedFunction)
|
||||
and not (isinstance(parent, exp.Table) and parent.db)
|
||||
and not expression.meta.get("is_table")
|
||||
):
|
||||
# In BigQuery, CTEs are case-insensitive, but UDF and table names are case-sensitive
|
||||
# by default. The following check uses a heuristic to detect tables based on whether
|
||||
# they are qualified. This should generally be correct, because tables in BigQuery
|
||||
# must be qualified with at least a dataset, unless @@dataset_id is set.
|
||||
case_sensitive = (
|
||||
isinstance(parent, exp.UserDefinedFunction)
|
||||
or (
|
||||
isinstance(parent, exp.Table)
|
||||
and parent.db
|
||||
and (parent.meta.get("quoted_table") or not parent.meta.get("maybe_column"))
|
||||
)
|
||||
or expression.meta.get("is_table")
|
||||
)
|
||||
if not case_sensitive:
|
||||
expression.set("this", expression.this.lower())
|
||||
|
||||
return expression
|
||||
|
@ -302,6 +316,7 @@ class BigQuery(Dialect):
|
|||
"BYTES": TokenType.BINARY,
|
||||
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
|
||||
"DECLARE": TokenType.COMMAND,
|
||||
"ELSEIF": TokenType.COMMAND,
|
||||
"EXCEPTION": TokenType.COMMAND,
|
||||
"FLOAT64": TokenType.DOUBLE,
|
||||
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
|
||||
|
@ -315,8 +330,8 @@ class BigQuery(Dialect):
|
|||
|
||||
class Parser(parser.Parser):
|
||||
PREFIXED_PIVOT_COLUMNS = True
|
||||
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
SUPPORTS_IMPLICIT_UNNEST = True
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
|
@ -410,6 +425,7 @@ class BigQuery(Dialect):
|
|||
|
||||
STATEMENT_PARSERS = {
|
||||
**parser.Parser.STATEMENT_PARSERS,
|
||||
TokenType.ELSE: lambda self: self._parse_as_command(self._prev),
|
||||
TokenType.END: lambda self: self._parse_as_command(self._prev),
|
||||
TokenType.FOR: lambda self: self._parse_for_in(),
|
||||
}
|
||||
|
@ -433,8 +449,11 @@ class BigQuery(Dialect):
|
|||
if isinstance(this, exp.Identifier):
|
||||
table_name = this.name
|
||||
while self._match(TokenType.DASH, advance=False) and self._next:
|
||||
self._advance(2)
|
||||
table_name += f"-{self._prev.text}"
|
||||
text = ""
|
||||
while self._curr and self._curr.token_type != TokenType.DOT:
|
||||
self._advance()
|
||||
text += self._prev.text
|
||||
table_name += text
|
||||
|
||||
this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))
|
||||
elif isinstance(this, exp.Literal):
|
||||
|
@ -448,12 +467,28 @@ class BigQuery(Dialect):
|
|||
return this
|
||||
|
||||
def _parse_table_parts(
|
||||
self, schema: bool = False, is_db_reference: bool = False
|
||||
self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
|
||||
) -> exp.Table:
|
||||
table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
|
||||
table = super()._parse_table_parts(
|
||||
schema=schema, is_db_reference=is_db_reference, wildcard=True
|
||||
)
|
||||
|
||||
# proj-1.db.tbl -- `1.` is tokenized as a float so we need to unravel it here
|
||||
if not table.catalog:
|
||||
if table.db:
|
||||
parts = table.db.split(".")
|
||||
if len(parts) == 2 and not table.args["db"].quoted:
|
||||
table.set("catalog", exp.Identifier(this=parts[0]))
|
||||
table.set("db", exp.Identifier(this=parts[1]))
|
||||
else:
|
||||
parts = table.name.split(".")
|
||||
if len(parts) == 2 and not table.this.quoted:
|
||||
table.set("db", exp.Identifier(this=parts[0]))
|
||||
table.set("this", exp.Identifier(this=parts[1]))
|
||||
|
||||
if isinstance(table.this, exp.Identifier) and "." in table.name:
|
||||
catalog, db, this, *rest = (
|
||||
t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
|
||||
t.cast(t.Optional[exp.Expression], exp.to_identifier(x, quoted=True))
|
||||
for x in split_num_words(table.name, ".", 3)
|
||||
)
|
||||
|
||||
|
@ -461,16 +496,15 @@ class BigQuery(Dialect):
|
|||
this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))
|
||||
|
||||
table = exp.Table(this=this, db=db, catalog=catalog)
|
||||
table.meta["quoted_table"] = True
|
||||
|
||||
return table
|
||||
|
||||
@t.overload
|
||||
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
|
||||
...
|
||||
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
|
||||
|
||||
@t.overload
|
||||
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
|
||||
...
|
||||
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
|
||||
|
||||
def _parse_json_object(self, agg=False):
|
||||
json_object = super()._parse_json_object()
|
||||
|
@ -532,6 +566,7 @@ class BigQuery(Dialect):
|
|||
IGNORE_NULLS_IN_FUNC = True
|
||||
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
|
||||
CAN_IMPLEMENT_ARRAY_ANY = True
|
||||
NAMED_PLACEHOLDER_TOKEN = "@"
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
|
@ -762,22 +797,25 @@ class BigQuery(Dialect):
|
|||
"within",
|
||||
}
|
||||
|
||||
def table_parts(self, expression: exp.Table) -> str:
|
||||
# Depending on the context, `x.y` may not resolve to the same data source as `x`.`y`, so
|
||||
# we need to make sure the correct quoting is used in each case.
|
||||
#
|
||||
# For example, if there is a CTE x that clashes with a schema name, then the former will
|
||||
# return the table y in that schema, whereas the latter will return the CTE's y column:
|
||||
#
|
||||
# - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x.y` -> cross join
|
||||
# - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x`.`y` -> implicit unnest
|
||||
if expression.meta.get("quoted_table"):
|
||||
table_parts = ".".join(p.name for p in expression.parts)
|
||||
return self.sql(exp.Identifier(this=table_parts, quoted=True))
|
||||
|
||||
return super().table_parts(expression)
|
||||
|
||||
def timetostr_sql(self, expression: exp.TimeToStr) -> str:
|
||||
this = expression.this if isinstance(expression.this, exp.TsOrDsToDate) else expression
|
||||
return self.func("FORMAT_DATE", self.format_time(expression), this.this)
|
||||
|
||||
def struct_sql(self, expression: exp.Struct) -> str:
|
||||
args = []
|
||||
for expr in expression.expressions:
|
||||
if isinstance(expr, self.KEY_VALUE_DEFINITIONS):
|
||||
arg = f"{self.sql(expr, 'expression')} AS {expr.this.name}"
|
||||
else:
|
||||
arg = self.sql(expr)
|
||||
|
||||
args.append(arg)
|
||||
|
||||
return self.func("STRUCT", *args)
|
||||
|
||||
def eq_sql(self, expression: exp.EQ) -> str:
|
||||
# Operands of = cannot be NULL in BigQuery
|
||||
if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null):
|
||||
|
@ -803,7 +841,7 @@ class BigQuery(Dialect):
|
|||
|
||||
def array_sql(self, expression: exp.Array) -> str:
|
||||
first_arg = seq_get(expression.expressions, 0)
|
||||
if isinstance(first_arg, exp.Subqueryable):
|
||||
if isinstance(first_arg, exp.Query):
|
||||
return f"ARRAY{self.wrap(self.sql(first_arg))}"
|
||||
|
||||
return inline_array_sql(self, expression)
|
||||
|
|
|
@ -68,7 +68,6 @@ class ClickHouse(Dialect):
|
|||
"DATE32": TokenType.DATE32,
|
||||
"DATETIME64": TokenType.DATETIME64,
|
||||
"DICTIONARY": TokenType.DICTIONARY,
|
||||
"ENUM": TokenType.ENUM,
|
||||
"ENUM8": TokenType.ENUM8,
|
||||
"ENUM16": TokenType.ENUM16,
|
||||
"FINAL": TokenType.FINAL,
|
||||
|
@ -93,6 +92,7 @@ class ClickHouse(Dialect):
|
|||
"AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION,
|
||||
"SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION,
|
||||
"SYSTEM": TokenType.COMMAND,
|
||||
"PREWHERE": TokenType.PREWHERE,
|
||||
}
|
||||
|
||||
SINGLE_TOKENS = {
|
||||
|
@ -129,6 +129,7 @@ class ClickHouse(Dialect):
|
|||
"MAP": parser.build_var_map,
|
||||
"MATCH": exp.RegexpLike.from_arg_list,
|
||||
"RANDCANONICAL": exp.Rand.from_arg_list,
|
||||
"TUPLE": exp.Struct.from_arg_list,
|
||||
"UNIQ": exp.ApproxDistinct.from_arg_list,
|
||||
"XOR": lambda args: exp.Xor(expressions=args),
|
||||
}
|
||||
|
@ -390,7 +391,7 @@ class ClickHouse(Dialect):
|
|||
|
||||
return self.expression(
|
||||
exp.CTE,
|
||||
this=self._parse_field(),
|
||||
this=self._parse_conjunction(),
|
||||
alias=self._parse_table_alias(),
|
||||
scalar=True,
|
||||
)
|
||||
|
@ -732,3 +733,7 @@ class ClickHouse(Dialect):
|
|||
return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}"
|
||||
|
||||
return super().createable_sql(expression, locations)
|
||||
|
||||
def prewhere_sql(self, expression: exp.PreWhere) -> str:
|
||||
this = self.indent(self.sql(expression, "this"))
|
||||
return f"{self.seg('PREWHERE')}{self.sep()}{this}"
|
||||
|
|
|
@ -69,7 +69,7 @@ class Databricks(Spark):
|
|||
|
||||
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
|
||||
constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint)
|
||||
kind = expression.args.get("kind")
|
||||
kind = expression.kind
|
||||
if (
|
||||
constraint
|
||||
and isinstance(kind, exp.DataType)
|
||||
|
|
|
@ -443,7 +443,7 @@ class Dialect(metaclass=_Dialect):
|
|||
identify: If set to `False`, the quotes will only be added if the identifier is deemed
|
||||
"unsafe", with respect to its characters and this dialect's normalization strategy.
|
||||
"""
|
||||
if isinstance(expression, exp.Identifier):
|
||||
if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
|
||||
name = expression.this
|
||||
expression.set(
|
||||
"quoted",
|
||||
|
|
|
@ -21,6 +21,7 @@ class Doris(MySQL):
|
|||
**MySQL.Parser.FUNCTIONS,
|
||||
"COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
|
||||
"DATE_TRUNC": build_timestamp_trunc,
|
||||
"MONTHS_ADD": exp.AddMonths.from_arg_list,
|
||||
"REGEXP": exp.RegexpLike.from_arg_list,
|
||||
"TO_DATE": exp.TsOrDsToDate.from_arg_list,
|
||||
}
|
||||
|
@ -41,6 +42,7 @@ class Doris(MySQL):
|
|||
|
||||
TRANSFORMS = {
|
||||
**MySQL.Generator.TRANSFORMS,
|
||||
exp.AddMonths: rename_func("MONTHS_ADD"),
|
||||
exp.ApproxDistinct: approx_count_distinct_sql,
|
||||
exp.ArgMax: rename_func("MAX_BY"),
|
||||
exp.ArgMin: rename_func("MIN_BY"),
|
||||
|
@ -58,7 +60,6 @@ class Doris(MySQL):
|
|||
exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
|
||||
exp.Split: rename_func("SPLIT_BY_STRING"),
|
||||
exp.TimeStrToDate: rename_func("TO_DATE"),
|
||||
exp.ToChar: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
|
||||
exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression),
|
||||
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
|
||||
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
|
||||
|
|
|
@ -156,6 +156,3 @@ class Drill(Dialect):
|
|||
exp.TsOrDiToDi: lambda self,
|
||||
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
|
||||
}
|
||||
|
||||
def normalize_func(self, name: str) -> str:
|
||||
return name if exp.SAFE_IDENTIFIER_RE.match(name) else f"`{name}`"
|
||||
|
|
|
@ -79,6 +79,21 @@ def _build_date_diff(args: t.List) -> exp.Expression:
|
|||
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
|
||||
|
||||
|
||||
def _build_generate_series(end_exclusive: bool = False) -> t.Callable[[t.List], exp.GenerateSeries]:
|
||||
def _builder(args: t.List) -> exp.GenerateSeries:
|
||||
# Check https://duckdb.org/docs/sql/functions/nested.html#range-functions
|
||||
if len(args) == 1:
|
||||
# DuckDB uses 0 as a default for the series' start when it's omitted
|
||||
args.insert(0, exp.Literal.number("0"))
|
||||
|
||||
gen_series = exp.GenerateSeries.from_arg_list(args)
|
||||
gen_series.set("is_end_exclusive", end_exclusive)
|
||||
|
||||
return gen_series
|
||||
|
||||
return _builder
|
||||
|
||||
|
||||
def _build_make_timestamp(args: t.List) -> exp.Expression:
|
||||
if len(args) == 1:
|
||||
return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS)
|
||||
|
@ -95,13 +110,13 @@ def _build_make_timestamp(args: t.List) -> exp.Expression:
|
|||
|
||||
def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
|
||||
args: t.List[str] = []
|
||||
for expr in expression.expressions:
|
||||
if isinstance(expr, exp.Alias):
|
||||
key = expr.alias
|
||||
value = expr.this
|
||||
else:
|
||||
key = expr.name or expr.this.name
|
||||
for i, expr in enumerate(expression.expressions):
|
||||
if isinstance(expr, exp.PropertyEQ):
|
||||
key = expr.name
|
||||
value = expr.expression
|
||||
else:
|
||||
key = f"_{i}"
|
||||
value = expr
|
||||
|
||||
args.append(f"{self.sql(exp.Literal.string(key))}: {self.sql(value)}")
|
||||
|
||||
|
@ -148,13 +163,6 @@ def _rename_unless_within_group(
|
|||
)
|
||||
|
||||
|
||||
def _build_struct_pack(args: t.List) -> exp.Struct:
|
||||
args_with_columns_as_identifiers = [
|
||||
exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args
|
||||
]
|
||||
return exp.Struct.from_arg_list(args_with_columns_as_identifiers)
|
||||
|
||||
|
||||
class DuckDB(Dialect):
|
||||
NULL_ORDERING = "nulls_are_last"
|
||||
SUPPORTS_USER_DEFINED_TYPES = False
|
||||
|
@ -189,6 +197,7 @@ class DuckDB(Dialect):
|
|||
"CHARACTER VARYING": TokenType.TEXT,
|
||||
"EXCLUDE": TokenType.EXCEPT,
|
||||
"LOGICAL": TokenType.BOOLEAN,
|
||||
"ONLY": TokenType.ONLY,
|
||||
"PIVOT_WIDER": TokenType.PIVOT,
|
||||
"SIGNED": TokenType.INT,
|
||||
"STRING": TokenType.VARCHAR,
|
||||
|
@ -213,6 +222,8 @@ class DuckDB(Dialect):
|
|||
TokenType.TILDA: exp.RegexpLike,
|
||||
}
|
||||
|
||||
FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "STRUCT_PACK"}
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"ARRAY_HAS": exp.ArrayContains.from_arg_list,
|
||||
|
@ -261,12 +272,14 @@ class DuckDB(Dialect):
|
|||
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
|
||||
"STRING_TO_ARRAY": exp.Split.from_arg_list,
|
||||
"STRPTIME": build_formatted_time(exp.StrToTime, "duckdb"),
|
||||
"STRUCT_PACK": _build_struct_pack,
|
||||
"STRUCT_PACK": exp.Struct.from_arg_list,
|
||||
"STR_SPLIT": exp.Split.from_arg_list,
|
||||
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
|
||||
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
|
||||
"UNNEST": exp.Explode.from_arg_list,
|
||||
"XOR": binary_from_function(exp.BitwiseXor),
|
||||
"GENERATE_SERIES": _build_generate_series(),
|
||||
"RANGE": _build_generate_series(end_exclusive=True),
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
|
||||
|
@ -313,6 +326,8 @@ class DuckDB(Dialect):
|
|||
return pivot_column_names(aggregations, dialect="duckdb")
|
||||
|
||||
class Generator(generator.Generator):
|
||||
PARAMETER_TOKEN = "$"
|
||||
NAMED_PLACEHOLDER_TOKEN = "$"
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
QUERY_HINTS = False
|
||||
|
@ -535,5 +550,22 @@ class DuckDB(Dialect):
|
|||
return self.sql(expression, "this")
|
||||
return super().columndef_sql(expression, sep)
|
||||
|
||||
def placeholder_sql(self, expression: exp.Placeholder) -> str:
|
||||
return f"${expression.name}" if expression.name else "?"
|
||||
def join_sql(self, expression: exp.Join) -> str:
|
||||
if (
|
||||
expression.side == "LEFT"
|
||||
and not expression.args.get("on")
|
||||
and isinstance(expression.this, exp.Unnest)
|
||||
):
|
||||
# Some dialects support `LEFT JOIN UNNEST(...)` without an explicit ON clause
|
||||
# DuckDB doesn't, but we can just add a dummy ON clause that is always true
|
||||
return super().join_sql(expression.on(exp.true()))
|
||||
|
||||
return super().join_sql(expression)
|
||||
|
||||
def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
|
||||
# GENERATE_SERIES(a, b) -> [a, b], RANGE(a, b) -> [a, b)
|
||||
if expression.args.get("is_end_exclusive"):
|
||||
expression.set("is_end_exclusive", None)
|
||||
return rename_func("RANGE")(self, expression)
|
||||
|
||||
return super().generateseries_sql(expression)
|
||||
|
|
|
@ -140,6 +140,15 @@ def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str:
|
|||
return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression))
|
||||
|
||||
|
||||
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
|
||||
timestamp = self.sql(expression, "this")
|
||||
scale = expression.args.get("scale")
|
||||
if scale in (None, exp.UnixToTime.SECONDS):
|
||||
return rename_func("FROM_UNIXTIME")(self, expression)
|
||||
|
||||
return f"FROM_UNIXTIME({timestamp} / POW(10, {scale}))"
|
||||
|
||||
|
||||
def _str_to_date_sql(self: Hive.Generator, expression: exp.StrToDate) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
time_format = self.format_time(expression)
|
||||
|
@ -536,7 +545,7 @@ class Hive(Dialect):
|
|||
exp.UnixToStr: lambda self, e: self.func(
|
||||
"FROM_UNIXTIME", e.this, time_format("hive")(self, e)
|
||||
),
|
||||
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
|
||||
exp.UnixToTime: _unix_to_time_sql,
|
||||
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
|
||||
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
|
||||
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
|
||||
|
@ -609,9 +618,8 @@ class Hive(Dialect):
|
|||
return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
|
||||
|
||||
def datatype_sql(self, expression: exp.DataType) -> str:
|
||||
if (
|
||||
expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
|
||||
and not expression.expressions
|
||||
if expression.this in self.PARAMETERIZABLE_TEXT_TYPES and (
|
||||
not expression.expressions or expression.expressions[0].name == "MAX"
|
||||
):
|
||||
expression = exp.DataType.build("text")
|
||||
elif expression.is_type(exp.DataType.Type.TEXT) and expression.expressions:
|
||||
|
@ -631,3 +639,15 @@ class Hive(Dialect):
|
|||
def version_sql(self, expression: exp.Version) -> str:
|
||||
sql = super().version_sql(expression)
|
||||
return sql.replace("FOR ", "", 1)
|
||||
|
||||
def struct_sql(self, expression: exp.Struct) -> str:
|
||||
values = []
|
||||
|
||||
for i, e in enumerate(expression.expressions):
|
||||
if isinstance(e, exp.PropertyEQ):
|
||||
self.unsupported("Hive does not support named structs.")
|
||||
values.append(e.expression)
|
||||
else:
|
||||
values.append(e)
|
||||
|
||||
return self.func("STRUCT", *values)
|
||||
|
|
|
@ -185,7 +185,6 @@ class MySQL(Dialect):
|
|||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"CHARSET": TokenType.CHARACTER_SET,
|
||||
"ENUM": TokenType.ENUM,
|
||||
"FORCE": TokenType.FORCE,
|
||||
"IGNORE": TokenType.IGNORE,
|
||||
"LOCK TABLES": TokenType.COMMAND,
|
||||
|
@ -391,6 +390,11 @@ class MySQL(Dialect):
|
|||
"WARNINGS": _show_parser("WARNINGS"),
|
||||
}
|
||||
|
||||
PROPERTY_PARSERS = {
|
||||
**parser.Parser.PROPERTY_PARSERS,
|
||||
"LOCK": lambda self: self._parse_property_assignment(exp.LockProperty),
|
||||
}
|
||||
|
||||
SET_PARSERS = {
|
||||
**parser.Parser.SET_PARSERS,
|
||||
"PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
|
||||
|
@ -416,16 +420,11 @@ class MySQL(Dialect):
|
|||
"SPATIAL",
|
||||
}
|
||||
|
||||
PROFILE_TYPES = {
|
||||
"ALL",
|
||||
"BLOCK IO",
|
||||
"CONTEXT SWITCHES",
|
||||
"CPU",
|
||||
"IPC",
|
||||
"MEMORY",
|
||||
"PAGE FAULTS",
|
||||
"SOURCE",
|
||||
"SWAPS",
|
||||
PROFILE_TYPES: parser.OPTIONS_TYPE = {
|
||||
**dict.fromkeys(("ALL", "CPU", "IPC", "MEMORY", "SOURCE", "SWAPS"), tuple()),
|
||||
"BLOCK": ("IO",),
|
||||
"CONTEXT": ("SWITCHES",),
|
||||
"PAGE": ("FAULTS",),
|
||||
}
|
||||
|
||||
TYPE_TOKENS = {
|
||||
|
|
|
@ -66,6 +66,26 @@ class Oracle(Dialect):
|
|||
"FF6": "%f", # only 6 digits are supported in python formats
|
||||
}
|
||||
|
||||
class Tokenizer(tokens.Tokenizer):
|
||||
VAR_SINGLE_TOKENS = {"@", "$", "#"}
|
||||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"(+)": TokenType.JOIN_MARKER,
|
||||
"BINARY_DOUBLE": TokenType.DOUBLE,
|
||||
"BINARY_FLOAT": TokenType.FLOAT,
|
||||
"COLUMNS": TokenType.COLUMN,
|
||||
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
|
||||
"MINUS": TokenType.EXCEPT,
|
||||
"NVARCHAR2": TokenType.NVARCHAR,
|
||||
"ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY,
|
||||
"SAMPLE": TokenType.TABLE_SAMPLE,
|
||||
"START": TokenType.BEGIN,
|
||||
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
|
||||
"TOP": TokenType.TOP,
|
||||
"VARCHAR2": TokenType.VARCHAR,
|
||||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
|
||||
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
|
||||
|
@ -93,6 +113,21 @@ class Oracle(Dialect):
|
|||
"XMLTABLE": lambda self: self._parse_xml_table(),
|
||||
}
|
||||
|
||||
NO_PAREN_FUNCTION_PARSERS = {
|
||||
**parser.Parser.NO_PAREN_FUNCTION_PARSERS,
|
||||
"CONNECT_BY_ROOT": lambda self: self.expression(
|
||||
exp.ConnectByRoot, this=self._parse_column()
|
||||
),
|
||||
}
|
||||
|
||||
PROPERTY_PARSERS = {
|
||||
**parser.Parser.PROPERTY_PARSERS,
|
||||
"GLOBAL": lambda self: self._match_text_seq("TEMPORARY")
|
||||
and self.expression(exp.TemporaryProperty, this="GLOBAL"),
|
||||
"PRIVATE": lambda self: self._match_text_seq("TEMPORARY")
|
||||
and self.expression(exp.TemporaryProperty, this="PRIVATE"),
|
||||
}
|
||||
|
||||
QUERY_MODIFIER_PARSERS = {
|
||||
**parser.Parser.QUERY_MODIFIER_PARSERS,
|
||||
TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()),
|
||||
|
@ -190,6 +225,7 @@ class Oracle(Dialect):
|
|||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.ConnectByRoot: lambda self, e: f"CONNECT_BY_ROOT {self.sql(e, 'this')}",
|
||||
exp.DateStrToDate: lambda self, e: self.func(
|
||||
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
|
||||
),
|
||||
|
@ -207,6 +243,7 @@ class Oracle(Dialect):
|
|||
exp.Substring: rename_func("SUBSTR"),
|
||||
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
|
||||
exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
|
||||
exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
|
||||
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
|
||||
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
|
||||
exp.Trim: trim_sql,
|
||||
|
@ -242,23 +279,3 @@ class Oracle(Dialect):
|
|||
if len(expression.args.get("actions", [])) > 1:
|
||||
return f"ADD ({actions})"
|
||||
return f"ADD {actions}"
|
||||
|
||||
class Tokenizer(tokens.Tokenizer):
|
||||
VAR_SINGLE_TOKENS = {"@", "$", "#"}
|
||||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"(+)": TokenType.JOIN_MARKER,
|
||||
"BINARY_DOUBLE": TokenType.DOUBLE,
|
||||
"BINARY_FLOAT": TokenType.FLOAT,
|
||||
"COLUMNS": TokenType.COLUMN,
|
||||
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
|
||||
"MINUS": TokenType.EXCEPT,
|
||||
"NVARCHAR2": TokenType.NVARCHAR,
|
||||
"ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY,
|
||||
"SAMPLE": TokenType.TABLE_SAMPLE,
|
||||
"START": TokenType.BEGIN,
|
||||
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
|
||||
"TOP": TokenType.TOP,
|
||||
"VARCHAR2": TokenType.VARCHAR,
|
||||
}
|
||||
|
|
|
@ -138,7 +138,9 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
|
|||
|
||||
|
||||
def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
|
||||
kind = expression.args.get("kind")
|
||||
if not isinstance(expression, exp.ColumnDef):
|
||||
return expression
|
||||
kind = expression.kind
|
||||
if not kind:
|
||||
return expression
|
||||
|
||||
|
@ -279,6 +281,7 @@ class Postgres(Dialect):
|
|||
"TEMP": TokenType.TEMPORARY,
|
||||
"CSTRING": TokenType.PSEUDO_TYPE,
|
||||
"OID": TokenType.OBJECT_IDENTIFIER,
|
||||
"ONLY": TokenType.ONLY,
|
||||
"OPERATOR": TokenType.OPERATOR,
|
||||
"REGCLASS": TokenType.OBJECT_IDENTIFIER,
|
||||
"REGCOLLATION": TokenType.OBJECT_IDENTIFIER,
|
||||
|
@ -451,6 +454,7 @@ class Postgres(Dialect):
|
|||
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
|
||||
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
|
||||
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
|
||||
exp.ParseJSON: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.JSON)),
|
||||
exp.JSONPathKey: json_path_key_only_name,
|
||||
exp.JSONPathRoot: lambda *_: "",
|
||||
exp.JSONPathSubscript: lambda self, e: self.json_path_part(e.this),
|
||||
|
@ -506,6 +510,26 @@ class Postgres(Dialect):
|
|||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
def unnest_sql(self, expression: exp.Unnest) -> str:
|
||||
if len(expression.expressions) == 1:
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
|
||||
this = annotate_types(expression.expressions[0])
|
||||
if this.is_type("array<json>"):
|
||||
while isinstance(this, exp.Cast):
|
||||
this = this.this
|
||||
|
||||
arg = self.sql(exp.cast(this, exp.DataType.Type.JSON))
|
||||
alias = self.sql(expression, "alias")
|
||||
alias = f" AS {alias}" if alias else ""
|
||||
|
||||
if expression.args.get("offset"):
|
||||
self.unsupported("Unsupported JSON_ARRAY_ELEMENTS with offset")
|
||||
|
||||
return f"JSON_ARRAY_ELEMENTS({arg}){alias}"
|
||||
|
||||
return super().unnest_sql(expression)
|
||||
|
||||
def bracket_sql(self, expression: exp.Bracket) -> str:
|
||||
"""Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY."""
|
||||
if isinstance(expression.this, exp.Array):
|
||||
|
|
|
@ -453,11 +453,32 @@ class Presto(Dialect):
|
|||
return super().bracket_sql(expression)
|
||||
|
||||
def struct_sql(self, expression: exp.Struct) -> str:
|
||||
if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions):
|
||||
self.unsupported("Struct with key-value definitions is unsupported.")
|
||||
return self.function_fallback_sql(expression)
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
|
||||
return rename_func("ROW")(self, expression)
|
||||
expression = annotate_types(expression)
|
||||
values: t.List[str] = []
|
||||
schema: t.List[str] = []
|
||||
unknown_type = False
|
||||
|
||||
for e in expression.expressions:
|
||||
if isinstance(e, exp.PropertyEQ):
|
||||
if e.type and e.type.is_type(exp.DataType.Type.UNKNOWN):
|
||||
unknown_type = True
|
||||
else:
|
||||
schema.append(f"{self.sql(e, 'this')} {self.sql(e.type)}")
|
||||
values.append(self.sql(e, "expression"))
|
||||
else:
|
||||
values.append(self.sql(e))
|
||||
|
||||
size = len(expression.expressions)
|
||||
|
||||
if not size or len(schema) != size:
|
||||
if unknown_type:
|
||||
self.unsupported(
|
||||
"Cannot convert untyped key-value definitions (try annotate_types)."
|
||||
)
|
||||
return self.func("ROW", *values)
|
||||
return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"
|
||||
|
||||
def interval_sql(self, expression: exp.Interval) -> str:
|
||||
unit = self.sql(expression, "unit")
|
||||
|
|
|
@ -70,6 +70,8 @@ class Redshift(Postgres):
|
|||
"SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, transaction=True),
|
||||
}
|
||||
|
||||
SUPPORTS_IMPLICIT_UNNEST = True
|
||||
|
||||
def _parse_table(
|
||||
self,
|
||||
schema: bool = False,
|
||||
|
@ -124,27 +126,6 @@ class Redshift(Postgres):
|
|||
self._retreat(index)
|
||||
return None
|
||||
|
||||
def _parse_query_modifiers(
|
||||
self, this: t.Optional[exp.Expression]
|
||||
) -> t.Optional[exp.Expression]:
|
||||
this = super()._parse_query_modifiers(this)
|
||||
|
||||
if this:
|
||||
refs = set()
|
||||
|
||||
for i, join in enumerate(this.args.get("joins", [])):
|
||||
refs.add(
|
||||
(
|
||||
this.args["from"] if i == 0 else this.args["joins"][i - 1]
|
||||
).this.alias.lower()
|
||||
)
|
||||
|
||||
table = join.this
|
||||
if isinstance(table, exp.Table) and not join.args.get("on"):
|
||||
if table.parts[0].name.lower() in refs:
|
||||
table.replace(table.to_column())
|
||||
return this
|
||||
|
||||
class Tokenizer(Postgres.Tokenizer):
|
||||
BIT_STRINGS = []
|
||||
HEX_STRINGS = []
|
||||
|
@ -225,6 +206,18 @@ class Redshift(Postgres):
|
|||
|
||||
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
|
||||
|
||||
def unnest_sql(self, expression: exp.Unnest) -> str:
|
||||
args = expression.expressions
|
||||
num_args = len(args)
|
||||
|
||||
if num_args > 1:
|
||||
self.unsupported(f"Unsupported number of arguments in UNNEST: {num_args}")
|
||||
return ""
|
||||
|
||||
arg = self.sql(seq_get(args, 0))
|
||||
alias = self.expressions(expression.args.get("alias"), key="columns")
|
||||
return f"{arg} AS {alias}" if alias else arg
|
||||
|
||||
def with_properties(self, properties: exp.Properties) -> str:
|
||||
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
|
||||
return self.properties(properties, prefix=" ", suffix="")
|
||||
|
|
|
@ -21,7 +21,7 @@ from sqlglot.dialects.dialect import (
|
|||
var_map_sql,
|
||||
)
|
||||
from sqlglot.expressions import Literal
|
||||
from sqlglot.helper import is_int, seq_get
|
||||
from sqlglot.helper import flatten, is_int, seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
@ -66,7 +66,7 @@ def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
|
|||
|
||||
return exp.Struct(
|
||||
expressions=[
|
||||
t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values)
|
||||
exp.PropertyEQ(this=k, expression=v) for k, v in zip(expression.keys, expression.values)
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -409,8 +409,16 @@ class Snowflake(Dialect):
|
|||
"TERSE OBJECTS": _show_parser("OBJECTS"),
|
||||
"TABLES": _show_parser("TABLES"),
|
||||
"TERSE TABLES": _show_parser("TABLES"),
|
||||
"VIEWS": _show_parser("VIEWS"),
|
||||
"TERSE VIEWS": _show_parser("VIEWS"),
|
||||
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
|
||||
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
|
||||
"IMPORTED KEYS": _show_parser("IMPORTED KEYS"),
|
||||
"TERSE IMPORTED KEYS": _show_parser("IMPORTED KEYS"),
|
||||
"UNIQUE KEYS": _show_parser("UNIQUE KEYS"),
|
||||
"TERSE UNIQUE KEYS": _show_parser("UNIQUE KEYS"),
|
||||
"SEQUENCES": _show_parser("SEQUENCES"),
|
||||
"TERSE SEQUENCES": _show_parser("SEQUENCES"),
|
||||
"COLUMNS": _show_parser("COLUMNS"),
|
||||
"USERS": _show_parser("USERS"),
|
||||
"TERSE USERS": _show_parser("USERS"),
|
||||
|
@ -424,11 +432,13 @@ class Snowflake(Dialect):
|
|||
|
||||
FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
|
||||
|
||||
SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"}
|
||||
|
||||
def _parse_colon_get_path(
|
||||
self: parser.Parser, this: t.Optional[exp.Expression]
|
||||
) -> t.Optional[exp.Expression]:
|
||||
while True:
|
||||
path = self._parse_bitwise()
|
||||
path = self._parse_bitwise() or self._parse_var(any_token=True)
|
||||
|
||||
# The cast :: operator has a lower precedence than the extraction operator :, so
|
||||
# we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
|
||||
|
@ -535,7 +545,7 @@ class Snowflake(Dialect):
|
|||
return table
|
||||
|
||||
def _parse_table_parts(
|
||||
self, schema: bool = False, is_db_reference: bool = False
|
||||
self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
|
||||
) -> exp.Table:
|
||||
# https://docs.snowflake.com/en/user-guide/querying-stage
|
||||
if self._match(TokenType.STRING, advance=False):
|
||||
|
@ -603,7 +613,7 @@ class Snowflake(Dialect):
|
|||
if self._curr:
|
||||
scope = self._parse_table_parts()
|
||||
elif self._curr:
|
||||
scope_kind = "SCHEMA" if this in ("OBJECTS", "TABLES") else "TABLE"
|
||||
scope_kind = "SCHEMA" if this in self.SCHEMA_KINDS else "TABLE"
|
||||
scope = self._parse_table_parts()
|
||||
|
||||
return self.expression(
|
||||
|
@ -758,10 +768,6 @@ class Snowflake(Dialect):
|
|||
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
|
||||
),
|
||||
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
|
||||
exp.Struct: lambda self, e: self.func(
|
||||
"OBJECT_CONSTRUCT",
|
||||
*(arg for expression in e.expressions for arg in expression.flatten()),
|
||||
),
|
||||
exp.Stuff: rename_func("INSERT"),
|
||||
exp.TimestampDiff: lambda self, e: self.func(
|
||||
"TIMESTAMPDIFF", e.unit, e.expression, e.this
|
||||
|
@ -937,3 +943,19 @@ class Snowflake(Dialect):
|
|||
|
||||
def cluster_sql(self, expression: exp.Cluster) -> str:
|
||||
return f"CLUSTER BY ({self.expressions(expression, flat=True)})"
|
||||
|
||||
def struct_sql(self, expression: exp.Struct) -> str:
|
||||
keys = []
|
||||
values = []
|
||||
|
||||
for i, e in enumerate(expression.expressions):
|
||||
if isinstance(e, exp.PropertyEQ):
|
||||
keys.append(
|
||||
exp.Literal.string(e.name) if isinstance(e.this, exp.Identifier) else e.this
|
||||
)
|
||||
values.append(e.expression)
|
||||
else:
|
||||
keys.append(exp.Literal.string(f"_{i}"))
|
||||
values.append(e)
|
||||
|
||||
return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values)))
|
||||
|
|
|
@ -263,14 +263,9 @@ class Spark2(Hive):
|
|||
CREATE_FUNCTION_RETURN_AS = False
|
||||
|
||||
def struct_sql(self, expression: exp.Struct) -> str:
|
||||
args = []
|
||||
for arg in expression.expressions:
|
||||
if isinstance(arg, self.KEY_VALUE_DEFINITIONS):
|
||||
args.append(exp.alias_(arg.expression, arg.this.name))
|
||||
else:
|
||||
args.append(arg)
|
||||
from sqlglot.generator import Generator
|
||||
|
||||
return self.func("STRUCT", *args)
|
||||
return Generator.struct_sql(self, expression)
|
||||
|
||||
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
|
||||
if is_parse_json(expression.this):
|
||||
|
|
|
@ -92,6 +92,7 @@ class SQLite(Dialect):
|
|||
NVL2_SUPPORTED = False
|
||||
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
|
||||
SUPPORTS_CREATE_TABLE_LIKE = False
|
||||
SUPPORTS_TABLE_ALIAS_COLUMNS = False
|
||||
|
||||
SUPPORTED_JSON_PATH_PARTS = {
|
||||
exp.JSONPathKey,
|
||||
|
@ -173,6 +174,21 @@ class SQLite(Dialect):
|
|||
|
||||
return super().cast_sql(expression)
|
||||
|
||||
def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
|
||||
parent = expression.parent
|
||||
alias = parent and parent.args.get("alias")
|
||||
|
||||
if isinstance(alias, exp.TableAlias) and alias.columns:
|
||||
column_alias = alias.columns[0]
|
||||
alias.set("columns", None)
|
||||
sql = self.sql(
|
||||
exp.select(exp.alias_("value", column_alias)).from_(expression).subquery()
|
||||
)
|
||||
else:
|
||||
sql = super().generateseries_sql(expression)
|
||||
|
||||
return sql
|
||||
|
||||
def datediff_sql(self, expression: exp.DateDiff) -> str:
|
||||
unit = expression.args.get("unit")
|
||||
unit = unit.name.upper() if unit else "DAY"
|
||||
|
|
|
@ -18,7 +18,6 @@ from sqlglot.dialects.dialect import (
|
|||
timestrtotime_sql,
|
||||
trim_sql,
|
||||
)
|
||||
from sqlglot.expressions import DataType
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.time import format_time
|
||||
from sqlglot.tokens import TokenType
|
||||
|
@ -63,6 +62,44 @@ DEFAULT_START_DATE = datetime.date(1900, 1, 1)
|
|||
|
||||
BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias}
|
||||
|
||||
# Unsupported options:
|
||||
# - OPTIMIZE FOR ( @variable_name { UNKNOWN | = <literal_constant> } [ , ...n ] )
|
||||
# - TABLE HINT
|
||||
OPTIONS: parser.OPTIONS_TYPE = {
|
||||
**dict.fromkeys(
|
||||
(
|
||||
"DISABLE_OPTIMIZED_PLAN_FORCING",
|
||||
"FAST",
|
||||
"IGNORE_NONCLUSTERED_COLUMNSTORE_INDEX",
|
||||
"LABEL",
|
||||
"MAXDOP",
|
||||
"MAXRECURSION",
|
||||
"MAX_GRANT_PERCENT",
|
||||
"MIN_GRANT_PERCENT",
|
||||
"NO_PERFORMANCE_SPOOL",
|
||||
"QUERYTRACEON",
|
||||
"RECOMPILE",
|
||||
),
|
||||
tuple(),
|
||||
),
|
||||
"CONCAT": ("UNION",),
|
||||
"DISABLE": ("EXTERNALPUSHDOWN", "SCALEOUTEXECUTION"),
|
||||
"EXPAND": ("VIEWS",),
|
||||
"FORCE": ("EXTERNALPUSHDOWN", "ORDER", "SCALEOUTEXECUTION"),
|
||||
"HASH": ("GROUP", "JOIN", "UNION"),
|
||||
"KEEP": ("PLAN",),
|
||||
"KEEPFIXED": ("PLAN",),
|
||||
"LOOP": ("JOIN",),
|
||||
"MERGE": ("JOIN", "UNION"),
|
||||
"OPTIMIZE": (("FOR", "UNKNOWN"),),
|
||||
"ORDER": ("GROUP",),
|
||||
"PARAMETERIZATION": ("FORCED", "SIMPLE"),
|
||||
"ROBUST": ("PLAN",),
|
||||
"USE": ("PLAN",),
|
||||
}
|
||||
|
||||
OPTIONS_THAT_REQUIRE_EQUAL = ("MAX_GRANT_PERCENT", "MIN_GRANT_PERCENT", "LABEL")
|
||||
|
||||
|
||||
def _build_formatted_time(
|
||||
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
|
||||
|
@ -221,19 +258,17 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
|
|||
# We keep track of the unaliased column projection indexes instead of the expressions
|
||||
# themselves, because the latter are going to be replaced by new nodes when the aliases
|
||||
# are added and hence we won't be able to reach these newly added Alias parents
|
||||
subqueryable = expression.this
|
||||
query = expression.this
|
||||
unaliased_column_indexes = (
|
||||
i
|
||||
for i, c in enumerate(subqueryable.selects)
|
||||
if isinstance(c, exp.Column) and not c.alias
|
||||
i for i, c in enumerate(query.selects) if isinstance(c, exp.Column) and not c.alias
|
||||
)
|
||||
|
||||
qualify_outputs(subqueryable)
|
||||
qualify_outputs(query)
|
||||
|
||||
# Preserve the quoting information of columns for newly added Alias nodes
|
||||
subqueryable_selects = subqueryable.selects
|
||||
query_selects = query.selects
|
||||
for select_index in unaliased_column_indexes:
|
||||
alias = subqueryable_selects[select_index]
|
||||
alias = query_selects[select_index]
|
||||
column = alias.this
|
||||
if isinstance(column.this, exp.Identifier):
|
||||
alias.args["alias"].set("quoted", column.this.quoted)
|
||||
|
@ -420,7 +455,6 @@ class TSQL(Dialect):
|
|||
"IMAGE": TokenType.IMAGE,
|
||||
"MONEY": TokenType.MONEY,
|
||||
"NTEXT": TokenType.TEXT,
|
||||
"NVARCHAR(MAX)": TokenType.TEXT,
|
||||
"PRINT": TokenType.COMMAND,
|
||||
"PROC": TokenType.PROCEDURE,
|
||||
"REAL": TokenType.FLOAT,
|
||||
|
@ -431,15 +465,24 @@ class TSQL(Dialect):
|
|||
"TOP": TokenType.TOP,
|
||||
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
|
||||
"UPDATE STATISTICS": TokenType.COMMAND,
|
||||
"VARCHAR(MAX)": TokenType.TEXT,
|
||||
"XML": TokenType.XML,
|
||||
"OUTPUT": TokenType.RETURNING,
|
||||
"SYSTEM_USER": TokenType.CURRENT_USER,
|
||||
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
|
||||
"OPTION": TokenType.OPTION,
|
||||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
SET_REQUIRES_ASSIGNMENT_DELIMITER = False
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
|
||||
STRING_ALIASES = True
|
||||
NO_PAREN_IF_COMMANDS = False
|
||||
|
||||
QUERY_MODIFIER_PARSERS = {
|
||||
**parser.Parser.QUERY_MODIFIER_PARSERS,
|
||||
TokenType.OPTION: lambda self: ("options", self._parse_options()),
|
||||
}
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
|
@ -472,19 +515,7 @@ class TSQL(Dialect):
|
|||
"TIMEFROMPARTS": _build_timefromparts,
|
||||
}
|
||||
|
||||
JOIN_HINTS = {
|
||||
"LOOP",
|
||||
"HASH",
|
||||
"MERGE",
|
||||
"REMOTE",
|
||||
}
|
||||
|
||||
VAR_LENGTH_DATATYPES = {
|
||||
DataType.Type.NVARCHAR,
|
||||
DataType.Type.VARCHAR,
|
||||
DataType.Type.CHAR,
|
||||
DataType.Type.NCHAR,
|
||||
}
|
||||
JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"}
|
||||
|
||||
RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
|
||||
TokenType.TABLE,
|
||||
|
@ -496,11 +527,21 @@ class TSQL(Dialect):
|
|||
TokenType.END: lambda self: self._parse_command(),
|
||||
}
|
||||
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
def _parse_options(self) -> t.Optional[t.List[exp.Expression]]:
|
||||
if not self._match(TokenType.OPTION):
|
||||
return None
|
||||
|
||||
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
|
||||
STRING_ALIASES = True
|
||||
NO_PAREN_IF_COMMANDS = False
|
||||
def _parse_option() -> t.Optional[exp.Expression]:
|
||||
option = self._parse_var_from_options(OPTIONS)
|
||||
if not option:
|
||||
return None
|
||||
|
||||
self._match(TokenType.EQ)
|
||||
return self.expression(
|
||||
exp.QueryOption, this=option, expression=self._parse_primary_or_var()
|
||||
)
|
||||
|
||||
return self._parse_wrapped_csv(_parse_option)
|
||||
|
||||
def _parse_projections(self) -> t.List[exp.Expression]:
|
||||
"""
|
||||
|
@ -576,48 +617,13 @@ class TSQL(Dialect):
|
|||
def _parse_convert(
|
||||
self, strict: bool, safe: t.Optional[bool] = None
|
||||
) -> t.Optional[exp.Expression]:
|
||||
to = self._parse_types()
|
||||
this = self._parse_types()
|
||||
self._match(TokenType.COMMA)
|
||||
this = self._parse_conjunction()
|
||||
|
||||
if not to or not this:
|
||||
return None
|
||||
|
||||
# Retrieve length of datatype and override to default if not specified
|
||||
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
|
||||
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
|
||||
|
||||
# Check whether a conversion with format is applicable
|
||||
if self._match(TokenType.COMMA):
|
||||
format_val = self._parse_number()
|
||||
format_val_name = format_val.name if format_val else ""
|
||||
|
||||
if format_val_name not in TSQL.CONVERT_FORMAT_MAPPING:
|
||||
raise ValueError(
|
||||
f"CONVERT function at T-SQL does not support format style {format_val_name}"
|
||||
)
|
||||
|
||||
format_norm = exp.Literal.string(TSQL.CONVERT_FORMAT_MAPPING[format_val_name])
|
||||
|
||||
# Check whether the convert entails a string to date format
|
||||
if to.this == DataType.Type.DATE:
|
||||
return self.expression(exp.StrToDate, this=this, format=format_norm)
|
||||
# Check whether the convert entails a string to datetime format
|
||||
elif to.this == DataType.Type.DATETIME:
|
||||
return self.expression(exp.StrToTime, this=this, format=format_norm)
|
||||
# Check whether the convert entails a date to string format
|
||||
elif to.this in self.VAR_LENGTH_DATATYPES:
|
||||
return self.expression(
|
||||
exp.Cast if strict else exp.TryCast,
|
||||
to=to,
|
||||
this=self.expression(exp.TimeToStr, this=this, format=format_norm),
|
||||
safe=safe,
|
||||
)
|
||||
elif to.this == DataType.Type.TEXT:
|
||||
return self.expression(exp.TimeToStr, this=this, format=format_norm)
|
||||
|
||||
# Entails a simple cast without any format requirement
|
||||
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe)
|
||||
args = [this, *self._parse_csv(self._parse_conjunction)]
|
||||
convert = exp.Convert.from_arg_list(args)
|
||||
convert.set("safe", safe)
|
||||
convert.set("strict", strict)
|
||||
return convert
|
||||
|
||||
def _parse_user_defined_function(
|
||||
self, kind: t.Optional[TokenType] = None
|
||||
|
@ -683,6 +689,26 @@ class TSQL(Dialect):
|
|||
|
||||
return self.expression(exp.UniqueColumnConstraint, this=this)
|
||||
|
||||
def _parse_partition(self) -> t.Optional[exp.Partition]:
|
||||
if not self._match_text_seq("WITH", "(", "PARTITIONS"):
|
||||
return None
|
||||
|
||||
def parse_range():
|
||||
low = self._parse_bitwise()
|
||||
high = self._parse_bitwise() if self._match_text_seq("TO") else None
|
||||
|
||||
return (
|
||||
self.expression(exp.PartitionRange, this=low, expression=high) if high else low
|
||||
)
|
||||
|
||||
partition = self.expression(
|
||||
exp.Partition, expressions=self._parse_wrapped_csv(parse_range)
|
||||
)
|
||||
|
||||
self._match_r_paren()
|
||||
|
||||
return partition
|
||||
|
||||
class Generator(generator.Generator):
|
||||
LIMIT_IS_TOP = True
|
||||
QUERY_HINTS = False
|
||||
|
@ -728,6 +754,9 @@ class TSQL(Dialect):
|
|||
exp.DataType.Type.VARIANT: "SQL_VARIANT",
|
||||
}
|
||||
|
||||
TYPE_MAPPING.pop(exp.DataType.Type.NCHAR)
|
||||
TYPE_MAPPING.pop(exp.DataType.Type.NVARCHAR)
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.AnyValue: any_value_to_max_sql,
|
||||
|
@ -779,6 +808,20 @@ class TSQL(Dialect):
|
|||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
def convert_sql(self, expression: exp.Convert) -> str:
|
||||
name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT"
|
||||
return self.func(
|
||||
name, expression.this, expression.expression, expression.args.get("style")
|
||||
)
|
||||
|
||||
def queryoption_sql(self, expression: exp.QueryOption) -> str:
|
||||
option = self.sql(expression, "this")
|
||||
value = self.sql(expression, "expression")
|
||||
if value:
|
||||
optional_equal_sign = "= " if option in OPTIONS_THAT_REQUIRE_EQUAL else ""
|
||||
return f"{option} {optional_equal_sign}{value}"
|
||||
return option
|
||||
|
||||
def lateral_op(self, expression: exp.Lateral) -> str:
|
||||
cross_apply = expression.args.get("cross_apply")
|
||||
if cross_apply is True:
|
||||
|
@ -876,11 +919,10 @@ class TSQL(Dialect):
|
|||
if ctas_with:
|
||||
ctas_with = ctas_with.pop()
|
||||
|
||||
subquery = ctas_expression
|
||||
if isinstance(subquery, exp.Subqueryable):
|
||||
subquery = subquery.subquery()
|
||||
if isinstance(ctas_expression, exp.UNWRAPPED_QUERIES):
|
||||
ctas_expression = ctas_expression.subquery()
|
||||
|
||||
select_into = exp.select("*").from_(exp.alias_(subquery, "temp", table=True))
|
||||
select_into = exp.select("*").from_(exp.alias_(ctas_expression, "temp", table=True))
|
||||
select_into.set("into", exp.Into(this=table))
|
||||
select_into.set("with", ctas_with)
|
||||
|
||||
|
@ -993,3 +1035,6 @@ class TSQL(Dialect):
|
|||
this_sql = self.sql(this)
|
||||
expression_sql = self.sql(expression, "expression")
|
||||
return self.func(name, this_sql, expression_sql if expression_sql else None)
|
||||
|
||||
def partition_sql(self, expression: exp.Partition) -> str:
|
||||
return f"WITH (PARTITIONS({self.expressions(expression, flat=True)}))"
|
||||
|
|
|
@ -119,13 +119,18 @@ def diff(
|
|||
return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy)
|
||||
|
||||
|
||||
LEAF_EXPRESSION_TYPES = (
|
||||
# The expression types for which Update edits are allowed.
|
||||
UPDATABLE_EXPRESSION_TYPES = (
|
||||
exp.Boolean,
|
||||
exp.DataType,
|
||||
exp.Identifier,
|
||||
exp.Literal,
|
||||
exp.Table,
|
||||
exp.Column,
|
||||
exp.Lambda,
|
||||
)
|
||||
|
||||
IGNORED_LEAF_EXPRESSION_TYPES = (exp.Identifier,)
|
||||
|
||||
|
||||
class ChangeDistiller:
|
||||
"""
|
||||
|
@ -152,8 +157,16 @@ class ChangeDistiller:
|
|||
|
||||
self._source = source
|
||||
self._target = target
|
||||
self._source_index = {id(n): n for n, *_ in self._source.bfs()}
|
||||
self._target_index = {id(n): n for n, *_ in self._target.bfs()}
|
||||
self._source_index = {
|
||||
id(n): n
|
||||
for n, *_ in self._source.bfs()
|
||||
if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
|
||||
}
|
||||
self._target_index = {
|
||||
id(n): n
|
||||
for n, *_ in self._target.bfs()
|
||||
if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
|
||||
}
|
||||
self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
|
||||
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
|
||||
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
|
||||
|
@ -170,7 +183,10 @@ class ChangeDistiller:
|
|||
for kept_source_node_id, kept_target_node_id in matching_set:
|
||||
source_node = self._source_index[kept_source_node_id]
|
||||
target_node = self._target_index[kept_target_node_id]
|
||||
if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
|
||||
if (
|
||||
not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES)
|
||||
or source_node == target_node
|
||||
):
|
||||
edit_script.extend(
|
||||
self._generate_move_edits(source_node, target_node, matching_set)
|
||||
)
|
||||
|
@ -307,17 +323,16 @@ def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
|
|||
has_child_exprs = False
|
||||
|
||||
for _, node in expression.iter_expressions():
|
||||
has_child_exprs = True
|
||||
yield from _get_leaves(node)
|
||||
if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES):
|
||||
has_child_exprs = True
|
||||
yield from _get_leaves(node)
|
||||
|
||||
if not has_child_exprs:
|
||||
yield expression
|
||||
|
||||
|
||||
def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
|
||||
if type(source) is type(target) and (
|
||||
not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent)
|
||||
):
|
||||
if type(source) is type(target):
|
||||
if isinstance(source, exp.Join):
|
||||
return source.args.get("side") == target.args.get("side")
|
||||
|
||||
|
@ -343,7 +358,11 @@ def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
|
|||
if expression:
|
||||
for a in expression.args.values():
|
||||
args.extend(ensure_list(a))
|
||||
return [a for a in args if isinstance(a, exp.Expression)]
|
||||
return [
|
||||
a
|
||||
for a in args
|
||||
if isinstance(a, exp.Expression) and not isinstance(a, IGNORED_LEAF_EXPRESSION_TYPES)
|
||||
]
|
||||
|
||||
|
||||
def _lcs(
|
||||
|
|
|
@ -78,7 +78,7 @@ class Context:
|
|||
def sort(self, key) -> None:
|
||||
def sort_key(row: t.Tuple) -> t.Tuple:
|
||||
self.set_row(row)
|
||||
return self.eval_tuple(key)
|
||||
return tuple((t is None, t) for t in self.eval_tuple(key))
|
||||
|
||||
self.table.rows.sort(key=sort_key)
|
||||
|
||||
|
|
|
@ -142,7 +142,6 @@ class PythonExecutor:
|
|||
context = self.context({alias: table})
|
||||
yield context
|
||||
types = []
|
||||
|
||||
for row in reader:
|
||||
if not types:
|
||||
for v in row:
|
||||
|
@ -150,7 +149,11 @@ class PythonExecutor:
|
|||
types.append(type(ast.literal_eval(v)))
|
||||
except (ValueError, SyntaxError):
|
||||
types.append(str)
|
||||
context.set_row(tuple(t(v) for t, v in zip(types, row)))
|
||||
|
||||
# We can't cast empty values ('') to non-string types, so we convert them to None instead
|
||||
context.set_row(
|
||||
tuple(None if (t is not str and v == "") else t(v) for t, v in zip(types, row))
|
||||
)
|
||||
yield context.table.reader
|
||||
|
||||
def join(self, step, context):
|
||||
|
|
|
@ -548,12 +548,10 @@ class Expression(metaclass=_Expression):
|
|||
return new_node
|
||||
|
||||
@t.overload
|
||||
def replace(self, expression: E) -> E:
|
||||
...
|
||||
def replace(self, expression: E) -> E: ...
|
||||
|
||||
@t.overload
|
||||
def replace(self, expression: None) -> None:
|
||||
...
|
||||
def replace(self, expression: None) -> None: ...
|
||||
|
||||
def replace(self, expression):
|
||||
"""
|
||||
|
@ -913,14 +911,142 @@ class Predicate(Condition):
|
|||
class DerivedTable(Expression):
|
||||
@property
|
||||
def selects(self) -> t.List[Expression]:
|
||||
return self.this.selects if isinstance(self.this, Subqueryable) else []
|
||||
return self.this.selects if isinstance(self.this, Query) else []
|
||||
|
||||
@property
|
||||
def named_selects(self) -> t.List[str]:
|
||||
return [select.output_name for select in self.selects]
|
||||
|
||||
|
||||
class Unionable(Expression):
|
||||
class Query(Expression):
|
||||
def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Subquery:
|
||||
"""
|
||||
Returns a `Subquery` that wraps around this query.
|
||||
|
||||
Example:
|
||||
>>> subquery = Select().select("x").from_("tbl").subquery()
|
||||
>>> Select().select("x").from_(subquery).sql()
|
||||
'SELECT x FROM (SELECT x FROM tbl)'
|
||||
|
||||
Args:
|
||||
alias: an optional alias for the subquery.
|
||||
copy: if `False`, modify this expression instance in-place.
|
||||
"""
|
||||
instance = maybe_copy(self, copy)
|
||||
if not isinstance(alias, Expression):
|
||||
alias = TableAlias(this=to_identifier(alias)) if alias else None
|
||||
|
||||
return Subquery(this=instance, alias=alias)
|
||||
|
||||
def limit(
|
||||
self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
|
||||
) -> Select:
|
||||
"""
|
||||
Adds a LIMIT clause to this query.
|
||||
|
||||
Example:
|
||||
>>> select("1").union(select("1")).limit(1).sql()
|
||||
'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
|
||||
|
||||
Args:
|
||||
expression: the SQL code string to parse.
|
||||
This can also be an integer.
|
||||
If a `Limit` instance is passed, it will be used as-is.
|
||||
If another `Expression` instance is passed, it will be wrapped in a `Limit`.
|
||||
dialect: the dialect used to parse the input expression.
|
||||
copy: if `False`, modify this expression instance in-place.
|
||||
opts: other options to use to parse the input expressions.
|
||||
|
||||
Returns:
|
||||
A limited Select expression.
|
||||
"""
|
||||
return (
|
||||
select("*")
|
||||
.from_(self.subquery(alias="_l_0", copy=copy))
|
||||
.limit(expression, dialect=dialect, copy=False, **opts)
|
||||
)
|
||||
|
||||
@property
|
||||
def ctes(self) -> t.List[CTE]:
|
||||
"""Returns a list of all the CTEs attached to this query."""
|
||||
with_ = self.args.get("with")
|
||||
return with_.expressions if with_ else []
|
||||
|
||||
@property
|
||||
def selects(self) -> t.List[Expression]:
|
||||
"""Returns the query's projections."""
|
||||
raise NotImplementedError("Query objects must implement `selects`")
|
||||
|
||||
@property
|
||||
def named_selects(self) -> t.List[str]:
|
||||
"""Returns the output names of the query's projections."""
|
||||
raise NotImplementedError("Query objects must implement `named_selects`")
|
||||
|
||||
def select(
|
||||
self,
|
||||
*expressions: t.Optional[ExpOrStr],
|
||||
append: bool = True,
|
||||
dialect: DialectType = None,
|
||||
copy: bool = True,
|
||||
**opts,
|
||||
) -> Query:
|
||||
"""
|
||||
Append to or set the SELECT expressions.
|
||||
|
||||
Example:
|
||||
>>> Select().select("x", "y").sql()
|
||||
'SELECT x, y'
|
||||
|
||||
Args:
|
||||
*expressions: the SQL code strings to parse.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
append: if `True`, add to any existing expressions.
|
||||
Otherwise, this resets the expressions.
|
||||
dialect: the dialect used to parse the input expressions.
|
||||
copy: if `False`, modify this expression instance in-place.
|
||||
opts: other options to use to parse the input expressions.
|
||||
|
||||
Returns:
|
||||
The modified Query expression.
|
||||
"""
|
||||
raise NotImplementedError("Query objects must implement `select`")
|
||||
|
||||
def with_(
|
||||
self,
|
||||
alias: ExpOrStr,
|
||||
as_: ExpOrStr,
|
||||
recursive: t.Optional[bool] = None,
|
||||
append: bool = True,
|
||||
dialect: DialectType = None,
|
||||
copy: bool = True,
|
||||
**opts,
|
||||
) -> Query:
|
||||
"""
|
||||
Append to or set the common table expressions.
|
||||
|
||||
Example:
|
||||
>>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql()
|
||||
'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2'
|
||||
|
||||
Args:
|
||||
alias: the SQL code string to parse as the table name.
|
||||
If an `Expression` instance is passed, this is used as-is.
|
||||
as_: the SQL code string to parse as the table expression.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
recursive: set the RECURSIVE part of the expression. Defaults to `False`.
|
||||
append: if `True`, add to any existing expressions.
|
||||
Otherwise, this resets the expressions.
|
||||
dialect: the dialect used to parse the input expression.
|
||||
copy: if `False`, modify this expression instance in-place.
|
||||
opts: other options to use to parse the input expressions.
|
||||
|
||||
Returns:
|
||||
The modified expression.
|
||||
"""
|
||||
return _apply_cte_builder(
|
||||
self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts
|
||||
)
|
||||
|
||||
def union(
|
||||
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
|
||||
) -> Union:
|
||||
|
@ -946,7 +1072,7 @@ class Unionable(Expression):
|
|||
|
||||
def intersect(
|
||||
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
|
||||
) -> Unionable:
|
||||
) -> Intersect:
|
||||
"""
|
||||
Builds an INTERSECT expression.
|
||||
|
||||
|
@ -969,7 +1095,7 @@ class Unionable(Expression):
|
|||
|
||||
def except_(
|
||||
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
|
||||
) -> Unionable:
|
||||
) -> Except:
|
||||
"""
|
||||
Builds an EXCEPT expression.
|
||||
|
||||
|
@ -991,7 +1117,7 @@ class Unionable(Expression):
|
|||
return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
|
||||
|
||||
|
||||
class UDTF(DerivedTable, Unionable):
|
||||
class UDTF(DerivedTable):
|
||||
@property
|
||||
def selects(self) -> t.List[Expression]:
|
||||
alias = self.args.get("alias")
|
||||
|
@ -1017,23 +1143,23 @@ class Refresh(Expression):
|
|||
|
||||
class DDL(Expression):
|
||||
@property
|
||||
def ctes(self):
|
||||
def ctes(self) -> t.List[CTE]:
|
||||
"""Returns a list of all the CTEs attached to this statement."""
|
||||
with_ = self.args.get("with")
|
||||
if not with_:
|
||||
return []
|
||||
return with_.expressions
|
||||
|
||||
@property
|
||||
def named_selects(self) -> t.List[str]:
|
||||
if isinstance(self.expression, Subqueryable):
|
||||
return self.expression.named_selects
|
||||
return []
|
||||
return with_.expressions if with_ else []
|
||||
|
||||
@property
|
||||
def selects(self) -> t.List[Expression]:
|
||||
if isinstance(self.expression, Subqueryable):
|
||||
return self.expression.selects
|
||||
return []
|
||||
"""If this statement contains a query (e.g. a CTAS), this returns the query's projections."""
|
||||
return self.expression.selects if isinstance(self.expression, Query) else []
|
||||
|
||||
@property
|
||||
def named_selects(self) -> t.List[str]:
|
||||
"""
|
||||
If this statement contains a query (e.g. a CTAS), this returns the output
|
||||
names of the query's projections.
|
||||
"""
|
||||
return self.expression.named_selects if isinstance(self.expression, Query) else []
|
||||
|
||||
|
||||
class DML(Expression):
|
||||
|
@ -1096,6 +1222,19 @@ class Create(DDL):
|
|||
return kind and kind.upper()
|
||||
|
||||
|
||||
class TruncateTable(Expression):
|
||||
arg_types = {
|
||||
"expressions": True,
|
||||
"is_database": False,
|
||||
"exists": False,
|
||||
"only": False,
|
||||
"cluster": False,
|
||||
"identity": False,
|
||||
"option": False,
|
||||
"partition": False,
|
||||
}
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/sql/create-clone
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy
|
||||
|
@ -1271,6 +1410,10 @@ class ColumnDef(Expression):
|
|||
def constraints(self) -> t.List[ColumnConstraint]:
|
||||
return self.args.get("constraints") or []
|
||||
|
||||
@property
|
||||
def kind(self) -> t.Optional[DataType]:
|
||||
return self.args.get("kind")
|
||||
|
||||
|
||||
class AlterColumn(Expression):
|
||||
arg_types = {
|
||||
|
@ -1367,7 +1510,7 @@ class CharacterSetColumnConstraint(ColumnConstraintKind):
|
|||
|
||||
|
||||
class CheckColumnConstraint(ColumnConstraintKind):
|
||||
pass
|
||||
arg_types = {"this": True, "enforced": False}
|
||||
|
||||
|
||||
class ClusteredColumnConstraint(ColumnConstraintKind):
|
||||
|
@ -1776,6 +1919,10 @@ class Partition(Expression):
|
|||
arg_types = {"expressions": True}
|
||||
|
||||
|
||||
class PartitionRange(Expression):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
|
||||
class Fetch(Expression):
|
||||
arg_types = {
|
||||
"direction": False,
|
||||
|
@ -2173,6 +2320,10 @@ class LocationProperty(Property):
|
|||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class LockProperty(Property):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class LockingProperty(Property):
|
||||
arg_types = {
|
||||
"this": False,
|
||||
|
@ -2310,7 +2461,7 @@ class StabilityProperty(Property):
|
|||
|
||||
|
||||
class TemporaryProperty(Property):
|
||||
arg_types = {}
|
||||
arg_types = {"this": False}
|
||||
|
||||
|
||||
class TransformModelProperty(Property):
|
||||
|
@ -2356,6 +2507,7 @@ class Properties(Expression):
|
|||
"FORMAT": FileFormatProperty,
|
||||
"LANGUAGE": LanguageProperty,
|
||||
"LOCATION": LocationProperty,
|
||||
"LOCK": LockProperty,
|
||||
"PARTITIONED_BY": PartitionedByProperty,
|
||||
"RETURNS": ReturnsProperty,
|
||||
"ROW_FORMAT": RowFormatProperty,
|
||||
|
@ -2445,102 +2597,13 @@ class Tuple(Expression):
|
|||
)
|
||||
|
||||
|
||||
class Subqueryable(Unionable):
    """
    Base for expressions that can be wrapped in a `Subquery` and extended with CTEs.

    `limit`, `selects`, `named_selects` and `select` raise `NotImplementedError` here
    and are meant to be provided by subclasses.
    """

    def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Subquery:
        """
        Wrap this expression in a `Subquery`, optionally giving it an alias.

        Example:
            >>> subquery = Select().select("x").from_("tbl").subquery()
            >>> Select().select("x").from_(subquery).sql()
            'SELECT x FROM (SELECT x FROM tbl)'

        Args:
            alias: an optional alias for the subquery. An `Expression` is used as-is;
                any other truthy value is turned into an identifier and wrapped in a
                `TableAlias`; a falsy value means no alias.
            copy: if `False`, modify this expression instance in-place.

        Returns:
            A `Subquery` whose `this` is the (possibly copied) expression.
        """
        target = maybe_copy(self, copy)

        if isinstance(alias, Expression):
            table_alias = alias
        elif alias:
            table_alias = TableAlias(this=to_identifier(alias))
        else:
            table_alias = None

        return Subquery(this=target, alias=table_alias)

    def limit(
        self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
    ) -> Select:
        """Not implemented on the base class."""
        raise NotImplementedError

    @property
    def ctes(self):
        """The expressions of the `with` arg, or an empty list when there is none."""
        with_ = self.args.get("with")
        return with_.expressions if with_ else []

    @property
    def selects(self) -> t.List[Expression]:
        """Not implemented on the base class."""
        raise NotImplementedError("Subqueryable objects must implement `selects`")

    @property
    def named_selects(self) -> t.List[str]:
        """Not implemented on the base class."""
        raise NotImplementedError("Subqueryable objects must implement `named_selects`")

    def select(
        self,
        *expressions: t.Optional[ExpOrStr],
        append: bool = True,
        dialect: DialectType = None,
        copy: bool = True,
        **opts,
    ) -> Subqueryable:
        """Not implemented on the base class."""
        raise NotImplementedError("Subqueryable objects must implement `select`")

    def with_(
        self,
        alias: ExpOrStr,
        as_: ExpOrStr,
        recursive: t.Optional[bool] = None,
        append: bool = True,
        dialect: DialectType = None,
        copy: bool = True,
        **opts,
    ) -> Subqueryable:
        """
        Append to or set the common table expressions.

        Example:
            >>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql()
            'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2'

        Args:
            alias: the SQL code string to parse as the table name.
                If an `Expression` instance is passed, this is used as-is.
            as_: the SQL code string to parse as the table expression.
                If an `Expression` instance is passed, it will be used as-is.
            recursive: set the RECURSIVE part of the expression; optional.
            append: if `True`, add to any existing expressions.
                Otherwise, this resets the expressions.
            dialect: the dialect used to parse the input expression.
            copy: if `False`, modify this expression instance in-place.
            opts: other options to use to parse the input expressions.

        Returns:
            The modified expression.
        """
        # All arguments are forwarded unchanged to the shared CTE builder.
        builder_kwargs = dict(
            recursive=recursive, append=append, dialect=dialect, copy=copy, **opts
        )
        return _apply_cte_builder(self, alias, as_, **builder_kwargs)
|
||||
|
||||
|
||||
QUERY_MODIFIERS = {
|
||||
"match": False,
|
||||
"laterals": False,
|
||||
"joins": False,
|
||||
"connect": False,
|
||||
"pivots": False,
|
||||
"prewhere": False,
|
||||
"where": False,
|
||||
"group": False,
|
||||
"having": False,
|
||||
|
@ -2556,9 +2619,16 @@ QUERY_MODIFIERS = {
|
|||
"sample": False,
|
||||
"settings": False,
|
||||
"format": False,
|
||||
"options": False,
|
||||
}
|
||||
|
||||
|
||||
# https://learn.microsoft.com/en-us/sql/t-sql/queries/option-clause-transact-sql?view=sql-server-ver16
|
||||
# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-query?view=sql-server-ver16
|
||||
class QueryOption(Expression):
    """Expression with a required `this` argument and an optional `expression` argument."""

    arg_types = dict(this=True, expression=False)
|
||||
|
||||
|
||||
# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16
|
||||
class WithTableHint(Expression):
    """Expression with a single, required `expressions` argument."""

    arg_types = dict(expressions=True)
|
||||
|
@ -2590,6 +2660,7 @@ class Table(Expression):
|
|||
"pattern": False,
|
||||
"ordinality": False,
|
||||
"when": False,
|
||||
"only": False,
|
||||
}
|
||||
|
||||
@property
|
||||
|
@ -2638,7 +2709,7 @@ class Table(Expression):
|
|||
return col
|
||||
|
||||
|
||||
class Union(Subqueryable):
|
||||
class Union(Query):
|
||||
arg_types = {
|
||||
"with": False,
|
||||
"this": True,
|
||||
|
@ -2648,34 +2719,6 @@ class Union(Subqueryable):
|
|||
**QUERY_MODIFIERS,
|
||||
}
|
||||
|
||||
def limit(
    self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
) -> Select:
    """
    Apply a LIMIT by selecting everything from this expression wrapped as a subquery.

    Example:
        >>> select("1").union(select("1")).limit(1).sql()
        'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'

    Args:
        expression: the SQL code string to parse.
            This can also be an integer.
            If a `Limit` instance is passed, this is used as-is.
            If another `Expression` instance is passed, it will be wrapped in a `Limit`.
        dialect: the dialect used to parse the input expression.
        copy: if `False`, modify this expression instance in-place.
        opts: other options to use to parse the input expressions.

    Returns:
        A new `Select` over this expression (aliased `_l_0`) with the limit applied.
    """
    star_query = select("*")
    # The outer SELECT is freshly built, so it is limited in place (copy=False);
    # `copy` only controls whether `self` is copied when wrapped.
    outer = star_query.from_(self.subquery(alias="_l_0", copy=copy))
    return outer.limit(expression, dialect=dialect, copy=False, **opts)
|
||||
|
||||
def select(
|
||||
self,
|
||||
*expressions: t.Optional[ExpOrStr],
|
||||
|
@ -2684,26 +2727,7 @@ class Union(Subqueryable):
|
|||
copy: bool = True,
|
||||
**opts,
|
||||
) -> Union:
|
||||
"""Append to or set the SELECT of the union recursively.
|
||||
|
||||
Example:
|
||||
>>> from sqlglot import parse_one
|
||||
>>> parse_one("select a from x union select a from y union select a from z").select("b").sql()
|
||||
'SELECT a, b FROM x UNION SELECT a, b FROM y UNION SELECT a, b FROM z'
|
||||
|
||||
Args:
|
||||
*expressions: the SQL code strings to parse.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
append: if `True`, add to any existing expressions.
|
||||
Otherwise, this resets the expressions.
|
||||
dialect: the dialect used to parse the input expressions.
|
||||
copy: if `False`, modify this expression instance in-place.
|
||||
opts: other options to use to parse the input expressions.
|
||||
|
||||
Returns:
|
||||
Union: the modified expression.
|
||||
"""
|
||||
this = self.copy() if copy else self
|
||||
this = maybe_copy(self, copy)
|
||||
this.this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts)
|
||||
this.expression.unnest().select(
|
||||
*expressions, append=append, dialect=dialect, copy=False, **opts
|
||||
|
@ -2800,7 +2824,7 @@ class Lock(Expression):
|
|||
arg_types = {"update": True, "expressions": False, "wait": False}
|
||||
|
||||
|
||||
class Select(Subqueryable):
|
||||
class Select(Query):
|
||||
arg_types = {
|
||||
"with": False,
|
||||
"kind": False,
|
||||
|
@ -3011,25 +3035,6 @@ class Select(Subqueryable):
|
|||
def limit(
|
||||
self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
|
||||
) -> Select:
|
||||
"""
|
||||
Set the LIMIT expression.
|
||||
|
||||
Example:
|
||||
>>> Select().from_("tbl").select("x").limit(10).sql()
|
||||
'SELECT x FROM tbl LIMIT 10'
|
||||
|
||||
Args:
|
||||
expression: the SQL code string to parse.
|
||||
This can also be an integer.
|
||||
If a `Limit` instance is passed, this is used as-is.
|
||||
If another `Expression` instance is passed, it will be wrapped in a `Limit`.
|
||||
dialect: the dialect used to parse the input expression.
|
||||
copy: if `False`, modify this expression instance in-place.
|
||||
opts: other options to use to parse the input expressions.
|
||||
|
||||
Returns:
|
||||
Select: the modified expression.
|
||||
"""
|
||||
return _apply_builder(
|
||||
expression=expression,
|
||||
instance=self,
|
||||
|
@ -3084,31 +3089,13 @@ class Select(Subqueryable):
|
|||
copy: bool = True,
|
||||
**opts,
|
||||
) -> Select:
|
||||
"""
|
||||
Append to or set the SELECT expressions.
|
||||
|
||||
Example:
|
||||
>>> Select().select("x", "y").sql()
|
||||
'SELECT x, y'
|
||||
|
||||
Args:
|
||||
*expressions: the SQL code strings to parse.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
append: if `True`, add to any existing expressions.
|
||||
Otherwise, this resets the expressions.
|
||||
dialect: the dialect used to parse the input expressions.
|
||||
copy: if `False`, modify this expression instance in-place.
|
||||
opts: other options to use to parse the input expressions.
|
||||
|
||||
Returns:
|
||||
The modified Select expression.
|
||||
"""
|
||||
return _apply_list_builder(
|
||||
*expressions,
|
||||
instance=self,
|
||||
arg="expressions",
|
||||
append=append,
|
||||
dialect=dialect,
|
||||
into=Expression,
|
||||
copy=copy,
|
||||
**opts,
|
||||
)
|
||||
|
@ -3416,12 +3403,8 @@ class Select(Subqueryable):
|
|||
The new Create expression.
|
||||
"""
|
||||
instance = maybe_copy(self, copy)
|
||||
table_expression = maybe_parse(
|
||||
table,
|
||||
into=Table,
|
||||
dialect=dialect,
|
||||
**opts,
|
||||
)
|
||||
table_expression = maybe_parse(table, into=Table, dialect=dialect, **opts)
|
||||
|
||||
properties_expression = None
|
||||
if properties:
|
||||
properties_expression = Properties.from_dict(properties)
|
||||
|
@ -3493,7 +3476,10 @@ class Select(Subqueryable):
|
|||
return self.expressions
|
||||
|
||||
|
||||
class Subquery(DerivedTable, Unionable):
|
||||
UNWRAPPED_QUERIES = (Select, Union)
|
||||
|
||||
|
||||
class Subquery(DerivedTable, Query):
|
||||
arg_types = {
|
||||
"this": True,
|
||||
"alias": False,
|
||||
|
@ -3502,9 +3488,7 @@ class Subquery(DerivedTable, Unionable):
|
|||
}
|
||||
|
||||
def unnest(self):
|
||||
"""
|
||||
Returns the first non subquery.
|
||||
"""
|
||||
"""Returns the first non subquery."""
|
||||
expression = self
|
||||
while isinstance(expression, Subquery):
|
||||
expression = expression.this
|
||||
|
@ -3516,6 +3500,18 @@ class Subquery(DerivedTable, Unionable):
|
|||
expression = t.cast(Subquery, expression.parent)
|
||||
return expression
|
||||
|
||||
def select(
    self,
    *expressions: t.Optional[ExpOrStr],
    append: bool = True,
    dialect: DialectType = None,
    copy: bool = True,
    **opts,
) -> Subquery:
    """
    Forward a `select` call to the expression returned by `unnest()`.

    Args:
        *expressions: passed through to the inner `select`.
        append: passed through to the inner `select`.
        dialect: passed through to the inner `select`.
        copy: if `False`, operate on this instance instead of a copy.
        opts: passed through to the inner `select`.

    Returns:
        This subquery (or its copy), whose inner expression was updated in place.
    """
    instance = maybe_copy(self, copy)
    inner = instance.unnest()
    # `instance` is already the copy (if requested), so the inner call must not copy again.
    inner.select(*expressions, append=append, dialect=dialect, copy=False, **opts)
    return instance
|
||||
|
||||
@property
|
||||
def is_wrapper(self) -> bool:
|
||||
"""
|
||||
|
@ -3603,6 +3599,10 @@ class WindowSpec(Expression):
|
|||
}
|
||||
|
||||
|
||||
class PreWhere(Expression):
    """Marker subclass; adds nothing to `Expression`."""
|
||||
|
||||
|
||||
class Where(Expression):
    """Marker subclass; adds nothing to `Expression`."""
|
||||
|
||||
|
@ -3646,6 +3646,10 @@ class Boolean(Condition):
|
|||
class DataTypeParam(Expression):
    """Expression with a required `this` argument and an optional `expression` argument."""

    arg_types = dict(this=True, expression=False)

    @property
    def name(self) -> str:
        """The `name` of the `this` argument."""
        inner = self.this
        return inner.name
|
||||
|
||||
|
||||
class DataType(Expression):
|
||||
arg_types = {
|
||||
|
@ -3926,11 +3930,17 @@ class Rollback(Expression):
|
|||
|
||||
|
||||
class AlterTable(Expression):
|
||||
arg_types = {"this": True, "actions": True, "exists": False, "only": False}
|
||||
arg_types = {
|
||||
"this": True,
|
||||
"actions": True,
|
||||
"exists": False,
|
||||
"only": False,
|
||||
"options": False,
|
||||
}
|
||||
|
||||
|
||||
class AddConstraint(Expression):
|
||||
arg_types = {"this": False, "expression": False, "enforced": False}
|
||||
arg_types = {"expressions": True}
|
||||
|
||||
|
||||
class DropPartition(Expression):
|
||||
|
@ -3995,6 +4005,10 @@ class Overlaps(Binary):
|
|||
|
||||
|
||||
class Dot(Binary):
|
||||
@property
def is_star(self) -> bool:
    """Delegate to the `is_star` flag of this node's `expression` argument."""
    operand = self.expression
    return operand.is_star
|
||||
|
||||
@property
def name(self) -> str:
    """The `name` of this node's `expression` argument."""
    operand = self.expression
    return operand.name
|
||||
|
@ -4390,6 +4404,10 @@ class Anonymous(Func):
|
|||
arg_types = {"this": True, "expressions": False}
|
||||
is_var_len_args = True
|
||||
|
||||
@property
def name(self) -> str:
    """Return `this` itself when it is a plain string, otherwise the `name` of `this`."""
    this = self.this
    if isinstance(this, str):
        return this
    return this.name
|
||||
|
||||
|
||||
class AnonymousAggFunc(AggFunc):
|
||||
arg_types = {"this": True, "expressions": False}
|
||||
|
@ -4433,8 +4451,13 @@ class ToChar(Func):
|
|||
arg_types = {"this": True, "format": False, "nlsparam": False}
|
||||
|
||||
|
||||
# https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax
|
||||
class Convert(Func):
    """Function node with required `this` and `expression` arguments and an optional `style`."""

    arg_types = {
        # Required arguments.
        **dict.fromkeys(("this", "expression"), True),
        # Optional argument.
        "style": False,
    }
|
||||
|
||||
|
||||
class GenerateSeries(Func):
|
||||
arg_types = {"start": True, "end": True, "step": False}
|
||||
arg_types = {"start": True, "end": True, "step": False, "is_end_exclusive": False}
|
||||
|
||||
|
||||
class ArrayAgg(AggFunc):
|
||||
|
@ -4624,6 +4647,11 @@ class ConcatWs(Concat):
|
|||
_sql_names = ["CONCAT_WS"]
|
||||
|
||||
|
||||
# https://docs.oracle.com/cd/B13789_01/server.101/b10759/operators004.htm#i1035022
|
||||
class ConnectByRoot(Func):
    """Marker subclass; adds nothing to `Func`."""
|
||||
|
||||
|
||||
class Count(AggFunc):
|
||||
arg_types = {"this": False, "expressions": False}
|
||||
is_var_len_args = True
|
||||
|
@ -5197,6 +5225,10 @@ class Month(Func):
|
|||
pass
|
||||
|
||||
|
||||
class AddMonths(Func):
    """Function node that requires both a `this` and an `expression` argument."""

    arg_types = dict.fromkeys(("this", "expression"), True)
|
||||
|
||||
|
||||
class Nvl2(Func):
|
||||
arg_types = {"this": True, "true": True, "false": False}
|
||||
|
||||
|
@ -5313,6 +5345,10 @@ class SHA2(Func):
|
|||
arg_types = {"this": True, "length": False}
|
||||
|
||||
|
||||
class Sign(Func):
    """Function node registered under two SQL names."""

    # NOTE(review): presumably the SQL function names mapped to this node — confirm against `Func`.
    _sql_names = [
        "SIGN",
        "SIGNUM",
    ]
|
||||
|
||||
|
||||
class SortArray(Func):
|
||||
arg_types = {"this": True, "asc": False}
|
||||
|
||||
|
@ -5554,7 +5590,13 @@ class Use(Expression):
|
|||
|
||||
|
||||
class Merge(Expression):
|
||||
arg_types = {"this": True, "using": True, "on": True, "expressions": True, "with": False}
|
||||
arg_types = {
|
||||
"this": True,
|
||||
"using": True,
|
||||
"on": True,
|
||||
"expressions": True,
|
||||
"with": False,
|
||||
}
|
||||
|
||||
|
||||
class When(Func):
|
||||
|
@ -5587,8 +5629,7 @@ def maybe_parse(
|
|||
prefix: t.Optional[str] = None,
|
||||
copy: bool = False,
|
||||
**opts,
|
||||
) -> E:
|
||||
...
|
||||
) -> E: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
|
@ -5600,8 +5641,7 @@ def maybe_parse(
|
|||
prefix: t.Optional[str] = None,
|
||||
copy: bool = False,
|
||||
**opts,
|
||||
) -> E:
|
||||
...
|
||||
) -> E: ...
|
||||
|
||||
|
||||
def maybe_parse(
|
||||
|
@ -5653,13 +5693,11 @@ def maybe_parse(
|
|||
|
||||
|
||||
@t.overload
|
||||
def maybe_copy(instance: None, copy: bool = True) -> None:
|
||||
...
|
||||
def maybe_copy(instance: None, copy: bool = True) -> None: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def maybe_copy(instance: E, copy: bool = True) -> E:
|
||||
...
|
||||
def maybe_copy(instance: E, copy: bool = True) -> E: ...
|
||||
|
||||
|
||||
def maybe_copy(instance, copy=True):
|
||||
|
@ -6282,15 +6320,13 @@ SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$")
|
|||
|
||||
|
||||
@t.overload
|
||||
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
|
||||
...
|
||||
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def to_identifier(
|
||||
name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
|
||||
) -> Identifier:
|
||||
...
|
||||
) -> Identifier: ...
|
||||
|
||||
|
||||
def to_identifier(name, quoted=None, copy=True):
|
||||
|
@ -6362,13 +6398,11 @@ def to_interval(interval: str | Literal) -> Interval:
|
|||
|
||||
|
||||
@t.overload
|
||||
def to_table(sql_path: str | Table, **kwargs) -> Table:
|
||||
...
|
||||
def to_table(sql_path: str | Table, **kwargs) -> Table: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def to_table(sql_path: None, **kwargs) -> None:
|
||||
...
|
||||
def to_table(sql_path: None, **kwargs) -> None: ...
|
||||
|
||||
|
||||
def to_table(
|
||||
|
@ -6929,7 +6963,7 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
|
|||
if isinstance(node, Placeholder):
|
||||
if node.name:
|
||||
new_name = kwargs.get(node.name)
|
||||
if new_name:
|
||||
if new_name is not None:
|
||||
return convert(new_name)
|
||||
else:
|
||||
try:
|
||||
|
@ -6943,7 +6977,7 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
|
|||
|
||||
def expand(
|
||||
expression: Expression,
|
||||
sources: t.Dict[str, Subqueryable],
|
||||
sources: t.Dict[str, Query],
|
||||
dialect: DialectType = None,
|
||||
copy: bool = True,
|
||||
) -> Expression:
|
||||
|
@ -6959,7 +6993,7 @@ def expand(
|
|||
|
||||
Args:
|
||||
expression: The expression to expand.
|
||||
sources: A dictionary of name to Subqueryables.
|
||||
sources: A dictionary of name to Queries.
|
||||
dialect: The dialect of the sources dict.
|
||||
copy: Whether to copy the expression during transformation. Defaults to True.
|
||||
|
||||
|
|
|
@ -73,17 +73,16 @@ class Generator(metaclass=_Generator):
|
|||
TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
|
||||
**JSON_PATH_PART_TRANSFORMS,
|
||||
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
|
||||
exp.CaseSpecificColumnConstraint: lambda self,
|
||||
exp.CaseSpecificColumnConstraint: lambda _,
|
||||
e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
|
||||
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
|
||||
exp.CharacterSetProperty: lambda self,
|
||||
e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
|
||||
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
|
||||
exp.ClusteredColumnConstraint: lambda self,
|
||||
e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
|
||||
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
|
||||
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
|
||||
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
|
||||
exp.CopyGrantsProperty: lambda *_: "COPY GRANTS",
|
||||
exp.DateAdd: lambda self, e: self.func(
|
||||
"DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
|
||||
),
|
||||
|
@ -91,8 +90,8 @@ class Generator(metaclass=_Generator):
|
|||
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
|
||||
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
|
||||
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
|
||||
exp.ExternalProperty: lambda self, e: "EXTERNAL",
|
||||
exp.HeapProperty: lambda self, e: "HEAP",
|
||||
exp.ExternalProperty: lambda *_: "EXTERNAL",
|
||||
exp.HeapProperty: lambda *_: "HEAP",
|
||||
exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})",
|
||||
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
|
||||
exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
|
||||
|
@ -105,13 +104,13 @@ class Generator(metaclass=_Generator):
|
|||
),
|
||||
exp.LanguageProperty: lambda self, e: self.naked_property(e),
|
||||
exp.LocationProperty: lambda self, e: self.naked_property(e),
|
||||
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
|
||||
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
|
||||
exp.LogProperty: lambda _, e: f"{'NO ' if e.args.get('no') else ''}LOG",
|
||||
exp.MaterializedProperty: lambda *_: "MATERIALIZED",
|
||||
exp.NonClusteredColumnConstraint: lambda self,
|
||||
e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
|
||||
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
|
||||
exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION",
|
||||
exp.OnCommitProperty: lambda self,
|
||||
exp.NoPrimaryIndexProperty: lambda *_: "NO PRIMARY INDEX",
|
||||
exp.NotForReplicationColumnConstraint: lambda *_: "NOT FOR REPLICATION",
|
||||
exp.OnCommitProperty: lambda _,
|
||||
e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
|
||||
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
|
||||
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
|
||||
|
@ -122,21 +121,21 @@ class Generator(metaclass=_Generator):
|
|||
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
|
||||
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
|
||||
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
|
||||
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
|
||||
exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
|
||||
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
|
||||
exp.SqlReadWriteProperty: lambda self, e: e.name,
|
||||
exp.SqlSecurityProperty: lambda self,
|
||||
exp.SqlReadWriteProperty: lambda _, e: e.name,
|
||||
exp.SqlSecurityProperty: lambda _,
|
||||
e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
|
||||
exp.StabilityProperty: lambda self, e: e.name,
|
||||
exp.TemporaryProperty: lambda self, e: "TEMPORARY",
|
||||
exp.StabilityProperty: lambda _, e: e.name,
|
||||
exp.TemporaryProperty: lambda *_: "TEMPORARY",
|
||||
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
|
||||
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
|
||||
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
|
||||
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
|
||||
exp.TransientProperty: lambda self, e: "TRANSIENT",
|
||||
exp.UppercaseColumnConstraint: lambda self, e: "UPPERCASE",
|
||||
exp.TransientProperty: lambda *_: "TRANSIENT",
|
||||
exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE",
|
||||
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
|
||||
exp.VolatileProperty: lambda self, e: "VOLATILE",
|
||||
exp.VolatileProperty: lambda *_: "VOLATILE",
|
||||
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
|
||||
}
|
||||
|
||||
|
@ -356,6 +355,7 @@ class Generator(metaclass=_Generator):
|
|||
STRUCT_DELIMITER = ("<", ">")
|
||||
|
||||
PARAMETER_TOKEN = "@"
|
||||
NAMED_PLACEHOLDER_TOKEN = ":"
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
|
||||
|
@ -388,6 +388,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.LikeProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.LockProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.LockingProperty: exp.Properties.Location.POST_ALIAS,
|
||||
exp.LogProperty: exp.Properties.Location.POST_NAME,
|
||||
exp.MaterializedProperty: exp.Properties.Location.POST_CREATE,
|
||||
|
@ -459,11 +460,16 @@ class Generator(metaclass=_Generator):
|
|||
exp.Paren,
|
||||
)
|
||||
|
||||
PARAMETERIZABLE_TEXT_TYPES = {
|
||||
exp.DataType.Type.NVARCHAR,
|
||||
exp.DataType.Type.VARCHAR,
|
||||
exp.DataType.Type.CHAR,
|
||||
exp.DataType.Type.NCHAR,
|
||||
}
|
||||
|
||||
# Expressions that need to have all CTEs under them bubbled up to them
|
||||
EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set()
|
||||
|
||||
KEY_VALUE_DEFINITIONS = (exp.EQ, exp.PropertyEQ, exp.Slice)
|
||||
|
||||
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
|
||||
|
||||
__slots__ = (
|
||||
|
@ -630,7 +636,7 @@ class Generator(metaclass=_Generator):
|
|||
this_sql = self.indent(
|
||||
(
|
||||
self.sql(expression)
|
||||
if isinstance(expression, (exp.Select, exp.Union))
|
||||
if isinstance(expression, exp.UNWRAPPED_QUERIES)
|
||||
else self.sql(expression, "this")
|
||||
),
|
||||
level=1,
|
||||
|
@ -1535,8 +1541,8 @@ class Generator(metaclass=_Generator):
|
|||
expr = self.sql(expression, "expression")
|
||||
return f"{this} ({kind} => {expr})"
|
||||
|
||||
def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
|
||||
table = ".".join(
|
||||
def table_parts(self, expression: exp.Table) -> str:
|
||||
return ".".join(
|
||||
self.sql(part)
|
||||
for part in (
|
||||
expression.args.get("catalog"),
|
||||
|
@ -1546,6 +1552,9 @@ class Generator(metaclass=_Generator):
|
|||
if part is not None
|
||||
)
|
||||
|
||||
def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
|
||||
table = self.table_parts(expression)
|
||||
only = "ONLY " if expression.args.get("only") else ""
|
||||
version = self.sql(expression, "version")
|
||||
version = f" {version}" if version else ""
|
||||
alias = self.sql(expression, "alias")
|
||||
|
@ -1572,7 +1581,7 @@ class Generator(metaclass=_Generator):
|
|||
if when:
|
||||
table = f"{table} {when}"
|
||||
|
||||
return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
|
||||
return f"{only}{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
|
||||
|
||||
def tablesample_sql(
|
||||
self,
|
||||
|
@ -1681,7 +1690,7 @@ class Generator(metaclass=_Generator):
|
|||
alias_node = expression.args.get("alias")
|
||||
column_names = alias_node and alias_node.columns
|
||||
|
||||
selects: t.List[exp.Subqueryable] = []
|
||||
selects: t.List[exp.Query] = []
|
||||
|
||||
for i, tup in enumerate(expression.expressions):
|
||||
row = tup.expressions
|
||||
|
@ -1697,10 +1706,8 @@ class Generator(metaclass=_Generator):
|
|||
# This may result in poor performance for large-cardinality `VALUES` tables, due to
|
||||
# the deep nesting of the resulting exp.Unions. If this is a problem, either increase
|
||||
# `sys.setrecursionlimit` to avoid RecursionErrors, or don't set `pretty`.
|
||||
subqueryable = reduce(lambda x, y: exp.union(x, y, distinct=False, copy=False), selects)
|
||||
return self.subquery_sql(
|
||||
subqueryable.subquery(alias_node and alias_node.this, copy=False)
|
||||
)
|
||||
query = reduce(lambda x, y: exp.union(x, y, distinct=False, copy=False), selects)
|
||||
return self.subquery_sql(query.subquery(alias_node and alias_node.this, copy=False))
|
||||
|
||||
alias = f" AS {self.sql(alias_node, 'this')}" if alias_node else ""
|
||||
unions = " UNION ALL ".join(self.sql(select) for select in selects)
|
||||
|
@ -1854,7 +1861,7 @@ class Generator(metaclass=_Generator):
|
|||
]
|
||||
|
||||
args_sql = ", ".join(self.sql(e) for e in args)
|
||||
args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql
|
||||
args_sql = f"({args_sql})" if top and any(not e.is_number for e in args) else args_sql
|
||||
expressions = self.expressions(expression, flat=True)
|
||||
expressions = f" BY {expressions}" if expressions else ""
|
||||
|
||||
|
@ -2070,12 +2077,17 @@ class Generator(metaclass=_Generator):
|
|||
else []
|
||||
)
|
||||
|
||||
options = self.expressions(expression, key="options")
|
||||
if options:
|
||||
options = f" OPTION{self.wrap(options)}"
|
||||
|
||||
return csv(
|
||||
*sqls,
|
||||
*[self.sql(join) for join in expression.args.get("joins") or []],
|
||||
self.sql(expression, "connect"),
|
||||
self.sql(expression, "match"),
|
||||
*[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
|
||||
self.sql(expression, "prewhere"),
|
||||
self.sql(expression, "where"),
|
||||
self.sql(expression, "group"),
|
||||
self.sql(expression, "having"),
|
||||
|
@ -2083,9 +2095,13 @@ class Generator(metaclass=_Generator):
|
|||
self.sql(expression, "order"),
|
||||
*offset_limit_modifiers,
|
||||
*self.after_limit_modifiers(expression),
|
||||
options,
|
||||
sep="",
|
||||
)
|
||||
|
||||
def queryoption_sql(self, expression: exp.QueryOption) -> str:
    """
    Render a `QueryOption`; this implementation always yields an empty string.

    NOTE(review): presumably overridden by dialect generators that support
    query options — confirm.
    """
    rendered = ""
    return rendered
|
||||
|
||||
def offset_limit_modifiers(
|
||||
self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
|
||||
) -> t.List[str]:
|
||||
|
@ -2140,9 +2156,9 @@ class Generator(metaclass=_Generator):
|
|||
self.sql(
|
||||
exp.Struct(
|
||||
expressions=[
|
||||
exp.column(e.output_name).eq(
|
||||
e.this if isinstance(e, exp.Alias) else e
|
||||
)
|
||||
exp.PropertyEQ(this=e.args.get("alias"), expression=e.this)
|
||||
if isinstance(e, exp.Alias)
|
||||
else e
|
||||
for e in expression.expressions
|
||||
]
|
||||
)
|
||||
|
@ -2204,7 +2220,7 @@ class Generator(metaclass=_Generator):
|
|||
return f"@@{kind}{this}"
|
||||
|
||||
def placeholder_sql(self, expression: exp.Placeholder) -> str:
|
||||
return f":{expression.name}" if expression.name else "?"
|
||||
return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.name else "?"
|
||||
|
||||
def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
|
||||
alias = self.sql(expression, "alias")
|
||||
|
@ -2261,6 +2277,9 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
return f"UNNEST({args}){suffix}"
|
||||
|
||||
def prewhere_sql(self, expression: exp.PreWhere) -> str:
    """
    Render a `PreWhere`; this implementation always yields an empty string.

    NOTE(review): presumably overridden by dialect generators that support
    PREWHERE — confirm.
    """
    rendered = ""
    return rendered
|
||||
|
||||
def where_sql(self, expression: exp.Where) -> str:
    """Render a WHERE clause: the `WHERE` segment, a separator, then the indented condition."""
    condition = self.sql(expression, "this")
    body = self.indent(condition)
    return "".join((self.seg("WHERE"), self.sep(), body))
|
||||
|
@ -2326,7 +2345,7 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
def any_sql(self, expression: exp.Any) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
if isinstance(expression.this, exp.Subqueryable):
|
||||
if isinstance(expression.this, exp.UNWRAPPED_QUERIES):
|
||||
this = self.wrap(this)
|
||||
return f"ANY {this}"
|
||||
|
||||
|
@ -2568,7 +2587,7 @@ class Generator(metaclass=_Generator):
|
|||
is_global = " GLOBAL" if expression.args.get("is_global") else ""
|
||||
|
||||
if query:
|
||||
in_sql = self.wrap(query)
|
||||
in_sql = self.wrap(self.sql(query))
|
||||
elif unnest:
|
||||
in_sql = self.in_unnest_op(unnest)
|
||||
elif field:
|
||||
|
@ -2610,7 +2629,7 @@ class Generator(metaclass=_Generator):
|
|||
return f"REFERENCES {this}{expressions}{options}"
|
||||
|
||||
def anonymous_sql(self, expression: exp.Anonymous) -> str:
|
||||
return self.func(expression.name, *expression.expressions)
|
||||
return self.func(self.sql(expression, "this"), *expression.expressions)
|
||||
|
||||
def paren_sql(self, expression: exp.Paren) -> str:
|
||||
if isinstance(expression.unnest(), exp.Select):
|
||||
|
@ -2822,7 +2841,9 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
exists = " IF EXISTS" if expression.args.get("exists") else ""
|
||||
only = " ONLY" if expression.args.get("only") else ""
|
||||
return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}"
|
||||
options = self.expressions(expression, key="options")
|
||||
options = f", {options}" if options else ""
|
||||
return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}{options}"
|
||||
|
||||
def add_column_sql(self, expression: exp.AlterTable) -> str:
|
||||
if self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD:
|
||||
|
@ -2839,15 +2860,7 @@ class Generator(metaclass=_Generator):
|
|||
return f"DROP{exists}{expressions}"
|
||||
|
||||
def addconstraint_sql(self, expression: exp.AddConstraint) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
expression_ = self.sql(expression, "expression")
|
||||
add_constraint = f"ADD CONSTRAINT {this}" if this else "ADD"
|
||||
|
||||
enforced = expression.args.get("enforced")
|
||||
if enforced is not None:
|
||||
return f"{add_constraint} CHECK ({expression_}){' ENFORCED' if enforced else ''}"
|
||||
|
||||
return f"{add_constraint} {expression_}"
|
||||
return f"ADD {self.expressions(expression)}"
|
||||
|
||||
def distinct_sql(self, expression: exp.Distinct) -> str:
|
||||
this = self.expressions(expression, flat=True)
|
||||
|
@ -3296,6 +3309,10 @@ class Generator(metaclass=_Generator):
|
|||
self.unsupported("Unsupported index constraint option.")
|
||||
return ""
|
||||
|
||||
def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str:
    """Generate a CHECK column constraint.

    Args:
        expression: The constraint; its "this" arg is the checked condition
            and a truthy "enforced" arg appends the ENFORCED keyword.

    Returns:
        ``CHECK (<condition>)`` optionally followed by `` ENFORCED``.
    """
    enforced = " ENFORCED" if expression.args.get("enforced") else ""
    condition = self.sql(expression, "this")
    return f"CHECK ({condition}){enforced}"
|
||||
|
||||
def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
|
||||
kind = self.sql(expression, "kind")
|
||||
kind = f"{kind} INDEX" if kind else "INDEX"
|
||||
|
@ -3452,9 +3469,87 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
return expression
|
||||
|
||||
def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]:
    """Wrap each value so that a NULL result becomes an empty string.

    Every truthy entry of ``values`` is cast to "text" and wrapped in a
    COALESCE call whose fallback is the empty string literal. Falsy entries
    are dropped from the result.

    Args:
        values: The expressions to wrap.

    Returns:
        A new list holding one COALESCE expression per truthy input.
    """
    return [
        exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string(""))
        for value in values
        if value
    ]
|
||||
def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
    """Generate SQL for a GenerateSeries expression.

    The "is_end_exclusive" arg is cleared on the expression (in place) before
    it is rendered through the generic function fallback, so that arg never
    reaches the output.

    Args:
        expression: The series expression; note that it is mutated.

    Returns:
        The SQL produced by ``function_fallback_sql``.
    """
    # NOTE(review): presumably the flag has no SQL representation in the
    # fallback form - confirm against the dialects that set it.
    expression.set("is_end_exclusive", None)
    return self.function_fallback_sql(expression)
|
||||
|
||||
def struct_sql(self, expression: exp.Struct) -> str:
    """Generate SQL for a Struct expression.

    Each ``exp.PropertyEQ`` entry in the struct's expressions is replaced by
    an alias built from its value (``entry.expression``) and its key
    (``entry.this``); other entries are kept unchanged. The rewritten list is
    stored back on the expression (in place) before it is rendered through
    the generic function fallback.

    Args:
        expression: The struct expression; note that it is mutated.

    Returns:
        The SQL produced by ``function_fallback_sql``.
    """
    fields = [
        exp.alias_(entry.expression, entry.this) if isinstance(entry, exp.PropertyEQ) else entry
        for entry in expression.expressions
    ]
    expression.set("expressions", fields)

    return self.function_fallback_sql(expression)
|
||||
|
||||
def partitionrange_sql(self, expression: exp.PartitionRange) -> str:
    """Generate SQL for a partition range.

    Args:
        expression: The range; "this" is the lower bound and "expression"
            the upper bound.

    Returns:
        ``<low> TO <high>``.
    """
    low = self.sql(expression, "this")
    high = self.sql(expression, "expression")

    return f"{low} TO {high}"
|
||||
|
||||
def truncatetable_sql(self, expression: exp.TruncateTable) -> str:
    """Generate a TRUNCATE statement.

    The statement is assembled in this order: ``TRUNCATE``, the target
    keyword (``DATABASE`` when the "is_database" arg is truthy, else
    ``TABLE``), an optional ``IF EXISTS``, the rendered expressions list,
    then the optional "cluster", "identity" (suffixed with ``IDENTITY``),
    "option" and "partition" parts. Optional parts that render to an empty
    string are omitted entirely.

    Args:
        expression: The truncate expression to render.

    Returns:
        The rendered TRUNCATE statement.
    """
    target = "DATABASE" if expression.args.get("is_database") else "TABLE"
    tables = f" {self.expressions(expression)}"

    exists = " IF EXISTS" if expression.args.get("exists") else ""

    on_cluster = self.sql(expression, "cluster")
    on_cluster = f" {on_cluster}" if on_cluster else ""

    identity = self.sql(expression, "identity")
    identity = f" {identity} IDENTITY" if identity else ""

    option = self.sql(expression, "option")
    option = f" {option}" if option else ""

    partition = self.sql(expression, "partition")
    partition = f" {partition}" if partition else ""

    return f"TRUNCATE {target}{exists}{tables}{on_cluster}{identity}{option}{partition}"
|
||||
|
||||
# This transpiles T-SQL's CONVERT function
# https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16
def convert_sql(self, expression: exp.Convert) -> str:
    """Generate SQL for a T-SQL style CONVERT expression.

    The conversion is rewritten into one of the following before rendering:

    * ``StrToDate`` / ``StrToTime`` when an integer style is given, that
      style is known, and the target type is DATE / DATETIME;
    * a cast of ``TimeToStr`` (or a bare ``TimeToStr`` for TEXT targets) when
      a known integer style is given and the target is a text type;
    * otherwise a plain cast of the value to the target type.

    ``exp.Cast`` is used when the "strict" arg is truthy, ``exp.TryCast``
    otherwise. A parameterizable text target without an explicit length is
    given a length of 30.

    Args:
        expression: The CONVERT expression; "this" is the target type and
            "expression" the value being converted.

    Returns:
        The rendered SQL, or an empty string when the target type or the
        value is missing.
    """
    to = expression.this
    value = expression.expression
    style = expression.args.get("style")
    safe = expression.args.get("safe")
    strict = expression.args.get("strict")

    if not to or not value:
        return ""

    # Retrieve length of datatype and override to default if not specified
    if not seq_get(to.expressions, 0) and to.this in self.PARAMETERIZABLE_TEXT_TYPES:
        to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)

    transformed: t.Optional[exp.Expression] = None
    cast = exp.Cast if strict else exp.TryCast

    # Check whether a conversion with format (T-SQL calls this 'style') is applicable
    if isinstance(style, exp.Literal) and style.is_int:
        # Imported locally, as in the original code, rather than at module level.
        from sqlglot.dialects.tsql import TSQL

        style_value = style.name
        converted_style = TSQL.CONVERT_FORMAT_MAPPING.get(style_value)
        if not converted_style:
            # Unknown style: report it and fall through to the plain cast
            # below instead of building a format literal out of None.
            self.unsupported(f"Unsupported T-SQL 'style' value: {style_value}")
        else:
            fmt = exp.Literal.string(converted_style)

            if to.this == exp.DataType.Type.DATE:
                transformed = exp.StrToDate(this=value, format=fmt)
            elif to.this == exp.DataType.Type.DATETIME:
                transformed = exp.StrToTime(this=value, format=fmt)
            elif to.this in self.PARAMETERIZABLE_TEXT_TYPES:
                transformed = cast(this=exp.TimeToStr(this=value, format=fmt), to=to, safe=safe)
            elif to.this == exp.DataType.Type.TEXT:
                transformed = exp.TimeToStr(this=value, format=fmt)

    if not transformed:
        transformed = cast(this=value, to=to, safe=safe)

    return self.sql(transformed)
|
||||
|
|
|
@ -53,13 +53,11 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
|
|||
|
||||
|
||||
@t.overload
|
||||
def ensure_list(value: t.Collection[T]) -> t.List[T]:
|
||||
...
|
||||
def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def ensure_list(value: T) -> t.List[T]:
|
||||
...
|
||||
def ensure_list(value: T) -> t.List[T]: ...
|
||||
|
||||
|
||||
def ensure_list(value):
|
||||
|
@ -81,13 +79,11 @@ def ensure_list(value):
|
|||
|
||||
|
||||
@t.overload
|
||||
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
|
||||
...
|
||||
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def ensure_collection(value: T) -> t.Collection[T]:
|
||||
...
|
||||
def ensure_collection(value: T) -> t.Collection[T]: ...
|
||||
|
||||
|
||||
def ensure_collection(value):
|
||||
|
|
|
@ -1,16 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import typing as t
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlglot import Schema, exp, maybe_parse
|
||||
from sqlglot.errors import SqlglotError
|
||||
from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify
|
||||
from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot.dialects.dialect import DialectType
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Node:
|
||||
|
@ -18,7 +21,8 @@ class Node:
|
|||
expression: exp.Expression
|
||||
source: exp.Expression
|
||||
downstream: t.List[Node] = field(default_factory=list)
|
||||
alias: str = ""
|
||||
source_name: str = ""
|
||||
reference_node_name: str = ""
|
||||
|
||||
def walk(self) -> t.Iterator[Node]:
|
||||
yield self
|
||||
|
@ -67,7 +71,7 @@ def lineage(
|
|||
column: str | exp.Column,
|
||||
sql: str | exp.Expression,
|
||||
schema: t.Optional[t.Dict | Schema] = None,
|
||||
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
|
||||
sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
|
||||
dialect: DialectType = None,
|
||||
**kwargs,
|
||||
) -> Node:
|
||||
|
@ -86,14 +90,12 @@ def lineage(
|
|||
"""
|
||||
|
||||
expression = maybe_parse(sql, dialect=dialect)
|
||||
column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
|
||||
|
||||
if sources:
|
||||
expression = exp.expand(
|
||||
expression,
|
||||
{
|
||||
k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
|
||||
for k, v in sources.items()
|
||||
},
|
||||
{k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
|
||||
dialect=dialect,
|
||||
)
|
||||
|
||||
|
@ -109,122 +111,141 @@ def lineage(
|
|||
if not scope:
|
||||
raise SqlglotError("Cannot build lineage, sql must be SELECT")
|
||||
|
||||
def to_node(
|
||||
column: str | int,
|
||||
scope: Scope,
|
||||
scope_name: t.Optional[str] = None,
|
||||
upstream: t.Optional[Node] = None,
|
||||
alias: t.Optional[str] = None,
|
||||
) -> Node:
|
||||
aliases = {
|
||||
dt.alias: dt.comments[0].split()[1]
|
||||
for dt in scope.derived_tables
|
||||
if dt.comments and dt.comments[0].startswith("source: ")
|
||||
}
|
||||
if not any(select.alias_or_name == column for select in scope.expression.selects):
|
||||
raise SqlglotError(f"Cannot find column '{column}' in query.")
|
||||
|
||||
# Find the specific select clause that is the source of the column we want.
|
||||
# This can either be a specific, named select or a generic `*` clause.
|
||||
select = (
|
||||
scope.expression.selects[column]
|
||||
return to_node(column, scope, dialect)
|
||||
|
||||
|
||||
def to_node(
|
||||
column: str | int,
|
||||
scope: Scope,
|
||||
dialect: DialectType,
|
||||
scope_name: t.Optional[str] = None,
|
||||
upstream: t.Optional[Node] = None,
|
||||
source_name: t.Optional[str] = None,
|
||||
reference_node_name: t.Optional[str] = None,
|
||||
) -> Node:
|
||||
source_names = {
|
||||
dt.alias: dt.comments[0].split()[1]
|
||||
for dt in scope.derived_tables
|
||||
if dt.comments and dt.comments[0].startswith("source: ")
|
||||
}
|
||||
|
||||
# Find the specific select clause that is the source of the column we want.
|
||||
# This can either be a specific, named select or a generic `*` clause.
|
||||
select = (
|
||||
scope.expression.selects[column]
|
||||
if isinstance(column, int)
|
||||
else next(
|
||||
(select for select in scope.expression.selects if select.alias_or_name == column),
|
||||
exp.Star() if scope.expression.is_star else scope.expression,
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(scope.expression, exp.Union):
|
||||
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
|
||||
|
||||
index = (
|
||||
column
|
||||
if isinstance(column, int)
|
||||
else next(
|
||||
(select for select in scope.expression.selects if select.alias_or_name == column),
|
||||
exp.Star() if scope.expression.is_star else scope.expression,
|
||||
(
|
||||
i
|
||||
for i, select in enumerate(scope.expression.selects)
|
||||
if select.alias_or_name == column or select.is_star
|
||||
),
|
||||
-1, # mypy will not allow a None here, but a negative index should never be returned
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(scope.expression, exp.Union):
|
||||
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
|
||||
if index == -1:
|
||||
raise ValueError(f"Could not find {column} in {scope.expression}")
|
||||
|
||||
index = (
|
||||
column
|
||||
if isinstance(column, int)
|
||||
else next(
|
||||
(
|
||||
i
|
||||
for i, select in enumerate(scope.expression.selects)
|
||||
if select.alias_or_name == column or select.is_star
|
||||
),
|
||||
-1, # mypy will not allow a None here, but a negative index should never be returned
|
||||
)
|
||||
for s in scope.union_scopes:
|
||||
to_node(
|
||||
index,
|
||||
scope=s,
|
||||
dialect=dialect,
|
||||
upstream=upstream,
|
||||
source_name=source_name,
|
||||
reference_node_name=reference_node_name,
|
||||
)
|
||||
|
||||
if index == -1:
|
||||
raise ValueError(f"Could not find {column} in {scope.expression}")
|
||||
return upstream
|
||||
|
||||
for s in scope.union_scopes:
|
||||
to_node(index, scope=s, upstream=upstream, alias=alias)
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
# For better ergonomics in our node labels, replace the full select with
|
||||
# a version that has only the column we care about.
|
||||
# "x", SELECT x, y FROM foo
|
||||
# => "x", SELECT x FROM foo
|
||||
source = t.cast(exp.Expression, scope.expression.select(select, append=False))
|
||||
else:
|
||||
source = scope.expression
|
||||
|
||||
return upstream
|
||||
# Create the node for this step in the lineage chain, and attach it to the previous one.
|
||||
node = Node(
|
||||
name=f"{scope_name}.{column}" if scope_name else str(column),
|
||||
source=source,
|
||||
expression=select,
|
||||
source_name=source_name or "",
|
||||
reference_node_name=reference_node_name or "",
|
||||
)
|
||||
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
# For better ergonomics in our node labels, replace the full select with
|
||||
# a version that has only the column we care about.
|
||||
# "x", SELECT x, y FROM foo
|
||||
# => "x", SELECT x FROM foo
|
||||
source = t.cast(exp.Expression, scope.expression.select(select, append=False))
|
||||
else:
|
||||
source = scope.expression
|
||||
if upstream:
|
||||
upstream.downstream.append(node)
|
||||
|
||||
# Create the node for this step in the lineage chain, and attach it to the previous one.
|
||||
node = Node(
|
||||
name=f"{scope_name}.{column}" if scope_name else str(column),
|
||||
source=source,
|
||||
expression=select,
|
||||
alias=alias or "",
|
||||
)
|
||||
subquery_scopes = {
|
||||
id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
|
||||
}
|
||||
|
||||
if upstream:
|
||||
upstream.downstream.append(node)
|
||||
for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
|
||||
subquery_scope = subquery_scopes.get(id(subquery))
|
||||
if not subquery_scope:
|
||||
logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
|
||||
continue
|
||||
|
||||
subquery_scopes = {
|
||||
id(subquery_scope.expression): subquery_scope
|
||||
for subquery_scope in scope.subquery_scopes
|
||||
}
|
||||
|
||||
for subquery in find_all_in_scope(select, exp.Subqueryable):
|
||||
subquery_scope = subquery_scopes[id(subquery)]
|
||||
|
||||
for name in subquery.named_selects:
|
||||
to_node(name, scope=subquery_scope, upstream=node)
|
||||
|
||||
# if the select is a star add all scope sources as downstreams
|
||||
if select.is_star:
|
||||
for source in scope.sources.values():
|
||||
if isinstance(source, Scope):
|
||||
source = source.expression
|
||||
node.downstream.append(Node(name=select.sql(), source=source, expression=source))
|
||||
|
||||
# Find all columns that went into creating this one to list their lineage nodes.
|
||||
source_columns = set(find_all_in_scope(select, exp.Column))
|
||||
|
||||
# If the source is a UDTF find columns used in the UTDF to generate the table
|
||||
if isinstance(source, exp.UDTF):
|
||||
source_columns |= set(source.find_all(exp.Column))
|
||||
|
||||
for c in source_columns:
|
||||
table = c.table
|
||||
source = scope.sources.get(table)
|
||||
for name in subquery.named_selects:
|
||||
to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)
|
||||
|
||||
# if the select is a star add all scope sources as downstreams
|
||||
if select.is_star:
|
||||
for source in scope.sources.values():
|
||||
if isinstance(source, Scope):
|
||||
# The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
|
||||
to_node(
|
||||
c.name,
|
||||
scope=source,
|
||||
scope_name=table,
|
||||
upstream=node,
|
||||
alias=aliases.get(table) or alias,
|
||||
)
|
||||
else:
|
||||
# The source is not a scope - we've reached the end of the line. At this point, if a source is not found
|
||||
# it means this column's lineage is unknown. This can happen if the definition of a source used in a query
|
||||
# is not passed into the `sources` map.
|
||||
source = source or exp.Placeholder()
|
||||
node.downstream.append(Node(name=c.sql(), source=source, expression=source))
|
||||
source = source.expression
|
||||
node.downstream.append(Node(name=select.sql(), source=source, expression=source))
|
||||
|
||||
return node
|
||||
# Find all columns that went into creating this one to list their lineage nodes.
|
||||
source_columns = set(find_all_in_scope(select, exp.Column))
|
||||
|
||||
return to_node(column if isinstance(column, str) else column.name, scope)
|
||||
# If the source is a UDTF find columns used in the UTDF to generate the table
|
||||
if isinstance(source, exp.UDTF):
|
||||
source_columns |= set(source.find_all(exp.Column))
|
||||
|
||||
for c in source_columns:
|
||||
table = c.table
|
||||
source = scope.sources.get(table)
|
||||
|
||||
if isinstance(source, Scope):
|
||||
selected_node, _ = scope.selected_sources.get(table, (None, None))
|
||||
# The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
|
||||
to_node(
|
||||
c.name,
|
||||
scope=source,
|
||||
dialect=dialect,
|
||||
scope_name=table,
|
||||
upstream=node,
|
||||
source_name=source_names.get(table) or source_name,
|
||||
reference_node_name=selected_node.name if selected_node else None,
|
||||
)
|
||||
else:
|
||||
# The source is not a scope - we've reached the end of the line. At this point, if a source is not found
|
||||
# it means this column's lineage is unknown. This can happen if the definition of a source used in a query
|
||||
# is not passed into the `sources` map.
|
||||
source = source or exp.Placeholder()
|
||||
node.downstream.append(Node(name=c.sql(), source=source, expression=source))
|
||||
|
||||
return node
|
||||
|
||||
|
||||
class GraphHTML:
|
||||
|
|
|
@ -191,6 +191,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.DateToDi,
|
||||
exp.Floor,
|
||||
exp.Levenshtein,
|
||||
exp.Sign,
|
||||
exp.StrPosition,
|
||||
exp.TsOrDiToDi,
|
||||
},
|
||||
|
@ -262,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
|
||||
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
|
||||
exp.Div: lambda self, e: self._annotate_div(e),
|
||||
exp.Dot: lambda self, e: self._annotate_dot(e),
|
||||
exp.Explode: lambda self, e: self._annotate_explode(e),
|
||||
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
|
||||
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
|
||||
|
@ -273,15 +275,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
|
||||
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
|
||||
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
|
||||
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
|
||||
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
|
||||
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
|
||||
exp.Timestamp: lambda self, e: self._annotate_with_type(
|
||||
e,
|
||||
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
|
||||
),
|
||||
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
|
||||
exp.Unnest: lambda self, e: self._annotate_unnest(e),
|
||||
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
|
||||
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
|
||||
}
|
||||
|
||||
NESTED_TYPES = {
|
||||
|
@ -380,8 +384,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
source = scope.sources.get(col.table)
|
||||
if isinstance(source, exp.Table):
|
||||
self._set_type(col, self.schema.get_column_type(source, col))
|
||||
elif source and col.table in selects and col.name in selects[col.table]:
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
elif source:
|
||||
if col.table in selects and col.name in selects[col.table]:
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
self._set_type(col, source.expression.type)
|
||||
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
@ -514,7 +521,14 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
last_datatype = None
|
||||
for expr in expressions:
|
||||
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
|
||||
expr_type = expr.type
|
||||
|
||||
# Stop at the first nested data type found - we don't want to _maybe_coerce nested types
|
||||
if expr_type.args.get("nested"):
|
||||
last_datatype = expr_type
|
||||
break
|
||||
|
||||
last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
|
||||
|
||||
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
|
||||
|
||||
|
@ -594,7 +608,26 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
return expression
|
||||
|
||||
def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
    """Annotate a Dot (member access) expression with its type.

    The arguments are annotated first and the Dot's own type is reset to
    None. If the left-hand side is typed as a STRUCT, the struct's fields are
    searched for one whose name equals the accessed member's name; the first
    match's ``kind`` becomes the type of the Dot expression.

    Args:
        expression: The Dot expression to annotate.

    Returns:
        The same expression, annotated in place.
    """
    self._annotate_args(expression)
    self._set_type(expression, None)
    this_type = expression.this.type

    if this_type and this_type.is_type(exp.DataType.Type.STRUCT):
        for field in this_type.expressions:
            if field.name == expression.expression.name:
                self._set_type(expression, field.kind)
                break

    return expression
|
||||
|
||||
def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
    """Annotate an Explode expression with its type.

    After annotating the arguments, the expression's type is set to the first
    entry of the operand type's ``expressions`` (or None when that list is
    empty, per ``seq_get``).

    Args:
        expression: The Explode expression to annotate.

    Returns:
        The same expression, annotated in place.
    """
    self._annotate_args(expression)
    element_type = seq_get(expression.this.type.expressions, 0)
    self._set_type(expression, element_type)
    return expression
|
||||
|
||||
def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
    """Annotate an Unnest expression with its type.

    After annotating the arguments, the first unnested expression is looked
    up. When it exists, the Unnest's type is set to the first entry of that
    child's type ``expressions``; when it does not, the falsy lookup result
    itself is passed to ``_set_type``.

    Args:
        expression: The Unnest expression to annotate.

    Returns:
        The same expression, annotated in place.
    """
    self._annotate_args(expression)
    child = seq_get(expression.expressions, 0)
    self._set_type(expression, child and seq_get(child.type.expressions, 0))
    return expression
|
||||
|
|
|
@ -10,13 +10,11 @@ if t.TYPE_CHECKING:
|
|||
|
||||
|
||||
@t.overload
|
||||
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
|
||||
...
|
||||
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
|
||||
...
|
||||
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
|
||||
|
||||
|
||||
def normalize_identifiers(expression, dialect=None):
|
||||
|
|
|
@ -120,6 +120,8 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->
|
|||
For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
|
||||
"""
|
||||
for derived_table in derived_tables:
|
||||
if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
|
||||
continue
|
||||
table_alias = derived_table.args.get("alias")
|
||||
if table_alias:
|
||||
table_alias.args.pop("columns", None)
|
||||
|
@ -214,7 +216,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
table = resolver.get_table(column.name) if resolve_table and not column.table else None
|
||||
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
|
||||
double_agg = (
|
||||
(alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
|
||||
(
|
||||
alias_expr.find(exp.AggFunc)
|
||||
and (
|
||||
column.find_ancestor(exp.AggFunc)
|
||||
and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
|
||||
)
|
||||
)
|
||||
if alias_expr
|
||||
else False
|
||||
)
|
||||
|
@ -404,7 +412,7 @@ def _expand_stars(
|
|||
tables = list(scope.selected_sources)
|
||||
_add_except_columns(expression, tables, except_columns)
|
||||
_add_replace_columns(expression, tables, replace_columns)
|
||||
elif expression.is_star:
|
||||
elif expression.is_star and not isinstance(expression, exp.Dot):
|
||||
tables = [expression.table]
|
||||
_add_except_columns(expression.this, tables, except_columns)
|
||||
_add_replace_columns(expression.this, tables, replace_columns)
|
||||
|
@ -437,7 +445,7 @@ def _expand_stars(
|
|||
|
||||
if pivot_columns:
|
||||
new_selections.extend(
|
||||
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
|
||||
alias(exp.column(name, table=pivot.alias), name, copy=False)
|
||||
for name in pivot_columns
|
||||
if name not in columns_to_exclude
|
||||
)
|
||||
|
@ -466,7 +474,7 @@ def _expand_stars(
|
|||
)
|
||||
|
||||
# Ensures we don't overwrite the initial selections with an empty list
|
||||
if new_selections:
|
||||
if new_selections and isinstance(scope.expression, exp.Select):
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
|
@ -528,7 +536,8 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
|
|||
|
||||
new_selections.append(selection)
|
||||
|
||||
scope.expression.set("expressions", new_selections)
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
|
||||
|
@ -615,7 +624,7 @@ class Resolver:
|
|||
|
||||
node, _ = self.scope.selected_sources.get(table_name)
|
||||
|
||||
if isinstance(node, exp.Subqueryable):
|
||||
if isinstance(node, exp.Query):
|
||||
while node and node.alias != table_name:
|
||||
node = node.parent
|
||||
|
||||
|
|
|
@ -55,8 +55,8 @@ def qualify_tables(
|
|||
if not table.args.get("catalog") and table.args.get("db"):
|
||||
table.set("catalog", catalog)
|
||||
|
||||
if not isinstance(expression, exp.Subqueryable):
|
||||
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
|
||||
if not isinstance(expression, exp.Query):
|
||||
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)):
|
||||
if isinstance(node, exp.Table):
|
||||
_qualify(node)
|
||||
|
||||
|
|
|
@ -138,7 +138,7 @@ class Scope:
|
|||
and _is_derived_table(node)
|
||||
):
|
||||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.Subqueryable):
|
||||
elif isinstance(node, exp.UNWRAPPED_QUERIES):
|
||||
self._subqueries.append(node)
|
||||
|
||||
self._collected = True
|
||||
|
@ -225,7 +225,7 @@ class Scope:
|
|||
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
|
||||
|
||||
Returns:
|
||||
list[exp.Subqueryable]: subqueries
|
||||
list[exp.Select | exp.Union]: subqueries
|
||||
"""
|
||||
self._ensure_collected()
|
||||
return self._subqueries
|
||||
|
@ -486,8 +486,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
|
|||
Returns:
|
||||
list[Scope]: scope instances
|
||||
"""
|
||||
if isinstance(expression, exp.Unionable) or (
|
||||
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
|
||||
if isinstance(expression, exp.Query) or (
|
||||
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
|
||||
):
|
||||
return list(_traverse_scope(Scope(expression)))
|
||||
|
||||
|
@ -615,7 +615,7 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
|
|||
as it doesn't introduce a new scope. If an alias is present, it shadows all names
|
||||
under the Subquery, so that's one exception to this rule.
|
||||
"""
|
||||
return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))
|
||||
return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
|
||||
|
||||
|
||||
def _traverse_tables(scope):
|
||||
|
@ -786,7 +786,7 @@ def walk_in_scope(expression, bfs=True, prune=None):
|
|||
and _is_derived_table(node)
|
||||
)
|
||||
or isinstance(node, exp.UDTF)
|
||||
or isinstance(node, exp.Subqueryable)
|
||||
or isinstance(node, exp.UNWRAPPED_QUERIES)
|
||||
):
|
||||
crossed_scope_boundary = True
|
||||
|
||||
|
|
|
@ -1185,7 +1185,7 @@ def gen(expression: t.Any) -> str:
|
|||
GEN_MAP = {
|
||||
exp.Add: lambda e: _binary(e, "+"),
|
||||
exp.And: lambda e: _binary(e, "AND"),
|
||||
exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}",
|
||||
exp.Anonymous: lambda e: _anonymous(e),
|
||||
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
|
||||
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
|
||||
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
|
||||
|
@ -1219,6 +1219,20 @@ GEN_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def _anonymous(e: exp.Anonymous) -> str:
    """Render an anonymous function call in the generated-code format.

    The function name is upper-cased when it is a plain string or an unquoted
    identifier; a quoted identifier keeps its exact name and is wrapped in
    double quotes. The name is followed by a space and the comma-joined
    renderings of the call's arguments.

    Args:
        e: The anonymous function expression.

    Returns:
        ``<NAME> <arg1>,<arg2>,...``

    Raises:
        ValueError: If ``e.this`` is neither a str nor an Identifier.
    """
    this = e.this
    if isinstance(this, str):
        name = this.upper()
    elif isinstance(this, exp.Identifier):
        name = f'"{this.name}"' if this.quoted else this.name.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
        )

    # A distinct loop variable avoids shadowing the parameter `e`.
    arguments = ",".join(gen(arg) for arg in e.expressions)
    return f"{name} {arguments}"
|
||||
|
||||
|
||||
def _binary(e: exp.Binary, op: str) -> str:
    """Render a binary expression as ``<left> <op> <right>``.

    Args:
        e: The binary expression; both operands are rendered with ``gen``.
        op: The operator text placed between the operands.

    Returns:
        The rendered expression.
    """
    left = gen(e.left)
    right = gen(e.right)
    return f"{left} {op} {right}"
|
||||
|
||||
|
|
|
@ -94,8 +94,20 @@ def unnest(select, parent_select, next_alias_name):
|
|||
else:
|
||||
_replace(predicate, join_key_not_null)
|
||||
|
||||
group = select.args.get("group")
|
||||
|
||||
if group:
|
||||
if {value.this} != set(group.expressions):
|
||||
select = (
|
||||
exp.select(exp.column(value.alias, "_q"))
|
||||
.from_(select.subquery("_q", copy=False), copy=False)
|
||||
.group_by(exp.column(value.alias, "_q"), copy=False)
|
||||
)
|
||||
else:
|
||||
select = select.group_by(value.this, copy=False)
|
||||
|
||||
parent_select.join(
|
||||
select.group_by(value.this, copy=False),
|
||||
select,
|
||||
on=column.eq(join_key),
|
||||
join_type="LEFT",
|
||||
join_alias=alias,
|
||||
|
|
|
@ -17,6 +17,8 @@ if t.TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
OPTIONS_TYPE = t.Dict[str, t.Sequence[t.Union[t.Sequence[str], str]]]
|
||||
|
||||
|
||||
def build_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
|
||||
if len(args) == 1 and args[0].is_star:
|
||||
|
@ -367,6 +369,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.TEMPORARY,
|
||||
TokenType.TOP,
|
||||
TokenType.TRUE,
|
||||
TokenType.TRUNCATE,
|
||||
TokenType.UNIQUE,
|
||||
TokenType.UNPIVOT,
|
||||
TokenType.UPDATE,
|
||||
|
@ -435,6 +438,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.TABLE,
|
||||
TokenType.TIMESTAMP,
|
||||
TokenType.TIMESTAMPTZ,
|
||||
TokenType.TRUNCATE,
|
||||
TokenType.WINDOW,
|
||||
TokenType.XOR,
|
||||
*TYPE_TOKENS,
|
||||
|
@ -578,7 +582,7 @@ class Parser(metaclass=_Parser):
|
|||
exp.Column: lambda self: self._parse_column(),
|
||||
exp.Condition: lambda self: self._parse_conjunction(),
|
||||
exp.DataType: lambda self: self._parse_types(allow_identifiers=False),
|
||||
exp.Expression: lambda self: self._parse_statement(),
|
||||
exp.Expression: lambda self: self._parse_expression(),
|
||||
exp.From: lambda self: self._parse_from(),
|
||||
exp.Group: lambda self: self._parse_group(),
|
||||
exp.Having: lambda self: self._parse_having(),
|
||||
|
@ -625,10 +629,10 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.SET: lambda self: self._parse_set(),
|
||||
TokenType.UNCACHE: lambda self: self._parse_uncache(),
|
||||
TokenType.UPDATE: lambda self: self._parse_update(),
|
||||
TokenType.TRUNCATE: lambda self: self._parse_truncate_table(),
|
||||
TokenType.USE: lambda self: self.expression(
|
||||
exp.Use,
|
||||
kind=self._match_texts(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"))
|
||||
and exp.var(self._prev.text),
|
||||
kind=self._parse_var_from_options(self.USABLES, raise_unmatched=False),
|
||||
this=self._parse_table(schema=False),
|
||||
),
|
||||
}
|
||||
|
@ -642,36 +646,44 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.DPIPE_SLASH: lambda self: self.expression(exp.Cbrt, this=self._parse_unary()),
|
||||
}
|
||||
|
||||
PRIMARY_PARSERS = {
|
||||
TokenType.STRING: lambda self, token: self.expression(
|
||||
exp.Literal, this=token.text, is_string=True
|
||||
STRING_PARSERS = {
|
||||
TokenType.HEREDOC_STRING: lambda self, token: self.expression(
|
||||
exp.RawString, this=token.text
|
||||
),
|
||||
TokenType.NUMBER: lambda self, token: self.expression(
|
||||
exp.Literal, this=token.text, is_string=False
|
||||
),
|
||||
TokenType.STAR: lambda self, _: self.expression(
|
||||
exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
|
||||
),
|
||||
TokenType.NULL: lambda self, _: self.expression(exp.Null),
|
||||
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
|
||||
TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
|
||||
TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
|
||||
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
|
||||
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
|
||||
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
|
||||
TokenType.NATIONAL_STRING: lambda self, token: self.expression(
|
||||
exp.National, this=token.text
|
||||
),
|
||||
TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text),
|
||||
TokenType.HEREDOC_STRING: lambda self, token: self.expression(
|
||||
exp.RawString, this=token.text
|
||||
TokenType.STRING: lambda self, token: self.expression(
|
||||
exp.Literal, this=token.text, is_string=True
|
||||
),
|
||||
TokenType.UNICODE_STRING: lambda self, token: self.expression(
|
||||
exp.UnicodeString,
|
||||
this=token.text,
|
||||
escape=self._match_text_seq("UESCAPE") and self._parse_string(),
|
||||
),
|
||||
}
|
||||
|
||||
NUMERIC_PARSERS = {
|
||||
TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
|
||||
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
|
||||
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
|
||||
TokenType.NUMBER: lambda self, token: self.expression(
|
||||
exp.Literal, this=token.text, is_string=False
|
||||
),
|
||||
}
|
||||
|
||||
PRIMARY_PARSERS = {
|
||||
**STRING_PARSERS,
|
||||
**NUMERIC_PARSERS,
|
||||
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
|
||||
TokenType.NULL: lambda self, _: self.expression(exp.Null),
|
||||
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
|
||||
TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
|
||||
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
|
||||
TokenType.STAR: lambda self, _: self.expression(
|
||||
exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
|
||||
),
|
||||
}
|
||||
|
||||
PLACEHOLDER_PARSERS = {
|
||||
|
@ -799,7 +811,9 @@ class Parser(metaclass=_Parser):
|
|||
exp.CharacterSetColumnConstraint, this=self._parse_var_or_string()
|
||||
),
|
||||
"CHECK": lambda self: self.expression(
|
||||
exp.CheckColumnConstraint, this=self._parse_wrapped(self._parse_conjunction)
|
||||
exp.CheckColumnConstraint,
|
||||
this=self._parse_wrapped(self._parse_conjunction),
|
||||
enforced=self._match_text_seq("ENFORCED"),
|
||||
),
|
||||
"COLLATE": lambda self: self.expression(
|
||||
exp.CollateColumnConstraint, this=self._parse_var()
|
||||
|
@ -873,6 +887,8 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
|
||||
|
||||
KEY_VALUE_DEFINITIONS = (exp.Alias, exp.EQ, exp.PropertyEQ, exp.Slice)
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
|
||||
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
|
||||
|
@ -895,6 +911,7 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
QUERY_MODIFIER_PARSERS = {
|
||||
TokenType.MATCH_RECOGNIZE: lambda self: ("match", self._parse_match_recognize()),
|
||||
TokenType.PREWHERE: lambda self: ("prewhere", self._parse_prewhere()),
|
||||
TokenType.WHERE: lambda self: ("where", self._parse_where()),
|
||||
TokenType.GROUP_BY: lambda self: ("group", self._parse_group()),
|
||||
TokenType.HAVING: lambda self: ("having", self._parse_having()),
|
||||
|
@ -934,22 +951,23 @@ class Parser(metaclass=_Parser):
|
|||
exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this),
|
||||
}
|
||||
|
||||
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
|
||||
|
||||
DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN}
|
||||
|
||||
PRE_VOLATILE_TOKENS = {TokenType.CREATE, TokenType.REPLACE, TokenType.UNIQUE}
|
||||
|
||||
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
|
||||
TRANSACTION_CHARACTERISTICS = {
|
||||
"ISOLATION LEVEL REPEATABLE READ",
|
||||
"ISOLATION LEVEL READ COMMITTED",
|
||||
"ISOLATION LEVEL READ UNCOMMITTED",
|
||||
"ISOLATION LEVEL SERIALIZABLE",
|
||||
"READ WRITE",
|
||||
"READ ONLY",
|
||||
# Maps the leading keyword of a transaction characteristic to the keyword
# sequences that are allowed to follow it (a bare string is a single keyword).
TRANSACTION_CHARACTERISTICS: OPTIONS_TYPE = {
    "ISOLATION": (
        ("LEVEL", "REPEATABLE", "READ"),
        ("LEVEL", "READ", "COMMITTED"),
        # Spelled "UNCOMMITTED" (double M): the misspelled keyword could never match
        # the SQL text `ISOLATION LEVEL READ UNCOMMITTED`.
        ("LEVEL", "READ", "UNCOMMITTED"),
        ("LEVEL", "SERIALIZABLE"),
    ),
    "READ": ("WRITE", "ONLY"),
}
|
||||
|
||||
USABLES: OPTIONS_TYPE = dict.fromkeys(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"), tuple())
|
||||
|
||||
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
|
||||
|
||||
CLONE_KEYWORDS = {"CLONE", "COPY"}
|
||||
|
@ -1012,6 +1030,9 @@ class Parser(metaclass=_Parser):
|
|||
# If this is True and '(' is not found, the keyword will be treated as an identifier
|
||||
VALUES_FOLLOWED_BY_PAREN = True
|
||||
|
||||
# Whether implicit unnesting is supported, e.g. SELECT 1 FROM y.z AS z, z.a (Redshift)
|
||||
SUPPORTS_IMPLICIT_UNNEST = False
|
||||
|
||||
__slots__ = (
|
||||
"error_level",
|
||||
"error_message_context",
|
||||
|
@ -2450,10 +2471,37 @@ class Parser(metaclass=_Parser):
|
|||
alias=self._parse_table_alias() if parse_alias else None,
|
||||
)
|
||||
|
||||
def _implicit_unnests_to_explicit(self, this: E) -> E:
    """
    Rewrite joins that implicitly unnest a column of an earlier relation into explicit UNNESTs.

    A join target is rewritten when it is an `exp.Table`, the join has no "on" arg, and the
    first part of its (normalized) name matches the name or alias of the FROM source or of a
    previously visited join target.

    Args:
        this: the expression whose "from" and "joins" args are inspected; mutated in place.

    Returns:
        The same `this` instance.
    """
    # NOTE(review): imported inside the function, presumably to avoid a circular import - confirm
    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers as _norm

    # Normalized names/aliases of the relations seen so far, seeded with the FROM source
    refs = {_norm(this.args["from"].this.copy(), dialect=self.dialect).alias_or_name}
    for i, join in enumerate(this.args.get("joins") or []):
        table = join.this
        # Normalize a copy so that the node attached to the tree keeps its original form
        normalized_table = table.copy()
        normalized_table.meta["maybe_column"] = True
        normalized_table = _norm(normalized_table, dialect=self.dialect)

        if isinstance(table, exp.Table) and not join.args.get("on"):
            # The leading name part refers to a known relation, so treat the "table" as a column path
            if normalized_table.parts[0].name in refs:
                table_as_column = table.to_column()
                unnest = exp.Unnest(expressions=[table_as_column])

                # Table.to_column creates a parent Alias node that we want to convert to
                # a TableAlias and attach to the Unnest, so it matches the parser's output
                if isinstance(table.args.get("alias"), exp.TableAlias):
                    table_as_column.replace(table_as_column.this)
                    exp.alias_(unnest, None, table=[table.args["alias"].this], copy=False)

                table.replace(unnest)

        # Every join target becomes referenceable by the joins that follow it
        refs.add(normalized_table.alias_or_name)

    return this
|
||||
|
||||
def _parse_query_modifiers(
|
||||
self, this: t.Optional[exp.Expression]
|
||||
) -> t.Optional[exp.Expression]:
|
||||
if isinstance(this, self.MODIFIABLES):
|
||||
if isinstance(this, (exp.Query, exp.Table)):
|
||||
for join in iter(self._parse_join, None):
|
||||
this.append("joins", join)
|
||||
for lateral in iter(self._parse_lateral, None):
|
||||
|
@ -2478,6 +2526,10 @@ class Parser(metaclass=_Parser):
|
|||
offset.set("expressions", limit_by_expressions)
|
||||
continue
|
||||
break
|
||||
|
||||
if self.SUPPORTS_IMPLICIT_UNNEST and this and "from" in this.args:
|
||||
this = self._implicit_unnests_to_explicit(this)
|
||||
|
||||
return this
|
||||
|
||||
def _parse_hint(self) -> t.Optional[exp.Hint]:
|
||||
|
@ -2803,7 +2855,9 @@ class Parser(metaclass=_Parser):
|
|||
or self._parse_placeholder()
|
||||
)
|
||||
|
||||
def _parse_table_parts(self, schema: bool = False, is_db_reference: bool = False) -> exp.Table:
|
||||
def _parse_table_parts(
|
||||
self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
|
||||
) -> exp.Table:
|
||||
catalog = None
|
||||
db = None
|
||||
table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema)
|
||||
|
@ -2817,8 +2871,20 @@ class Parser(metaclass=_Parser):
|
|||
else:
|
||||
catalog = db
|
||||
db = table
|
||||
# "" used for tsql FROM a..b case
|
||||
table = self._parse_table_part(schema=schema) or ""
|
||||
|
||||
if (
|
||||
wildcard
|
||||
and self._is_connected()
|
||||
and (isinstance(table, exp.Identifier) or not table)
|
||||
and self._match(TokenType.STAR)
|
||||
):
|
||||
if isinstance(table, exp.Identifier):
|
||||
table.args["this"] += "*"
|
||||
else:
|
||||
table = exp.Identifier(this="*")
|
||||
|
||||
if is_db_reference:
|
||||
catalog = db
|
||||
db = table
|
||||
|
@ -2861,6 +2927,9 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
bracket = parse_bracket and self._parse_bracket(None)
|
||||
bracket = self.expression(exp.Table, this=bracket) if bracket else None
|
||||
|
||||
only = self._match(TokenType.ONLY)
|
||||
|
||||
this = t.cast(
|
||||
exp.Expression,
|
||||
bracket
|
||||
|
@ -2869,6 +2938,12 @@ class Parser(metaclass=_Parser):
|
|||
),
|
||||
)
|
||||
|
||||
if only:
|
||||
this.set("only", only)
|
||||
|
||||
# Postgres supports a wildcard (table) suffix operator, which is a no-op in this context
|
||||
self._match_text_seq("*")
|
||||
|
||||
if schema:
|
||||
return self._parse_schema(this=this)
|
||||
|
||||
|
@ -3161,6 +3236,14 @@ class Parser(metaclass=_Parser):
|
|||
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
|
||||
return [agg.alias for agg in aggregations]
|
||||
|
||||
def _parse_prewhere(self, skip_where_token: bool = False) -> t.Optional[exp.PreWhere]:
    """
    Parse a PREWHERE clause.

    Args:
        skip_where_token: when True, the PREWHERE token is not matched before the
            condition is parsed.

    Returns:
        An `exp.PreWhere` wrapping the parsed condition, or None if the PREWHERE token
        was required but is not the current token.
    """
    if skip_where_token or self._match(TokenType.PREWHERE):
        # Comments are captured before the condition is parsed, same as the keyword order below
        return self.expression(
            exp.PreWhere, comments=self._prev_comments, this=self._parse_conjunction()
        )

    return None
|
||||
|
||||
def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Where]:
|
||||
if not skip_where_token and not self._match(TokenType.WHERE):
|
||||
return None
|
||||
|
@ -3291,8 +3374,12 @@ class Parser(metaclass=_Parser):
|
|||
return None
|
||||
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
|
||||
|
||||
def _parse_ordered(self, parse_method: t.Optional[t.Callable] = None) -> exp.Ordered:
|
||||
def _parse_ordered(
|
||||
self, parse_method: t.Optional[t.Callable] = None
|
||||
) -> t.Optional[exp.Ordered]:
|
||||
this = parse_method() if parse_method else self._parse_conjunction()
|
||||
if not this:
|
||||
return None
|
||||
|
||||
asc = self._match(TokenType.ASC)
|
||||
desc = self._match(TokenType.DESC) or (asc and False)
|
||||
|
@ -3510,7 +3597,7 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
if self._match_text_seq("DISTINCT", "FROM"):
|
||||
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
|
||||
return self.expression(klass, this=this, expression=self._parse_conjunction())
|
||||
return self.expression(klass, this=this, expression=self._parse_bitwise())
|
||||
|
||||
expression = self._parse_null() or self._parse_boolean()
|
||||
if not expression:
|
||||
|
@ -3528,7 +3615,7 @@ class Parser(metaclass=_Parser):
|
|||
matched_l_paren = self._prev.token_type == TokenType.L_PAREN
|
||||
expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias))
|
||||
|
||||
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
|
||||
if len(expressions) == 1 and isinstance(expressions[0], exp.Query):
|
||||
this = self.expression(exp.In, this=this, query=expressions[0])
|
||||
else:
|
||||
this = self.expression(exp.In, this=this, expressions=expressions)
|
||||
|
@ -3959,7 +4046,7 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
this = self._parse_query_modifiers(seq_get(expressions, 0))
|
||||
|
||||
if isinstance(this, exp.Subqueryable):
|
||||
if isinstance(this, exp.UNWRAPPED_QUERIES):
|
||||
this = self._parse_set_operations(
|
||||
self._parse_subquery(this=this, parse_alias=False)
|
||||
)
|
||||
|
@ -4064,6 +4151,9 @@ class Parser(metaclass=_Parser):
|
|||
alias = upper in self.FUNCTIONS_WITH_ALIASED_ARGS
|
||||
args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
|
||||
|
||||
if alias:
|
||||
args = self._kv_to_prop_eq(args)
|
||||
|
||||
if function and not anonymous:
|
||||
if "dialect" in function.__code__.co_varnames:
|
||||
func = function(args, dialect=self.dialect)
|
||||
|
@ -4076,6 +4166,8 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
this = func
|
||||
else:
|
||||
if token_type == TokenType.IDENTIFIER:
|
||||
this = exp.Identifier(this=this, quoted=True)
|
||||
this = self.expression(exp.Anonymous, this=this, expressions=args)
|
||||
|
||||
if isinstance(this, exp.Expression):
|
||||
|
@ -4084,6 +4176,26 @@ class Parser(metaclass=_Parser):
|
|||
self._match_r_paren(this)
|
||||
return self._parse_window(this)
|
||||
|
||||
def _kv_to_prop_eq(self, expressions: t.List[exp.Expression]) -> t.List[exp.Expression]:
    """
    Convert key-value style expressions into `exp.PropertyEQ` nodes.

    Only instances of `KEY_VALUE_DEFINITIONS` are converted; every other expression is
    passed through unchanged.

    Args:
        expressions: the expressions to convert.

    Returns:
        A new list with the same length and order as `expressions`.
    """
    transformed = []

    for e in expressions:
        if isinstance(e, self.KEY_VALUE_DEFINITIONS):
            # `value AS key`: the alias becomes the key and the aliased expression the value
            if isinstance(e, exp.Alias):
                e = self.expression(exp.PropertyEQ, this=e.args.get("alias"), expression=e.this)

            # Any remaining non-PropertyEQ definition: key is an identifier built from its name
            if not isinstance(e, exp.PropertyEQ):
                e = self.expression(
                    exp.PropertyEQ, this=exp.to_identifier(e.name), expression=e.expression
                )

            # A key that was parsed as a Column is replaced by the column's inner `this`
            if isinstance(e.this, exp.Column):
                e.this.replace(e.this.this)

        transformed.append(e)

    return transformed
|
||||
|
||||
def _parse_function_parameter(self) -> t.Optional[exp.Expression]:
|
||||
return self._parse_column_def(self._parse_id_var())
|
||||
|
||||
|
@ -4496,7 +4608,7 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
# https://duckdb.org/docs/sql/data_types/struct.html#creating-structs
|
||||
if bracket_kind == TokenType.L_BRACE:
|
||||
this = self.expression(exp.Struct, expressions=expressions)
|
||||
this = self.expression(exp.Struct, expressions=self._kv_to_prop_eq(expressions))
|
||||
elif not this or this.name.upper() == "ARRAY":
|
||||
this = self.expression(exp.Array, expressions=expressions)
|
||||
else:
|
||||
|
@ -4747,12 +4859,10 @@ class Parser(metaclass=_Parser):
|
|||
return None
|
||||
|
||||
@t.overload
|
||||
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
|
||||
...
|
||||
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
|
||||
|
||||
@t.overload
|
||||
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
|
||||
...
|
||||
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
|
||||
|
||||
def _parse_json_object(self, agg=False):
|
||||
star = self._parse_star()
|
||||
|
@ -5140,16 +5250,16 @@ class Parser(metaclass=_Parser):
|
|||
return None
|
||||
|
||||
def _parse_string(self) -> t.Optional[exp.Expression]:
|
||||
if self._match_set((TokenType.STRING, TokenType.RAW_STRING)):
|
||||
return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
|
||||
if self._match_set(self.STRING_PARSERS):
|
||||
return self.STRING_PARSERS[self._prev.token_type](self, self._prev)
|
||||
return self._parse_placeholder()
|
||||
|
||||
def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]:
|
||||
return exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True)
|
||||
|
||||
def _parse_number(self) -> t.Optional[exp.Expression]:
|
||||
if self._match(TokenType.NUMBER):
|
||||
return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)
|
||||
if self._match_set(self.NUMERIC_PARSERS):
|
||||
return self.NUMERIC_PARSERS[self._prev.token_type](self, self._prev)
|
||||
return self._parse_placeholder()
|
||||
|
||||
def _parse_identifier(self) -> t.Optional[exp.Expression]:
|
||||
|
@ -5182,6 +5292,9 @@ class Parser(metaclass=_Parser):
|
|||
def _parse_var_or_string(self) -> t.Optional[exp.Expression]:
|
||||
return self._parse_var() or self._parse_string()
|
||||
|
||||
def _parse_primary_or_var(self) -> t.Optional[exp.Expression]:
|
||||
return self._parse_primary() or self._parse_var(any_token=True)
|
||||
|
||||
def _parse_null(self) -> t.Optional[exp.Expression]:
|
||||
if self._match_set(self.NULL_TOKENS):
|
||||
return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
|
||||
|
@ -5200,16 +5313,12 @@ class Parser(metaclass=_Parser):
|
|||
return self._parse_placeholder()
|
||||
|
||||
def _parse_parameter(self) -> exp.Parameter:
|
||||
def _parse_parameter_part() -> t.Optional[exp.Expression]:
|
||||
return (
|
||||
self._parse_identifier() or self._parse_primary() or self._parse_var(any_token=True)
|
||||
)
|
||||
|
||||
self._match(TokenType.L_BRACE)
|
||||
this = _parse_parameter_part()
|
||||
expression = self._match(TokenType.COLON) and _parse_parameter_part()
|
||||
this = self._parse_identifier() or self._parse_primary_or_var()
|
||||
expression = self._match(TokenType.COLON) and (
|
||||
self._parse_identifier() or self._parse_primary_or_var()
|
||||
)
|
||||
self._match(TokenType.R_BRACE)
|
||||
|
||||
return self.expression(exp.Parameter, this=this, expression=expression)
|
||||
|
||||
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
|
||||
|
@ -5376,35 +5485,15 @@ class Parser(metaclass=_Parser):
|
|||
exp.DropPartition, expressions=self._parse_csv(self._parse_partition), exists=exists
|
||||
)
|
||||
|
||||
def _parse_add_constraint(self) -> exp.AddConstraint:
|
||||
this = None
|
||||
kind = self._prev.token_type
|
||||
|
||||
if kind == TokenType.CONSTRAINT:
|
||||
this = self._parse_id_var()
|
||||
|
||||
if self._match_text_seq("CHECK"):
|
||||
expression = self._parse_wrapped(self._parse_conjunction)
|
||||
enforced = self._match_text_seq("ENFORCED") or False
|
||||
|
||||
return self.expression(
|
||||
exp.AddConstraint, this=this, expression=expression, enforced=enforced
|
||||
)
|
||||
|
||||
if kind == TokenType.FOREIGN_KEY or self._match(TokenType.FOREIGN_KEY):
|
||||
expression = self._parse_foreign_key()
|
||||
elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY):
|
||||
expression = self._parse_primary_key()
|
||||
else:
|
||||
expression = None
|
||||
|
||||
return self.expression(exp.AddConstraint, this=this, expression=expression)
|
||||
|
||||
def _parse_alter_table_add(self) -> t.List[exp.Expression]:
|
||||
index = self._index - 1
|
||||
|
||||
if self._match_set(self.ADD_CONSTRAINT_TOKENS):
|
||||
return self._parse_csv(self._parse_add_constraint)
|
||||
if self._match_set(self.ADD_CONSTRAINT_TOKENS, advance=False):
|
||||
return self._parse_csv(
|
||||
lambda: self.expression(
|
||||
exp.AddConstraint, expressions=self._parse_csv(self._parse_constraint)
|
||||
)
|
||||
)
|
||||
|
||||
self._retreat(index)
|
||||
if not self.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and self._match_text_seq("ADD"):
|
||||
|
@ -5472,6 +5561,7 @@ class Parser(metaclass=_Parser):
|
|||
parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
|
||||
if parser:
|
||||
actions = ensure_list(parser(self))
|
||||
options = self._parse_csv(self._parse_property)
|
||||
|
||||
if not self._curr and actions:
|
||||
return self.expression(
|
||||
|
@ -5480,6 +5570,7 @@ class Parser(metaclass=_Parser):
|
|||
exists=exists,
|
||||
actions=actions,
|
||||
only=only,
|
||||
options=options,
|
||||
)
|
||||
|
||||
return self._parse_as_command(start)
|
||||
|
@ -5610,11 +5701,34 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
return set_
|
||||
|
||||
def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Var]:
|
||||
for option in options:
|
||||
if self._match_text_seq(*option.split(" ")):
|
||||
return exp.var(option)
|
||||
return None
|
||||
def _parse_var_from_options(
    self, options: OPTIONS_TYPE, raise_unmatched: bool = True
) -> t.Optional[exp.Var]:
    """
    Parse an option consisting of a leading keyword plus an optional keyword continuation.

    Args:
        options: maps the upper-cased leading keyword to the keyword sequences that may
            follow it. A plain string continuation is treated as a one-keyword sequence,
            and an empty collection means the leading keyword is accepted on its own.
        raise_unmatched: whether `raise_error` is called when the current token is not a
            key of `options`, or none of its non-empty continuations match.

    Returns:
        An `exp.Var` holding the leading keyword and the matched continuation joined by
        spaces, or None if there is no current token or nothing matched.
    """
    start = self._curr
    if not start:
        return None

    option = start.text.upper()
    # None means the keyword is unknown; an empty collection means "no continuation needed"
    continuations = options.get(option)

    # Remember the position so the leading keyword can be un-consumed on failure
    index = self._index
    self._advance()
    for keywords in continuations or []:
        if isinstance(keywords, str):
            keywords = (keywords,)

        if self._match_text_seq(*keywords):
            option = f"{option} {' '.join(keywords)}"
            break
    else:
        # Reached when no continuation matched, including when there were none to try
        if continuations or continuations is None:
            if raise_unmatched:
                self.raise_error(f"Unknown option {option}")

            # If raise_error returned (or was skipped), rewind and report no match
            self._retreat(index)
            return None

    return exp.var(option)
|
||||
|
||||
def _parse_as_command(self, start: Token) -> exp.Command:
|
||||
while self._curr:
|
||||
|
@ -5806,14 +5920,12 @@ class Parser(metaclass=_Parser):
|
|||
return True
|
||||
|
||||
@t.overload
|
||||
def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
|
||||
...
|
||||
def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...
|
||||
|
||||
@t.overload
|
||||
def _replace_columns_with_dots(
|
||||
self, this: t.Optional[exp.Expression]
|
||||
) -> t.Optional[exp.Expression]:
|
||||
...
|
||||
) -> t.Optional[exp.Expression]: ...
|
||||
|
||||
def _replace_columns_with_dots(self, this):
|
||||
if isinstance(this, exp.Dot):
|
||||
|
@ -5849,3 +5961,53 @@ class Parser(metaclass=_Parser):
|
|||
else:
|
||||
column.replace(dot_or_id)
|
||||
return node
|
||||
|
||||
def _parse_truncate_table(self) -> t.Optional[exp.TruncateTable] | exp.Expression:
    """
    Parse a TRUNCATE statement, or a TRUNCATE(...) function call.

    Returns:
        The result of `_parse_function` when TRUNCATE is followed by an opening paren,
        an `exp.TruncateTable` when the whole statement was consumed, or the result of
        `_parse_as_command` when tokens remain after parsing.
    """
    truncate_token = self._prev

    # Not to be confused with TRUNCATE(number, decimals) function call
    if self._match(TokenType.L_PAREN):
        # Rewind two positions (presumably the paren and the TRUNCATE keyword - verify)
        self._retreat(self._index - 2)
        return self._parse_function()

    # Clickhouse supports TRUNCATE DATABASE as well
    is_database = self._match(TokenType.DATABASE)

    # The TABLE keyword is consumed if present; its absence is not an error
    self._match(TokenType.TABLE)

    exists = self._parse_exists(not_=False)

    def _parse_target() -> t.Optional[exp.Expression]:
        return self._parse_table(schema=True, is_db_reference=is_database)

    expressions = self._parse_csv(_parse_target)

    cluster = None
    if self._match(TokenType.ON):
        cluster = self._parse_on_property()

    # RESTART IDENTITY is tried before CONTINUE IDENTITY
    identity = None
    for keyword in ("RESTART", "CONTINUE"):
        if self._match_text_seq(keyword, "IDENTITY"):
            identity = keyword
            break

    option = None
    if self._match_text_seq("CASCADE") or self._match_text_seq("RESTRICT"):
        option = self._prev.text

    partition = self._parse_partition()

    # Fallback case
    if self._curr:
        return self._parse_as_command(truncate_token)

    return self.expression(
        exp.TruncateTable,
        expressions=expressions,
        is_database=is_database,
        exists=exists,
        cluster=cluster,
        identity=identity,
        option=option,
        partition=partition,
    )
|
||||
|
|
|
@ -302,6 +302,7 @@ class TokenType(AutoName):
|
|||
OBJECT_IDENTIFIER = auto()
|
||||
OFFSET = auto()
|
||||
ON = auto()
|
||||
ONLY = auto()
|
||||
OPERATOR = auto()
|
||||
ORDER_BY = auto()
|
||||
ORDER_SIBLINGS_BY = auto()
|
||||
|
@ -317,6 +318,7 @@ class TokenType(AutoName):
|
|||
PIVOT = auto()
|
||||
PLACEHOLDER = auto()
|
||||
PRAGMA = auto()
|
||||
PREWHERE = auto()
|
||||
PRIMARY_KEY = auto()
|
||||
PROCEDURE = auto()
|
||||
PROPERTIES = auto()
|
||||
|
@ -353,6 +355,7 @@ class TokenType(AutoName):
|
|||
TOP = auto()
|
||||
THEN = auto()
|
||||
TRUE = auto()
|
||||
TRUNCATE = auto()
|
||||
UNCACHE = auto()
|
||||
UNION = auto()
|
||||
UNNEST = auto()
|
||||
|
@ -370,6 +373,7 @@ class TokenType(AutoName):
|
|||
UNIQUE = auto()
|
||||
VERSION_SNAPSHOT = auto()
|
||||
TIMESTAMP_SNAPSHOT = auto()
|
||||
OPTION = auto()
|
||||
|
||||
|
||||
_ALL_TOKEN_TYPES = list(TokenType)
|
||||
|
@ -657,6 +661,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
"DROP": TokenType.DROP,
|
||||
"ELSE": TokenType.ELSE,
|
||||
"END": TokenType.END,
|
||||
"ENUM": TokenType.ENUM,
|
||||
"ESCAPE": TokenType.ESCAPE,
|
||||
"EXCEPT": TokenType.EXCEPT,
|
||||
"EXECUTE": TokenType.EXECUTE,
|
||||
|
@ -752,6 +757,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
"TEMPORARY": TokenType.TEMPORARY,
|
||||
"THEN": TokenType.THEN,
|
||||
"TRUE": TokenType.TRUE,
|
||||
"TRUNCATE": TokenType.TRUNCATE,
|
||||
"UNION": TokenType.UNION,
|
||||
"UNKNOWN": TokenType.UNKNOWN,
|
||||
"UNNEST": TokenType.UNNEST,
|
||||
|
@ -860,7 +866,6 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
"GRANT": TokenType.COMMAND,
|
||||
"OPTIMIZE": TokenType.COMMAND,
|
||||
"PREPARE": TokenType.COMMAND,
|
||||
"TRUNCATE": TokenType.COMMAND,
|
||||
"VACUUM": TokenType.COMMAND,
|
||||
"USER-DEFINED": TokenType.USERDEFINED,
|
||||
"FOR VERSION": TokenType.VERSION_SNAPSHOT,
|
||||
|
@ -1036,12 +1041,6 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
def _text(self) -> str:
|
||||
return self.sql[self._start : self._current]
|
||||
|
||||
def peek(self, i: int = 0) -> str:
|
||||
i = self._current + i
|
||||
if i < self.size:
|
||||
return self.sql[i]
|
||||
return ""
|
||||
|
||||
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
|
||||
self._prev_token_line = self._line
|
||||
|
||||
|
@ -1182,12 +1181,8 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
if self._peek.isdigit():
|
||||
self._advance()
|
||||
elif self._peek == "." and not decimal:
|
||||
after = self.peek(1)
|
||||
if after.isdigit() or not after.isalpha():
|
||||
decimal = True
|
||||
self._advance()
|
||||
else:
|
||||
return self._add(TokenType.VAR)
|
||||
decimal = True
|
||||
self._advance()
|
||||
elif self._peek in ("-", "+") and scientific == 1:
|
||||
scientific += 1
|
||||
self._advance()
|
||||
|
|
|
@ -547,7 +547,7 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Exp
|
|||
prop
|
||||
and prop.this
|
||||
and isinstance(prop.this, exp.Schema)
|
||||
and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions)
|
||||
and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
|
||||
):
|
||||
prop_this = exp.Tuple(
|
||||
expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
|
||||
|
@ -560,6 +560,22 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Exp
|
|||
return expression
|
||||
|
||||
|
||||
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Convert the key-value arguments of a struct into aliases, e.g. STRUCT(1 AS y).

    Every `exp.PropertyEQ` argument of an `exp.Struct` is replaced by an alias of its
    value named after its key. Other arguments, and expressions that are not structs,
    are left untouched.

    Args:
        expression: the expression to transform; a struct is modified in place.

    Returns:
        The same `expression` instance.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    aliased_args = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        aliased_args.append(arg)

    expression.set("expressions", aliased_args)
    return expression
|
||||
|
||||
|
||||
def preprocess(
|
||||
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
|
||||
) -> t.Callable[[Generator, exp.Expression], str]:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue