Merging upstream version 21.1.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in: parent 4e41aa0bbb, commit bf03050a25
91 changed files with 49165 additions and 47854 deletions
@@ -148,7 +148,7 @@ def atanh(col: ColumnOrName) -> Column:


 def cbrt(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "CBRT")
+    return Column.invoke_expression_over_column(col, expression.Cbrt)


 def ceil(col: ColumnOrName) -> Column:
@@ -70,12 +70,10 @@ class SparkSession:
         column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

         data_expressions = [
-            exp.Tuple(
-                expressions=list(
-                    map(
-                        lambda x: F.lit(x).expression,
-                        row if not isinstance(row, dict) else row.values(),
-                    )
+            exp.tuple_(
+                *map(
+                    lambda x: F.lit(x).expression,
+                    row if not isinstance(row, dict) else row.values(),
                 )
             )
             for row in data
@@ -39,24 +39,31 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
     alias = expression.args.get("alias")

-    structs = [
-        exp.Struct(
-            expressions=[
-                exp.alias_(value, column_name)
-                for value, column_name in zip(
-                    t.expressions,
-                    (
-                        alias.columns
-                        if alias and alias.columns
-                        else (f"_c{i}" for i in range(len(t.expressions)))
-                    ),
-                )
-            ]
-        )
-        for t in expression.find_all(exp.Tuple)
-    ]
-
-    return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
+    return self.unnest_sql(
+        exp.Unnest(
+            expressions=[
+                exp.array(
+                    *(
+                        exp.Struct(
+                            expressions=[
+                                exp.alias_(value, column_name)
+                                for value, column_name in zip(
+                                    t.expressions,
+                                    (
+                                        alias.columns
+                                        if alias and alias.columns
+                                        else (f"_c{i}" for i in range(len(t.expressions)))
+                                    ),
+                                )
+                            ]
+                        )
+                        for t in expression.find_all(exp.Tuple)
+                    ),
+                    copy=False,
+                )
+            ]
+        )
+    )


 def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
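A sketch of what the rewritten helper produces for a derived VALUES table, using sqlglot's public API (the output shown is approximate and assumes sqlglot 21.x):

    import sqlglot

    sql = "SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t(x, y)"
    print(sqlglot.transpile(sql, read="presto", write="bigquery")[0])
    # Roughly: SELECT * FROM UNNEST([STRUCT(1 AS x, 'a' AS y), STRUCT(2 AS x, 'b' AS y)]) AS t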
@@ -161,12 +168,18 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
     return expression


-def _parse_timestamp(args: t.List) -> exp.StrToTime:
+def _parse_parse_timestamp(args: t.List) -> exp.StrToTime:
     this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
     this.set("zone", seq_get(args, 2))
     return this


+def _parse_timestamp(args: t.List) -> exp.Timestamp:
+    timestamp = exp.Timestamp.from_arg_list(args)
+    timestamp.set("with_tz", True)
+    return timestamp
+
+
 def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts:
     expr_type = exp.DateFromParts if len(args) == 3 else exp.Date
     return expr_type.from_arg_list(args)
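The rename separates BigQuery's PARSE_TIMESTAMP (a string-to-time conversion) from its TIMESTAMP constructor, which returns a timezone-aware value. A minimal sketch of the new behavior, assuming sqlglot 21.x:

    import sqlglot
    from sqlglot import exp

    ast = sqlglot.parse_one("SELECT TIMESTAMP('2008-12-25 15:30:00+00')", read="bigquery")
    ts = ast.find(exp.Timestamp)
    print(ts.args.get("with_tz"))  # True, so the type annotator infers TIMESTAMPTZ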
@@ -318,6 +331,7 @@ class BigQuery(Dialect):
             "TIMESTAMP": TokenType.TIMESTAMPTZ,
         }
         KEYWORDS.pop("DIV")
+        KEYWORDS.pop("VALUES")

     class Parser(parser.Parser):
         PREFIXED_PIVOT_COLUMNS = True
@@ -348,7 +362,7 @@ class BigQuery(Dialect):
             "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
                 [seq_get(args, 1), seq_get(args, 0)]
             ),
-            "PARSE_TIMESTAMP": _parse_timestamp,
+            "PARSE_TIMESTAMP": _parse_parse_timestamp,
             "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
             "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
                 this=seq_get(args, 0),
@@ -367,6 +381,7 @@ class BigQuery(Dialect):
             "TIME": _parse_time,
             "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
             "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
+            "TIMESTAMP": _parse_timestamp,
             "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
             "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
             "TIMESTAMP_MICROS": lambda args: exp.UnixToTime(
@@ -395,11 +410,6 @@ class BigQuery(Dialect):
             TokenType.TABLE,
         }

-        ID_VAR_TOKENS = {
-            *parser.Parser.ID_VAR_TOKENS,
-            TokenType.VALUES,
-        }
-
         PROPERTY_PARSERS = {
             **parser.Parser.PROPERTY_PARSERS,
             "NOT DETERMINISTIC": lambda self: self.expression(
@@ -93,6 +93,7 @@ class ClickHouse(Dialect):
             "IPV6": TokenType.IPV6,
             "AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION,
             "SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION,
+            "SYSTEM": TokenType.COMMAND,
         }

         SINGLE_TOKENS = {
@@ -654,28 +654,6 @@ def time_format(
     return _time_format


-def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
-    """
-    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
-    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
-    columns are removed from the create statement.
-    """
-    has_schema = isinstance(expression.this, exp.Schema)
-    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
-
-    if has_schema and is_partitionable:
-        prop = expression.find(exp.PartitionedByProperty)
-        if prop and prop.this and not isinstance(prop.this, exp.Schema):
-            schema = expression.this
-            columns = {v.name.upper() for v in prop.this.expressions}
-            partitions = [col for col in schema.expressions if col.name.upper() in columns]
-            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
-            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
-            expression.set("this", schema)
-
-    return self.create_sql(expression)
-
-
 def parse_date_delta(
     exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 ) -> t.Callable[[t.List], E]:
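The removed generator-side helper is superseded by the composable AST transforms added to sqlglot.transforms later in this diff (remove_unique_constraints, ctas_with_tmp_tables_to_create_tmp_view, move_schema_columns_to_partitioned_by). A sketch of the round-trip they preserve, assuming sqlglot 21.x; the exact output is an expectation, not a guarantee:

    import sqlglot

    sql = "CREATE TABLE x (a INT) PARTITIONED BY (b STRING)"
    print(sqlglot.transpile(sql, read="hive", write="hive")[0])
    # Expected: CREATE TABLE x (a INT) PARTITIONED BY (b STRING)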
@@ -742,7 +720,10 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:

 def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
     if not expression.expression:
-        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
+        from sqlglot.optimizer.annotate_types import annotate_types
+
+        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
+        return self.sql(exp.cast(expression.this, to=target_type))
     if expression.text("expression").lower() in TIMEZONES:
         return self.sql(
             exp.AtTimeZone(
@@ -750,7 +731,7 @@ def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
                 zone=expression.expression,
             )
         )
-    return self.function_fallback_sql(expression)
+    return self.func("TIMESTAMP", expression.this, expression.expression)


 def locate_to_strposition(args: t.List) -> exp.Expression:
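A minimal sketch of the annotate_types fallback that no_timestamp_sql now relies on, assuming sqlglot 21.x:

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    node = annotate_types(parse_one("CAST('2020-01-01 00:00:00' AS DATETIME)"))
    print(node.type.sql())  # DATETIME, used as the cast target instead of a blanket TIMESTAMP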
@@ -5,7 +5,6 @@ import typing as t
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
-    create_with_partitions_sql,
     datestrtodate_sql,
     format_time_lambda,
     no_trycast_sql,
@@ -13,6 +12,7 @@ from sqlglot.dialects.dialect import (
     str_position_sql,
     timestrtotime_sql,
 )
+from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by


 def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
@@ -125,7 +125,7 @@ class Drill(Dialect):
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
             exp.ArraySize: rename_func("REPEATED_COUNT"),
-            exp.Create: create_with_partitions_sql,
+            exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
             exp.DateAdd: _date_add_sql("ADD"),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: _date_add_sql("SUB"),
@@ -9,7 +9,6 @@ from sqlglot.dialects.dialect import (
     NormalizationStrategy,
     approx_count_distinct_sql,
     arg_max_or_min_no_count,
-    create_with_partitions_sql,
     datestrtodate_sql,
     format_time_lambda,
     if_sql,
@@ -32,6 +31,12 @@ from sqlglot.dialects.dialect import (
     timestrtotime_sql,
     var_map_sql,
 )
+from sqlglot.transforms import (
+    remove_unique_constraints,
+    ctas_with_tmp_tables_to_create_tmp_view,
+    preprocess,
+    move_schema_columns_to_partitioned_by,
+)
 from sqlglot.helper import seq_get
 from sqlglot.parser import parse_var_map
 from sqlglot.tokens import TokenType
@@ -55,30 +60,6 @@ TIME_DIFF_FACTOR = {
 DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")


-def _create_sql(self, expression: exp.Create) -> str:
-    # remove UNIQUE column constraints
-    for constraint in expression.find_all(exp.UniqueColumnConstraint):
-        if constraint.parent:
-            constraint.parent.pop()
-
-    properties = expression.args.get("properties")
-    temporary = any(
-        isinstance(prop, exp.TemporaryProperty)
-        for prop in (properties.expressions if properties else [])
-    )
-
-    # CTAS with temp tables map to CREATE TEMPORARY VIEW
-    kind = expression.args["kind"]
-    if kind.upper() == "TABLE" and temporary:
-        if expression.expression:
-            return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}"
-        else:
-            # CREATE TEMPORARY TABLE may require storage provider
-            expression = self.temporary_storage_provider(expression)
-
-    return create_with_partitions_sql(self, expression)
-
-
 def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str:
     if isinstance(expression, exp.TsOrDsAdd) and not expression.unit:
         return self.func("DATE_ADD", expression.this, expression.expression)
@@ -285,6 +266,7 @@ class Hive(Dialect):
     class Parser(parser.Parser):
         LOG_DEFAULTS_TO_LN = True
         STRICT_CAST = False
+        VALUES_FOLLOWED_BY_PAREN = False

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -518,7 +500,13 @@ class Hive(Dialect):
                 "" if e.args.get("allow_null") else "NOT NULL"
             ),
             exp.VarMap: var_map_sql,
-            exp.Create: _create_sql,
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    ctas_with_tmp_tables_to_create_tmp_view,
+                    move_schema_columns_to_partitioned_by,
+                ]
+            ),
             exp.Quantile: rename_func("PERCENTILE"),
             exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
             exp.RegexpExtract: regexp_extract_sql,
@@ -581,10 +569,6 @@ class Hive(Dialect):

         return super()._jsonpathkey_sql(expression)

-    def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
-        # Hive has no temporary storage provider (there are hive settings though)
-        return expression
-
     def parameter_sql(self, expression: exp.Parameter) -> str:
         this = self.sql(expression, "this")
         expression_sql = self.sql(expression, "expression")
@@ -445,6 +445,7 @@ class MySQL(Dialect):

         LOG_DEFAULTS_TO_LN = True
         STRING_ALIASES = True
+        VALUES_FOLLOWED_BY_PAREN = False

         def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
             this = self._parse_id_var()
@@ -88,6 +88,7 @@ class Oracle(Dialect):
     class Parser(parser.Parser):
         ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
         WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
+        VALUES_FOLLOWED_BY_PAREN = False

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -244,6 +244,8 @@ class Postgres(Dialect):
             "@@": TokenType.DAT,
             "@>": TokenType.AT_GT,
             "<@": TokenType.LT_AT,
+            "|/": TokenType.PIPE_SLASH,
+            "||/": TokenType.DPIPE_SLASH,
             "BEGIN": TokenType.COMMAND,
             "BEGIN TRANSACTION": TokenType.BEGIN,
             "BIGSERIAL": TokenType.BIGSERIAL,
@@ -225,6 +225,8 @@ class Presto(Dialect):
     }

     class Parser(parser.Parser):
+        VALUES_FOLLOWED_BY_PAREN = False
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ARBITRARY": exp.AnyValue.from_arg_list,
@@ -136,11 +136,11 @@ class Redshift(Postgres):
                 refs.add(
                     (
                         this.args["from"] if i == 0 else this.args["joins"][i - 1]
-                    ).alias_or_name.lower()
+                    ).this.alias.lower()
                 )
-                table = join.this

-                if isinstance(table, exp.Table):
+                table = join.this
+                if isinstance(table, exp.Table) and not join.args.get("on"):
                     if table.parts[0].name.lower() in refs:
                         table.replace(table.to_column())
         return this
@@ -158,6 +158,7 @@ class Redshift(Postgres):
             "UNLOAD": TokenType.COMMAND,
             "VARBYTE": TokenType.VARBINARY,
         }
+        KEYWORDS.pop("VALUES")

         # Redshift allows # to appear as a table identifier prefix
         SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy()
@@ -477,6 +477,8 @@ class Snowflake(Dialect):
             "PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
             "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
             "COLUMNS": _show_parser("COLUMNS"),
+            "USERS": _show_parser("USERS"),
+            "TERSE USERS": _show_parser("USERS"),
         }

         STAGED_FILE_SINGLE_TOKENS = {
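With these entries, SHOW USERS and SHOW TERSE USERS parse into a structured node instead of falling through to a generic command. A quick sketch, assuming sqlglot 21.x:

    import sqlglot

    node = sqlglot.parse_one("SHOW TERSE USERS", read="snowflake")
    print(type(node).__name__)  # Show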
@@ -5,8 +5,14 @@ import typing as t
 from sqlglot import exp
 from sqlglot.dialects.dialect import rename_func
 from sqlglot.dialects.hive import _parse_ignore_nulls
-from sqlglot.dialects.spark2 import Spark2
+from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
 from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+    ctas_with_tmp_tables_to_create_tmp_view,
+    remove_unique_constraints,
+    preprocess,
+    move_partitioned_by_to_schema_columns,
+)


 def _parse_datediff(args: t.List) -> exp.Expression:
@@ -35,6 +41,15 @@ def _parse_datediff(args: t.List) -> exp.Expression:
     )


+def _normalize_partition(e: exp.Expression) -> exp.Expression:
+    """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
+    if isinstance(e, str):
+        return exp.to_identifier(e)
+    if isinstance(e, exp.Literal):
+        return exp.to_identifier(e.name)
+    return e
+
+
 class Spark(Spark2):
     class Tokenizer(Spark2.Tokenizer):
         RAW_STRINGS = [
@@ -72,6 +87,17 @@ class Spark(Spark2):

         TRANSFORMS = {
             **Spark2.Generator.TRANSFORMS,
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+                        e, temporary_storage_provider
+                    ),
+                    move_partitioned_by_to_schema_columns,
+                ]
+            ),
+            exp.PartitionedByProperty: lambda self,
+            e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
             exp.StartsWith: rename_func("STARTSWITH"),
             exp.TimestampAdd: lambda self, e: self.func(
                 "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
@@ -13,6 +13,12 @@ from sqlglot.dialects.dialect import (
 )
 from sqlglot.dialects.hive import Hive
 from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+    preprocess,
+    remove_unique_constraints,
+    ctas_with_tmp_tables_to_create_tmp_view,
+    move_schema_columns_to_partitioned_by,
+)


 def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
@@ -95,6 +101,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
     return expression


+def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
+    # spark2, spark, Databricks require a storage provider for temporary tables
+    provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
+    expression.args["properties"].append("expressions", provider)
+    return expression
+
+
 class Spark2(Hive):
     class Parser(Hive.Parser):
         TRIM_PATTERN_FIRST = True
@@ -121,7 +134,6 @@ class Spark2(Hive):
                 ),
                 zone=seq_get(args, 1),
             ),
-            "IIF": exp.If.from_arg_list,
             "INT": _parse_as_cast("int"),
             "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
             "RLIKE": exp.RegexpLike.from_arg_list,
@@ -193,6 +205,15 @@ class Spark2(Hive):
             e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+                        e, temporary_storage_provider
+                    ),
+                    move_schema_columns_to_partitioned_by,
+                ]
+            ),
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
             exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -251,12 +272,6 @@ class Spark2(Hive):

             return self.func("STRUCT", *args)

-        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
-            # spark2, spark, Databricks require a storage provider for temporary tables
-            provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
-            expression.args["properties"].append("expressions", provider)
-            return expression
-
         def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
             if is_parse_json(expression.this):
                 schema = f"'{self.sql(expression, 'to')}'"
@@ -132,6 +132,7 @@ class SQLite(Dialect):
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DateAdd: _date_add_sql,
             exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
+            exp.If: rename_func("IIF"),
             exp.ILike: no_ilike_sql,
             exp.JSONExtract: _json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_sql,
@@ -1,10 +1,14 @@
 from __future__ import annotations

-from sqlglot import exp, generator, parser, transforms
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import Dialect, rename_func


 class Tableau(Dialect):
+    class Tokenizer(tokens.Tokenizer):
+        IDENTIFIERS = [("[", "]")]
+        QUOTES = ["'", '"']
+
     class Generator(generator.Generator):
         JOIN_HINTS = False
         TABLE_HINTS = False
@@ -74,6 +74,7 @@ class Teradata(Dialect):

     class Parser(parser.Parser):
         TABLESAMPLE_CSV = True
+        VALUES_FOLLOWED_BY_PAREN = False

         CHARSET_TRANSLATORS = {
             "GRAPHIC_TO_KANJISJIS",
@@ -457,7 +457,6 @@ class TSQL(Dialect):
             "FORMAT": _parse_format,
             "GETDATE": exp.CurrentTimestamp.from_arg_list,
             "HASHBYTES": _parse_hashbytes,
-            "IIF": exp.If.from_arg_list,
             "ISNULL": exp.Coalesce.from_arg_list,
             "JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract),
             "JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
@@ -1090,6 +1090,11 @@ class Create(DDL):
         "clone": False,
     }

+    @property
+    def kind(self) -> t.Optional[str]:
+        kind = self.args.get("kind")
+        return kind and kind.upper()
+

 # https://docs.snowflake.com/en/sql-reference/sql/create-clone
 # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
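A minimal sketch of the new Create.kind property, which normalizes the raw kind argument to upper case:

    from sqlglot import parse_one

    create = parse_one("create table x (a int)")
    print(create.kind)  # TABLE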
@@ -4626,6 +4631,11 @@ class CountIf(AggFunc):
     _sql_names = ["COUNT_IF", "COUNTIF"]


+# cube root
+class Cbrt(Func):
+    pass
+
+
 class CurrentDate(Func):
     arg_types = {"this": False}
@@ -4728,7 +4738,7 @@ class Extract(Func):


 class Timestamp(Func):
-    arg_types = {"this": False, "expression": False}
+    arg_types = {"this": False, "expression": False, "with_tz": False}


 class TimestampAdd(Func, TimeUnit):
@@ -4833,7 +4843,7 @@ class Posexplode(Explode):
     pass


-class PosexplodeOuter(Posexplode):
+class PosexplodeOuter(Posexplode, ExplodeOuter):
     pass

@@ -4868,6 +4878,7 @@ class Xor(Connector, Func):

 class If(Func):
     arg_types = {"this": True, "true": True, "false": False}
+    _sql_names = ["IF", "IIF"]


 class Nullif(Func):
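Registering IIF as a second SQL name for exp.If is what makes the dialect-specific "IIF" parser entries removed elsewhere in this diff redundant. A sketch, assuming sqlglot 21.x:

    import sqlglot

    print(sqlglot.transpile("SELECT IIF(x > 1, 1, 0)", read="tsql", write="sqlite")[0])
    # Expected: SELECT IIF(x > 1, 1, 0)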
@@ -6883,6 +6894,7 @@ def replace_tables(
         table = to_table(
             new_name,
             **{k: v for k, v in node.args.items() if k not in TABLE_PARTS},
+            dialect=dialect,
         )
         table.add_comments([original])
         return table
@@ -7072,6 +7084,60 @@ def cast_unless(
     return cast(expr, to, **opts)


+def array(
+    *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs
+) -> Array:
+    """
+    Returns an array.
+
+    Examples:
+        >>> array(1, 'x').sql()
+        'ARRAY(1, x)'
+
+    Args:
+        expressions: the expressions to add to the array.
+        copy: whether or not to copy the argument expressions.
+        dialect: the source dialect.
+        kwargs: the kwargs used to instantiate the function of interest.
+
+    Returns:
+        An array expression.
+    """
+    return Array(
+        expressions=[
+            maybe_parse(expression, copy=copy, dialect=dialect, **kwargs)
+            for expression in expressions
+        ]
+    )
+
+
+def tuple_(
+    *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs
+) -> Tuple:
+    """
+    Returns a tuple.
+
+    Examples:
+        >>> tuple_(1, 'x').sql()
+        '(1, x)'
+
+    Args:
+        expressions: the expressions to add to the tuple.
+        copy: whether or not to copy the argument expressions.
+        dialect: the source dialect.
+        kwargs: the kwargs used to instantiate the function of interest.
+
+    Returns:
+        A tuple expression.
+    """
+    return Tuple(
+        expressions=[
+            maybe_parse(expression, copy=copy, dialect=dialect, **kwargs)
+            for expression in expressions
+        ]
+    )
+
+
 def true() -> Boolean:
     """
     Returns a true Boolean expression.
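The two new builders behave as their docstrings show; copy=False reuses the argument expressions instead of copying them, which is why the parser and generator call sites in this diff pass it. A short sketch:

    from sqlglot import exp

    print(exp.array(1, "x").sql())   # ARRAY(1, x)
    print(exp.tuple_(1, "x").sql())  # (1, x)

    col = exp.to_column("a")
    print(exp.array(col, copy=False).sql())  # ARRAY(a), reusing `col` itself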
@@ -124,6 +124,7 @@ class Generator(metaclass=_Generator):
         exp.StabilityProperty: lambda self, e: e.name,
         exp.TemporaryProperty: lambda self, e: "TEMPORARY",
         exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
+        exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
         exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
         exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
         exp.TransientProperty: lambda self, e: "TRANSIENT",
@@ -3360,7 +3361,7 @@ class Generator(metaclass=_Generator):
             return self.sql(arg)

         cond_for_null = arg.is_(exp.null())
-        return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg])))
+        return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.array(arg, copy=False)))

     def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str:
         this = expression.this
@@ -6,7 +6,7 @@ import logging
 import re
 import sys
 import typing as t
-from collections.abc import Collection
+from collections.abc import Collection, Set
 from contextlib import contextmanager
 from copy import copy
 from enum import Enum
@@ -496,3 +496,31 @@ DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}

 def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
     return expression is not None and expression.name.lower() in DATE_UNITS
+
+
+K = t.TypeVar("K")
+V = t.TypeVar("V")
+
+
+class SingleValuedMapping(t.Mapping[K, V]):
+    """
+    Mapping where all keys return the same value.
+
+    This rigamarole is meant to avoid copying keys, which was originally intended
+    as an optimization while qualifying columns for tables with lots of columns.
+    """
+
+    def __init__(self, keys: t.Collection[K], value: V):
+        self._keys = keys if isinstance(keys, Set) else set(keys)
+        self._value = value
+
+    def __getitem__(self, key: K) -> V:
+        if key in self._keys:
+            return self._value
+        raise KeyError(key)
+
+    def __len__(self) -> int:
+        return len(self._keys)
+
+    def __iter__(self) -> t.Iterator[K]:
+        return iter(self._keys)
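A minimal usage sketch of SingleValuedMapping (a read-only Mapping):

    from sqlglot.helper import SingleValuedMapping

    m = SingleValuedMapping(["id", "name"], "users")
    print(m["id"], m["name"])  # users users
    print(len(m), sorted(m))   # 2 ['id', 'name']
    # m["missing"] raises KeyError, like a regular dict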
@@ -153,7 +153,7 @@ def lineage(
             raise ValueError(f"Could not find {column} in {scope.expression}")

         for s in scope.union_scopes:
-            to_node(index, scope=s, upstream=upstream)
+            to_node(index, scope=s, upstream=upstream, alias=alias)

         return upstream
@@ -209,7 +209,11 @@ def lineage(
             if isinstance(source, Scope):
                 # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
                 to_node(
-                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
+                    c.name,
+                    scope=source,
+                    scope_name=table,
+                    upstream=node,
+                    alias=aliases.get(table) or alias,
                 )
             else:
                 # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
@@ -204,7 +204,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         exp.TimeAdd,
         exp.TimeStrToTime,
         exp.TimeSub,
-        exp.Timestamp,
         exp.TimestampAdd,
         exp.TimestampSub,
         exp.UnixToTime,
@@ -276,6 +275,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
         exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
         exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
+        exp.Timestamp: lambda self, e: self._annotate_with_type(
+            e,
+            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
+        ),
         exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
         exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
         exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
@@ -38,7 +38,12 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
     if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
         return exp.cast(node.this, to=exp.DataType.Type.DATE)
     if isinstance(node, exp.Timestamp) and not node.expression:
-        return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)
+        if not node.type:
+            from sqlglot.optimizer.annotate_types import annotate_types
+
+            node = annotate_types(node)
+        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)

     return node
@@ -76,9 +81,8 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
 def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
     if (
         isinstance(expression, exp.Cast)
-        and expression.to.type
         and expression.this.type
-        and expression.to.type.this == expression.this.type.this
+        and expression.to.this == expression.this.type.this
     ):
         return expression.this
     return expression
@@ -6,7 +6,7 @@ import typing as t
 from sqlglot import alias, exp
 from sqlglot.dialects.dialect import Dialect, DialectType
 from sqlglot.errors import OptimizeError
-from sqlglot.helper import seq_get
+from sqlglot.helper import seq_get, SingleValuedMapping
 from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 from sqlglot.optimizer.simplify import simplify_parens
 from sqlglot.schema import Schema, ensure_schema
@@ -586,8 +586,8 @@ class Resolver:
     def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
         self.scope = scope
         self.schema = schema
-        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
-        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
+        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
+        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
         self._infer_schema = infer_schema
@@ -640,7 +640,7 @@ class Resolver:
         }
         return self._all_columns

-    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
+    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
         """Resolve the source columns for a given source `name`."""
         if name not in self.scope.sources:
             raise OptimizeError(f"Unknown table: {name}")
@@ -662,10 +662,15 @@ class Resolver:
         else:
             column_aliases = []

-        # If the source's columns are aliased, their aliases shadow the corresponding column names
-        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]
+        if column_aliases:
+            # If the source's columns are aliased, their aliases shadow the corresponding column names.
+            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
+            return [
+                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
+            ]
+        return columns

-    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
+    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
         if self._source_columns is None:
             self._source_columns = {
                 source_name: self.get_source_columns(source_name)
@@ -676,8 +681,8 @@ class Resolver:
         return self._source_columns

     def _get_unambiguous_columns(
-        self, source_columns: t.Dict[str, t.List[str]]
-    ) -> t.Dict[str, str]:
+        self, source_columns: t.Dict[str, t.Sequence[str]]
+    ) -> t.Mapping[str, str]:
         """
         Find all the unambiguous columns in sources.

@@ -693,12 +698,17 @@ class Resolver:
         source_columns_pairs = list(source_columns.items())

         first_table, first_columns = source_columns_pairs[0]
-        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
+
+        if len(source_columns_pairs) == 1:
+            # Performance optimization - avoid copying first_columns if there is only one table.
+            return SingleValuedMapping(first_columns, first_table)
+
+        unambiguous_columns = {col: first_table for col in first_columns}
         all_columns = set(unambiguous_columns)

         for table, columns in source_columns_pairs[1:]:
-            unique = self._find_unique_columns(columns)
-            ambiguous = set(all_columns).intersection(unique)
+            unique = set(columns)
+            ambiguous = all_columns.intersection(unique)
             all_columns.update(columns)

             for column in ambiguous:
@@ -707,19 +717,3 @@ class Resolver:
                 unambiguous_columns[column] = table

         return unambiguous_columns
-
-    @staticmethod
-    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
-        """
-        Find the unique columns in a list of columns.
-
-        Example:
-            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
-            ['a', 'c']
-
-        This is necessary because duplicate column names are ambiguous.
-        """
-        counts: t.Dict[str, int] = {}
-        for column in columns:
-            counts[column] = counts.get(column, 0) + 1
-        return {column for column, count in counts.items() if count == 1}
@@ -29,8 +29,8 @@ def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
         values.append(args[i + 1])

     return exp.VarMap(
-        keys=exp.Array(expressions=keys),
-        values=exp.Array(expressions=values),
+        keys=exp.array(*keys, copy=False),
+        values=exp.array(*values, copy=False),
     )

@@ -638,6 +638,8 @@ class Parser(metaclass=_Parser):
         TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()),
         TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()),
         TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()),
+        TokenType.PIPE_SLASH: lambda self: self.expression(exp.Sqrt, this=self._parse_unary()),
+        TokenType.DPIPE_SLASH: lambda self: self.expression(exp.Cbrt, this=self._parse_unary()),
     }

     PRIMARY_PARSERS = {
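These unary parsers wire the new Postgres operator tokens to expressions: |/ is square root and ||/ is cube root. A sketch, assuming sqlglot 21.x:

    import sqlglot

    print(sqlglot.parse_one("SELECT |/ 25", read="postgres").sql())   # SELECT SQRT(25)
    print(sqlglot.parse_one("SELECT ||/ 27", read="postgres").sql())  # SELECT CBRT(27)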
@@ -1000,9 +1002,13 @@ class Parser(metaclass=_Parser):
     MODIFIERS_ATTACHED_TO_UNION = True
     UNION_MODIFIERS = {"order", "limit", "offset"}

-    # parses no parenthesis if statements as commands
+    # Parses no parenthesis if statements as commands
     NO_PAREN_IF_COMMANDS = True

+    # Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause.
+    # If this is True and '(' is not found, the keyword will be treated as an identifier
+    VALUES_FOLLOWED_BY_PAREN = True
+
     __slots__ = (
         "error_level",
         "error_message_context",
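A sketch of what the new flag changes: in dialects that set VALUES_FOLLOWED_BY_PAREN = False elsewhere in this diff (hive, mysql, oracle, presto, teradata, ...), VALUES without a following '(' can be an ordinary identifier (assumes sqlglot 21.x):

    import sqlglot

    print(sqlglot.parse_one("SELECT * FROM values", read="mysql").sql())
    # SELECT * FROM values -- here `values` is just a table name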
@@ -2058,7 +2064,7 @@ class Parser(metaclass=_Parser):
             partition=self._parse_partition(),
             where=self._match_pair(TokenType.REPLACE, TokenType.WHERE)
             and self._parse_conjunction(),
-            expression=self._parse_ddl_select(),
+            expression=self._parse_derived_table_values() or self._parse_ddl_select(),
             conflict=self._parse_on_conflict(),
             returning=returning or self._parse_returning(),
             overwrite=overwrite,
@@ -2267,8 +2273,7 @@ class Parser(metaclass=_Parser):
             self._match_r_paren()
             return self.expression(exp.Tuple, expressions=expressions)

-        # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
-        # https://prestodb.io/docs/current/sql/values.html
+        # In some dialects we can have VALUES 1, 2 which results in 1 column & 2 rows.
         return self.expression(exp.Tuple, expressions=[self._parse_expression()])

     def _parse_projections(self) -> t.List[exp.Expression]:
@@ -2367,12 +2372,8 @@ class Parser(metaclass=_Parser):
             # We return early here so that the UNION isn't attached to the subquery by the
             # following call to _parse_set_operations, but instead becomes the parent node
             return self._parse_subquery(this, parse_alias=parse_subquery_alias)
-        elif self._match(TokenType.VALUES):
-            this = self.expression(
-                exp.Values,
-                expressions=self._parse_csv(self._parse_value),
-                alias=self._parse_table_alias(),
-            )
+        elif self._match(TokenType.VALUES, advance=False):
+            this = self._parse_derived_table_values()
         elif from_:
             this = exp.select("*").from_(from_.this, copy=False)
         else:
@@ -2969,7 +2970,7 @@ class Parser(metaclass=_Parser):

     def _parse_derived_table_values(self) -> t.Optional[exp.Values]:
         is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
-        if not is_derived and not self._match(TokenType.VALUES):
+        if not is_derived and not self._match_text_seq("VALUES"):
             return None

         expressions = self._parse_csv(self._parse_value)
@@ -3655,8 +3656,15 @@ class Parser(metaclass=_Parser):
     def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
         interval = parse_interval and self._parse_interval()
         if interval:
-            # Convert INTERVAL 'val_1' unit_1 ... 'val_n' unit_n into a sum of intervals
-            while self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
+            # Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals
+            while True:
+                index = self._index
+                self._match(TokenType.PLUS)
+
+                if not self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
+                    self._retreat(index)
+                    break
+
                 interval = self.expression(  # type: ignore
                     exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
                 )
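A sketch of the extended INTERVAL parsing, which now also accepts an optional '+' between the parts (assumes sqlglot 21.x; the rendered output is approximate):

    import sqlglot

    ast = sqlglot.parse_one("SELECT INTERVAL '1' YEAR + '2' MONTH")
    print(ast.sql())
    # Roughly: SELECT INTERVAL '1' YEAR + INTERVAL '2' MONTH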
@@ -3872,9 +3880,15 @@ class Parser(metaclass=_Parser):

     def _parse_column_reference(self) -> t.Optional[exp.Expression]:
         this = self._parse_field()
-        if isinstance(this, exp.Identifier):
-            this = self.expression(exp.Column, this=this)
-        return this
+        if (
+            not this
+            and self._match(TokenType.VALUES, advance=False)
+            and self.VALUES_FOLLOWED_BY_PAREN
+            and (not self._next or self._next.token_type != TokenType.L_PAREN)
+        ):
+            this = self._parse_id_var()
+
+        return self.expression(exp.Column, this=this) if isinstance(this, exp.Identifier) else this

     def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
         this = self._parse_bracket(this)
@@ -5511,7 +5525,7 @@ class Parser(metaclass=_Parser):
                 then = self.expression(
                     exp.Insert,
                     this=self._parse_value(),
-                    expression=self._match(TokenType.VALUES) and self._parse_value(),
+                    expression=self._match_text_seq("VALUES") and self._parse_value(),
                 )
             elif self._match(TokenType.UPDATE):
                 expressions = self._parse_star()
@@ -49,7 +49,7 @@ class Schema(abc.ABC):
         only_visible: bool = False,
         dialect: DialectType = None,
         normalize: t.Optional[bool] = None,
-    ) -> t.List[str]:
+    ) -> t.Sequence[str]:
         """
         Get the column names for a table.

@@ -60,7 +60,7 @@ class Schema(abc.ABC):
             normalize: whether to normalize identifiers according to the dialect of interest.

         Returns:
-            The list of column names.
+            The sequence of column names.
         """

     @abc.abstractmethod
@@ -57,6 +57,8 @@ class TokenType(AutoName):
     AMP = auto()
     DPIPE = auto()
     PIPE = auto()
+    PIPE_SLASH = auto()
+    DPIPE_SLASH = auto()
     CARET = auto()
     TILDA = auto()
     ARROW = auto()
@@ -213,6 +213,19 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
             is_posexplode = isinstance(explode, exp.Posexplode)
             explode_arg = explode.this

+            if isinstance(explode, exp.ExplodeOuter):
+                bracket = explode_arg[0]
+                bracket.set("safe", True)
+                bracket.set("offset", True)
+                explode_arg = exp.func(
+                    "IF",
+                    exp.func(
+                        "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
+                    ).eq(0),
+                    exp.array(bracket, copy=False),
+                    explode_arg,
+                )
+
             # This ensures that we won't use [POS]EXPLODE's argument as a new selection
             if isinstance(explode_arg, exp.Column):
                 taken_select_names.add(explode_arg.output_name)
@@ -466,6 +479,87 @@ def unqualify_columns(expression: exp.Expression) -> exp.Expression:
     return expression


+def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
+    assert isinstance(expression, exp.Create)
+    for constraint in expression.find_all(exp.UniqueColumnConstraint):
+        if constraint.parent:
+            constraint.parent.pop()
+
+    return expression
+
+
+def ctas_with_tmp_tables_to_create_tmp_view(
+    expression: exp.Expression,
+    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
+) -> exp.Expression:
+    assert isinstance(expression, exp.Create)
+    properties = expression.args.get("properties")
+    temporary = any(
+        isinstance(prop, exp.TemporaryProperty)
+        for prop in (properties.expressions if properties else [])
+    )
+
+    # CTAS with temp tables map to CREATE TEMPORARY VIEW
+    if expression.kind == "TABLE" and temporary:
+        if expression.expression:
+            return exp.Create(
+                kind="TEMPORARY VIEW",
+                this=expression.this,
+                expression=expression.expression,
+            )
+        return tmp_storage_provider(expression)
+
+    return expression
+
+
+def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
+    """
+    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
+    PARTITIONED BY value is an array of column names, they are transformed into a schema.
+    The corresponding columns are removed from the create statement.
+    """
+    assert isinstance(expression, exp.Create)
+    has_schema = isinstance(expression.this, exp.Schema)
+    is_partitionable = expression.kind in {"TABLE", "VIEW"}
+
+    if has_schema and is_partitionable:
+        prop = expression.find(exp.PartitionedByProperty)
+        if prop and prop.this and not isinstance(prop.this, exp.Schema):
+            schema = expression.this
+            columns = {v.name.upper() for v in prop.this.expressions}
+            partitions = [col for col in schema.expressions if col.name.upper() in columns]
+            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
+            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
+            expression.set("this", schema)
+
+    return expression
+
+
+def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
+    """
+    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
+
+    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
+    """
+    assert isinstance(expression, exp.Create)
+    prop = expression.find(exp.PartitionedByProperty)
+    if (
+        prop
+        and prop.this
+        and isinstance(prop.this, exp.Schema)
+        and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions)
+    ):
+        prop_this = exp.Tuple(
+            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
+        )
+        schema = expression.this
+        for e in prop.this.expressions:
+            schema.append("expressions", e)
+        prop.set("this", prop_this)
+
+    return expression
+
+
 def preprocess(
     transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 ) -> t.Callable[[Generator, exp.Expression], str]: