
Merging upstream version 12.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:53:39 +01:00
parent fffa0d5761
commit 62b2b24d3b
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
100 changed files with 35022 additions and 30936 deletions

sqlglot/__init__.py

@ -50,7 +50,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
__version__ = "11.7.1"
__version__ = "12.2.0"
pretty = False
"""Whether to format generated SQL by default."""
@ -181,7 +181,7 @@ def transpile(
Returns:
The list of transpiled SQL statements.
"""
write = write or read if identity else write
write = (read if write is None else write) if identity else write
return [
Dialect.get_or_raise(write)().generate(expression, **opts)
for expression in parse(sql, read, error_level=error_level)
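Worth noting for this one-liner: due to Python precedence, the old line parsed as write = (write or read) if identity else write, so any falsy write (for example the empty string, i.e. the default dialect) silently fell back to read; the new form only falls back when write is None. A quick illustration, assuming sqlglot's public API:

import sqlglot

# Identity transpile: leaving `write` as None reuses the read dialect.
print(sqlglot.transpile("SELECT CURRENT_DATETIME()", read="bigquery")[0])
# Under the old `write or read` logic, an explicit write="" (the default
# dialect) would also have been replaced by the read dialect; now it is honored.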

sqlglot/dataframe/sql/functions.py

@ -747,11 +747,11 @@ def ascii(col: ColumnOrLiteral) -> Column:
def base64(col: ColumnOrLiteral) -> Column:
return Column.invoke_anonymous_function(col, "BASE64")
return Column.invoke_expression_over_column(col, expression.ToBase64)
def unbase64(col: ColumnOrLiteral) -> Column:
return Column.invoke_anonymous_function(col, "UNBASE64")
return Column.invoke_expression_over_column(col, expression.FromBase64)
def ltrim(col: ColumnOrName) -> Column:

sqlglot/dialects/__init__.py

@ -70,6 +70,7 @@ from sqlglot.dialects.presto import Presto
from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.spark2 import Spark2
from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau

sqlglot/dialects/bigquery.py

@ -39,18 +39,26 @@ def _date_add_sql(
def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
if not isinstance(expression.unnest().parent, exp.From):
expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
return self.values_sql(expression)
rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
structs = []
for row in rows:
aliases = [
exp.alias_(value, column_name)
for value, column_name in zip(row, expression.args["alias"].args["columns"])
]
structs.append(exp.Struct(expressions=aliases))
unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
return self.unnest_sql(unnest_exp)
alias = expression.args.get("alias")
structs = [
exp.Struct(
expressions=[
exp.alias_(value, column_name)
for value, column_name in zip(
t.expressions,
alias.columns
if alias and alias.columns
else (f"_c{i}" for i in range(len(t.expressions))),
)
]
)
for t in expression.find_all(exp.Tuple)
]
return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
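A sketch of what this rewrite produces (illustrative; exact formatting can vary by version):

import sqlglot

# BigQuery lacks VALUES derived tables, so they become UNNEST over an array
# of structs; columns now fall back to _c0, _c1, ... when no alias is given.
sql = "SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t(x, y)"
print(sqlglot.transpile(sql, write="bigquery")[0])
# e.g. SELECT * FROM UNNEST([STRUCT(1 AS x, 'a' AS y), STRUCT(2 AS x, 'b' AS y)]) AS t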
def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
@ -128,6 +136,7 @@ class BigQuery(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
BYTE_STRINGS = [("b'", "'"), ("B'", "'")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -139,6 +148,7 @@ class BigQuery(Dialect):
"GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
"BYTES": TokenType.BINARY,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
}
@ -153,7 +163,7 @@ class BigQuery(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
@ -206,6 +216,12 @@ class BigQuery(Dialect):
"NOT DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
),
"OPTIONS": lambda self: self._parse_with_property(),
}
CONSTRAINT_PARSERS = {
**parser.Parser.CONSTRAINT_PARSERS, # type: ignore
"OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
}
class Generator(generator.Generator):
@ -217,11 +233,11 @@ class BigQuery(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.AtTimeZone: lambda self, e: self.func(
"TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@ -234,7 +250,9 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess([_unqualify_unnest]),
exp.Select: transforms.preprocess(
[_unqualify_unnest, transforms.eliminate_distinct_on]
),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@ -259,6 +277,7 @@ class BigQuery(Dialect):
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.BINARY: "BYTES",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.CHAR: "STRING",
exp.DataType.Type.DECIMAL: "NUMERIC",
@ -272,6 +291,7 @@ class BigQuery(Dialect):
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARBINARY: "BYTES",
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.VARIANT: "ANY TYPE",
}
@ -310,3 +330,6 @@ class BigQuery(Dialect):
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("OPTIONS"))

sqlglot/dialects/clickhouse.py

@ -22,6 +22,8 @@ class ClickHouse(Dialect):
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
BIT_STRINGS = [("0b", "")]
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -31,10 +33,18 @@ class ClickHouse(Dialect):
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"INT16": TokenType.SMALLINT,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
"INT8": TokenType.TINYINT,
"UINT8": TokenType.UTINYINT,
"INT16": TokenType.SMALLINT,
"UINT16": TokenType.USMALLINT,
"INT32": TokenType.INT,
"UINT32": TokenType.UINT,
"INT64": TokenType.BIGINT,
"UINT64": TokenType.UBIGINT,
"INT128": TokenType.INT128,
"UINT128": TokenType.UINT128,
"INT256": TokenType.INT256,
"UINT256": TokenType.UINT256,
"TUPLE": TokenType.STRUCT,
}
@ -121,9 +131,17 @@ class ClickHouse(Dialect):
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.STRUCT: "Tuple",
exp.DataType.Type.TINYINT: "Int8",
exp.DataType.Type.UTINYINT: "UInt8",
exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.USMALLINT: "UInt16",
exp.DataType.Type.INT: "Int32",
exp.DataType.Type.UINT: "UInt32",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.UBIGINT: "UInt64",
exp.DataType.Type.INT128: "Int128",
exp.DataType.Type.UINT128: "UInt128",
exp.DataType.Type.INT256: "Int256",
exp.DataType.Type.UINT256: "UInt256",
exp.DataType.Type.FLOAT: "Float32",
exp.DataType.Type.DOUBLE: "Float64",
}
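The widened integer coverage means signed, unsigned and extended-width ClickHouse types now round-trip; a minimal check:

import sqlglot

# Tokenizer maps UINT*/INT128/... in, TYPE_MAPPING maps them back out.
ddl = "CREATE TABLE t (a UInt8, b Int128, c UInt256)"
print(sqlglot.transpile(ddl, read="clickhouse", write="clickhouse")[0])
# e.g. CREATE TABLE t (a UInt8, b Int128, c UInt256)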

sqlglot/dialects/databricks.py

@ -1,6 +1,6 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
@ -29,13 +29,20 @@ class Databricks(Spark):
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
]
),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation
PARAMETER_TOKEN = "$"
class Tokenizer(Spark.Tokenizer):
HEX_STRINGS = []
SINGLE_TOKENS = {
**Spark.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,

sqlglot/dialects/dialect.py

@ -28,6 +28,7 @@ class Dialects(str, Enum):
REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
SPARK2 = "spark2"
SQLITE = "sqlite"
STARROCKS = "starrocks"
TABLEAU = "tableau"
@ -69,30 +70,17 @@ class _Dialect(type):
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
if (
klass.tokenizer_class._BIT_STRINGS
and exp.BitString not in klass.generator_class.TRANSFORMS
):
bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.BitString
] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
if (
klass.tokenizer_class._HEX_STRINGS
and exp.HexString not in klass.generator_class.TRANSFORMS
):
hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
if (
klass.tokenizer_class._BYTE_STRINGS
and exp.ByteString not in klass.generator_class.TRANSFORMS
):
be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.ByteString
] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"
klass.bit_start, klass.bit_end = seq_get(
list(klass.tokenizer_class._BIT_STRINGS.items()), 0
) or (None, None)
klass.hex_start, klass.hex_end = seq_get(
list(klass.tokenizer_class._HEX_STRINGS.items()), 0
) or (None, None)
klass.byte_start, klass.byte_end = seq_get(
list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
) or (None, None)
return klass
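The per-dialect TRANSFORMS lambdas are replaced by plain class attributes resolved with seq_get; in isolation the pattern behaves like this:

from sqlglot.helper import seq_get

# First (start, end) delimiter pair from the tokenizer mapping, or
# (None, None) when a dialect defines no such literals.
hex_strings = {"0x": ""}  # e.g. a BigQuery-style 0x... hex literal
assert (seq_get(list(hex_strings.items()), 0) or (None, None)) == ("0x", "")
assert (seq_get([], 0) or (None, None)) == (None, None)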
@ -198,6 +186,12 @@ class Dialect(metaclass=_Dialect):
**{
"quote_start": self.quote_start,
"quote_end": self.quote_end,
"bit_start": self.bit_start,
"bit_end": self.bit_end,
"hex_start": self.hex_start,
"hex_end": self.hex_end,
"byte_start": self.byte_start,
"byte_end": self.byte_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],

sqlglot/dialects/drill.py

@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
@ -145,6 +145,7 @@ class Drill(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToDate: _str_to_date,
exp.Pow: rename_func("POW"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,

sqlglot/dialects/duckdb.py

@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
@ -23,52 +25,61 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _ts_or_ds_add(self, expression):
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _date_add(self, expression):
def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _array_sort_sql(self, expression):
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
return f"ARRAY_SORT({self.sql(expression, 'this')})"
def _sort_array_sql(self, expression):
def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
this = self.sql(expression, "this")
if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
return f"ARRAY_SORT({this})"
def _sort_array_reverse(args):
def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
def _struct_sql(self, expression):
def _parse_date_diff(args: t.Sequence) -> exp.Expression:
return exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
)
def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
args = [
f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
]
return f"{{{', '.join(args)}}}"
def _datatype_sql(self, expression):
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
def _regexp_extract_sql(self, expression):
def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str:
bad_args = list(filter(expression.args.get, ("position", "occurrence")))
if bad_args:
self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")
return self.func(
"REGEXP_EXTRACT",
expression.args.get("this"),
@ -108,6 +119,8 @@ class DuckDB(Dialect):
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _sort_array_reverse,
"DATEDIFF": _parse_date_diff,
"DATE_DIFF": _parse_date_diff,
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
@ -115,18 +128,18 @@ class DuckDB(Dialect):
expression=exp.Literal.number(1000),
)
),
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
"REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
"STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
"STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
"STR_SPLIT": exp.Split.from_arg_list,
"STRING_SPLIT": exp.Split.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
"STRUCT_PACK": exp.Struct.from_arg_list,
"STR_SPLIT": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
}
@ -142,10 +155,11 @@ class DuckDB(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if isinstance(seq_get(e.expressions, 0), exp.Select)
@ -154,13 +168,16 @@ class DuckDB(Dialect):
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
exp.CurrentDate: lambda self, e: "CURRENT_DATE",
exp.CurrentTime: lambda self, e: "CURRENT_TIME",
exp.CurrentTimestamp: lambda self, e: "CURRENT_TIMESTAMP",
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add,
exp.DateAdd: _date_add_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
@ -192,7 +209,7 @@ class DuckDB(Dialect):
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add,
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
@ -201,7 +218,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.CHAR: "TEXT",
exp.DataType.Type.FLOAT: "REAL",
@ -212,17 +229,14 @@ class DuckDB(Dialect):
exp.DataType.Type.VARCHAR: "TEXT",
}
STAR_MAPPING = {
**generator.Generator.STAR_MAPPING,
"except": "EXCLUDE",
}
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)

sqlglot/dialects/hive.py

@ -81,7 +81,20 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
return f"{diff_sql}{multiplier_sql}"
def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
this = expression.this
if not this.type:
from sqlglot.optimizer.annotate_types import annotate_types
annotate_types(this)
if this.type.is_type(exp.DataType.Type.JSON):
return self.sql(this)
return self.func("TO_JSON", this, expression.args.get("options"))
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
@ -91,11 +104,11 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
return f"'{expression.name}'={self.sql(expression, 'value')}"
def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@ -103,7 +116,7 @@ def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
return f"CAST({this} AS DATE)"
def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@ -214,6 +227,7 @@ class Hive(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0),
@ -251,6 +265,7 @@ class Hive(Dialect):
"SPLIT": exp.RegexpSplit.from_arg_list,
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
"UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
@ -280,16 +295,20 @@ class Hive(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.unnest_to_explode]
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
]
),
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort,
exp.ArraySort: _array_sort_sql,
exp.With: no_recursive_cte_sql,
exp.DateAdd: _add_date_sql,
exp.DateDiff: _date_diff_sql,
@ -298,12 +317,13 @@ class Hive(Dialect):
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: rename_func("TO_JSON"),
exp.JSONFormat: _json_format_sql,
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
@ -318,9 +338,9 @@ class Hive(Dialect):
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: _str_to_unix_sql,
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
@ -328,6 +348,7 @@ class Hive(Dialect):
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToBase64: rename_func("BASE64"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.TsOrDsToDate: _to_date_sql,
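With the parser and generator entries above, BASE64/UNBASE64 now round-trip through dedicated expressions rather than anonymous functions:

import sqlglot

# Hive in, Hive out: ToBase64/FromBase64 regenerate their original names.
print(sqlglot.transpile("SELECT BASE64(x), UNBASE64(y)", read="hive", write="hive")[0])
# SELECT BASE64(x), UNBASE64(y)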

sqlglot/dialects/mysql.py

@ -1,6 +1,6 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@ -403,6 +403,7 @@ class MySQL(Dialect):
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
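transforms.eliminate_distinct_on, wired into MySQL here and into most other dialects in this release, rewrites Postgres-style DISTINCT ON as a ROW_NUMBER window; roughly:

import sqlglot

sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b DESC"
print(sqlglot.transpile(sql, read="postgres", write="mysql")[0])
# e.g. SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER
#   (PARTITION BY a ORDER BY a, b DESC) AS _row_number FROM t) AS _t
# WHERE _row_number = 1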

sqlglot/dialects/oracle.py

@ -34,6 +34,8 @@ def _parse_xml_table(self) -> exp.XMLTable:
class Oracle(Dialect):
alias_post_tablesample = True
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
time_mapping = {
@ -121,21 +123,23 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.IfNull: rename_func("NVL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
exp.IfNull: rename_func("NVL"),
}
PROPERTIES_LOCATION = {
@ -164,14 +168,19 @@ class Oracle(Dialect):
return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
class Tokenizer(tokens.Tokenizer):
VAR_SINGLE_TOKENS = {"@"}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"(+)": TokenType.JOIN_MARKER,
"BINARY_DOUBLE": TokenType.DOUBLE,
"BINARY_FLOAT": TokenType.FLOAT,
"COLUMNS": TokenType.COLUMN,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
"RETURNING": TokenType.RETURNING,
"SAMPLE": TokenType.TABLE_SAMPLE,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,

sqlglot/dialects/postgres.py

@ -1,6 +1,8 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@ -20,7 +22,6 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
from sqlglot.transforms import preprocess, remove_target_from_merge
DATE_DIFF_FACTOR = {
"MICROSECOND": " * 1000000",
@ -274,8 +275,7 @@ class Postgres(Dialect):
TokenType.HASH: exp.BitwiseXor,
}
FACTOR = {
**parser.Parser.FACTOR,
EXPONENT = {
TokenType.CARET: exp.Pow,
}
@ -286,6 +286,12 @@ class Postgres(Dialect):
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
def _parse_factor(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_exponent, self.FACTOR)
def _parse_exponent(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_unary, self.EXPONENT)
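Splitting the caret out of FACTOR into a tighter EXPONENT level fixes operator precedence, since Postgres exponentiation binds more tightly than multiplication:

import sqlglot
from sqlglot import exp

# 2 * 3 ^ 2 must parse as 2 * (3 ^ 2)
tree = sqlglot.parse_one("2 * 3 ^ 2", read="postgres")
assert isinstance(tree, exp.Mul) and isinstance(tree.expression, exp.Pow)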
def _parse_date_part(self) -> exp.Expression:
part = self._parse_type()
self._match(TokenType.COMMA)
@ -316,7 +322,7 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: preprocess(
exp.ColumnDef: transforms.preprocess(
[
_auto_increment_to_serial,
_serial_to_generated,
@ -341,7 +347,7 @@ class Postgres(Dialect):
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.Merge: preprocess([remove_target_from_merge]),
exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,

sqlglot/dialects/presto.py

@ -130,7 +130,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
step = expression.args.get("step")
target_type = None
@ -147,7 +147,11 @@ def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) ->
else:
start = exp.Cast(this=start, to=to)
return self.func("SEQUENCE", start, end, step)
sql = self.func("SEQUENCE", start, end, step)
if isinstance(expression.parent, exp.Table):
sql = f"UNNEST({sql})"
return sql
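The new parent check matters when GENERATE_SERIES is used as a relation, since Presto's SEQUENCE returns an array that must be unnested; indicatively:

import sqlglot

print(sqlglot.transpile("SELECT * FROM GENERATE_SERIES(1, 5)", read="postgres", write="presto")[0])
# e.g. SELECT * FROM UNNEST(SEQUENCE(1, 5))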
def _ensure_utf8(charset: exp.Literal) -> None:
@ -204,6 +208,7 @@ class Presto(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
@ -219,23 +224,23 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
"FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
"NOW": exp.CurrentTimestamp.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
instance=seq_get(args, 2),
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("TRIM")
@ -264,7 +269,6 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@ -290,6 +294,7 @@ class Presto(Dialect):
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.GenerateSeries: _sequence_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
@ -303,7 +308,11 @@ class Presto(Dialect):
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.explode_to_unnest]
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.explode_to_unnest,
]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
@ -327,6 +336,9 @@ class Presto(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
}
def interval_sql(self, expression: exp.Interval) -> str:

sqlglot/dialects/redshift.py

@ -52,6 +52,8 @@ class Redshift(Postgres):
return this
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
STRING_ESCAPES = ["\\"]
KEYWORDS = {
@ -90,7 +92,6 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@ -102,6 +103,7 @@ class Redshift(Postgres):
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}

sqlglot/dialects/snowflake.py

@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
date_trunc_to_time,
@ -252,6 +252,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -305,6 +306,7 @@ class Snowflake(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")

sqlglot/dialects/spark.py

@ -2,222 +2,54 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot import exp
from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
def _parse_datediff(args: t.Sequence) -> exp.Expression:
"""
Although Spark docs don't mention the "unit" argument, Spark3 added support for
it at some point. Databricks also supports this variation (see below).
if kind.upper() == "TABLE" and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
):
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return create_with_partitions_sql(self, e)
For example, in spark-sql (v3.3.1):
- SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
- SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4
See also:
- https://docs.databricks.com/sql/language-manual/functions/datediff3.html
- https://docs.databricks.com/sql/language-manual/functions/datediff.html
"""
unit = None
this = seq_get(args, 0)
expression = seq_get(args, 1)
if len(args) == 3:
unit = this
this = args[2]
return exp.DateDiff(
this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
)
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
keys = self.sql(expression.args["keys"])
values = self.sql(expression.args["values"])
return f"MAP_FROM_ARRAYS({keys}, {values})"
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.date_format:
return f"TO_DATE({this})"
return f"TO_DATE({this}, {time_format})"
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
return f"FROM_UNIXTIME({timestamp})"
if scale == exp.UnixToTime.SECONDS:
return f"TIMESTAMP_SECONDS({timestamp})"
if scale == exp.UnixToTime.MILLIS:
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
raise ValueError("Improper scale for timestamp")
class Spark(Hive):
class Parser(Hive.Parser):
class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS, # type: ignore
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Literal.number(1),
length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Sub(
this=exp.Length(this=seq_get(args, 0)),
expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"BOOLEAN": lambda args: exp.Cast(
this=seq_get(args, 0), to=exp.DataType.build("boolean")
),
"IIF": exp.If.from_arg_list,
"INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
"AGGREGATE": exp.Reduce.from_arg_list,
"DAYOFWEEK": lambda args: exp.DayOfWeek(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFMONTH": lambda args: exp.DayOfMonth(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFYEAR": lambda args: exp.DayOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"TIMESTAMP": lambda args: exp.Cast(
this=seq_get(args, 0), to=exp.DataType.build("timestamp")
),
**Spark2.Parser.FUNCTIONS, # type: ignore
"DATEDIFF": _parse_datediff,
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
"MERGE": lambda self: self._parse_join_hint("MERGE"),
"SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
"MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
"SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
class Generator(Spark2.Generator):
TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
TRANSFORMS.pop(exp.DateDiff)
def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = self.sql(expression, "unit")
end = self.sql(expression, "this")
start = self.sql(expression, "expression")
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
)
if unit:
return self.func("DATEDIFF", unit, start, end)
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
if len(pivot_columns) == 1:
return [""]
names = []
for agg in pivot_columns:
if isinstance(agg, exp.Alias):
names.append(agg.alias)
else:
"""
This case corresponds to aggregations without aliases being used as suffixes
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
"""
agg_all_unquoted = agg.transform(
lambda node: exp.Identifier(this=node.name, quoted=False)
if isinstance(node, exp.Identifier)
else node
)
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
return names
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION, # type: ignore
exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.Create: _create_sql,
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.Trim: trim_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
}
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
exp.DataType.Type.JSON
):
schema = f"'{self.sql(expression, 'to')}'"
return self.func("FROM_JSON", expression.this.this, schema)
if expression.to.is_type(exp.DataType.Type.JSON):
return self.func("TO_JSON", expression.this)
return super(Spark.Generator, self).cast_sql(expression)
class Tokenizer(Hive.Tokenizer):
HEX_STRINGS = [("X'", "'")]
return self.func("DATEDIFF", end, start)

sqlglot/dialects/spark2.py (new file, 238 lines)

@ -0,0 +1,238 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, parser, transforms
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
if kind.upper() == "TABLE" and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
):
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return create_with_partitions_sql(self, e)
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
keys = self.sql(expression.args["keys"])
values = self.sql(expression.args["values"])
return f"MAP_FROM_ARRAYS({keys}, {values})"
def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.date_format:
return f"TO_DATE({this})"
return f"TO_DATE({this}, {time_format})"
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
if scale == exp.UnixToTime.SECONDS:
return f"TIMESTAMP_SECONDS({timestamp})"
if scale == exp.UnixToTime.MILLIS:
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
raise ValueError("Improper scale for timestamp")
class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS, # type: ignore
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Literal.number(1),
length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Sub(
this=exp.Length(this=seq_get(args, 0)),
expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
"AGGREGATE": exp.Reduce.from_arg_list,
"DAYOFWEEK": lambda args: exp.DayOfWeek(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFMONTH": lambda args: exp.DayOfMonth(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFYEAR": lambda args: exp.DayOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"BOOLEAN": _parse_as_cast("boolean"),
"DOUBLE": _parse_as_cast("double"),
"FLOAT": _parse_as_cast("float"),
"INT": _parse_as_cast("int"),
"STRING": _parse_as_cast("string"),
"TIMESTAMP": _parse_as_cast("timestamp"),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
"MERGE": lambda self: self._parse_join_hint("MERGE"),
"SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
"MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
"SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
)
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
if len(pivot_columns) == 1:
return [""]
names = []
for agg in pivot_columns:
if isinstance(agg, exp.Alias):
names.append(agg.alias)
else:
"""
This case corresponds to aggregations without aliases being used as suffixes
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
"""
agg_all_unquoted = agg.transform(
lambda node: exp.Identifier(this=node.name, quoted=False)
if isinstance(node, exp.Identifier)
else node
)
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
return names
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION, # type: ignore
exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.Create: _create_sql,
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
exp.Reduce: rename_func("AGGREGATE"),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.Trim: trim_sql,
exp.UnixToTime: _unix_to_time_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
}
TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
exp.DataType.Type.JSON
):
schema = f"'{self.sql(expression, 'to')}'"
return self.func("FROM_JSON", expression.this.this, schema)
if expression.to.is_type(exp.DataType.Type.JSON):
return self.func("TO_JSON", expression.this)
return super(Hive.Generator, self).cast_sql(expression)
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
return super().columndef_sql(
expression,
sep=": "
if isinstance(expression.parent, exp.DataType)
and expression.parent.is_type(exp.DataType.Type.STRUCT)
else sep,
)
class Tokenizer(Hive.Tokenizer):
HEX_STRINGS = [("X'", "'")]

sqlglot/dialects/sqlite.py

@ -22,6 +22,40 @@ def _date_add_sql(self, expression):
return self.func("DATE", expression.this, modifier)
def _transform_create(expression: exp.Expression) -> exp.Expression:
"""Move primary key to a column and enforce auto_increment on primary keys."""
schema = expression.this
if isinstance(expression, exp.Create) and isinstance(schema, exp.Schema):
defs = {}
primary_key = None
for e in schema.expressions:
if isinstance(e, exp.ColumnDef):
defs[e.name] = e
elif isinstance(e, exp.PrimaryKey):
primary_key = e
if primary_key and len(primary_key.expressions) == 1:
column = defs[primary_key.expressions[0].name]
column.append(
"constraints", exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint())
)
schema.expressions.remove(primary_key)
else:
for column in defs.values():
auto_increment = None
for constraint in column.constraints.copy():
if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint):
break
if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint):
auto_increment = constraint
if auto_increment:
column.constraints.remove(auto_increment)
return expression
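A sketch of the effect on generated DDL (exact output is indicative):

import sqlglot

# The single-column PRIMARY KEY table constraint is folded into the column.
ddl = "CREATE TABLE t (x INT, PRIMARY KEY (x))"
print(sqlglot.transpile(ddl, read="sqlite", write="sqlite")[0])
# e.g. CREATE TABLE t (x INTEGER PRIMARY KEY)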
class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@ -65,8 +99,8 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTime: lambda *_: "CURRENT_TIME",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
@ -80,14 +114,17 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_qualify]
),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TryCast: no_trycast_sql,
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
k: exp.Properties.Location.UNSUPPORTED
for k, v in generator.Generator.PROPERTIES_LOCATION.items()
}
LIMIT_FETCH = "LIMIT"

sqlglot/dialects/starrocks.py

@ -34,6 +34,7 @@ class StarRocks(MySQL):
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
exp.RegexpLike: rename_func("REGEXP"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this

sqlglot/dialects/tableau.py

@ -1,6 +1,6 @@
from __future__ import annotations
from sqlglot import exp, generator, parser
from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect
@ -29,6 +29,7 @@ class Tableau(Dialect):
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}
PROPERTIES_LOCATION = {

sqlglot/dialects/teradata.py

@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@ -148,6 +148,7 @@ class Teradata(Dialect):
**generator.Generator.TRANSFORMS,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}

sqlglot/dialects/tsql.py

@ -3,7 +3,7 @@ from __future__ import annotations
import re
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
max_or_greatest,
@ -259,8 +259,8 @@ class TSQL(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
QUOTES = ["'", '"']
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -463,17 +463,18 @@ class TSQL(Dialect):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
exp.TimeToStr: _format_sql,
}
TRANSFORMS.pop(exp.ReturnsProperty)

sqlglot/expressions.py

@ -64,6 +64,13 @@ class Expression(metaclass=_Expression):
and representing expressions as strings.
arg_types: determines what arguments (child nodes) are supported by an expression. It
maps arg keys to booleans that indicate whether the corresponding args are optional.
parent: a reference to the parent expression (or None, in case of root expressions).
arg_key: the arg key an expression is associated with, i.e. the name its parent expression
uses to refer to it.
comments: a list of comments that are associated with a given expression. This is used in
order to preserve comments when transpiling SQL code.
_type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
optimizer, in order to enable some transformations that require type information.
Example:
>>> class Foo(Expression):
@ -74,13 +81,6 @@ class Expression(metaclass=_Expression):
Args:
args: a mapping used for retrieving the arguments of an expression, given their arg keys.
parent: a reference to the parent expression (or None, in case of root expressions).
arg_key: the arg key an expression is associated with, i.e. the name its parent expression
uses to refer to it.
comments: a list of comments that are associated with a given expression. This is used in
order to preserve comments when transpiling SQL code.
_type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
optimizer, in order to enable some transformations that require type information.
"""
key = "expression"
@ -258,6 +258,12 @@ class Expression(metaclass=_Expression):
new.parent = self.parent
return new
def add_comments(self, comments: t.Optional[t.List[str]]) -> None:
if self.comments is None:
self.comments = []
if comments:
self.comments.extend(comments)
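
Editor's note: unlike assigning to `comments` directly, the new helper appends, so comments gathered at different parse points accumulate. A small sketch:

from sqlglot import exp

e = exp.column("x")
e.add_comments([" a "])
e.add_comments([" b "])
e.comments  # [' a ', ' b ']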
def append(self, arg_key, value):
"""
Appends value to arg_key if it's a list or sets it as a new list.
@ -650,7 +656,7 @@ ExpOrStr = t.Union[str, Expression]
class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts):
def and_(self, *expressions, dialect=None, copy=True, **opts):
"""
AND this condition with one or multiple expressions.
@ -662,14 +668,15 @@ class Condition(Expression):
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
opts (kwargs): other options to use to parse the input expressions.
Returns:
And: the new condition.
"""
return and_(self, *expressions, dialect=dialect, **opts)
return and_(self, *expressions, dialect=dialect, copy=copy, **opts)
def or_(self, *expressions, dialect=None, **opts):
def or_(self, *expressions, dialect=None, copy=True, **opts):
"""
OR this condition with one or multiple expressions.
@ -681,14 +688,15 @@ class Condition(Expression):
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
opts (kwargs): other options to use to parse the input expressions.
Returns:
Or: the new condition.
"""
return or_(self, *expressions, dialect=dialect, **opts)
return or_(self, *expressions, dialect=dialect, copy=copy, **opts)
def not_(self):
def not_(self, copy=True):
"""
Wrap this condition with NOT.
@ -696,14 +704,17 @@ class Condition(Expression):
>>> condition("x=1").not_().sql()
'NOT x = 1'
Args:
copy (bool): whether or not to copy this object.
Returns:
Not: the new condition.
"""
return not_(self)
return not_(self, copy=copy)
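
Editor's note: the `copy` flag now threads through the fluent builders; `copy=False` embeds the receiver as-is instead of deep-copying it, which the optimizer relies on in hot paths. A sketch:

from sqlglot import condition

c = condition("x = 1")
expr = c.and_("y = 2", copy=False)  # reuses `c` rather than copying it
expr.sql()                          # 'x = 1 AND y = 2'
expr.not_(copy=False).sql()         # 'NOT (x = 1 AND y = 2)'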
def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
this = self
other = convert(other)
this = self.copy()
other = convert(other, copy=True)
if not isinstance(this, klass) and not isinstance(other, klass):
this = _wrap(this, Binary)
other = _wrap(other, Binary)
@ -711,20 +722,25 @@ class Condition(Expression):
return klass(this=other, expression=this)
return klass(this=this, expression=other)
def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
if isinstance(other, slice):
return Between(
this=self,
low=convert(other.start),
high=convert(other.stop),
)
return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])
def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]):
return Bracket(
this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)]
)
def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
def isin(
self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
) -> In:
return In(
this=self,
expressions=[convert(e) for e in expressions],
query=maybe_parse(query, **opts) if query else None,
this=_maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
)
def between(self, low: t.Any, high: t.Any, copy=True, **opts) -> Between:
return Between(
this=_maybe_copy(self, copy),
low=convert(low, copy=copy, **opts),
high=convert(high, copy=copy, **opts),
)
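
Editor's note: a quick sketch of the builder methods above:

from sqlglot import condition

condition("x").isin(1, 2, 3).sql()                   # 'x IN (1, 2, 3)'
condition("x").isin(query="SELECT id FROM t").sql()  # 'x IN (SELECT id FROM t)'
condition("x").between(1, 10).sql()                  # 'x BETWEEN 1 AND 10'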
def like(self, other: ExpOrStr) -> Like:
@ -809,10 +825,10 @@ class Condition(Expression):
return self._binop(Or, other, reverse=True)
def __neg__(self) -> Neg:
return Neg(this=_wrap(self, Binary))
return Neg(this=_wrap(self.copy(), Binary))
def __invert__(self) -> Not:
return not_(self)
return not_(self.copy())
class Predicate(Condition):
@ -830,11 +846,7 @@ class DerivedTable(Expression):
@property
def selects(self):
alias = self.args.get("alias")
if alias:
return alias.columns
return []
return self.this.selects if isinstance(self.this, Subqueryable) else []
@property
def named_selects(self):
@ -904,7 +916,10 @@ class Unionable(Expression):
class UDTF(DerivedTable, Unionable):
pass
@property
def selects(self):
alias = self.args.get("alias")
return alias.columns if alias else []
class Cache(Expression):
@ -1073,6 +1088,10 @@ class ColumnDef(Expression):
"position": False,
}
@property
def constraints(self) -> t.List[ColumnConstraint]:
return self.args.get("constraints") or []
class AlterColumn(Expression):
arg_types = {
@ -1100,6 +1119,10 @@ class Comment(Expression):
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
@property
def kind(self) -> ColumnConstraintKind:
return self.args["kind"]
class ColumnConstraintKind(Expression):
pass
@ -1937,6 +1960,15 @@ class Reference(Expression):
class Tuple(Expression):
arg_types = {"expressions": False}
def isin(
self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
) -> In:
return In(
this=_maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
)
class Subqueryable(Unionable):
def subquery(self, alias=None, copy=True) -> Subquery:
@ -2236,6 +2268,8 @@ class Select(Subqueryable):
"expressions": False,
"hint": False,
"distinct": False,
"struct": False, # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#return_query_results_as_a_value_table
"value": False,
"into": False,
"from": False,
**QUERY_MODIFIERS,
@ -2611,7 +2645,7 @@ class Select(Subqueryable):
join.set("kind", kind.text)
if on:
on = and_(*ensure_collection(on), dialect=dialect, **opts)
on = and_(*ensure_collection(on), dialect=dialect, copy=copy, **opts)
join.set("on", on)
if using:
@ -2723,7 +2757,7 @@ class Select(Subqueryable):
**opts,
)
def distinct(self, distinct=True, copy=True) -> Select:
def distinct(self, *ons: ExpOrStr, distinct: bool = True, copy: bool = True) -> Select:
"""
Set the DISTINCT expression.
@ -2732,14 +2766,16 @@ class Select(Subqueryable):
'SELECT DISTINCT x FROM tbl'
Args:
distinct (bool): whether the Select should be distinct
copy (bool): if `False`, modify this expression instance in-place.
ons: the expressions to distinct on
distinct: whether the Select should be distinct
copy: if `False`, modify this expression instance in-place.
Returns:
Select: the modified expression.
"""
instance = _maybe_copy(self, copy)
instance.set("distinct", Distinct() if distinct else None)
on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons]) if ons else None
instance.set("distinct", Distinct(on=on) if distinct else None)
return instance
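
Editor's note: with the new `*ons` parameter, DISTINCT ON can be built fluently. A sketch:

from sqlglot import select

select("x", "y").from_("tbl").distinct("x").sql()
# 'SELECT DISTINCT ON (x) x, y FROM tbl'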
def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
@ -2969,6 +3005,10 @@ class DataType(Expression):
USMALLINT = auto()
BIGINT = auto()
UBIGINT = auto()
INT128 = auto()
UINT128 = auto()
INT256 = auto()
UINT256 = auto()
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
@ -3022,6 +3062,8 @@ class DataType(Expression):
Type.TINYINT,
Type.SMALLINT,
Type.BIGINT,
Type.INT128,
Type.INT256,
}
FLOAT_TYPES = {
@ -3069,10 +3111,6 @@ class PseudoType(Expression):
pass
class StructKwarg(Expression):
arg_types = {"this": True, "expression": True}
# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
class SubqueryPredicate(Predicate):
pass
@ -3538,14 +3576,20 @@ class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}
def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
this = self.copy() if copy else self
this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
return this
instance = _maybe_copy(self, copy)
instance.append(
"ifs",
If(
this=maybe_parse(condition, copy=copy, **opts),
true=maybe_parse(then, copy=copy, **opts),
),
)
return instance
def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
this = self.copy() if copy else self
this.set("default", maybe_parse(condition, **opts))
return this
instance = _maybe_copy(self, copy)
instance.set("default", maybe_parse(condition, copy=copy, **opts))
return instance
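
Editor's note: a sketch of the Case builder after this change (arguments go through `maybe_parse`, so plain SQL strings work):

from sqlglot import exp

exp.Case().when("x = 1", "'one'").else_("'other'").sql()
# "CASE WHEN x = 1 THEN 'one' ELSE 'other' END"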
class Cast(Func):
@ -3760,6 +3804,14 @@ class Floor(Func):
arg_types = {"this": True, "decimals": False}
class FromBase64(Func):
pass
class ToBase64(Func):
pass
class Greatest(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -3930,11 +3982,11 @@ class Pow(Binary, Func):
class PercentileCont(AggFunc):
pass
arg_types = {"this": True, "expression": False}
class PercentileDisc(AggFunc):
pass
arg_types = {"this": True, "expression": False}
class Quantile(AggFunc):
@ -4405,14 +4457,16 @@ def _apply_conjunction_builder(
if append and existing is not None:
expressions = [existing.this if into else existing] + list(expressions)
node = and_(*expressions, dialect=dialect, **opts)
node = and_(*expressions, dialect=dialect, copy=copy, **opts)
inst.set(arg, into(this=node) if into else node)
return inst
def _combine(expressions, operator, dialect=None, **opts):
expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
def _combine(expressions, operator, dialect=None, copy=True, **opts):
expressions = [
condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions
]
this = expressions[0]
if expressions[1:]:
this = _wrap(this, Connector)
@ -4626,7 +4680,7 @@ def delete(
return delete_expr
def condition(expression, dialect=None, **opts) -> Condition:
def condition(expression, dialect=None, copy=True, **opts) -> Condition:
"""
Initialize a logical condition expression.
@ -4645,6 +4699,7 @@ def condition(expression, dialect=None, **opts) -> Condition:
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
copy (bool): Whether or not to copy `expression` (only applies to expressions).
**opts: other options to use to parse the input expressions (again, in the case
that the input expression is a SQL string).
@ -4655,11 +4710,12 @@ def condition(expression, dialect=None, **opts) -> Condition:
expression,
into=Condition,
dialect=dialect,
copy=copy,
**opts,
)
def and_(*expressions, dialect=None, **opts) -> And:
def and_(*expressions, dialect=None, copy=True, **opts) -> And:
"""
Combine multiple conditions with an AND logical operator.
@ -4671,15 +4727,16 @@ def and_(*expressions, dialect=None, **opts) -> And:
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
And: the new condition
"""
return _combine(expressions, And, dialect, **opts)
return _combine(expressions, And, dialect, copy=copy, **opts)
def or_(*expressions, dialect=None, **opts) -> Or:
def or_(*expressions, dialect=None, copy=True, **opts) -> Or:
"""
Combine multiple conditions with an OR logical operator.
@ -4691,15 +4748,16 @@ def or_(*expressions, dialect=None, **opts) -> Or:
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
Or: the new condition
"""
return _combine(expressions, Or, dialect, **opts)
return _combine(expressions, Or, dialect, copy=copy, **opts)
def not_(expression, dialect=None, **opts) -> Not:
def not_(expression, dialect=None, copy=True, **opts) -> Not:
"""
Wrap a condition with a NOT operator.
@ -4719,13 +4777,14 @@ def not_(expression, dialect=None, **opts) -> Not:
this = condition(
expression,
dialect=dialect,
copy=copy,
**opts,
)
return Not(this=_wrap(this, Connector))
def paren(expression) -> Paren:
return Paren(this=expression)
def paren(expression, copy=True) -> Paren:
return Paren(this=_maybe_copy(expression, copy))
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
@ -4998,29 +5057,20 @@ def values(
alias: optional alias
columns: Optional list of ordered column names or ordered dictionary of column names to types.
If either is provided then an alias is also required.
If a dictionary is provided then the first column of the values will be cast to the expected type
in order to help with type inference.
Returns:
Values: the Values expression object
"""
if columns and not alias:
raise ValueError("Alias is required when providing columns")
table_alias = (
TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns])
if columns
else TableAlias(this=to_identifier(alias) if alias else None)
)
expressions = [convert(tup) for tup in values]
if columns and isinstance(columns, dict):
types = list(columns.values())
expressions[0].set(
"expressions",
[cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
)
return Values(
expressions=expressions,
alias=table_alias,
expressions=[convert(tup) for tup in values],
alias=(
TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns])
if columns
else (TableAlias(this=to_identifier(alias)) if alias else None)
),
)
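
Editor's note: a sketch of the simplified builder (exact whitespace and wrapping of the rendered VALUES clause can vary by dialect and context):

from sqlglot import exp

exp.values([(1, "a"), (2, "b")], alias="t", columns=["id", "name"]).sql()
# roughly: "VALUES (1, 'a'), (2, 'b') AS t(id, name)"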
@ -5068,19 +5118,20 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
)
def convert(value) -> Expression:
def convert(value: t.Any, copy: bool = False) -> Expression:
"""Convert a python value into an expression object.
Raises an error if a conversion is not possible.
Args:
value (Any): a python object
value: A python object.
copy: Whether or not to copy `value` (only applies to Expressions and collections).
Returns:
Expression: the equivalent expression object
Expression: the equivalent expression object.
"""
if isinstance(value, Expression):
return value
return _maybe_copy(value, copy)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, bool):
@ -5098,13 +5149,13 @@ def convert(value) -> Expression:
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
return Tuple(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, list):
return Array(expressions=[convert(v) for v in value])
return Array(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, dict):
return Map(
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
keys=[convert(k, copy=copy) for k in value],
values=[convert(v, copy=copy) for v in value.values()],
)
raise ValueError(f"Cannot convert {value}")
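
Editor's note: a sketch of `convert` on a few Python values:

import datetime
from sqlglot.expressions import convert

convert((1, "x", None)).sql()             # "(1, 'x', NULL)"
convert([1, 2]).sql()                     # 'ARRAY(1, 2)'
convert(datetime.date(2023, 1, 1)).sql()  # "DATE_STR_TO_DATE('2023-01-01')"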

View file

@ -25,6 +25,12 @@ class Generator:
quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
bit_start (str): specifies which starting character to use to delimit bit literals. Default: None.
bit_end (str): specifies which ending character to use to delimit bit literals. Default: None.
hex_start (str): specifies which starting character to use to delimit hex literals. Default: None.
hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an uppercase character, True defaults to always.
normalize (bool): if set to True all identifiers will be lowercased
string_escape (str): specifies a string escape character. Default: '.
@ -227,6 +233,12 @@ class Generator:
"quote_end",
"identifier_start",
"identifier_end",
"bit_start",
"bit_end",
"hex_start",
"hex_end",
"byte_start",
"byte_end",
"identify",
"normalize",
"string_escape",
@ -258,6 +270,12 @@ class Generator:
quote_end=None,
identifier_start=None,
identifier_end=None,
bit_start=None,
bit_end=None,
hex_start=None,
hex_end=None,
byte_start=None,
byte_end=None,
identify=False,
normalize=False,
string_escape=None,
@ -284,6 +302,12 @@ class Generator:
self.quote_end = quote_end or "'"
self.identifier_start = identifier_start or '"'
self.identifier_end = identifier_end or '"'
self.bit_start = bit_start
self.bit_end = bit_end
self.hex_start = hex_start
self.hex_end = hex_end
self.byte_start = byte_start
self.byte_end = byte_end
self.identify = identify
self.normalize = normalize
self.string_escape = string_escape or "'"
@ -361,7 +385,7 @@ class Generator:
expression: t.Optional[exp.Expression] = None,
comments: t.Optional[t.List[str]] = None,
) -> str:
comments = (comments or (expression and expression.comments)) if self._comments else None # type: ignore
comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None # type: ignore
if not comments or isinstance(expression, exp.Binary):
return sql
@ -510,12 +534,12 @@ class Generator:
position = self.sql(expression, "position")
return f"{position}{this}"
def columndef_sql(self, expression: exp.ColumnDef) -> str:
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
kind = f" {kind}" if kind else ""
kind = f"{sep}{kind}" if kind else ""
constraints = f" {constraints}" if constraints else ""
position = self.sql(expression, "position")
position = f" {position}" if position else ""
@ -524,7 +548,7 @@ class Generator:
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
kind_sql = self.sql(expression, "kind")
kind_sql = self.sql(expression, "kind").strip()
return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql
def autoincrementcolumnconstraint_sql(self, _) -> str:
@ -716,13 +740,22 @@ class Generator:
return f"{alias}{columns}"
def bitstring_sql(self, expression: exp.BitString) -> str:
return self.sql(expression, "this")
this = self.sql(expression, "this")
if self.bit_start:
return f"{self.bit_start}{this}{self.bit_end}"
return f"{int(this, 2)}"
def hexstring_sql(self, expression: exp.HexString) -> str:
return self.sql(expression, "this")
this = self.sql(expression, "this")
if self.hex_start:
return f"{self.hex_start}{this}{self.hex_end}"
return f"{int(this, 16)}"
def bytestring_sql(self, expression: exp.ByteString) -> str:
return self.sql(expression, "this")
this = self.sql(expression, "this")
if self.byte_start:
return f"{self.byte_start}{this}{self.byte_end}"
return this
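
Editor's note: a sketch exercising the new delimiter handling directly on the base Generator (dialects normally wire `hex_start`/`hex_end` in from their tokenizer settings):

from sqlglot import exp
from sqlglot.generator import Generator

hexlit = exp.select(exp.HexString(this="1A"))
Generator().generate(hexlit)                            # 'SELECT 26'
Generator(hex_start="0x", hex_end="").generate(hexlit)  # 'SELECT 0x1A'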
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
@ -1115,10 +1148,12 @@ class Generator:
return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
if self.alias_post_tablesample and expression.this.alias:
this = self.sql(expression.this, "this")
alias = f" AS {self.sql(expression.this, 'alias')}"
alias = f"{sep}{self.sql(expression.this, 'alias')}"
else:
this = self.sql(expression, "this")
alias = ""
@ -1447,16 +1482,16 @@ class Generator:
)
def select_sql(self, expression: exp.Select) -> str:
kind = expression.args.get("kind")
kind = f" AS {kind}" if kind else ""
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
kind = expression.args.get("kind")
kind = f" AS {kind}" if kind else ""
expressions = self.expressions(expression)
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
f"SELECT{kind}{hint}{distinct}{expressions}",
f"SELECT{hint}{distinct}{kind}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
@ -1475,9 +1510,6 @@ class Generator:
replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
return f"*{except_}{replace}"
def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}"
@ -1806,7 +1838,7 @@ class Generator:
return self.binary(expression, op)
sqls = tuple(
self.maybe_comment(self.sql(e), e, e.parent.comments) if i != 1 else self.sql(e)
self.maybe_comment(self.sql(e), e, e.parent.comments or []) if i != 1 else self.sql(e)
for i, e in enumerate(expression.flatten(unnest=False))
)

View file

@ -153,7 +153,7 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)
for condition in on.flatten():
if isinstance(condition, exp.EQ):

View file

@ -29,6 +29,6 @@ def expand_laterals(expression: exp.Expression) -> exp.Expression:
for column in projection.find_all(exp.Column):
if not column.table and column.name in alias_to_expression:
column.replace(alias_to_expression[column.name].copy())
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = projection.this
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = projection.this
return expression
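
Editor's note: for reference, what this rule does (essentially the module's documented example):

import sqlglot
from sqlglot.optimizer.expand_laterals import expand_laterals

expand_laterals(sqlglot.parse_one("SELECT x + 1 AS y, y + 1 AS z FROM t")).sql()
# 'SELECT x + 1 AS y, x + 1 + 1 AS z FROM t'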

View file

@ -152,12 +152,14 @@ def _distribute(a, b, from_func, to_func, cache):
lambda c: to_func(
uniq_sort(flatten(from_func(c, b.left)), cache),
uniq_sort(flatten(from_func(c, b.right)), cache),
copy=False,
),
)
else:
a = to_func(
uniq_sort(flatten(from_func(a, b.left)), cache),
uniq_sort(flatten(from_func(a, b.right)), cache),
copy=False,
)
return a

View file

@ -10,7 +10,6 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
@ -30,7 +29,6 @@ RULES = (
qualify_tables,
isolate_table_selects,
qualify_columns,
expand_laterals,
pushdown_projections,
validate_qualify_columns,
normalize,

View file

@ -3,11 +3,12 @@ import typing as t
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
def qualify_columns(expression, schema):
def qualify_columns(expression, schema, expand_laterals=True):
"""
Rewrite sqlglot AST to have fully qualified columns.
@ -26,6 +27,9 @@ def qualify_columns(expression, schema):
"""
schema = ensure_schema(schema)
if not schema.mapping and expand_laterals:
expression = _expand_laterals(expression)
for scope in traverse_scope(expression):
resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
@ -39,6 +43,9 @@ def qualify_columns(expression, schema):
_expand_group_by(scope, resolver)
_expand_order_by(scope)
if schema.mapping and expand_laterals:
expression = _expand_laterals(expression)
return expression
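
Editor's note: a sketch of the entry point, now with lateral expansion folded in:

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"tbl": {"col": "INT"}}
qualify_columns(sqlglot.parse_one("SELECT col FROM tbl"), schema).sql()
# 'SELECT tbl.col AS col FROM tbl'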
@ -124,7 +131,7 @@ def _expand_using(scope, resolver):
tables[join_table] = None
join.args.pop("using")
join.set("on", exp.and_(*conditions))
join.set("on", exp.and_(*conditions, copy=False))
if column_tables:
for column in scope.columns:
@ -240,7 +247,9 @@ def _qualify_columns(scope, resolver):
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", column_table)
elif column_table not in scope.sources:
elif column_table not in scope.sources and (
not scope.parent or column_table not in scope.parent.sources
):
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
@ -376,10 +385,13 @@ def _qualify_outputs(scope):
if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias) and not selection.is_star:
alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
alias_.set("this", selection)
selection = alias_
selection = alias(
selection,
alias=selection.output_name or f"_col_{i}",
quoted=True
if isinstance(selection, exp.Column) and selection.this.quoted
else None,
)
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))

View file

@ -7,21 +7,29 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
Rewrite sqlglot AST to have fully qualified tables.
Rewrite sqlglot AST to have fully qualified tables. Additionally, this
replaces "join constructs" (*) by equivalent SELECT * subqueries.
Example:
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
>>> qualify_tables(expression).sql()
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
Args:
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
schema: A schema to populate
Returns:
sqlglot.Expression: qualified expression
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
sequence = itertools.count()
@ -29,6 +37,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
# Expand join construct
if isinstance(derived_table, exp.Subquery):
unnested = derived_table.unnest()
if isinstance(unnested, exp.Table):
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))

View file

@ -510,6 +510,9 @@ def _traverse_scope(scope):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.Table):
# This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
else:
@ -587,6 +590,9 @@ def _traverse_tables(scope):
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)
if isinstance(scope.expression, exp.Table):
expressions.append(scope.expression)
expressions.extend(scope.expression.args.get("laterals") or [])
for expression in expressions:

View file

@ -60,6 +60,7 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
return exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
copy=False,
)
return expression
@ -76,9 +77,17 @@ def simplify_not(expression):
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
return exp.or_(
exp.not_(condition.left, copy=False),
exp.not_(condition.right, copy=False),
copy=False,
)
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
return exp.and_(
exp.not_(condition.left, copy=False),
exp.not_(condition.right, copy=False),
copy=False,
)
if is_null(condition):
return exp.null()
if always_true(expression.this):
@ -254,12 +263,12 @@ def uniq_sort(expression, cache=None, root=True):
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
expression = result_func(*(e for _, e in sorted(arr)))
expression = result_func(*(e for _, e in sorted(arr)), copy=False)
break
else:
# we didn't have to sort but maybe we need to dedup
if len(deduped) < len(flattened):
expression = result_func(*deduped.values())
expression = result_func(*deduped.values(), copy=False)
return expression
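
Editor's note: a sketch of the NOT push-down above, end to end through `simplify`:

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

simplify(parse_one("NOT (x = 1 AND y = 2)")).sql()
# 'NOT x = 1 OR NOT y = 2'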

View file

@ -126,9 +126,17 @@ class Parser(metaclass=_Parser):
TokenType.BIT,
TokenType.BOOLEAN,
TokenType.TINYINT,
TokenType.UTINYINT,
TokenType.SMALLINT,
TokenType.USMALLINT,
TokenType.INT,
TokenType.UINT,
TokenType.BIGINT,
TokenType.UBIGINT,
TokenType.INT128,
TokenType.UINT128,
TokenType.INT256,
TokenType.UINT256,
TokenType.FLOAT,
TokenType.DOUBLE,
TokenType.CHAR,
@ -961,14 +969,15 @@ class Parser(metaclass=_Parser):
The target expression.
"""
instance = exp_class(**kwargs)
if self._prev_comments:
instance.comments = self._prev_comments
self._prev_comments = None
if comments:
instance.comments = comments
instance.add_comments(comments) if comments else self._add_comments(instance)
self.validate_expression(instance)
return instance
def _add_comments(self, expression: t.Optional[exp.Expression]) -> None:
if expression and self._prev_comments:
expression.add_comments(self._prev_comments)
self._prev_comments = None
def validate_expression(
self, expression: exp.Expression, args: t.Optional[t.List] = None
) -> None:
@ -1567,7 +1576,7 @@ class Parser(metaclass=_Parser):
value = self.expression(
exp.Schema,
this="TABLE",
expressions=self._parse_csv(self._parse_struct_kwargs),
expressions=self._parse_csv(self._parse_struct_types),
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
@ -1802,14 +1811,15 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.SELECT):
comments = self._prev_comments
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT)
kind = (
self._match(TokenType.ALIAS)
and self._match_texts(("STRUCT", "VALUE"))
and self._prev.text
)
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT)
if distinct:
distinct = self.expression(
@ -2284,7 +2294,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.UNNEST):
return None
expressions = self._parse_wrapped_csv(self._parse_column)
expressions = self._parse_wrapped_csv(self._parse_type)
ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
alias = self._parse_table_alias()
@ -2333,7 +2343,9 @@ class Parser(metaclass=_Parser):
size = None
seed = None
kind = "TABLESAMPLE" if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
kind = (
self._prev.text if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
)
method = self._parse_var(tokens=(TokenType.ROW,))
self._match(TokenType.L_PAREN)
@ -2684,7 +2696,7 @@ class Parser(metaclass=_Parser):
else:
this = self.expression(exp.In, this=this, expressions=expressions)
self._match_r_paren()
self._match_r_paren(this)
else:
this = self.expression(exp.In, this=this, field=self._parse_field())
@ -2798,7 +2810,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.L_PAREN):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
expressions = self._parse_csv(self._parse_struct_types)
elif nested:
expressions = self._parse_csv(self._parse_types)
else:
@ -2833,7 +2845,7 @@ class Parser(metaclass=_Parser):
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
expressions = self._parse_csv(self._parse_struct_types)
else:
expressions = self._parse_csv(self._parse_types)
@ -2891,16 +2903,10 @@ class Parser(metaclass=_Parser):
prefix=prefix,
)
def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
index = self._index
this = self._parse_id_var()
def _parse_struct_types(self) -> t.Optional[exp.Expression]:
this = self._parse_type() or self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
if not data_type:
self._retreat(index)
return self._parse_types()
return self.expression(exp.StructKwarg, this=this, expression=data_type)
return self._parse_column_def(this)
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.AT_TIME_ZONE):
@ -2932,7 +2938,11 @@ class Parser(metaclass=_Parser):
else exp.Literal.string(value)
)
else:
field = self._parse_star() or self._parse_function() or self._parse_id_var()
field = (
self._parse_star()
or self._parse_function(anonymous=True)
or self._parse_id_var()
)
if isinstance(field, exp.Func):
# bigquery allows function calls like x.y.count(...)
@ -2995,11 +3005,9 @@ class Parser(metaclass=_Parser):
else:
this = self.expression(exp.Paren, this=self._parse_set_operations(this))
self._match_r_paren()
comments.extend(self._prev_comments)
if this and comments:
this.comments = comments
if this:
this.add_comments(comments)
self._match_r_paren(expression=this)
return this
@ -3017,7 +3025,7 @@ class Parser(metaclass=_Parser):
)
def _parse_function(
self, functions: t.Optional[t.Dict[str, t.Callable]] = None
self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
) -> t.Optional[exp.Expression]:
if not self._curr:
return None
@ -3043,7 +3051,7 @@ class Parser(metaclass=_Parser):
parser = self.FUNCTION_PARSERS.get(upper)
if parser:
if parser and not anonymous:
this = parser(self)
else:
subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
@ -3059,7 +3067,7 @@ class Parser(metaclass=_Parser):
function = functions.get(upper)
args = self._parse_csv(self._parse_lambda)
if function:
if function and not anonymous:
# Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the
# second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists.
if count_params(function) == 2:
@ -3148,12 +3156,7 @@ class Parser(metaclass=_Parser):
if isinstance(left, exp.Column):
left.replace(exp.Var(this=left.text("this")))
if self._match(TokenType.IGNORE_NULLS):
this = self.expression(exp.IgnoreNulls, this=this)
else:
self._match(TokenType.RESPECT_NULLS)
return self._parse_limit(self._parse_order(this))
return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this)))
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
@ -3177,6 +3180,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Schema, this=this, expressions=args)
def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
# column defs are not really columns, they're identifiers
if isinstance(this, exp.Column):
this = this.this
kind = self._parse_types()
if self._match_text_seq("FOR", "ORDINALITY"):
@ -3420,7 +3426,7 @@ class Parser(metaclass=_Parser):
elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE:
self.raise_error("Expected }")
this.comments = self._prev_comments
self._add_comments(this)
return self._parse_bracket(this)
def _parse_slice(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
@ -3584,7 +3590,9 @@ class Parser(metaclass=_Parser):
exp.and_(
exp.Is(this=expression.copy(), expression=exp.Null()),
exp.Is(this=search.copy(), expression=exp.Null()),
copy=False,
),
copy=False,
)
ifs.append(exp.If(this=cond, true=result))
@ -3717,15 +3725,15 @@ class Parser(metaclass=_Parser):
if self._match_set(self.TRIM_TYPES):
position = self._prev.text.upper()
expression = self._parse_term()
expression = self._parse_bitwise()
if self._match_set((TokenType.FROM, TokenType.COMMA)):
this = self._parse_term()
this = self._parse_bitwise()
else:
this = expression
expression = None
if self._match(TokenType.COLLATE):
collation = self._parse_term()
collation = self._parse_bitwise()
return self.expression(
exp.Trim,
@ -3741,6 +3749,15 @@ class Parser(metaclass=_Parser):
def _parse_named_window(self) -> t.Optional[exp.Expression]:
return self._parse_window(self._parse_id_var(), alias=True)
def _parse_respect_or_ignore_nulls(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
if self._match(TokenType.IGNORE_NULLS):
return self.expression(exp.IgnoreNulls, this=this)
if self._match(TokenType.RESPECT_NULLS):
return self.expression(exp.RespectNulls, this=this)
return this
def _parse_window(
self, this: t.Optional[exp.Expression], alias: bool = False
) -> t.Optional[exp.Expression]:
@ -3768,10 +3785,7 @@ class Parser(metaclass=_Parser):
# (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
# and Snowflake chose to do the same for familiarity
# https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
if self._match(TokenType.IGNORE_NULLS):
this = self.expression(exp.IgnoreNulls, this=this)
elif self._match(TokenType.RESPECT_NULLS):
this = self.expression(exp.RespectNulls, this=this)
this = self._parse_respect_or_ignore_nulls(this)
# bigquery select from window x AS (partition by ...)
if alias:
@ -3975,9 +3989,7 @@ class Parser(metaclass=_Parser):
items = [parse_result] if parse_result is not None else []
while self._match(sep):
if parse_result and self._prev_comments:
parse_result.comments = self._prev_comments
self._add_comments(parse_result)
parse_result = parse_method()
if parse_result is not None:
items.append(parse_result)
@ -4345,13 +4357,14 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return None
def _match(self, token_type, advance=True):
def _match(self, token_type, advance=True, expression=None):
if not self._curr:
return None
if self._curr.token_type == token_type:
if advance:
self._advance()
self._add_comments(expression)
return True
return None
@ -4379,16 +4392,12 @@ class Parser(metaclass=_Parser):
return None
def _match_l_paren(self, expression=None):
if not self._match(TokenType.L_PAREN):
if not self._match(TokenType.L_PAREN, expression=expression):
self.raise_error("Expecting (")
if expression and self._prev_comments:
expression.comments = self._prev_comments
def _match_r_paren(self, expression=None):
if not self._match(TokenType.R_PAREN):
if not self._match(TokenType.R_PAREN, expression=expression):
self.raise_error("Expecting )")
if expression and self._prev_comments:
expression.comments = self._prev_comments
def _match_texts(self, texts, advance=True):
if self._curr and self._curr.text.upper() in texts:

View file

@ -84,6 +84,10 @@ class TokenType(AutoName):
UINT = auto()
BIGINT = auto()
UBIGINT = auto()
INT128 = auto()
UINT128 = auto()
INT256 = auto()
UINT256 = auto()
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
@ -774,8 +778,6 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
"_prev_token_comments",
"_prev_token_type",
)
def __init__(self) -> None:
@ -795,8 +797,6 @@ class Tokenizer(metaclass=_Tokenizer):
self._end = False
self._peek = ""
self._prev_token_line = -1
self._prev_token_comments: t.List[str] = []
self._prev_token_type: t.Optional[TokenType] = None
def tokenize(self, sql: str) -> t.List[Token]:
"""Returns a list of tokens corresponding to the SQL string `sql`."""
@ -846,7 +846,7 @@ class Tokenizer(metaclass=_Tokenizer):
return self.sql[start:end]
return ""
def _advance(self, i: int = 1) -> None:
def _advance(self, i: int = 1, alnum: bool = False) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
self._col = 1
self._line += 1
@ -858,14 +858,30 @@ class Tokenizer(metaclass=_Tokenizer):
self._char = self.sql[self._current - 1]
self._peek = "" if self._end else self.sql[self._current]
if alnum and self._char.isalnum():
_col = self._col
_current = self._current
_end = self._end
_peek = self._peek
while _peek.isalnum():
_col += 1
_current += 1
_end = _current >= self.size
_peek = "" if _end else self.sql[_current]
self._col = _col
self._current = _current
self._end = _end
self._peek = _peek
self._char = self.sql[_current - 1]
@property
def _text(self) -> str:
return self.sql[self._start : self._current]
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self._prev_token_comments = self._comments
self._prev_token_type = token_type
self.tokens.append(
Token(
token_type,
@ -966,13 +982,13 @@ class Tokenizer(metaclass=_Tokenizer):
comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
self._advance(alnum=True)
self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
self._advance()
self._advance(alnum=True)
self._comments.append(self._text[comment_start_size:])
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
@ -988,9 +1004,9 @@ class Tokenizer(metaclass=_Tokenizer):
if self._char == "0":
peek = self._peek.upper()
if peek == "B":
return self._scan_bits()
return self._scan_bits() if self._BIT_STRINGS else self._add(TokenType.NUMBER)
elif peek == "X":
return self._scan_hex()
return self._scan_hex() if self._HEX_STRINGS else self._add(TokenType.NUMBER)
decimal = False
scientific = 0
@ -1033,7 +1049,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
value = self._extract_value()
try:
self._add(TokenType.BIT_STRING, f"{int(value, 2)}")
# If `value` can't be parsed as a base-2 integer, fall back to tokenizing it as an identifier
int(value, 2)
self._add(TokenType.BIT_STRING, value[2:]) # Drop the 0b
except ValueError:
self._add(TokenType.IDENTIFIER)
@ -1041,7 +1059,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
value = self._extract_value()
try:
self._add(TokenType.HEX_STRING, f"{int(value, 16)}")
# If `value` can't be parsed as a base-16 integer, fall back to tokenizing it as an identifier
int(value, 16)
self._add(TokenType.HEX_STRING, value[2:]) # Drop the 0x
except ValueError:
self._add(TokenType.IDENTIFIER)
@ -1049,7 +1069,7 @@ class Tokenizer(metaclass=_Tokenizer):
while True:
char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
self._advance()
self._advance(alnum=True)
else:
break
@ -1066,7 +1086,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
# X'1234, b'0110', E'\\\\\' etc.
# X'1234', b'0110', E'\\\\\' etc.
def _scan_formatted_string(self, string_start: str) -> bool:
if string_start in self._HEX_STRINGS:
delimiters = self._HEX_STRINGS
@ -1087,60 +1107,43 @@ class Tokenizer(metaclass=_Tokenizer):
string_end = delimiters[string_start]
text = self._extract_string(string_end)
if base is None:
self._add(token_type, text)
else:
if base:
try:
self._add(token_type, f"{int(text, base)}")
int(text, base)
except:
raise RuntimeError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
self._add(token_type, text)
return True
def _scan_identifier(self, identifier_end: str) -> None:
text = ""
identifier_end_is_escape = identifier_end in self._IDENTIFIER_ESCAPES
while True:
if self._end:
raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
self._advance()
if self._char == identifier_end:
if identifier_end_is_escape and self._peek == identifier_end:
text += identifier_end
self._advance()
continue
break
text += self._char
self._advance()
text = self._extract_string(identifier_end, self._IDENTIFIER_ESCAPES)
self._add(TokenType.IDENTIFIER, text)
def _scan_var(self) -> None:
while True:
char = self._peek.strip()
if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS):
self._advance()
self._advance(alnum=True)
else:
break
self._add(
TokenType.VAR
if self._prev_token_type == TokenType.PARAMETER
if self.tokens and self.tokens[-1].token_type == TokenType.PARAMETER
else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
)
def _extract_string(self, delimiter: str) -> str:
def _extract_string(self, delimiter: str, escapes=None) -> str:
text = ""
delim_size = len(delimiter)
escapes = self._STRING_ESCAPES if escapes is None else escapes
while True:
if self._char in self._STRING_ESCAPES and (
self._peek == delimiter or self._peek in self._STRING_ESCAPES
):
if self._char in escapes and (self._peek == delimiter or self._peek in escapes):
if self._peek == delimiter:
text += self._peek
else:
@ -1158,7 +1161,9 @@ class Tokenizer(metaclass=_Tokenizer):
if self._end:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
text += self._char
self._advance()
current = self._current - 1
self._advance(alnum=True)
text += self.sql[current : self._current - 1]
return text

View file

@ -121,20 +121,9 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
other expressions. This transform removes the precision from parameterized types in expressions.
"""
return expression.transform(
lambda node: exp.DataType(
**{
**node.args,
"expressions": [
node_expression
for node_expression in node.expressions
if isinstance(node_expression, exp.DataType)
],
}
)
if isinstance(node, exp.DataType)
else node,
)
for node in expression.find_all(exp.DataType):
node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])
return expression
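
Editor's note: a sketch of the rewritten transform (same behavior, now a plain in-place walk instead of `Expression.transform`):

from sqlglot import parse_one
from sqlglot.transforms import remove_precision_parameterized_types

remove_precision_parameterized_types(parse_one("SELECT CAST(1 AS DECIMAL(10, 2))")).sql()
# 'SELECT CAST(1 AS DECIMAL)'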
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
@ -240,12 +229,36 @@ def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
return expression
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
if (
isinstance(expression, exp.WithinGroup)
and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
and isinstance(expression.expression, exp.Order)
):
quantile = expression.this.this
input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
return expression
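
Editor's note: a sketch of the new percentile rewrite (applied node-by-node via `Expression.transform` here for illustration):

from sqlglot import parse_one
from sqlglot.transforms import remove_within_group_for_percentiles

expr = parse_one("SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t")
expr.transform(remove_within_group_for_percentiles).sql()
# 'SELECT APPROX_QUANTILE(x, 0.5) FROM t'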
def unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Pivot):
expression.args["field"].transform(
lambda node: exp.column(node.output_name) if isinstance(node, exp.Column) else node,
copy=False,
)
return expression
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using an appropriate `Generator.TRANSFORMS` function.
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Args:
transforms: sequence of transform functions. These will be called in order.
@ -255,17 +268,28 @@ def preprocess(
"""
def _to_sql(self, expression: exp.Expression) -> str:
expression_type = type(expression)
expression = transforms[0](expression.copy())
for t in transforms[1:]:
expression = t(expression)
return getattr(self, expression.key + "_sql")(expression)
_sql_handler = getattr(self, expression.key + "_sql", None)
if _sql_handler:
return _sql_handler(expression)
transforms_handler = self.TRANSFORMS.get(type(expression))
if transforms_handler:
# Ensures we don't enter an infinite loop. This can happen when the original expression
# has the same type as the final expression and there's no _sql method available for it,
# because then it'd re-enter _to_sql.
if expression_type is type(expression):
raise ValueError(
f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
)
return transforms_handler(self, expression)
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
return _to_sql
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])}
ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])}
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
exp.Cast: preprocess([remove_precision_parameterized_types])
}