Merging upstream version 12.2.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent fffa0d5761
commit 62b2b24d3b
100 changed files with 35022 additions and 30936 deletions
@@ -70,6 +70,7 @@ from sqlglot.dialects.presto import Presto
 from sqlglot.dialects.redshift import Redshift
 from sqlglot.dialects.snowflake import Snowflake
 from sqlglot.dialects.spark import Spark
+from sqlglot.dialects.spark2 import Spark2
 from sqlglot.dialects.sqlite import SQLite
 from sqlglot.dialects.starrocks import StarRocks
 from sqlglot.dialects.tableau import Tableau

@@ -39,18 +39,26 @@ def _date_add_sql(

 def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
     if not isinstance(expression.unnest().parent, exp.From):
-        expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
         return self.values_sql(expression)
-    rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
-    structs = []
-    for row in rows:
-        aliases = [
-            exp.alias_(value, column_name)
-            for value, column_name in zip(row, expression.args["alias"].args["columns"])
-        ]
-        structs.append(exp.Struct(expressions=aliases))
-    unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
-    return self.unnest_sql(unnest_exp)
+
+    alias = expression.args.get("alias")
+
+    structs = [
+        exp.Struct(
+            expressions=[
+                exp.alias_(value, column_name)
+                for value, column_name in zip(
+                    t.expressions,
+                    alias.columns
+                    if alias and alias.columns
+                    else (f"_c{i}" for i in range(len(t.expressions))),
+                )
+            ]
+        )
+        for t in expression.find_all(exp.Tuple)
+    ]
+
+    return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))


 def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
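For illustration, the rewritten transform turns a VALUES derived table into an UNNEST over an array of structs, synthesizing `_c0`, `_c1`, ... column names when the alias carries none. A minimal sketch of the expected behavior (the printed output is my expectation for this version, not taken from the diff):

```python
import sqlglot

sql = "SELECT a, b FROM (VALUES (1, 'x'), (2, 'y')) AS t(a, b)"
print(sqlglot.transpile(sql, read="presto", write="bigquery")[0])
# roughly: SELECT a, b FROM UNNEST([STRUCT(1 AS a, 'x' AS b), STRUCT(2 AS a, 'y' AS b)])
```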
@@ -128,6 +136,7 @@ class BigQuery(Dialect):
         IDENTIFIERS = ["`"]
         STRING_ESCAPES = ["\\"]
         HEX_STRINGS = [("0x", ""), ("0X", "")]
+        BYTE_STRINGS = [("b'", "'"), ("B'", "'")]

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
@@ -139,6 +148,7 @@ class BigQuery(Dialect):
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "FLOAT64": TokenType.DOUBLE,
             "INT64": TokenType.BIGINT,
+            "BYTES": TokenType.BINARY,
             "NOT DETERMINISTIC": TokenType.VOLATILE,
             "UNKNOWN": TokenType.NULL,
         }
@@ -153,7 +163,7 @@ class BigQuery(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
             "DATE_TRUNC": lambda args: exp.DateTrunc(
-                unit=exp.Literal.string(seq_get(args, 1).name),  # type: ignore
+                unit=exp.Literal.string(str(seq_get(args, 1))),
                 this=seq_get(args, 0),
             ),
             "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
@@ -206,6 +216,12 @@ class BigQuery(Dialect):
             "NOT DETERMINISTIC": lambda self: self.expression(
                 exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
             ),
             "OPTIONS": lambda self: self._parse_with_property(),
         }

+        CONSTRAINT_PARSERS = {
+            **parser.Parser.CONSTRAINT_PARSERS,  # type: ignore
+            "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
+        }
+
     class Generator(generator.Generator):
@@ -217,11 +233,11 @@ class BigQuery(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
-            **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES,  # type: ignore
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.AtTimeZone: lambda self, e: self.func(
                 "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
             ),
+            exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
             exp.DateSub: _date_add_sql("DATE", "SUB"),
             exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@@ -234,7 +250,9 @@ class BigQuery(Dialect):
             exp.IntDiv: rename_func("DIV"),
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
-            exp.Select: transforms.preprocess([_unqualify_unnest]),
+            exp.Select: transforms.preprocess(
+                [_unqualify_unnest, transforms.eliminate_distinct_on]
+            ),
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.TimeAdd: _date_add_sql("TIME", "ADD"),
             exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -259,6 +277,7 @@ class BigQuery(Dialect):
             **generator.Generator.TYPE_MAPPING,  # type: ignore
+            exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
             exp.DataType.Type.BIGINT: "INT64",
             exp.DataType.Type.BINARY: "BYTES",
             exp.DataType.Type.BOOLEAN: "BOOL",
             exp.DataType.Type.CHAR: "STRING",
             exp.DataType.Type.DECIMAL: "NUMERIC",
@@ -272,6 +291,7 @@ class BigQuery(Dialect):
             exp.DataType.Type.TIMESTAMP: "DATETIME",
             exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
             exp.DataType.Type.TINYINT: "INT64",
+            exp.DataType.Type.VARBINARY: "BYTES",
             exp.DataType.Type.VARCHAR: "STRING",
             exp.DataType.Type.VARIANT: "ANY TYPE",
         }
@@ -310,3 +330,6 @@ class BigQuery(Dialect):
             if not expression.args.get("distinct", False):
                 self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
             return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
+
+        def with_properties(self, properties: exp.Properties) -> str:
+            return self.properties(properties, prefix=self.seg("OPTIONS"))

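The DATE_TRUNC change above swaps `seq_get(args, 1).name` for `str(seq_get(args, 1))` when capturing the unit, so any second argument stringifies cleanly instead of assuming a named expression. A hedged sketch of inspecting the parsed unit:

```python
import sqlglot
from sqlglot import exp

trunc = sqlglot.parse_one("SELECT DATE_TRUNC(ts, MONTH)", read="bigquery").find(exp.DateTrunc)
print(trunc.text("unit"))  # MONTH
```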
@@ -22,6 +22,8 @@ class ClickHouse(Dialect):
     class Tokenizer(tokens.Tokenizer):
         COMMENTS = ["--", "#", "#!", ("/*", "*/")]
         IDENTIFIERS = ['"', "`"]
+        BIT_STRINGS = [("0b", "")]
+        HEX_STRINGS = [("0x", ""), ("0X", "")]

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
@@ -31,10 +33,18 @@ class ClickHouse(Dialect):
             "FINAL": TokenType.FINAL,
             "FLOAT32": TokenType.FLOAT,
             "FLOAT64": TokenType.DOUBLE,
-            "INT16": TokenType.SMALLINT,
-            "INT32": TokenType.INT,
-            "INT64": TokenType.BIGINT,
             "INT8": TokenType.TINYINT,
+            "UINT8": TokenType.UTINYINT,
+            "INT16": TokenType.SMALLINT,
+            "UINT16": TokenType.USMALLINT,
+            "INT32": TokenType.INT,
+            "UINT32": TokenType.UINT,
+            "INT64": TokenType.BIGINT,
+            "UINT64": TokenType.UBIGINT,
+            "INT128": TokenType.INT128,
+            "UINT128": TokenType.UINT128,
+            "INT256": TokenType.INT256,
+            "UINT256": TokenType.UINT256,
             "TUPLE": TokenType.STRUCT,
         }

@@ -121,9 +131,17 @@ class ClickHouse(Dialect):
             exp.DataType.Type.ARRAY: "Array",
             exp.DataType.Type.STRUCT: "Tuple",
             exp.DataType.Type.TINYINT: "Int8",
+            exp.DataType.Type.UTINYINT: "UInt8",
             exp.DataType.Type.SMALLINT: "Int16",
+            exp.DataType.Type.USMALLINT: "UInt16",
             exp.DataType.Type.INT: "Int32",
+            exp.DataType.Type.UINT: "UInt32",
             exp.DataType.Type.BIGINT: "Int64",
+            exp.DataType.Type.UBIGINT: "UInt64",
+            exp.DataType.Type.INT128: "Int128",
+            exp.DataType.Type.UINT128: "UInt128",
+            exp.DataType.Type.INT256: "Int256",
+            exp.DataType.Type.UINT256: "UInt256",
             exp.DataType.Type.FLOAT: "Float32",
             exp.DataType.Type.DOUBLE: "Float64",
         }

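With the UINT* keywords and type mappings above, unsigned ClickHouse integers should survive a round trip instead of collapsing into their signed counterparts. A small sketch (output is the expected mapping, not verified against this exact commit):

```python
import sqlglot

sql = "CREATE TABLE t (a UInt8, b UInt64, c Int256)"
print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
# CREATE TABLE t (a UInt8, b UInt64, c Int256)
```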
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from sqlglot import exp
+from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import parse_date_delta
 from sqlglot.dialects.spark import Spark
 from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
@@ -29,13 +29,20 @@ class Databricks(Spark):
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
             exp.JSONExtract: lambda self, e: self.binary(e, ":"),
+            exp.Select: transforms.preprocess(
+                [
+                    transforms.eliminate_distinct_on,
+                    transforms.unnest_to_explode,
+                ]
+            ),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
         }
-        TRANSFORMS.pop(exp.Select)  # Remove the ELIMINATE_QUALIFY transformation

         PARAMETER_TOKEN = "$"

     class Tokenizer(Spark.Tokenizer):
         HEX_STRINGS = []

         SINGLE_TOKENS = {
             **Spark.Tokenizer.SINGLE_TOKENS,
             "$": TokenType.PARAMETER,

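`transforms.eliminate_distinct_on`, which this merge wires into several dialects, rewrites PostgreSQL's `DISTINCT ON` into a `ROW_NUMBER()` window filter so engines without the construct can run the query. A hedged sketch:

```python
import sqlglot

sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b"
# The result should keep one row per `a` via ROW_NUMBER() OVER (PARTITION BY a ...)
print(sqlglot.transpile(sql, read="postgres", write="databricks")[0])
```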
@@ -28,6 +28,7 @@ class Dialects(str, Enum):
     REDSHIFT = "redshift"
     SNOWFLAKE = "snowflake"
     SPARK = "spark"
+    SPARK2 = "spark2"
     SQLITE = "sqlite"
     STARROCKS = "starrocks"
     TABLEAU = "tableau"
@@ -69,30 +70,17 @@ class _Dialect(type):
             klass.tokenizer_class._IDENTIFIERS.items()
         )[0]

-        if (
-            klass.tokenizer_class._BIT_STRINGS
-            and exp.BitString not in klass.generator_class.TRANSFORMS
-        ):
-            bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
-            klass.generator_class.TRANSFORMS[
-                exp.BitString
-            ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
-        if (
-            klass.tokenizer_class._HEX_STRINGS
-            and exp.HexString not in klass.generator_class.TRANSFORMS
-        ):
-            hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
-            klass.generator_class.TRANSFORMS[
-                exp.HexString
-            ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
-        if (
-            klass.tokenizer_class._BYTE_STRINGS
-            and exp.ByteString not in klass.generator_class.TRANSFORMS
-        ):
-            be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
-            klass.generator_class.TRANSFORMS[
-                exp.ByteString
-            ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"
+        klass.bit_start, klass.bit_end = seq_get(
+            list(klass.tokenizer_class._BIT_STRINGS.items()), 0
+        ) or (None, None)
+
+        klass.hex_start, klass.hex_end = seq_get(
+            list(klass.tokenizer_class._HEX_STRINGS.items()), 0
+        ) or (None, None)
+
+        klass.byte_start, klass.byte_end = seq_get(
+            list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
+        ) or (None, None)

         return klass

@@ -198,6 +186,12 @@ class Dialect(metaclass=_Dialect):
             **{
                 "quote_start": self.quote_start,
                 "quote_end": self.quote_end,
+                "bit_start": self.bit_start,
+                "bit_end": self.bit_end,
+                "hex_start": self.hex_start,
+                "hex_end": self.hex_end,
+                "byte_start": self.byte_start,
+                "byte_end": self.byte_end,
                 "identifier_start": self.identifier_start,
                 "identifier_end": self.identifier_end,
                 "string_escape": self.tokenizer_class.STRING_ESCAPES[0],

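The metaclass rewrite leans on `seq_get` returning `None` for an out-of-range index, so `seq_get(...) or (None, None)` always yields something unpackable even for tokenizers that define no bit/hex/byte strings:

```python
from sqlglot.helper import seq_get

assert seq_get([("0x", "")], 0) == ("0x", "")
assert seq_get([], 0) is None

# mirrors: klass.hex_start, klass.hex_end = seq_get(...) or (None, None)
hex_start, hex_end = seq_get([], 0) or (None, None)
```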
@@ -2,7 +2,7 @@ from __future__ import annotations

 import typing as t

-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     create_with_partitions_sql,
@@ -145,6 +145,7 @@ class Drill(Dialect):
             exp.StrPosition: str_position_sql,
             exp.StrToDate: _str_to_date,
             exp.Pow: rename_func("POW"),
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
             exp.TimeStrToTime: timestrtotime_sql,

@@ -1,5 +1,7 @@
 from __future__ import annotations

+import typing as t
+
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
@@ -23,52 +25,61 @@ from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType


-def _ts_or_ds_add(self, expression):
+def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
     return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"


-def _date_add(self, expression):
+def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
     return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"


-def _array_sort_sql(self, expression):
+def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
     if expression.expression:
         self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
     return f"ARRAY_SORT({self.sql(expression, 'this')})"


-def _sort_array_sql(self, expression):
+def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
     this = self.sql(expression, "this")
     if expression.args.get("asc") == exp.false():
         return f"ARRAY_REVERSE_SORT({this})"
     return f"ARRAY_SORT({this})"


-def _sort_array_reverse(args):
+def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
     return exp.SortArray(this=seq_get(args, 0), asc=exp.false())


-def _struct_sql(self, expression):
+def _parse_date_diff(args: t.Sequence) -> exp.Expression:
+    return exp.DateDiff(
+        this=seq_get(args, 2),
+        expression=seq_get(args, 1),
+        unit=seq_get(args, 0),
+    )
+
+
+def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
     args = [
         f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
     ]
     return f"{{{', '.join(args)}}}"


-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
     if expression.this == exp.DataType.Type.ARRAY:
         return f"{self.expressions(expression, flat=True)}[]"
     return self.datatype_sql(expression)


-def _regexp_extract_sql(self, expression):
+def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str:
     bad_args = list(filter(expression.args.get, ("position", "occurrence")))
     if bad_args:
         self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")

     return self.func(
         "REGEXP_EXTRACT",
         expression.args.get("this"),
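`_parse_date_diff` maps DuckDB's unit-first `DATE_DIFF(part, start, end)` onto sqlglot's `DateDiff(this=end, expression=start, unit=part)`. A quick sketch of the argument shuffle:

```python
import sqlglot
from sqlglot import exp

diff = sqlglot.parse_one("SELECT DATE_DIFF('day', x, y)", read="duckdb").find(exp.DateDiff)
print(diff.this.sql(), diff.expression.sql())  # y x
```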
@@ -108,6 +119,8 @@ class DuckDB(Dialect):
             "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
             "ARRAY_SORT": exp.SortArray.from_arg_list,
             "ARRAY_REVERSE_SORT": _sort_array_reverse,
+            "DATEDIFF": _parse_date_diff,
+            "DATE_DIFF": _parse_date_diff,
             "EPOCH": exp.TimeToUnix.from_arg_list,
             "EPOCH_MS": lambda args: exp.UnixToTime(
                 this=exp.Div(
@@ -115,18 +128,18 @@ class DuckDB(Dialect):
                     expression=exp.Literal.number(1000),
                 )
             ),
-            "LIST_SORT": exp.SortArray.from_arg_list,
             "LIST_REVERSE_SORT": _sort_array_reverse,
+            "LIST_SORT": exp.SortArray.from_arg_list,
             "LIST_VALUE": exp.Array.from_arg_list,
             "REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
             "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
-            "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
-            "STR_SPLIT": exp.Split.from_arg_list,
             "STRING_SPLIT": exp.Split.from_arg_list,
-            "STRING_TO_ARRAY": exp.Split.from_arg_list,
-            "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
+            "STRING_TO_ARRAY": exp.Split.from_arg_list,
+            "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
             "STRUCT_PACK": exp.Struct.from_arg_list,
+            "STR_SPLIT": exp.Split.from_arg_list,
+            "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
             "UNNEST": exp.Explode.from_arg_list,
         }
@@ -142,10 +155,11 @@ class DuckDB(Dialect):
     class Generator(generator.Generator):
+        JOIN_HINTS = False
+        TABLE_HINTS = False
+        LIMIT_FETCH = "LIMIT"
         STRUCT_DELIMITER = ("(", ")")

         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,  # type: ignore
+            **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
             if isinstance(seq_get(e.expressions, 0), exp.Select)
@@ -154,13 +168,16 @@ class DuckDB(Dialect):
             exp.ArraySort: _array_sort_sql,
             exp.ArraySum: rename_func("LIST_SUM"),
             exp.CommentColumnConstraint: no_comment_column_constraint_sql,
             exp.CurrentDate: lambda self, e: "CURRENT_DATE",
             exp.CurrentTime: lambda self, e: "CURRENT_TIME",
             exp.CurrentTimestamp: lambda self, e: "CURRENT_TIMESTAMP",
+            exp.DayOfMonth: rename_func("DAYOFMONTH"),
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
+            exp.DayOfYear: rename_func("DAYOFYEAR"),
             exp.DataType: _datatype_sql,
-            exp.DateAdd: _date_add,
+            exp.DateAdd: _date_add_sql,
             exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
+                "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
             ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
@@ -192,7 +209,7 @@ class DuckDB(Dialect):
             exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToUnix: rename_func("EPOCH"),
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
-            exp.TsOrDsAdd: _ts_or_ds_add,
+            exp.TsOrDsAdd: _ts_or_ds_add_sql,
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
             exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
             exp.UnixToTime: rename_func("TO_TIMESTAMP"),
@@ -201,7 +218,7 @@ class DuckDB(Dialect):
         }

         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,  # type: ignore
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BINARY: "BLOB",
             exp.DataType.Type.CHAR: "TEXT",
             exp.DataType.Type.FLOAT: "REAL",
@@ -212,17 +229,14 @@ class DuckDB(Dialect):
             exp.DataType.Type.VARCHAR: "TEXT",
         }

-        STAR_MAPPING = {
-            **generator.Generator.STAR_MAPPING,
-            "except": "EXCLUDE",
-        }
+        STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}

         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

-        LIMIT_FETCH = "LIMIT"
-
-        def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
-            return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
+        def tablesample_sql(
+            self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+        ) -> str:
+            return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)

@@ -81,7 +81,20 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
     return f"{diff_sql}{multiplier_sql}"


-def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
+def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+    this = expression.this
+
+    if not this.type:
+        from sqlglot.optimizer.annotate_types import annotate_types
+
+        annotate_types(this)
+
+    if this.type.is_type(exp.DataType.Type.JSON):
+        return self.sql(this)
+    return self.func("TO_JSON", this, expression.args.get("options"))
+
+
+def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
     if expression.expression:
         self.unsupported("Hive SORT_ARRAY does not support a comparator")
     return f"SORT_ARRAY({self.sql(expression, 'this')})"
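`_json_format_sql` lazily annotates types so that `TO_JSON` over a value that is already JSON-typed collapses to the value itself, while anything else still renders as a `TO_JSON` call. A hedged round-trip sketch:

```python
import sqlglot

# x has no known JSON type here, so the call should be kept
print(sqlglot.transpile("SELECT TO_JSON(x)", read="hive", write="hive")[0])
# SELECT TO_JSON(x)
```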
@@ -91,11 +104,11 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
     return f"'{expression.name}'={self.sql(expression, 'value')}"


-def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
+def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
     return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))


-def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format not in (Hive.time_format, Hive.date_format):
@@ -103,7 +116,7 @@ def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
     return f"CAST({this} AS DATE)"


-def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
+def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format not in (Hive.time_format, Hive.date_format):
@@ -214,6 +227,7 @@ class Hive(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
             "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
+            "BASE64": exp.ToBase64.from_arg_list,
             "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
             "DATE_ADD": lambda args: exp.TsOrDsAdd(
                 this=seq_get(args, 0),
@@ -251,6 +265,7 @@ class Hive(Dialect):
             "SPLIT": exp.RegexpSplit.from_arg_list,
             "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
             "TO_JSON": exp.JSONFormat.from_arg_list,
+            "UNBASE64": exp.FromBase64.from_arg_list,
             "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True),
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }
@@ -280,16 +295,20 @@ class Hive(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
-            **transforms.UNALIAS_GROUP,  # type: ignore
-            **transforms.ELIMINATE_QUALIFY,  # type: ignore
+            exp.Group: transforms.preprocess([transforms.unalias_group]),
             exp.Select: transforms.preprocess(
-                [transforms.eliminate_qualify, transforms.unnest_to_explode]
+                [
+                    transforms.eliminate_qualify,
+                    transforms.eliminate_distinct_on,
+                    transforms.unnest_to_explode,
+                ]
             ),
             exp.Property: _property_sql,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.ArrayConcat: rename_func("CONCAT"),
+            exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
             exp.ArraySize: rename_func("SIZE"),
-            exp.ArraySort: _array_sort,
+            exp.ArraySort: _array_sort_sql,
             exp.With: no_recursive_cte_sql,
             exp.DateAdd: _add_date_sql,
             exp.DateDiff: _date_diff_sql,
@@ -298,12 +317,13 @@ class Hive(Dialect):
             exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
             exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
+            exp.FromBase64: rename_func("UNBASE64"),
             exp.If: if_sql,
             exp.Index: _index_sql,
             exp.ILike: no_ilike_sql,
             exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
             exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
-            exp.JSONFormat: rename_func("TO_JSON"),
+            exp.JSONFormat: _json_format_sql,
             exp.Map: var_map_sql,
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
@@ -318,9 +338,9 @@ class Hive(Dialect):
             exp.SetAgg: rename_func("COLLECT_SET"),
             exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
             exp.StrPosition: strposition_to_locate_sql,
-            exp.StrToDate: _str_to_date,
-            exp.StrToTime: _str_to_time,
-            exp.StrToUnix: _str_to_unix,
+            exp.StrToDate: _str_to_date_sql,
+            exp.StrToTime: _str_to_time_sql,
+            exp.StrToUnix: _str_to_unix_sql,
             exp.StructExtract: struct_extract_sql,
             exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
             exp.TimeStrToDate: rename_func("TO_DATE"),
@@ -328,6 +348,7 @@ class Hive(Dialect):
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeToStr: _time_to_str,
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+            exp.ToBase64: rename_func("BASE64"),
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
             exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.TsOrDsToDate: _to_date_sql,

@@ -1,6 +1,6 @@
 from __future__ import annotations

-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
@@ -403,6 +403,7 @@ class MySQL(Dialect):
             exp.Min: min_or_least,
             exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
             exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.StrPosition: strposition_to_locate_sql,
             exp.StrToDate: _str_to_date_sql,
             exp.StrToTime: _str_to_date_sql,

@@ -34,6 +34,8 @@ def _parse_xml_table(self) -> exp.XMLTable:


 class Oracle(Dialect):
+    alias_post_tablesample = True
+
     # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
     # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
     time_mapping = {
@@ -121,21 +123,23 @@ class Oracle(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
-            **transforms.UNALIAS_GROUP,  # type: ignore
             exp.DateStrToDate: lambda self, e: self.func(
                 "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
             ),
+            exp.Group: transforms.preprocess([transforms.unalias_group]),
             exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
             exp.ILike: no_ilike_sql,
+            exp.IfNull: rename_func("NVL"),
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
             exp.Substring: rename_func("SUBSTR"),
             exp.Table: lambda self, e: self.table_sql(e, sep=" "),
             exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.Trim: trim_sql,
             exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
-            exp.IfNull: rename_func("NVL"),
         }

         PROPERTIES_LOCATION = {
@@ -164,14 +168,19 @@ class Oracle(Dialect):
         return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"

     class Tokenizer(tokens.Tokenizer):
+        VAR_SINGLE_TOKENS = {"@"}
+
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "(+)": TokenType.JOIN_MARKER,
             "BINARY_DOUBLE": TokenType.DOUBLE,
             "BINARY_FLOAT": TokenType.FLOAT,
             "COLUMNS": TokenType.COLUMN,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
             "MINUS": TokenType.EXCEPT,
             "NVARCHAR2": TokenType.NVARCHAR,
             "RETURNING": TokenType.RETURNING,
             "SAMPLE": TokenType.TABLE_SAMPLE,
             "START": TokenType.BEGIN,
             "TOP": TokenType.TOP,
             "VARCHAR2": TokenType.VARCHAR,

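`transforms.unalias_group`, now applied per-expression via `exp.Group`, replaces select-alias references in GROUP BY with their ordinal position, which Oracle requires. Sketch:

```python
import sqlglot

sql = "SELECT x AS a, COUNT(*) FROM t GROUP BY a"
print(sqlglot.transpile(sql, write="oracle")[0])
# SELECT x AS a, COUNT(*) FROM t GROUP BY 1
```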
@@ -1,6 +1,8 @@
 from __future__ import annotations

-from sqlglot import exp, generator, parser, tokens
+import typing as t
+
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
@@ -20,7 +22,6 @@ from sqlglot.dialects.dialect import (
 from sqlglot.helper import seq_get
 from sqlglot.parser import binary_range_parser
 from sqlglot.tokens import TokenType
-from sqlglot.transforms import preprocess, remove_target_from_merge

 DATE_DIFF_FACTOR = {
     "MICROSECOND": " * 1000000",
@@ -274,8 +275,7 @@ class Postgres(Dialect):
             TokenType.HASH: exp.BitwiseXor,
         }

-        FACTOR = {
-            **parser.Parser.FACTOR,
+        EXPONENT = {
             TokenType.CARET: exp.Pow,
         }

@@ -286,6 +286,12 @@ class Postgres(Dialect):
             TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
         }

+        def _parse_factor(self) -> t.Optional[exp.Expression]:
+            return self._parse_tokens(self._parse_exponent, self.FACTOR)
+
+        def _parse_exponent(self) -> t.Optional[exp.Expression]:
+            return self._parse_tokens(self._parse_unary, self.EXPONENT)
+
         def _parse_date_part(self) -> exp.Expression:
             part = self._parse_type()
             self._match(TokenType.COMMA)
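Pulling `^` out of FACTOR into a dedicated EXPONENT level parsed beneath it makes exponentiation bind tighter than `*` and `/`, matching PostgreSQL's operator precedence. Sketch:

```python
import sqlglot

ast = sqlglot.parse_one("SELECT 2 * 3 ^ 2", read="postgres")
# The top-level arithmetic node is the multiplication, so ^ bound tighter:
print(type(ast.expressions[0]).__name__)  # Mul
```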
@@ -316,7 +322,7 @@ class Postgres(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
-            exp.ColumnDef: preprocess(
+            exp.ColumnDef: transforms.preprocess(
                 [
                     _auto_increment_to_serial,
                     _serial_to_generated,
@@ -341,7 +347,7 @@ class Postgres(Dialect):
             exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
             exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
             exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
-            exp.Merge: preprocess([remove_target_from_merge]),
+            exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
             exp.RegexpLike: lambda self, e: self.binary(e, "~"),
             exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
             exp.StrPosition: str_position_sql,

@@ -130,7 +130,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
 def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
     start = expression.args["start"]
     end = expression.args["end"]
-    step = expression.args.get("step", 1)  # Postgres defaults to 1 for generate_series
+    step = expression.args.get("step")

     target_type = None

@@ -147,7 +147,11 @@ def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
     else:
         start = exp.Cast(this=start, to=to)

-    return self.func("SEQUENCE", start, end, step)
+    sql = self.func("SEQUENCE", start, end, step)
+    if isinstance(expression.parent, exp.Table):
+        sql = f"UNNEST({sql})"
+
+    return sql


 def _ensure_utf8(charset: exp.Literal) -> None:
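With the change above, a `GENERATE_SERIES` that appears as a relation is wrapped in `UNNEST`, since a bare `SEQUENCE(...)` is not a valid table expression in Presto. A hedged sketch:

```python
import sqlglot

sql = "SELECT * FROM GENERATE_SERIES(1, 5)"
print(sqlglot.transpile(sql, read="postgres", write="presto")[0])
# expected: SELECT * FROM UNNEST(SEQUENCE(1, 5))
```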
@@ -204,6 +208,7 @@ class Presto(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
             "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
+            "APPROX_PERCENTILE": _approx_percentile,
             "CARDINALITY": exp.ArraySize.from_arg_list,
             "CONTAINS": exp.ArrayContains.from_arg_list,
             "DATE_ADD": lambda args: exp.DateAdd(
@@ -219,23 +224,23 @@ class Presto(Dialect):
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
             "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
             "DATE_TRUNC": date_trunc_to_time,
+            "FROM_HEX": exp.Unhex.from_arg_list,
             "FROM_UNIXTIME": _from_unixtime,
+            "FROM_UTF8": lambda args: exp.Decode(
+                this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
+            ),
             "NOW": exp.CurrentTimestamp.from_arg_list,
             "SEQUENCE": exp.GenerateSeries.from_arg_list,
             "STRPOS": lambda args: exp.StrPosition(
                 this=seq_get(args, 0),
                 substr=seq_get(args, 1),
                 instance=seq_get(args, 2),
             ),
             "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
-            "APPROX_PERCENTILE": _approx_percentile,
-            "FROM_HEX": exp.Unhex.from_arg_list,
             "TO_HEX": exp.Hex.from_arg_list,
             "TO_UTF8": lambda args: exp.Encode(
                 this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
             ),
-            "FROM_UTF8": lambda args: exp.Decode(
-                this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
-            ),
         }
         FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
         FUNCTION_PARSERS.pop("TRIM")
@@ -264,7 +269,6 @@ class Presto(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
-            **transforms.UNALIAS_GROUP,  # type: ignore
             exp.ApproxDistinct: _approx_distinct_sql,
             exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
             exp.ArrayConcat: rename_func("CONCAT"),
@@ -290,6 +294,7 @@ class Presto(Dialect):
             exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
             exp.Encode: _encode_sql,
             exp.GenerateSeries: _sequence_sql,
+            exp.Group: transforms.preprocess([transforms.unalias_group]),
             exp.Hex: rename_func("TO_HEX"),
             exp.If: if_sql,
             exp.ILike: no_ilike_sql,
@@ -303,7 +308,11 @@ class Presto(Dialect):
             exp.SafeDivide: no_safe_divide_sql,
             exp.Schema: _schema_sql,
             exp.Select: transforms.preprocess(
-                [transforms.eliminate_qualify, transforms.explode_to_unnest]
+                [
+                    transforms.eliminate_qualify,
+                    transforms.eliminate_distinct_on,
+                    transforms.explode_to_unnest,
+                ]
             ),
             exp.SortArray: _no_sort_array,
             exp.StrPosition: rename_func("STRPOS"),
@@ -327,6 +336,9 @@ class Presto(Dialect):
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
             exp.VariancePop: rename_func("VAR_POP"),
+            exp.WithinGroup: transforms.preprocess(
+                [transforms.remove_within_group_for_percentiles]
+            ),
         }

         def interval_sql(self, expression: exp.Interval) -> str:

@@ -52,6 +52,8 @@ class Redshift(Postgres):
         return this

     class Tokenizer(Postgres.Tokenizer):
+        BIT_STRINGS = []
+        HEX_STRINGS = []
         STRING_ESCAPES = ["\\"]

         KEYWORDS = {
@@ -90,7 +92,6 @@ class Redshift(Postgres):

         TRANSFORMS = {
             **Postgres.Generator.TRANSFORMS,  # type: ignore
-            **transforms.ELIMINATE_DISTINCT_ON,  # type: ignore
             exp.CurrentTimestamp: lambda self, e: "SYSDATE",
             exp.DateAdd: lambda self, e: self.func(
                 "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@@ -102,6 +103,7 @@ class Redshift(Postgres):
             exp.DistStyleProperty: lambda self, e: self.naked_property(e),
             exp.JSONExtract: _json_sql,
             exp.JSONExtractScalar: _json_sql,
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
         }

@@ -2,7 +2,7 @@ from __future__ import annotations

 import typing as t

-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     date_trunc_to_time,
@@ -252,6 +252,7 @@ class Snowflake(Dialect):
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
         STRING_ESCAPES = ["\\", "'"]
+        HEX_STRINGS = [("x'", "'"), ("X'", "'")]

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
@@ -305,6 +306,7 @@ class Snowflake(Dialect):
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
             exp.StrPosition: lambda self, e: self.func(
                 "POSITION", e.args.get("substr"), e.this, e.args.get("position")

@@ -2,222 +2,54 @@ from __future__ import annotations

 import typing as t

-from sqlglot import exp, parser
-from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
-from sqlglot.dialects.hive import Hive
+from sqlglot import exp
+from sqlglot.dialects.spark2 import Spark2
 from sqlglot.helper import seq_get


-def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
-    kind = e.args["kind"]
-    properties = e.args.get("properties")
-
-    if kind.upper() == "TABLE" and any(
-        isinstance(prop, exp.TemporaryProperty)
-        for prop in (properties.expressions if properties else [])
-    ):
-        return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
-    return create_with_partitions_sql(self, e)
+def _parse_datediff(args: t.Sequence) -> exp.Expression:
+    """
+    Although Spark docs don't mention the "unit" argument, Spark3 added support for
+    it at some point. Databricks also supports this variation (see below).
+
+    For example, in spark-sql (v3.3.1):
+    - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
+    - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4
+
+    See also:
+    - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
+    - https://docs.databricks.com/sql/language-manual/functions/datediff.html
+    """
+    unit = None
+    this = seq_get(args, 0)
+    expression = seq_get(args, 1)
+
+    if len(args) == 3:
+        unit = this
+        this = args[2]
+
+    return exp.DateDiff(
+        this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
+    )


-def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
-    keys = self.sql(expression.args["keys"])
-    values = self.sql(expression.args["values"])
-    return f"MAP_FROM_ARRAYS({keys}, {values})"
-
-
-def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
-    this = self.sql(expression, "this")
-    time_format = self.format_time(expression)
-    if time_format == Hive.date_format:
-        return f"TO_DATE({this})"
-    return f"TO_DATE({this}, {time_format})"
-
-
-def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
-    scale = expression.args.get("scale")
-    timestamp = self.sql(expression, "this")
-    if scale is None:
-        return f"FROM_UNIXTIME({timestamp})"
-    if scale == exp.UnixToTime.SECONDS:
-        return f"TIMESTAMP_SECONDS({timestamp})"
-    if scale == exp.UnixToTime.MILLIS:
-        return f"TIMESTAMP_MILLIS({timestamp})"
-    if scale == exp.UnixToTime.MICROS:
-        return f"TIMESTAMP_MICROS({timestamp})"
-
-    raise ValueError("Improper scale for timestamp")
-
-
-class Spark(Hive):
-    class Parser(Hive.Parser):
+class Spark(Spark2):
+    class Parser(Spark2.Parser):
         FUNCTIONS = {
-            **Hive.Parser.FUNCTIONS,  # type: ignore
-            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
-            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
-            "LEFT": lambda args: exp.Substring(
-                this=seq_get(args, 0),
-                start=exp.Literal.number(1),
-                length=seq_get(args, 1),
-            ),
-            "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-            ),
-            "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-            ),
-            "RIGHT": lambda args: exp.Substring(
-                this=seq_get(args, 0),
-                start=exp.Sub(
-                    this=exp.Length(this=seq_get(args, 0)),
-                    expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
-                ),
-                length=seq_get(args, 1),
-            ),
-            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
-            "BOOLEAN": lambda args: exp.Cast(
-                this=seq_get(args, 0), to=exp.DataType.build("boolean")
-            ),
-            "IIF": exp.If.from_arg_list,
-            "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
-            "AGGREGATE": exp.Reduce.from_arg_list,
-            "DAYOFWEEK": lambda args: exp.DayOfWeek(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DAYOFMONTH": lambda args: exp.DayOfMonth(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DAYOFYEAR": lambda args: exp.DayOfYear(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
-            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
-                this=seq_get(args, 1),
-                unit=exp.var(seq_get(args, 0)),
-            ),
-            "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
-            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
-            "TIMESTAMP": lambda args: exp.Cast(
-                this=seq_get(args, 0), to=exp.DataType.build("timestamp")
-            ),
+            **Spark2.Parser.FUNCTIONS,  # type: ignore
+            "DATEDIFF": _parse_datediff,
         }

-        FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
-            "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
-            "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
-            "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
-            "MERGE": lambda self: self._parse_join_hint("MERGE"),
-            "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
-            "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
-            "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
-            "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
-        }
+    class Generator(Spark2.Generator):
+        TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
+        TRANSFORMS.pop(exp.DateDiff)

-        def _parse_add_column(self) -> t.Optional[exp.Expression]:
-            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
+        def datediff_sql(self, expression: exp.DateDiff) -> str:
+            unit = self.sql(expression, "unit")
+            end = self.sql(expression, "this")
+            start = self.sql(expression, "expression")

-        def _parse_drop_column(self) -> t.Optional[exp.Expression]:
-            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
-                exp.Drop,
-                this=self._parse_schema(),
-                kind="COLUMNS",
-            )
+            if unit:
+                return self.func("DATEDIFF", unit, start, end)

-        def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
-            # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
-            if len(pivot_columns) == 1:
-                return [""]
-
-            names = []
-            for agg in pivot_columns:
-                if isinstance(agg, exp.Alias):
-                    names.append(agg.alias)
-                else:
-                    """
-                    This case corresponds to aggregations without aliases being used as suffixes
-                    (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
-                    be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
-                    Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
-                    Moreover, function names are lowercased in order to mimic Spark's naming scheme.
-                    """
-                    agg_all_unquoted = agg.transform(
-                        lambda node: exp.Identifier(this=node.name, quoted=False)
-                        if isinstance(node, exp.Identifier)
-                        else node
-                    )
-                    names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
-
-            return names
-
-    class Generator(Hive.Generator):
-        TYPE_MAPPING = {
-            **Hive.Generator.TYPE_MAPPING,  # type: ignore
-            exp.DataType.Type.TINYINT: "BYTE",
-            exp.DataType.Type.SMALLINT: "SHORT",
-            exp.DataType.Type.BIGINT: "LONG",
-        }
-
-        PROPERTIES_LOCATION = {
-            **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
-            exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
-            exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
-            exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
-            exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
-        }
-
-        TRANSFORMS = {
-            **Hive.Generator.TRANSFORMS,  # type: ignore
-            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
-            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
-            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
-            exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
-            exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
-            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
-            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
-            exp.StrToDate: _str_to_date,
-            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time_sql,
-            exp.Create: _create_sql,
-            exp.Map: _map_sql,
-            exp.Reduce: rename_func("AGGREGATE"),
-            exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
-            exp.TimestampTrunc: lambda self, e: self.func(
-                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
-            ),
-            exp.Trim: trim_sql,
-            exp.VariancePop: rename_func("VAR_POP"),
-            exp.DateFromParts: rename_func("MAKE_DATE"),
-            exp.LogicalOr: rename_func("BOOL_OR"),
-            exp.LogicalAnd: rename_func("BOOL_AND"),
-            exp.DayOfWeek: rename_func("DAYOFWEEK"),
-            exp.DayOfMonth: rename_func("DAYOFMONTH"),
-            exp.DayOfYear: rename_func("DAYOFYEAR"),
-            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
-            exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
-        }
-        TRANSFORMS.pop(exp.ArraySort)
-        TRANSFORMS.pop(exp.ILike)
-
-        WRAP_DERIVED_VALUES = False
-        CREATE_FUNCTION_RETURN_AS = False
-
-        def cast_sql(self, expression: exp.Cast) -> str:
-            if isinstance(expression.this, exp.Cast) and expression.this.is_type(
-                exp.DataType.Type.JSON
-            ):
-                schema = f"'{self.sql(expression, 'to')}'"
-                return self.func("FROM_JSON", expression.this.this, schema)
-            if expression.to.is_type(exp.DataType.Type.JSON):
-                return self.func("TO_JSON", expression.this)
-
-            return super(Spark.Generator, self).cast_sql(expression)
-
-    class Tokenizer(Hive.Tokenizer):
-        HEX_STRINGS = [("X'", "'")]
+            return self.func("DATEDIFF", end, start)

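The net effect of this split: `Spark` now models Spark 3, whose `DATEDIFF` accepts an optional leading unit, while the old behavior lives on in the new `Spark2` dialect. A hedged sketch of both call shapes:

```python
import sqlglot

# 3-arg form keeps the unit; 2-arg form stays unit-less
print(sqlglot.transpile("SELECT DATEDIFF(day, x, y)", read="spark", write="spark")[0])
print(sqlglot.transpile("SELECT DATEDIFF(y, x)", read="spark", write="spark")[0])
```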
sqlglot/dialects/spark2.py (new file, 238 lines)
|
@ -0,0 +1,238 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, parser, transforms
|
||||
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
|
||||
from sqlglot.dialects.hive import Hive
|
||||
from sqlglot.helper import seq_get
|
||||
|
||||
|
||||
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
|
||||
kind = e.args["kind"]
|
||||
properties = e.args.get("properties")
|
||||
|
||||
if kind.upper() == "TABLE" and any(
|
||||
isinstance(prop, exp.TemporaryProperty)
|
||||
for prop in (properties.expressions if properties else [])
|
||||
):
|
||||
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
|
||||
return create_with_partitions_sql(self, e)
|
||||
|
||||
|
||||
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
|
||||
keys = self.sql(expression.args["keys"])
|
||||
values = self.sql(expression.args["values"])
|
||||
return f"MAP_FROM_ARRAYS({keys}, {values})"
|
||||
|
||||
|
||||
def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
|
||||
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
|
||||
|
||||
|
||||
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
time_format = self.format_time(expression)
|
||||
if time_format == Hive.date_format:
|
||||
return f"TO_DATE({this})"
|
||||
return f"TO_DATE({this}, {time_format})"
|
||||
|
||||
|
||||
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
|
||||
scale = expression.args.get("scale")
|
||||
timestamp = self.sql(expression, "this")
|
||||
if scale is None:
|
||||
return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
|
||||
if scale == exp.UnixToTime.SECONDS:
|
||||
return f"TIMESTAMP_SECONDS({timestamp})"
|
||||
if scale == exp.UnixToTime.MILLIS:
|
||||
return f"TIMESTAMP_MILLIS({timestamp})"
|
||||
if scale == exp.UnixToTime.MICROS:
|
||||
return f"TIMESTAMP_MICROS({timestamp})"
|
||||
|
||||
raise ValueError("Improper scale for timestamp")
|
||||
|
||||
|
||||
class Spark2(Hive):
|
||||
class Parser(Hive.Parser):
|
||||
FUNCTIONS = {
|
||||
**Hive.Parser.FUNCTIONS, # type: ignore
|
||||
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
|
||||
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
|
||||
"LEFT": lambda args: exp.Substring(
|
||||
this=seq_get(args, 0),
|
||||
start=exp.Literal.number(1),
|
||||
length=seq_get(args, 1),
|
||||
),
|
||||
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
|
||||
this=seq_get(args, 0),
|
||||
expression=seq_get(args, 1),
|
||||
),
|
||||
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
|
||||
this=seq_get(args, 0),
|
||||
expression=seq_get(args, 1),
|
||||
),
|
||||
"RIGHT": lambda args: exp.Substring(
|
||||
this=seq_get(args, 0),
|
||||
start=exp.Sub(
|
||||
this=exp.Length(this=seq_get(args, 0)),
|
||||
expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
|
||||
),
|
||||
length=seq_get(args, 1),
|
||||
),
|
||||
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
|
||||
"IIF": exp.If.from_arg_list,
|
||||
"AGGREGATE": exp.Reduce.from_arg_list,
|
||||
"DAYOFWEEK": lambda args: exp.DayOfWeek(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
),
|
||||
"DAYOFMONTH": lambda args: exp.DayOfMonth(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
),
|
||||
"DAYOFYEAR": lambda args: exp.DayOfYear(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
),
|
||||
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
),
|
||||
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
|
||||
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
|
||||
this=seq_get(args, 1),
|
||||
unit=exp.var(seq_get(args, 0)),
|
||||
),
|
||||
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
|
||||
"BOOLEAN": _parse_as_cast("boolean"),
|
||||
"DOUBLE": _parse_as_cast("double"),
|
||||
"FLOAT": _parse_as_cast("float"),
|
||||
"INT": _parse_as_cast("int"),
|
||||
"STRING": _parse_as_cast("string"),
|
||||
"TIMESTAMP": _parse_as_cast("timestamp"),
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
**parser.Parser.FUNCTION_PARSERS, # type: ignore
|
||||
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
|
||||
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
|
||||
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
|
||||
"MERGE": lambda self: self._parse_join_hint("MERGE"),
|
||||
"SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
|
||||
"MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
|
||||
"SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
|
||||
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
|
||||
}
|
||||
|
||||
def _parse_add_column(self) -> t.Optional[exp.Expression]:
|
||||
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
|
||||
|
||||
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
|
||||
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
|
||||
exp.Drop,
|
||||
this=self._parse_schema(),
|
||||
kind="COLUMNS",
|
||||
)
|
||||
|
||||
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
|
||||
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
|
||||
if len(pivot_columns) == 1:
|
||||
return [""]
|
||||
|
||||
names = []
|
||||
for agg in pivot_columns:
|
||||
if isinstance(agg, exp.Alias):
|
||||
names.append(agg.alias)
|
||||
else:
|
||||
"""
|
||||
This case corresponds to aggregations without aliases being used as suffixes
|
||||
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
|
||||
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
|
||||
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
|
||||
|
||||
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
|
||||
"""
|
||||
agg_all_unquoted = agg.transform(
|
||||
lambda node: exp.Identifier(this=node.name, quoted=False)
|
||||
if isinstance(node, exp.Identifier)
|
||||
else node
|
||||
)
|
||||
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
|
||||
|
||||
return names
|
||||
|
||||
class Generator(Hive.Generator):
|
||||
TYPE_MAPPING = {
|
||||
**Hive.Generator.TYPE_MAPPING, # type: ignore
|
||||
exp.DataType.Type.TINYINT: "BYTE",
|
||||
exp.DataType.Type.SMALLINT: "SHORT",
|
||||
exp.DataType.Type.BIGINT: "LONG",
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**Hive.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
        TRANSFORMS = {
            **Hive.Generator.TRANSFORMS,  # type: ignore
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
            exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
            exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
            exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
            exp.Create: _create_sql,
            exp.DateFromParts: rename_func("MAKE_DATE"),
            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
            exp.DayOfMonth: rename_func("DAYOFMONTH"),
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
            exp.DayOfYear: rename_func("DAYOFYEAR"),
            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
            exp.LogicalAnd: rename_func("BOOL_AND"),
            exp.LogicalOr: rename_func("BOOL_OR"),
            exp.Map: _map_sql,
            exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
            exp.Reduce: rename_func("AGGREGATE"),
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimestampTrunc: lambda self, e: self.func(
                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
            ),
            exp.Trim: trim_sql,
            exp.UnixToTime: _unix_to_time_sql,
            exp.VariancePop: rename_func("VAR_POP"),
            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
            exp.WithinGroup: transforms.preprocess(
                [transforms.remove_within_group_for_percentiles]
            ),
        }
        TRANSFORMS.pop(exp.ArrayJoin)
        TRANSFORMS.pop(exp.ArraySort)
        TRANSFORMS.pop(exp.ILike)
        WRAP_DERIVED_VALUES = False
        CREATE_FUNCTION_RETURN_AS = False

        def cast_sql(self, expression: exp.Cast) -> str:
            if isinstance(expression.this, exp.Cast) and expression.this.is_type(
                exp.DataType.Type.JSON
            ):
                schema = f"'{self.sql(expression, 'to')}'"
                return self.func("FROM_JSON", expression.this.this, schema)
            if expression.to.is_type(exp.DataType.Type.JSON):
                return self.func("TO_JSON", expression.this)

            return super(Hive.Generator, self).cast_sql(expression)
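cast_sql special-cases JSON: casting to JSON becomes TO_JSON, and re-casting a JSON-typed value to another type becomes FROM_JSON with the target schema. A rough sketch, assuming the Presto reader for the JSON cast:

import sqlglot

print(sqlglot.transpile("SELECT CAST(x AS JSON)", read="presto", write="spark")[0])
# expected: SELECT TO_JSON(x)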
        def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
            return super().columndef_sql(
                expression,
                sep=": "
                if isinstance(expression.parent, exp.DataType)
                and expression.parent.is_type(exp.DataType.Type.STRUCT)
                else sep,
            )

    class Tokenizer(Hive.Tokenizer):
        HEX_STRINGS = [("X'", "'")]
@@ -22,6 +22,40 @@ def _date_add_sql(self, expression):
    return self.func("DATE", expression.this, modifier)


def _transform_create(expression: exp.Expression) -> exp.Expression:
    """Move primary key to a column and enforce auto_increment on primary keys."""
    schema = expression.this

    if isinstance(expression, exp.Create) and isinstance(schema, exp.Schema):
        defs = {}
        primary_key = None

        for e in schema.expressions:
            if isinstance(e, exp.ColumnDef):
                defs[e.name] = e
            elif isinstance(e, exp.PrimaryKey):
                primary_key = e

        if primary_key and len(primary_key.expressions) == 1:
            column = defs[primary_key.expressions[0].name]
            column.append(
                "constraints", exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint())
            )
            schema.expressions.remove(primary_key)
        else:
            for column in defs.values():
                auto_increment = None
                for constraint in column.constraints.copy():
                    if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint):
                        break
                    if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint):
                        auto_increment = constraint
                if auto_increment:
                    column.constraints.remove(auto_increment)

    return expression
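_transform_create folds a single-column table-level PRIMARY KEY into the column definition, and strips AUTO_INCREMENT from columns that are not primary keys. A hedged sketch of the effect (output shape may differ by version):

import sqlglot

sql = "CREATE TABLE t (id INT, name TEXT, PRIMARY KEY (id))"
print(sqlglot.transpile(sql, read="mysql", write="sqlite")[0])
# expected: CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)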
class SQLite(Dialect):
    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]"), "`"]

@@ -65,8 +99,8 @@ class SQLite(Dialect):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,  # type: ignore
            **transforms.ELIMINATE_QUALIFY,  # type: ignore
            exp.CountIf: count_if_to_sum,
            exp.Create: transforms.preprocess([_transform_create]),
            exp.CurrentDate: lambda *_: "CURRENT_DATE",
            exp.CurrentTime: lambda *_: "CURRENT_TIME",
            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",

@@ -80,14 +114,17 @@ class SQLite(Dialect):
            exp.Levenshtein: rename_func("EDITDIST3"),
            exp.LogicalOr: rename_func("MAX"),
            exp.LogicalAnd: rename_func("MIN"),
            exp.Select: transforms.preprocess(
                [transforms.eliminate_distinct_on, transforms.eliminate_qualify]
            ),
            exp.TableSample: no_tablesample_sql,
            exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
            exp.TryCast: no_trycast_sql,
        }

        PROPERTIES_LOCATION = {
            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
            k: exp.Properties.Location.UNSUPPORTED
            for k, v in generator.Generator.PROPERTIES_LOCATION.items()
        }

        LIMIT_FETCH = "LIMIT"
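SQLite has no COUNT_IF, so count_if_to_sum rewrites it as a conditional SUM. Illustrative (exact spelling may vary by version):

import sqlglot

print(sqlglot.transpile("SELECT COUNT_IF(x > 0) FROM t", write="sqlite")[0])
# expected: SELECT SUM(CASE WHEN x > 0 THEN 1 ELSE 0 END) FROM t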
@@ -34,6 +34,7 @@ class StarRocks(MySQL):
            exp.JSONExtractScalar: arrow_json_extract_sql,
            exp.JSONExtract: arrow_json_extract_sql,
            exp.DateDiff: rename_func("DATEDIFF"),
            exp.RegexpLike: rename_func("REGEXP"),
            exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimestampTrunc: lambda self, e: self.func(
                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
@@ -1,6 +1,6 @@
from __future__ import annotations

from sqlglot import exp, generator, parser
from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect
@@ -29,6 +29,7 @@ class Tableau(Dialect):
            exp.If: _if_sql,
            exp.Coalesce: _coalesce_sql,
            exp.Count: _count_sql,
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
        }

        PROPERTIES_LOCATION = {
@@ -2,7 +2,7 @@ from __future__ import annotations

import typing as t

from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    format_time_lambda,
@@ -148,6 +148,7 @@ class Teradata(Dialect):
            **generator.Generator.TRANSFORMS,
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
            exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
        }
@@ -3,7 +3,7 @@ from __future__ import annotations

import re
import typing as t

from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    max_or_greatest,
@@ -259,8 +259,8 @@ class TSQL(Dialect):

    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]")]

        QUOTES = ["'", '"']
        HEX_STRINGS = [("0x", ""), ("0X", "")]

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
@@ -463,17 +463,18 @@ class TSQL(Dialect):
            exp.DateDiff: generate_date_delta_with_unit_sql,
            exp.CurrentDate: rename_func("GETDATE"),
            exp.CurrentTimestamp: rename_func("GETDATE"),
            exp.If: rename_func("IIF"),
            exp.NumberToStr: _format_sql,
            exp.TimeToStr: _format_sql,
            exp.GroupConcat: _string_agg_sql,
            exp.If: rename_func("IIF"),
            exp.Max: max_or_greatest,
            exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
            exp.Min: min_or_least,
            exp.NumberToStr: _format_sql,
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
            exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
            exp.SHA2: lambda self, e: self.func(
                "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
            ),
            exp.TimeToStr: _format_sql,
        }

        TRANSFORMS.pop(exp.ReturnsProperty)
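The MD5/SHA/SHA2 transforms all funnel into T-SQL's HASHBYTES, passing the algorithm name as the first argument. A hedged sketch (expected output in the comment; may vary by version):

import sqlglot

print(sqlglot.transpile("SELECT MD5(x), SHA2(x, 512)", write="tsql")[0])
# expected: SELECT HASHBYTES('MD5', x), HASHBYTES('SHA2_512', x)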