
Merging upstream version 12.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann, 2025-02-13 15:53:39 +01:00
parent fffa0d5761
commit 62b2b24d3b
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
100 changed files with 35022 additions and 30936 deletions

sqlglot/dialects/__init__.py

@@ -70,6 +70,7 @@ from sqlglot.dialects.presto import Presto
from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.spark2 import Spark2
from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau

sqlglot/dialects/bigquery.py

@@ -39,18 +39,26 @@ def _date_add_sql(
def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
if not isinstance(expression.unnest().parent, exp.From):
expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
return self.values_sql(expression)
rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
structs = []
for row in rows:
aliases = [
exp.alias_(value, column_name)
for value, column_name in zip(row, expression.args["alias"].args["columns"])
]
structs.append(exp.Struct(expressions=aliases))
unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
return self.unnest_sql(unnest_exp)
alias = expression.args.get("alias")
structs = [
exp.Struct(
expressions=[
exp.alias_(value, column_name)
for value, column_name in zip(
t.expressions,
alias.columns
if alias and alias.columns
else (f"_c{i}" for i in range(len(t.expressions))),
)
]
)
for t in expression.find_all(exp.Tuple)
]
return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
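A minimal sketch of what this hunk does (assuming this sqlglot build; table and column names are illustrative, exact output may differ): a VALUES clause used as a FROM relation is rendered for BigQuery as UNNEST over an array of STRUCTs, falling back to _c0, _c1, ... field names when the alias carries no columns.

import sqlglot

# VALUES with alias columns -> struct fields named after the alias columns
sql = "SELECT * FROM (VALUES (1, 'a')) AS t(x, y)"
print(sqlglot.transpile(sql, read="presto", write="bigquery")[0])
# expected roughly: SELECT * FROM UNNEST([STRUCT(1 AS x, 'a' AS y)]) AS t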
def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
@@ -128,6 +136,7 @@ class BigQuery(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
BYTE_STRINGS = [("b'", "'"), ("B'", "'")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -139,6 +148,7 @@ class BigQuery(Dialect):
"GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
"BYTES": TokenType.BINARY,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
}
@@ -153,7 +163,7 @@ class BigQuery(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
@@ -206,6 +216,12 @@ class BigQuery(Dialect):
"NOT DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
),
"OPTIONS": lambda self: self._parse_with_property(),
}
CONSTRAINT_PARSERS = {
**parser.Parser.CONSTRAINT_PARSERS, # type: ignore
"OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
}
class Generator(generator.Generator):
@@ -217,11 +233,11 @@ class BigQuery(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.AtTimeZone: lambda self, e: self.func(
"TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@@ -234,7 +250,9 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess([_unqualify_unnest]),
exp.Select: transforms.preprocess(
[_unqualify_unnest, transforms.eliminate_distinct_on]
),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -259,6 +277,7 @@ class BigQuery(Dialect):
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.BINARY: "BYTES",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.CHAR: "STRING",
exp.DataType.Type.DECIMAL: "NUMERIC",
@@ -272,6 +291,7 @@ class BigQuery(Dialect):
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARBINARY: "BYTES",
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.VARIANT: "ANY TYPE",
}
@@ -310,3 +330,6 @@ class BigQuery(Dialect):
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("OPTIONS"))
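A hedged round-trip sketch of the new OPTIONS handling: the parser treats OPTIONS as a WITH-style property list, and with_properties re-emits it behind an OPTIONS segment.

import sqlglot

sql = "CREATE TABLE x (y INT64) OPTIONS (description='x')"
# assuming this build, the statement should survive a BigQuery round trip
print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])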

sqlglot/dialects/clickhouse.py

@@ -22,6 +22,8 @@ class ClickHouse(Dialect):
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
BIT_STRINGS = [("0b", "")]
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -31,10 +33,18 @@ class ClickHouse(Dialect):
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"INT16": TokenType.SMALLINT,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
"INT8": TokenType.TINYINT,
"UINT8": TokenType.UTINYINT,
"INT16": TokenType.SMALLINT,
"UINT16": TokenType.USMALLINT,
"INT32": TokenType.INT,
"UINT32": TokenType.UINT,
"INT64": TokenType.BIGINT,
"UINT64": TokenType.UBIGINT,
"INT128": TokenType.INT128,
"UINT128": TokenType.UINT128,
"INT256": TokenType.INT256,
"UINT256": TokenType.UINT256,
"TUPLE": TokenType.STRUCT,
}
@@ -121,9 +131,17 @@ class ClickHouse(Dialect):
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.STRUCT: "Tuple",
exp.DataType.Type.TINYINT: "Int8",
exp.DataType.Type.UTINYINT: "UInt8",
exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.USMALLINT: "UInt16",
exp.DataType.Type.INT: "Int32",
exp.DataType.Type.UINT: "UInt32",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.UBIGINT: "UInt64",
exp.DataType.Type.INT128: "Int128",
exp.DataType.Type.UINT128: "UInt128",
exp.DataType.Type.INT256: "Int256",
exp.DataType.Type.UINT256: "UInt256",
exp.DataType.Type.FLOAT: "Float32",
exp.DataType.Type.DOUBLE: "Float64",
}
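A small round-trip sketch for the widened integer coverage (a hedged check, assuming this build):

import sqlglot

# the new tokenizer keywords and TYPE_MAPPING entries should preserve these types
print(sqlglot.transpile("SELECT CAST(x AS UInt32), CAST(y AS Int128)", read="clickhouse", write="clickhouse")[0])
# expected roughly: SELECT CAST(x AS UInt32), CAST(y AS Int128)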

sqlglot/dialects/databricks.py

@@ -1,6 +1,6 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
@@ -29,13 +29,20 @@ class Databricks(Spark):
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
]
),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation
PARAMETER_TOKEN = "$"
class Tokenizer(Spark.Tokenizer):
HEX_STRINGS = []
SINGLE_TOKENS = {
**Spark.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,

sqlglot/dialects/dialect.py

@@ -28,6 +28,7 @@ class Dialects(str, Enum):
REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
SPARK2 = "spark2"
SQLITE = "sqlite"
STARROCKS = "starrocks"
TABLEAU = "tableau"
@@ -69,30 +70,17 @@ class _Dialect(type):
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
if (
klass.tokenizer_class._BIT_STRINGS
and exp.BitString not in klass.generator_class.TRANSFORMS
):
bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.BitString
] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
if (
klass.tokenizer_class._HEX_STRINGS
and exp.HexString not in klass.generator_class.TRANSFORMS
):
hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
if (
klass.tokenizer_class._BYTE_STRINGS
and exp.ByteString not in klass.generator_class.TRANSFORMS
):
be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.ByteString
] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"
klass.bit_start, klass.bit_end = seq_get(
list(klass.tokenizer_class._BIT_STRINGS.items()), 0
) or (None, None)
klass.hex_start, klass.hex_end = seq_get(
list(klass.tokenizer_class._HEX_STRINGS.items()), 0
) or (None, None)
klass.byte_start, klass.byte_end = seq_get(
list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
) or (None, None)
return klass
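A hedged illustration of the refactor above: instead of registering BitString/HexString/ByteString TRANSFORMS per dialect, the metaclass now exposes each tokenizer's delimiters as class attributes, left as (None, None) when a tokenizer defines none.

from sqlglot.dialects.bigquery import BigQuery

# BigQuery declares HEX_STRINGS = [("0x", "")], so (assuming this build):
print(repr(BigQuery.hex_start), repr(BigQuery.hex_end))  # expected: '0x' ''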
@@ -198,6 +186,12 @@ class Dialect(metaclass=_Dialect):
**{
"quote_start": self.quote_start,
"quote_end": self.quote_end,
"bit_start": self.bit_start,
"bit_end": self.bit_end,
"hex_start": self.hex_start,
"hex_end": self.hex_end,
"byte_start": self.byte_start,
"byte_end": self.byte_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],

sqlglot/dialects/drill.py

@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
@@ -145,6 +145,7 @@ class Drill(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToDate: _str_to_date,
exp.Pow: rename_func("POW"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,

sqlglot/dialects/duckdb.py

@@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
@@ -23,52 +25,61 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _ts_or_ds_add(self, expression):
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _date_add(self, expression):
def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _array_sort_sql(self, expression):
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
return f"ARRAY_SORT({self.sql(expression, 'this')})"
def _sort_array_sql(self, expression):
def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
this = self.sql(expression, "this")
if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
return f"ARRAY_SORT({this})"
def _sort_array_reverse(args):
def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
def _struct_sql(self, expression):
def _parse_date_diff(args: t.Sequence) -> exp.Expression:
return exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
)
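A quick sketch of the argument order `_parse_date_diff` establishes, checked against the AST rather than generated SQL (names are illustrative):

import sqlglot
from sqlglot import exp

dd = sqlglot.parse_one("SELECT DATE_DIFF('day', a, b)", read="duckdb").find(exp.DateDiff)
# DuckDB's DATE_DIFF(unit, start, end): the end date becomes `this`,
# the start date becomes `expression`
print(dd.this.sql(), dd.expression.sql())  # expected: b a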
def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
args = [
f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
]
return f"{{{', '.join(args)}}}"
def _datatype_sql(self, expression):
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
def _regexp_extract_sql(self, expression):
def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str:
bad_args = list(filter(expression.args.get, ("position", "occurrence")))
if bad_args:
self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")
return self.func(
"REGEXP_EXTRACT",
expression.args.get("this"),
@@ -108,6 +119,8 @@ class DuckDB(Dialect):
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _sort_array_reverse,
"DATEDIFF": _parse_date_diff,
"DATE_DIFF": _parse_date_diff,
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
@@ -115,18 +128,18 @@ class DuckDB(Dialect):
expression=exp.Literal.number(1000),
)
),
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
"REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
"STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
"STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
"STR_SPLIT": exp.Split.from_arg_list,
"STRING_SPLIT": exp.Split.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
"STRUCT_PACK": exp.Struct.from_arg_list,
"STR_SPLIT": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
}
@@ -142,10 +155,11 @@ class DuckDB(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if isinstance(seq_get(e.expressions, 0), exp.Select)
@@ -154,13 +168,16 @@ class DuckDB(Dialect):
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
exp.CurrentDate: lambda self, e: "CURRENT_DATE",
exp.CurrentTime: lambda self, e: "CURRENT_TIME",
exp.CurrentTimestamp: lambda self, e: "CURRENT_TIMESTAMP",
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add,
exp.DateAdd: _date_add_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
@@ -192,7 +209,7 @@ class DuckDB(Dialect):
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add,
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
@@ -201,7 +218,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.CHAR: "TEXT",
exp.DataType.Type.FLOAT: "REAL",
@@ -212,17 +229,14 @@ class DuckDB(Dialect):
exp.DataType.Type.VARCHAR: "TEXT",
}
STAR_MAPPING = {
**generator.Generator.STAR_MAPPING,
"except": "EXCLUDE",
}
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)

sqlglot/dialects/hive.py

@@ -81,7 +81,20 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
return f"{diff_sql}{multiplier_sql}"
def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
this = expression.this
if not this.type:
from sqlglot.optimizer.annotate_types import annotate_types
annotate_types(this)
if this.type.is_type(exp.DataType.Type.JSON):
return self.sql(this)
return self.func("TO_JSON", this, expression.args.get("options"))
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
@@ -91,11 +104,11 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
return f"'{expression.name}'={self.sql(expression, 'value')}"
def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -103,7 +116,7 @@ def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
return f"CAST({this} AS DATE)"
def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -214,6 +227,7 @@ class Hive(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0),
@@ -251,6 +265,7 @@ class Hive(Dialect):
"SPLIT": exp.RegexpSplit.from_arg_list,
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
"UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
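A minimal round-trip sketch for the new BASE64/UNBASE64 mappings (assuming this build; column names are illustrative):

import sqlglot

print(sqlglot.transpile("SELECT BASE64(x), UNBASE64(y)", read="hive", write="hive")[0])
# expected: SELECT BASE64(x), UNBASE64(y)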
@@ -280,16 +295,20 @@ class Hive(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.unnest_to_explode]
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
]
),
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort,
exp.ArraySort: _array_sort_sql,
exp.With: no_recursive_cte_sql,
exp.DateAdd: _add_date_sql,
exp.DateDiff: _date_diff_sql,
@@ -298,12 +317,13 @@ class Hive(Dialect):
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: rename_func("TO_JSON"),
exp.JSONFormat: _json_format_sql,
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
@@ -318,9 +338,9 @@ class Hive(Dialect):
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: _str_to_unix_sql,
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
@@ -328,6 +348,7 @@ class Hive(Dialect):
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToBase64: rename_func("BASE64"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.TsOrDsToDate: _to_date_sql,

sqlglot/dialects/mysql.py

@@ -1,6 +1,6 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -403,6 +403,7 @@ class MySQL(Dialect):
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,

sqlglot/dialects/oracle.py

@@ -34,6 +34,8 @@ def _parse_xml_table(self) -> exp.XMLTable:
class Oracle(Dialect):
alias_post_tablesample = True
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
time_mapping = {
@@ -121,21 +123,23 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.IfNull: rename_func("NVL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
exp.IfNull: rename_func("NVL"),
}
PROPERTIES_LOCATION = {
@@ -164,14 +168,19 @@ class Oracle(Dialect):
return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
class Tokenizer(tokens.Tokenizer):
VAR_SINGLE_TOKENS = {"@"}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"(+)": TokenType.JOIN_MARKER,
"BINARY_DOUBLE": TokenType.DOUBLE,
"BINARY_FLOAT": TokenType.FLOAT,
"COLUMNS": TokenType.COLUMN,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
"RETURNING": TokenType.RETURNING,
"SAMPLE": TokenType.TABLE_SAMPLE,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,

sqlglot/dialects/postgres.py

@@ -1,6 +1,8 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -20,7 +22,6 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
from sqlglot.transforms import preprocess, remove_target_from_merge
DATE_DIFF_FACTOR = {
"MICROSECOND": " * 1000000",
@@ -274,8 +275,7 @@ class Postgres(Dialect):
TokenType.HASH: exp.BitwiseXor,
}
FACTOR = {
**parser.Parser.FACTOR,
EXPONENT = {
TokenType.CARET: exp.Pow,
}
@@ -286,6 +286,12 @@ class Postgres(Dialect):
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
def _parse_factor(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_exponent, self.FACTOR)
def _parse_exponent(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_unary, self.EXPONENT)
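A hedged sketch of the new precedence level: `^` is parsed one level beneath FACTOR, so it binds tighter than multiplication and maps to exp.Pow.

import sqlglot

# assuming this build; DuckDB has no Pow override, so the generic POWER is used
print(sqlglot.transpile("SELECT 2 * 3 ^ 2", read="postgres", write="duckdb")[0])
# expected roughly: SELECT 2 * POWER(3, 2)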
def _parse_date_part(self) -> exp.Expression:
part = self._parse_type()
self._match(TokenType.COMMA)
@@ -316,7 +322,7 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: preprocess(
exp.ColumnDef: transforms.preprocess(
[
_auto_increment_to_serial,
_serial_to_generated,
@@ -341,7 +347,7 @@ class Postgres(Dialect):
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.Merge: preprocess([remove_target_from_merge]),
exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,

sqlglot/dialects/presto.py

@@ -130,7 +130,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
step = expression.args.get("step")
target_type = None
@@ -147,7 +151,11 @@ def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
else:
start = exp.Cast(this=start, to=to)
return self.func("SEQUENCE", start, end, step)
sql = self.func("SEQUENCE", start, end, step)
if isinstance(expression.parent, exp.Table):
sql = f"UNNEST({sql})"
return sql
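A short sketch of the wrapping above, assuming this build: Presto only accepts SEQUENCE as a relation inside UNNEST, so a series used as a table gets wrapped.

import sqlglot

print(sqlglot.transpile("SELECT * FROM GENERATE_SERIES(1, 5)", read="postgres", write="presto")[0])
# expected roughly: SELECT * FROM UNNEST(SEQUENCE(1, 5))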
def _ensure_utf8(charset: exp.Literal) -> None:
@@ -204,6 +208,7 @@ class Presto(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
@@ -219,23 +224,23 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
"FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
"NOW": exp.CurrentTimestamp.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
instance=seq_get(args, 2),
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("TRIM")
@@ -264,7 +269,6 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@@ -290,6 +294,7 @@ class Presto(Dialect):
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.GenerateSeries: _sequence_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
@@ -303,7 +308,11 @@ class Presto(Dialect):
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.explode_to_unnest]
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.explode_to_unnest,
]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
@@ -327,6 +336,9 @@ class Presto(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
}
def interval_sql(self, expression: exp.Interval) -> str:

sqlglot/dialects/redshift.py

@@ -52,6 +52,8 @@ class Redshift(Postgres):
return this
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
STRING_ESCAPES = ["\\"]
KEYWORDS = {
@@ -90,7 +92,6 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@@ -102,6 +103,7 @@ class Redshift(Postgres):
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
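A hedged sketch of eliminate_distinct_on, which this release wires up as a Select preprocess here and in many other dialects (Redshift previously used the dict-level ELIMINATE_DISTINCT_ON):

import sqlglot

sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b"
print(sqlglot.transpile(sql, read="postgres", write="redshift")[0])
# expected shape (exact text may vary): a subquery that adds
# ROW_NUMBER() OVER (PARTITION BY a ORDER BY ...), filtered to the
# first row of each partition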

sqlglot/dialects/snowflake.py

@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
date_trunc_to_time,
@@ -252,6 +252,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -305,6 +306,7 @@ class Snowflake(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")

sqlglot/dialects/spark.py

@@ -2,222 +2,54 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot import exp
from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
def _parse_datediff(args: t.Sequence) -> exp.Expression:
"""
Although Spark docs don't mention the "unit" argument, Spark3 added support for
it at some point. Databricks also supports this variation (see below).
if kind.upper() == "TABLE" and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
):
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return create_with_partitions_sql(self, e)
For example, in spark-sql (v3.3.1):
- SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
- SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4
See also:
- https://docs.databricks.com/sql/language-manual/functions/datediff3.html
- https://docs.databricks.com/sql/language-manual/functions/datediff.html
"""
unit = None
this = seq_get(args, 0)
expression = seq_get(args, 1)
if len(args) == 3:
unit = this
this = args[2]
return exp.DateDiff(
this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
)
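A hedged round-trip sketch of the 3-argument form described in the docstring (exact output may vary):

import sqlglot

print(sqlglot.transpile("SELECT DATEDIFF(day, '2020-01-01', '2020-01-05')", read="spark", write="spark")[0])
# expected roughly: SELECT DATEDIFF(day, TO_DATE('2020-01-01'), TO_DATE('2020-01-05'))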
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
keys = self.sql(expression.args["keys"])
values = self.sql(expression.args["values"])
return f"MAP_FROM_ARRAYS({keys}, {values})"
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.date_format:
return f"TO_DATE({this})"
return f"TO_DATE({this}, {time_format})"
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
return f"FROM_UNIXTIME({timestamp})"
if scale == exp.UnixToTime.SECONDS:
return f"TIMESTAMP_SECONDS({timestamp})"
if scale == exp.UnixToTime.MILLIS:
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
raise ValueError("Improper scale for timestamp")
class Spark(Hive):
class Parser(Hive.Parser):
class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS, # type: ignore
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Literal.number(1),
length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Sub(
this=exp.Length(this=seq_get(args, 0)),
expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"BOOLEAN": lambda args: exp.Cast(
this=seq_get(args, 0), to=exp.DataType.build("boolean")
),
"IIF": exp.If.from_arg_list,
"INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
"AGGREGATE": exp.Reduce.from_arg_list,
"DAYOFWEEK": lambda args: exp.DayOfWeek(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFMONTH": lambda args: exp.DayOfMonth(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFYEAR": lambda args: exp.DayOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"TIMESTAMP": lambda args: exp.Cast(
this=seq_get(args, 0), to=exp.DataType.build("timestamp")
),
**Spark2.Parser.FUNCTIONS, # type: ignore
"DATEDIFF": _parse_datediff,
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
"MERGE": lambda self: self._parse_join_hint("MERGE"),
"SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
"MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
"SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
class Generator(Spark2.Generator):
TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
TRANSFORMS.pop(exp.DateDiff)
def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = self.sql(expression, "unit")
end = self.sql(expression, "this")
start = self.sql(expression, "expression")
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
)
if unit:
return self.func("DATEDIFF", unit, start, end)
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
if len(pivot_columns) == 1:
return [""]
names = []
for agg in pivot_columns:
if isinstance(agg, exp.Alias):
names.append(agg.alias)
else:
"""
This case corresponds to aggregations without aliases being used as suffixes
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
"""
agg_all_unquoted = agg.transform(
lambda node: exp.Identifier(this=node.name, quoted=False)
if isinstance(node, exp.Identifier)
else node
)
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
return names
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION, # type: ignore
exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.Create: _create_sql,
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.Trim: trim_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
}
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
exp.DataType.Type.JSON
):
schema = f"'{self.sql(expression, 'to')}'"
return self.func("FROM_JSON", expression.this.this, schema)
if expression.to.is_type(exp.DataType.Type.JSON):
return self.func("TO_JSON", expression.this)
return super(Spark.Generator, self).cast_sql(expression)
class Tokenizer(Hive.Tokenizer):
HEX_STRINGS = [("X'", "'")]
return self.func("DATEDIFF", end, start)

sqlglot/dialects/spark2.py (new file, 238 lines)

@@ -0,0 +1,238 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, parser, transforms
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
if kind.upper() == "TABLE" and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
):
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return create_with_partitions_sql(self, e)
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
keys = self.sql(expression.args["keys"])
values = self.sql(expression.args["values"])
return f"MAP_FROM_ARRAYS({keys}, {values})"
def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
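A small sketch of `_parse_as_cast` in action, assuming this build: type-named calls such as STRING(x) parse into plain casts, so other dialects can render them with their own type names.

import sqlglot

print(sqlglot.transpile("SELECT STRING(x), BOOLEAN(y)", read="spark", write="duckdb")[0])
# expected roughly: SELECT CAST(x AS TEXT), CAST(y AS BOOLEAN)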
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.date_format:
return f"TO_DATE({this})"
return f"TO_DATE({this}, {time_format})"
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
if scale == exp.UnixToTime.SECONDS:
return f"TIMESTAMP_SECONDS({timestamp})"
if scale == exp.UnixToTime.MILLIS:
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
raise ValueError("Improper scale for timestamp")
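A hedged sketch of the scale-less branch above, which now casts so the result is a timestamp rather than Hive's formatted string:

import sqlglot

print(sqlglot.transpile("SELECT FROM_UNIXTIME(x)", read="presto", write="spark")[0])
# expected roughly: SELECT CAST(FROM_UNIXTIME(x) AS TIMESTAMP)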
class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS, # type: ignore
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Literal.number(1),
length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Sub(
this=exp.Length(this=seq_get(args, 0)),
expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
"AGGREGATE": exp.Reduce.from_arg_list,
"DAYOFWEEK": lambda args: exp.DayOfWeek(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFMONTH": lambda args: exp.DayOfMonth(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFYEAR": lambda args: exp.DayOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"BOOLEAN": _parse_as_cast("boolean"),
"DOUBLE": _parse_as_cast("double"),
"FLOAT": _parse_as_cast("float"),
"INT": _parse_as_cast("int"),
"STRING": _parse_as_cast("string"),
"TIMESTAMP": _parse_as_cast("timestamp"),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
"MERGE": lambda self: self._parse_join_hint("MERGE"),
"SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
"MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
"SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
)
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
if len(pivot_columns) == 1:
return [""]
names = []
for agg in pivot_columns:
if isinstance(agg, exp.Alias):
names.append(agg.alias)
else:
"""
This case corresponds to aggregations without aliases being used as suffixes
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
"""
agg_all_unquoted = agg.transform(
lambda node: exp.Identifier(this=node.name, quoted=False)
if isinstance(node, exp.Identifier)
else node
)
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
return names
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION, # type: ignore
exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.Create: _create_sql,
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
exp.Reduce: rename_func("AGGREGATE"),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.Trim: trim_sql,
exp.UnixToTime: _unix_to_time_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
}
TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
exp.DataType.Type.JSON
):
schema = f"'{self.sql(expression, 'to')}'"
return self.func("FROM_JSON", expression.this.this, schema)
if expression.to.is_type(exp.DataType.Type.JSON):
return self.func("TO_JSON", expression.this)
return super(Hive.Generator, self).cast_sql(expression)
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
return super().columndef_sql(
expression,
sep=": "
if isinstance(expression.parent, exp.DataType)
and expression.parent.is_type(exp.DataType.Type.STRUCT)
else sep,
)
class Tokenizer(Hive.Tokenizer):
HEX_STRINGS = [("X'", "'")]

sqlglot/dialects/sqlite.py

@@ -22,6 +22,40 @@ def _date_add_sql(self, expression):
return self.func("DATE", expression.this, modifier)
def _transform_create(expression: exp.Expression) -> exp.Expression:
"""Move primary key to a column and enforce auto_increment on primary keys."""
schema = expression.this
if isinstance(expression, exp.Create) and isinstance(schema, exp.Schema):
defs = {}
primary_key = None
for e in schema.expressions:
if isinstance(e, exp.ColumnDef):
defs[e.name] = e
elif isinstance(e, exp.PrimaryKey):
primary_key = e
if primary_key and len(primary_key.expressions) == 1:
column = defs[primary_key.expressions[0].name]
column.append(
"constraints", exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint())
)
schema.expressions.remove(primary_key)
else:
for column in defs.values():
auto_increment = None
for constraint in column.constraints.copy():
if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint):
break
if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint):
auto_increment = constraint
if auto_increment:
column.constraints.remove(auto_increment)
return expression
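A minimal sketch of `_transform_create`, assuming this build: a table-level single-column PRIMARY KEY is folded into the column definition, so INTEGER PRIMARY KEY keeps its SQLite rowid/AUTOINCREMENT semantics.

import sqlglot

sql = "CREATE TABLE t (id INTEGER, name TEXT, PRIMARY KEY (id))"
print(sqlglot.transpile(sql, read="mysql", write="sqlite")[0])
# expected roughly: CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)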
class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@@ -65,8 +99,8 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTime: lambda *_: "CURRENT_TIME",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
@@ -80,14 +114,17 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_qualify]
),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TryCast: no_trycast_sql,
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
k: exp.Properties.Location.UNSUPPORTED
for k, v in generator.Generator.PROPERTIES_LOCATION.items()
}
LIMIT_FETCH = "LIMIT"

sqlglot/dialects/starrocks.py

@@ -34,6 +34,7 @@ class StarRocks(MySQL):
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
exp.RegexpLike: rename_func("REGEXP"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this

sqlglot/dialects/tableau.py

@@ -1,6 +1,6 @@
from __future__ import annotations
from sqlglot import exp, generator, parser
from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect
@@ -29,6 +29,7 @@ class Tableau(Dialect):
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}
PROPERTIES_LOCATION = {

sqlglot/dialects/teradata.py

@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@@ -148,6 +148,7 @@ class Teradata(Dialect):
**generator.Generator.TRANSFORMS,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}

sqlglot/dialects/tsql.py

@@ -3,7 +3,7 @@ from __future__ import annotations
import re
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
max_or_greatest,
@@ -259,8 +259,8 @@ class TSQL(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
QUOTES = ["'", '"']
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -463,17 +463,18 @@ class TSQL(Dialect):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
exp.TimeToStr: _format_sql,
}
TRANSFORMS.pop(exp.ReturnsProperty)