
Merging upstream version 21.1.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:27:51 +01:00
parent 4e41aa0bbb
commit bf03050a25
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
91 changed files with 49165 additions and 47854 deletions

View file

@@ -148,7 +148,7 @@ def atanh(col: ColumnOrName) -> Column:
def cbrt(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "CBRT")
return Column.invoke_expression_over_column(col, expression.Cbrt)
def ceil(col: ColumnOrName) -> Column:

View file

@@ -70,12 +70,10 @@ class SparkSession:
column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
data_expressions = [
exp.Tuple(
expressions=list(
map(
lambda x: F.lit(x).expression,
row if not isinstance(row, dict) else row.values(),
)
exp.tuple_(
*map(
lambda x: F.lit(x).expression,
row if not isinstance(row, dict) else row.values(),
)
)
for row in data

View file

@@ -39,24 +39,31 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
alias = expression.args.get("alias")
structs = [
exp.Struct(
return self.unnest_sql(
exp.Unnest(
expressions=[
exp.alias_(value, column_name)
for value, column_name in zip(
t.expressions,
(
alias.columns
if alias and alias.columns
else (f"_c{i}" for i in range(len(t.expressions)))
exp.array(
*(
exp.Struct(
expressions=[
exp.alias_(value, column_name)
for value, column_name in zip(
t.expressions,
(
alias.columns
if alias and alias.columns
else (f"_c{i}" for i in range(len(t.expressions)))
),
)
]
)
for t in expression.find_all(exp.Tuple)
),
copy=False,
)
]
)
for t in expression.find_all(exp.Tuple)
]
return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
)
def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
@@ -161,12 +168,18 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
return expression
def _parse_timestamp(args: t.List) -> exp.StrToTime:
def _parse_parse_timestamp(args: t.List) -> exp.StrToTime:
this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
this.set("zone", seq_get(args, 2))
return this
def _parse_timestamp(args: t.List) -> exp.Timestamp:
timestamp = exp.Timestamp.from_arg_list(args)
timestamp.set("with_tz", True)
return timestamp
def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts:
expr_type = exp.DateFromParts if len(args) == 3 else exp.Date
return expr_type.from_arg_list(args)
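
A minimal usage sketch, assuming the new _parse_timestamp wiring below makes BigQuery's TIMESTAMP() call parse into an exp.Timestamp node flagged with with_tz (the query is a hypothetical example, and the annotator change later in this commit then types it as TIMESTAMPTZ):

import sqlglot
from sqlglot import exp

# Assumption: TIMESTAMP(...) in BigQuery now parses to exp.Timestamp with with_tz=True.
node = sqlglot.parse_one("SELECT TIMESTAMP('2020-01-01 00:00:00+00')", read="bigquery")
ts = node.find(exp.Timestamp)
print(ts and ts.args.get("with_tz"))  # expected: True
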
@@ -318,6 +331,7 @@ class BigQuery(Dialect):
"TIMESTAMP": TokenType.TIMESTAMPTZ,
}
KEYWORDS.pop("DIV")
KEYWORDS.pop("VALUES")
class Parser(parser.Parser):
PREFIXED_PIVOT_COLUMNS = True
@@ -348,7 +362,7 @@ class BigQuery(Dialect):
"PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
),
"PARSE_TIMESTAMP": _parse_timestamp,
"PARSE_TIMESTAMP": _parse_parse_timestamp,
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0),
@@ -367,6 +381,7 @@ class BigQuery(Dialect):
"TIME": _parse_time,
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
"TIMESTAMP": _parse_timestamp,
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
"TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
"TIMESTAMP_MICROS": lambda args: exp.UnixToTime(
@@ -395,11 +410,6 @@ class BigQuery(Dialect):
TokenType.TABLE,
}
ID_VAR_TOKENS = {
*parser.Parser.ID_VAR_TOKENS,
TokenType.VALUES,
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
"NOT DETERMINISTIC": lambda self: self.expression(

View file

@@ -93,6 +93,7 @@ class ClickHouse(Dialect):
"IPV6": TokenType.IPV6,
"AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION,
"SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION,
"SYSTEM": TokenType.COMMAND,
}
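
A minimal sketch of the assumed effect: mapping SYSTEM to TokenType.COMMAND should let ClickHouse SYSTEM statements parse as opaque exp.Command nodes instead of failing (the statement below is a hypothetical example):

import sqlglot
from sqlglot import exp

# Assumption: SYSTEM statements are now kept as raw commands in the clickhouse dialect.
stmt = sqlglot.parse_one("SYSTEM RELOAD DICTIONARIES", read="clickhouse")
print(isinstance(stmt, exp.Command))  # expected: True
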
SINGLE_TOKENS = {

View file

@@ -654,28 +654,6 @@ def time_format(
return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
"""
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
columns are removed from the create statement.
"""
has_schema = isinstance(expression.this, exp.Schema)
is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
if has_schema and is_partitionable:
prop = expression.find(exp.PartitionedByProperty)
if prop and prop.this and not isinstance(prop.this, exp.Schema):
schema = expression.this
columns = {v.name.upper() for v in prop.this.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set("expressions", [e for e in schema.expressions if e not in partitions])
prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
expression.set("this", schema)
return self.create_sql(expression)
def parse_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
@@ -742,7 +720,10 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
if not expression.expression:
return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
from sqlglot.optimizer.annotate_types import annotate_types
target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
return self.sql(exp.cast(expression.this, to=target_type))
if expression.text("expression").lower() in TIMEZONES:
return self.sql(
exp.AtTimeZone(
@@ -750,7 +731,7 @@ def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
zone=expression.expression,
)
)
return self.function_fallback_sql(expression)
return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: t.List) -> exp.Expression:

View file

@@ -5,7 +5,6 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
datestrtodate_sql,
format_time_lambda,
no_trycast_sql,
@@ -13,6 +12,7 @@ from sqlglot.dialects.dialect import (
str_position_sql,
timestrtotime_sql,
)
from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by
def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
@@ -125,7 +125,7 @@ class Drill(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: create_with_partitions_sql,
exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),

View file

@@ -9,7 +9,6 @@ from sqlglot.dialects.dialect import (
NormalizationStrategy,
approx_count_distinct_sql,
arg_max_or_min_no_count,
create_with_partitions_sql,
datestrtodate_sql,
format_time_lambda,
if_sql,
@@ -32,6 +31,12 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
var_map_sql,
)
from sqlglot.transforms import (
remove_unique_constraints,
ctas_with_tmp_tables_to_create_tmp_view,
preprocess,
move_schema_columns_to_partitioned_by,
)
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
@@ -55,30 +60,6 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
def _create_sql(self, expression: exp.Create) -> str:
# remove UNIQUE column constraints
for constraint in expression.find_all(exp.UniqueColumnConstraint):
if constraint.parent:
constraint.parent.pop()
properties = expression.args.get("properties")
temporary = any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
)
# CTAS with temp tables map to CREATE TEMPORARY VIEW
kind = expression.args["kind"]
if kind.upper() == "TABLE" and temporary:
if expression.expression:
return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}"
else:
# CREATE TEMPORARY TABLE may require storage provider
expression = self.temporary_storage_provider(expression)
return create_with_partitions_sql(self, expression)
def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str:
if isinstance(expression, exp.TsOrDsAdd) and not expression.unit:
return self.func("DATE_ADD", expression.this, expression.expression)
@@ -285,6 +266,7 @@ class Hive(Dialect):
class Parser(parser.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
VALUES_FOLLOWED_BY_PAREN = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -518,7 +500,13 @@ class Hive(Dialect):
"" if e.args.get("allow_null") else "NOT NULL"
),
exp.VarMap: var_map_sql,
exp.Create: _create_sql,
exp.Create: preprocess(
[
remove_unique_constraints,
ctas_with_tmp_tables_to_create_tmp_view,
move_schema_columns_to_partitioned_by,
]
),
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpExtract: regexp_extract_sql,
@@ -581,10 +569,6 @@ class Hive(Dialect):
return super()._jsonpathkey_sql(expression)
def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
# Hive has no temporary storage provider (there are hive settings though)
return expression
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")

View file

@@ -445,6 +445,7 @@ class MySQL(Dialect):
LOG_DEFAULTS_TO_LN = True
STRING_ALIASES = True
VALUES_FOLLOWED_BY_PAREN = False
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()

View file

@@ -88,6 +88,7 @@ class Oracle(Dialect):
class Parser(parser.Parser):
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
VALUES_FOLLOWED_BY_PAREN = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,

View file

@@ -244,6 +244,8 @@ class Postgres(Dialect):
"@@": TokenType.DAT,
"@>": TokenType.AT_GT,
"<@": TokenType.LT_AT,
"|/": TokenType.PIPE_SLASH,
"||/": TokenType.DPIPE_SLASH,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,

View file

@@ -225,6 +225,8 @@ class Presto(Dialect):
}
class Parser(parser.Parser):
VALUES_FOLLOWED_BY_PAREN = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARBITRARY": exp.AnyValue.from_arg_list,

View file

@@ -136,11 +136,11 @@ class Redshift(Postgres):
refs.add(
(
this.args["from"] if i == 0 else this.args["joins"][i - 1]
).alias_or_name.lower()
).this.alias.lower()
)
table = join.this
if isinstance(table, exp.Table):
table = join.this
if isinstance(table, exp.Table) and not join.args.get("on"):
if table.parts[0].name.lower() in refs:
table.replace(table.to_column())
return this
@@ -158,6 +158,7 @@ class Redshift(Postgres):
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
}
KEYWORDS.pop("VALUES")
# Redshift allows # to appear as a table identifier prefix
SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy()

View file

@@ -477,6 +477,8 @@ class Snowflake(Dialect):
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"COLUMNS": _show_parser("COLUMNS"),
"USERS": _show_parser("USERS"),
"TERSE USERS": _show_parser("USERS"),
}
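
A minimal sketch, assuming the new SHOW parsers make Snowflake's SHOW USERS parse into a structured exp.Show node (the statement is a hypothetical example):

import sqlglot
from sqlglot import exp

# Assumption: SHOW USERS is now parsed into exp.Show instead of a raw command.
show = sqlglot.parse_one("SHOW USERS", read="snowflake")
print(isinstance(show, exp.Show))  # expected: True
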
STAGED_FILE_SINGLE_TOKENS = {

View file

@@ -5,8 +5,14 @@ import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.hive import _parse_ignore_nulls
from sqlglot.dialects.spark2 import Spark2
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get
from sqlglot.transforms import (
ctas_with_tmp_tables_to_create_tmp_view,
remove_unique_constraints,
preprocess,
move_partitioned_by_to_schema_columns,
)
def _parse_datediff(args: t.List) -> exp.Expression:
@@ -35,6 +41,15 @@ def _parse_datediff(args: t.List) -> exp.Expression:
)
def _normalize_partition(e: exp.Expression) -> exp.Expression:
"""Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
if isinstance(e, str):
return exp.to_identifier(e)
if isinstance(e, exp.Literal):
return exp.to_identifier(e.name)
return e
class Spark(Spark2):
class Tokenizer(Spark2.Tokenizer):
RAW_STRINGS = [
@@ -72,6 +87,17 @@ class Spark(Spark2):
TRANSFORMS = {
**Spark2.Generator.TRANSFORMS,
exp.Create: preprocess(
[
remove_unique_constraints,
lambda e: ctas_with_tmp_tables_to_create_tmp_view(
e, temporary_storage_provider
),
move_partitioned_by_to_schema_columns,
]
),
exp.PartitionedByProperty: lambda self,
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this

View file

@@ -13,6 +13,12 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
from sqlglot.transforms import (
preprocess,
remove_unique_constraints,
ctas_with_tmp_tables_to_create_tmp_view,
move_schema_columns_to_partitioned_by,
)
def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
@@ -95,6 +101,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
return expression
def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
# spark2, spark, Databricks require a storage provider for temporary tables
provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
expression.args["properties"].append("expressions", provider)
return expression
class Spark2(Hive):
class Parser(Hive.Parser):
TRIM_PATTERN_FIRST = True
@@ -121,7 +134,6 @@ class Spark2(Hive):
),
zone=seq_get(args, 1),
),
"IIF": exp.If.from_arg_list,
"INT": _parse_as_cast("int"),
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
@@ -193,6 +205,15 @@ class Spark2(Hive):
e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.Create: preprocess(
[
remove_unique_constraints,
lambda e: ctas_with_tmp_tables_to_create_tmp_view(
e, temporary_storage_provider
),
move_schema_columns_to_partitioned_by,
]
),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -251,12 +272,6 @@ class Spark2(Hive):
return self.func("STRUCT", *args)
def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
# spark2, spark, Databricks require a storage provider for temporary tables
provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
expression.args["properties"].append("expressions", provider)
return expression
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if is_parse_json(expression.this):
schema = f"'{self.sql(expression, 'to')}'"

View file

@@ -132,6 +132,7 @@ class SQLite(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.If: rename_func("IIF"),
exp.ILike: no_ilike_sql,
exp.JSONExtract: _json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,

View file

@@ -1,10 +1,14 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, transforms
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, rename_func
class Tableau(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = [("[", "]")]
QUOTES = ["'", '"']
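
A minimal sketch, assuming the new Tokenizer settings let Tableau-style bracketed identifiers round-trip (the identifier below is a hypothetical example):

import sqlglot

# Assumption: '[' and ']' are now read (and emitted) as identifier quotes for tableau.
sql = sqlglot.parse_one("SELECT [Sales Amount] FROM t", read="tableau").sql(dialect="tableau")
print(sql)  # expected, roughly: SELECT [Sales Amount] FROM t
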
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False

View file

@@ -74,6 +74,7 @@ class Teradata(Dialect):
class Parser(parser.Parser):
TABLESAMPLE_CSV = True
VALUES_FOLLOWED_BY_PAREN = False
CHARSET_TRANSLATORS = {
"GRAPHIC_TO_KANJISJIS",

View file

@@ -457,7 +457,6 @@ class TSQL(Dialect):
"FORMAT": _parse_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
"HASHBYTES": _parse_hashbytes,
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract),
"JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar),

View file

@@ -1090,6 +1090,11 @@ class Create(DDL):
"clone": False,
}
@property
def kind(self) -> t.Optional[str]:
kind = self.args.get("kind")
return kind and kind.upper()
# https://docs.snowflake.com/en/sql-reference/sql/create-clone
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
@@ -4626,6 +4631,11 @@ class CountIf(AggFunc):
_sql_names = ["COUNT_IF", "COUNTIF"]
# cube root
class Cbrt(Func):
pass
class CurrentDate(Func):
arg_types = {"this": False}
@@ -4728,7 +4738,7 @@ class Extract(Func):
class Timestamp(Func):
arg_types = {"this": False, "expression": False}
arg_types = {"this": False, "expression": False, "with_tz": False}
class TimestampAdd(Func, TimeUnit):
@@ -4833,7 +4843,7 @@ class Posexplode(Explode):
pass
class PosexplodeOuter(Posexplode):
class PosexplodeOuter(Posexplode, ExplodeOuter):
pass
@@ -4868,6 +4878,7 @@ class Xor(Connector, Func):
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
_sql_names = ["IF", "IIF"]
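
A minimal sketch, assuming IIF now resolves to exp.If in any dialect and SQLite renders exp.If back as IIF (see the SQLite and TSQL hunks in this commit); the query is a hypothetical example:

import sqlglot

# Assumption: IIF maps to exp.If everywhere, and SQLite renders exp.If as IIF.
print(sqlglot.transpile("SELECT IIF(x > 0, 1, 0)", read="tsql", write="sqlite")[0])
# expected, roughly: SELECT IIF(x > 0, 1, 0)
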
class Nullif(Func):
@@ -6883,6 +6894,7 @@ def replace_tables(
table = to_table(
new_name,
**{k: v for k, v in node.args.items() if k not in TABLE_PARTS},
dialect=dialect,
)
table.add_comments([original])
return table
@@ -7072,6 +7084,60 @@ def cast_unless(
return cast(expr, to, **opts)
def array(
*expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs
) -> Array:
"""
Returns an array.
Examples:
>>> array(1, 'x').sql()
'ARRAY(1, x)'
Args:
expressions: the expressions to add to the array.
copy: whether or not to copy the argument expressions.
dialect: the source dialect.
kwargs: the kwargs used to instantiate the function of interest.
Returns:
An array expression.
"""
return Array(
expressions=[
maybe_parse(expression, copy=copy, dialect=dialect, **kwargs)
for expression in expressions
]
)
def tuple_(
*expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs
) -> Tuple:
"""
Returns a tuple.
Examples:
>>> tuple_(1, 'x').sql()
'(1, x)'
Args:
expressions: the expressions to add to the tuple.
copy: whether or not to copy the argument expressions.
dialect: the source dialect.
kwargs: the kwargs used to instantiate the function of interest.
Returns:
A tuple expression.
"""
return Tuple(
expressions=[
maybe_parse(expression, copy=copy, dialect=dialect, **kwargs)
for expression in expressions
]
)
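
A minimal usage sketch of the two new builders, mirroring the doctests above; elsewhere in this commit they replace hand-built exp.Array and exp.Tuple nodes:

from sqlglot import exp

# The builders accept raw values or expressions, per the doctests above.
print(exp.array(1, "x").sql())   # ARRAY(1, x)
print(exp.tuple_(1, "x").sql())  # (1, x)
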
def true() -> Boolean:
"""
Returns a true Boolean expression.

View file

@@ -124,6 +124,7 @@ class Generator(metaclass=_Generator):
exp.StabilityProperty: lambda self, e: e.name,
exp.TemporaryProperty: lambda self, e: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TransientProperty: lambda self, e: "TRANSIENT",
@@ -3360,7 +3361,7 @@ class Generator(metaclass=_Generator):
return self.sql(arg)
cond_for_null = arg.is_(exp.null())
return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg])))
return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.array(arg, copy=False)))
def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str:
this = expression.this

View file

@@ -6,7 +6,7 @@ import logging
import re
import sys
import typing as t
from collections.abc import Collection
from collections.abc import Collection, Set
from contextlib import contextmanager
from copy import copy
from enum import Enum
@@ -496,3 +496,31 @@ DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
return expression is not None and expression.name.lower() in DATE_UNITS
K = t.TypeVar("K")
V = t.TypeVar("V")
class SingleValuedMapping(t.Mapping[K, V]):
"""
Mapping where all keys return the same value.
This rigamarole is meant to avoid copying keys, which was originally intended
as an optimization while qualifying columns for tables with lots of columns.
"""
def __init__(self, keys: t.Collection[K], value: V):
self._keys = keys if isinstance(keys, Set) else set(keys)
self._value = value
def __getitem__(self, key: K) -> V:
if key in self._keys:
return self._value
raise KeyError(key)
def __len__(self) -> int:
return len(self._keys)
def __iter__(self) -> t.Iterator[K]:
return iter(self._keys)
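
A minimal usage sketch based on the class above (keys and value are hypothetical):

from sqlglot.helper import SingleValuedMapping

# Every key resolves to the same value without building a per-key dict.
m = SingleValuedMapping(["a", "b", "c"], "t1")
print(m["a"], len(m), "z" in m)  # expected: t1 3 False
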

View file

@@ -153,7 +153,7 @@ def lineage(
raise ValueError(f"Could not find {column} in {scope.expression}")
for s in scope.union_scopes:
to_node(index, scope=s, upstream=upstream)
to_node(index, scope=s, upstream=upstream, alias=alias)
return upstream
@@ -209,7 +209,11 @@ def lineage(
if isinstance(source, Scope):
# The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
to_node(
c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
c.name,
scope=source,
scope_name=table,
upstream=node,
alias=aliases.get(table) or alias,
)
else:
# The source is not a scope - we've reached the end of the line. At this point, if a source is not found

View file

@@ -204,7 +204,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.TimeAdd,
exp.TimeStrToTime,
exp.TimeSub,
exp.Timestamp,
exp.TimestampAdd,
exp.TimestampSub,
exp.UnixToTime,
@@ -276,6 +275,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),

View file

@@ -38,7 +38,12 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
return exp.cast(node.this, to=exp.DataType.Type.DATE)
if isinstance(node, exp.Timestamp) and not node.expression:
return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)
if not node.type:
from sqlglot.optimizer.annotate_types import annotate_types
node = annotate_types(node)
return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
return node
@@ -76,9 +81,8 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
if (
isinstance(expression, exp.Cast)
and expression.to.type
and expression.this.type
and expression.to.type.this == expression.this.type.this
and expression.to.this == expression.this.type.this
):
return expression.this
return expression

View file

@@ -6,7 +6,7 @@ import typing as t
from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
@@ -586,8 +586,8 @@ class Resolver:
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
self._all_columns: t.Optional[t.Set[str]] = None
self._infer_schema = infer_schema
@@ -640,7 +640,7 @@ class Resolver:
}
return self._all_columns
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
"""Resolve the source columns for a given source `name`."""
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
@@ -662,10 +662,15 @@ class Resolver:
else:
column_aliases = []
# If the source's columns are aliased, their aliases shadow the corresponding column names
return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]
if column_aliases:
# If the source's columns are aliased, their aliases shadow the corresponding column names.
# This can be expensive if there are lots of columns, so only do this if column_aliases exist.
return [
alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
]
return columns
def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
if self._source_columns is None:
self._source_columns = {
source_name: self.get_source_columns(source_name)
@@ -676,8 +681,8 @@ class Resolver:
return self._source_columns
def _get_unambiguous_columns(
self, source_columns: t.Dict[str, t.List[str]]
) -> t.Dict[str, str]:
self, source_columns: t.Dict[str, t.Sequence[str]]
) -> t.Mapping[str, str]:
"""
Find all the unambiguous columns in sources.
@@ -693,12 +698,17 @@ class Resolver:
source_columns_pairs = list(source_columns.items())
first_table, first_columns = source_columns_pairs[0]
unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
if len(source_columns_pairs) == 1:
# Performance optimization - avoid copying first_columns if there is only one table.
return SingleValuedMapping(first_columns, first_table)
unambiguous_columns = {col: first_table for col in first_columns}
all_columns = set(unambiguous_columns)
for table, columns in source_columns_pairs[1:]:
unique = self._find_unique_columns(columns)
ambiguous = set(all_columns).intersection(unique)
unique = set(columns)
ambiguous = all_columns.intersection(unique)
all_columns.update(columns)
for column in ambiguous:
@@ -707,19 +717,3 @@ class Resolver:
unambiguous_columns[column] = table
return unambiguous_columns
@staticmethod
def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
"""
Find the unique columns in a list of columns.
Example:
>>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
['a', 'c']
This is necessary because duplicate column names are ambiguous.
"""
counts: t.Dict[str, int] = {}
for column in columns:
counts[column] = counts.get(column, 0) + 1
return {column for column, count in counts.items() if count == 1}

View file

@@ -29,8 +29,8 @@ def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
values.append(args[i + 1])
return exp.VarMap(
keys=exp.Array(expressions=keys),
values=exp.Array(expressions=values),
keys=exp.array(*keys, copy=False),
values=exp.array(*values, copy=False),
)
@@ -638,6 +638,8 @@ class Parser(metaclass=_Parser):
TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()),
TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()),
TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()),
TokenType.PIPE_SLASH: lambda self: self.expression(exp.Sqrt, this=self._parse_unary()),
TokenType.DPIPE_SLASH: lambda self: self.expression(exp.Cbrt, this=self._parse_unary()),
}
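
A minimal sketch, assuming these new unary parsers pick up the Postgres |/ and ||/ tokens added earlier in this commit (PIPE_SLASH to exp.Sqrt, DPIPE_SLASH to exp.Cbrt); the query is a hypothetical example:

import sqlglot
from sqlglot import exp

# Assumption: |/ (square root) and ||/ (cube root) now parse into exp.Sqrt / exp.Cbrt.
tree = sqlglot.parse_one("SELECT |/ 4, ||/ 27", read="postgres")
print(tree.find(exp.Sqrt) is not None, tree.find(exp.Cbrt) is not None)  # expected: True True
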
PRIMARY_PARSERS = {
@@ -1000,9 +1002,13 @@ class Parser(metaclass=_Parser):
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
# parses no parenthesis if statements as commands
# Parses no parenthesis if statements as commands
NO_PAREN_IF_COMMANDS = True
# Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause.
# If this is True and '(' is not found, the keyword will be treated as an identifier
VALUES_FOLLOWED_BY_PAREN = True
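
A minimal sketch of the assumed default behavior: with VALUES_FOLLOWED_BY_PAREN left at True, a VALUES token that is not followed by '(' should fall back to an identifier, so a column literally named values parses instead of raising (hypothetical query):

import sqlglot

# Assumption: in default dialects a bare VALUES token without '(' is treated as an identifier.
stmt = sqlglot.parse_one("SELECT t.x FROM t WHERE values > 0")
print(stmt.sql())  # "values" is expected to survive as a plain column name
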
__slots__ = (
"error_level",
"error_message_context",
@@ -2058,7 +2064,7 @@ class Parser(metaclass=_Parser):
partition=self._parse_partition(),
where=self._match_pair(TokenType.REPLACE, TokenType.WHERE)
and self._parse_conjunction(),
expression=self._parse_ddl_select(),
expression=self._parse_derived_table_values() or self._parse_ddl_select(),
conflict=self._parse_on_conflict(),
returning=returning or self._parse_returning(),
overwrite=overwrite,
@@ -2267,8 +2273,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return self.expression(exp.Tuple, expressions=expressions)
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# https://prestodb.io/docs/current/sql/values.html
# In some dialects we can have VALUES 1, 2 which results in 1 column & 2 rows.
return self.expression(exp.Tuple, expressions=[self._parse_expression()])
def _parse_projections(self) -> t.List[exp.Expression]:
@@ -2367,12 +2372,8 @@ class Parser(metaclass=_Parser):
# We return early here so that the UNION isn't attached to the subquery by the
# following call to _parse_set_operations, but instead becomes the parent node
return self._parse_subquery(this, parse_alias=parse_subquery_alias)
elif self._match(TokenType.VALUES):
this = self.expression(
exp.Values,
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
elif self._match(TokenType.VALUES, advance=False):
this = self._parse_derived_table_values()
elif from_:
this = exp.select("*").from_(from_.this, copy=False)
else:
@@ -2969,7 +2970,7 @@ class Parser(metaclass=_Parser):
def _parse_derived_table_values(self) -> t.Optional[exp.Values]:
is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
if not is_derived and not self._match(TokenType.VALUES):
if not is_derived and not self._match_text_seq("VALUES"):
return None
expressions = self._parse_csv(self._parse_value)
@@ -3655,8 +3656,15 @@ class Parser(metaclass=_Parser):
def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
interval = parse_interval and self._parse_interval()
if interval:
# Convert INTERVAL 'val_1' unit_1 ... 'val_n' unit_n into a sum of intervals
while self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
# Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals
while True:
index = self._index
self._match(TokenType.PLUS)
if not self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
self._retreat(index)
break
interval = self.expression( # type: ignore
exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
)
@@ -3872,9 +3880,15 @@ class Parser(metaclass=_Parser):
def _parse_column_reference(self) -> t.Optional[exp.Expression]:
this = self._parse_field()
if isinstance(this, exp.Identifier):
this = self.expression(exp.Column, this=this)
return this
if (
not this
and self._match(TokenType.VALUES, advance=False)
and self.VALUES_FOLLOWED_BY_PAREN
and (not self._next or self._next.token_type != TokenType.L_PAREN)
):
this = self._parse_id_var()
return self.expression(exp.Column, this=this) if isinstance(this, exp.Identifier) else this
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
@@ -5511,7 +5525,7 @@ class Parser(metaclass=_Parser):
then = self.expression(
exp.Insert,
this=self._parse_value(),
expression=self._match(TokenType.VALUES) and self._parse_value(),
expression=self._match_text_seq("VALUES") and self._parse_value(),
)
elif self._match(TokenType.UPDATE):
expressions = self._parse_star()

View file

@@ -49,7 +49,7 @@ class Schema(abc.ABC):
only_visible: bool = False,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> t.List[str]:
) -> t.Sequence[str]:
"""
Get the column names for a table.
@@ -60,7 +60,7 @@ class Schema(abc.ABC):
normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
The sequence of column names.
"""
@abc.abstractmethod

View file

@@ -57,6 +57,8 @@ class TokenType(AutoName):
AMP = auto()
DPIPE = auto()
PIPE = auto()
PIPE_SLASH = auto()
DPIPE_SLASH = auto()
CARET = auto()
TILDA = auto()
ARROW = auto()

View file

@@ -213,6 +213,19 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
is_posexplode = isinstance(explode, exp.Posexplode)
explode_arg = explode.this
if isinstance(explode, exp.ExplodeOuter):
bracket = explode_arg[0]
bracket.set("safe", True)
bracket.set("offset", True)
explode_arg = exp.func(
"IF",
exp.func(
"ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
).eq(0),
exp.array(bracket, copy=False),
explode_arg,
)
# This ensures that we won't use [POS]EXPLODE's argument as a new selection
if isinstance(explode_arg, exp.Column):
taken_select_names.add(explode_arg.output_name)
@@ -466,6 +479,87 @@ def unqualify_columns(expression: exp.Expression) -> exp.Expression:
return expression
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
assert isinstance(expression, exp.Create)
for constraint in expression.find_all(exp.UniqueColumnConstraint):
if constraint.parent:
constraint.parent.pop()
return expression
def ctas_with_tmp_tables_to_create_tmp_view(
expression: exp.Expression,
tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
assert isinstance(expression, exp.Create)
properties = expression.args.get("properties")
temporary = any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
)
# CTAS with temp tables map to CREATE TEMPORARY VIEW
if expression.kind == "TABLE" and temporary:
if expression.expression:
return exp.Create(
kind="TEMPORARY VIEW",
this=expression.this,
expression=expression.expression,
)
return tmp_storage_provider(expression)
return expression
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
"""
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
PARTITIONED BY value is an array of column names, they are transformed into a schema.
The corresponding columns are removed from the create statement.
"""
assert isinstance(expression, exp.Create)
has_schema = isinstance(expression.this, exp.Schema)
is_partitionable = expression.kind in {"TABLE", "VIEW"}
if has_schema and is_partitionable:
prop = expression.find(exp.PartitionedByProperty)
if prop and prop.this and not isinstance(prop.this, exp.Schema):
schema = expression.this
columns = {v.name.upper() for v in prop.this.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set("expressions", [e for e in schema.expressions if e not in partitions])
prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
expression.set("this", schema)
return expression
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
"""
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
"""
assert isinstance(expression, exp.Create)
prop = expression.find(exp.PartitionedByProperty)
if (
prop
and prop.this
and isinstance(prop.this, exp.Schema)
and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions)
):
prop_this = exp.Tuple(
expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
)
schema = expression.this
for e in prop.this.expressions:
schema.append("expressions", e)
prop.set("this", prop_this)
return expression
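
A minimal sketch, assuming the two transforms above are wired into the Hive and Spark generators as shown earlier in this commit; the statement and expected output are illustrative, not asserted:

import sqlglot

# Assumption: Hive keeps partition columns out of the column list while Spark 3's
# DATASOURCE format keeps them in it, so the transforms move columns in opposite directions.
print(sqlglot.transpile("CREATE TABLE t (a INT) PARTITIONED BY (b INT)", read="hive", write="spark")[0])
# expected, roughly: CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)
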
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]: