Merging upstream version 12.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent fffa0d5761
commit 62b2b24d3b

100 changed files with 35022 additions and 30936 deletions
@@ -50,7 +50,7 @@ if t.TYPE_CHECKING:

     T = t.TypeVar("T", bound=Expression)


-__version__ = "11.7.1"
+__version__ = "12.2.0"

 pretty = False
 """Whether to format generated SQL by default."""

@@ -181,7 +181,7 @@ def transpile(
     Returns:
         The list of transpiled SQL statements.
     """
-    write = write or read if identity else write
+    write = (read if write is None else write) if identity else write
    return [
        Dialect.get_or_raise(write)().generate(expression, **opts)
        for expression in parse(sql, read, error_level=error_level)
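The transpile() guard is the subtle change here: because `or` binds tighter than a conditional expression, the old line meant `(write or read) if identity else write` and fell back to `read` whenever `write` was falsy, while the new line substitutes `read` only when `write` is explicitly None, so a falsy-but-deliberate `write` value is respected. A quick sanity check of both paths (a sketch assuming only that the sqlglot package is installed; exact output strings depend on the installed version):

    import sqlglot

    # Identity transpile: write is left as None, so the read dialect is
    # reused and the statement round-trips through parser and generator.
    sql = "SELECT EPOCH_MS(1618088028295)"
    print(sqlglot.transpile(sql, read="duckdb")[0])

    # Cross-dialect transpile: an explicit write dialect takes precedence.
    print(sqlglot.transpile(sql, read="duckdb", write="hive")[0])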
@@ -747,11 +747,11 @@ def ascii(col: ColumnOrLiteral) -> Column:


 def base64(col: ColumnOrLiteral) -> Column:
-    return Column.invoke_anonymous_function(col, "BASE64")
+    return Column.invoke_expression_over_column(col, expression.ToBase64)


 def unbase64(col: ColumnOrLiteral) -> Column:
-    return Column.invoke_anonymous_function(col, "UNBASE64")
+    return Column.invoke_expression_over_column(col, expression.FromBase64)


 def ltrim(col: ColumnOrName) -> Column:
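Both helpers now build typed AST nodes (exp.ToBase64 / exp.FromBase64) instead of anonymous function calls, which lets each dialect's generator rename them; Hive maps them back to BASE64/UNBASE64 further down in this diff. A minimal sketch of the effect, assuming a sqlglot build that includes this change:

    import sqlglot
    from sqlglot import exp

    # BASE64 parses to a typed ToBase64 node rather than an Anonymous one,
    # so generators can translate it per dialect.
    node = sqlglot.parse_one("BASE64('x')", read="hive")
    assert isinstance(node, exp.ToBase64)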
@@ -70,6 +70,7 @@ from sqlglot.dialects.presto import Presto
 from sqlglot.dialects.redshift import Redshift
 from sqlglot.dialects.snowflake import Snowflake
 from sqlglot.dialects.spark import Spark
+from sqlglot.dialects.spark2 import Spark2
 from sqlglot.dialects.sqlite import SQLite
 from sqlglot.dialects.starrocks import StarRocks
 from sqlglot.dialects.tableau import Tableau
@@ -39,18 +39,26 @@ def _date_add_sql(


 def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
     if not isinstance(expression.unnest().parent, exp.From):
+        expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
         return self.values_sql(expression)
-    rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
-    structs = []
-    for row in rows:
-        aliases = [
-            exp.alias_(value, column_name)
-            for value, column_name in zip(row, expression.args["alias"].args["columns"])
-        ]
-        structs.append(exp.Struct(expressions=aliases))
-    unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
-    return self.unnest_sql(unnest_exp)
+
+    alias = expression.args.get("alias")
+
+    structs = [
+        exp.Struct(
+            expressions=[
+                exp.alias_(value, column_name)
+                for value, column_name in zip(
+                    t.expressions,
+                    alias.columns
+                    if alias and alias.columns
+                    else (f"_c{i}" for i in range(len(t.expressions))),
+                )
+            ]
+        )
+        for t in expression.find_all(exp.Tuple)
+    ]
+
+    return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))


 def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
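The rewritten _derived_table_values_to_unnest handles VALUES clauses whose alias carries no column list: instead of indexing expression.args["alias"].args["columns"] (a KeyError when the columns are absent), it falls back to generated _c0, _c1, ... struct field names. A sketch of the behavior this enables, assuming a sqlglot version with this change:

    import sqlglot

    # A derived VALUES table without column aliases now transpiles to
    # BigQuery as an UNNEST of STRUCTs with generated _c0, _c1 field names.
    sql = "SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t"
    print(sqlglot.transpile(sql, read="presto", write="bigquery")[0])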
@@ -128,6 +136,7 @@ class BigQuery(Dialect):
         IDENTIFIERS = ["`"]
         STRING_ESCAPES = ["\\"]
         HEX_STRINGS = [("0x", ""), ("0X", "")]
+        BYTE_STRINGS = [("b'", "'"), ("B'", "'")]

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,

@@ -139,6 +148,7 @@ class BigQuery(Dialect):
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "FLOAT64": TokenType.DOUBLE,
             "INT64": TokenType.BIGINT,
+            "BYTES": TokenType.BINARY,
             "NOT DETERMINISTIC": TokenType.VOLATILE,
             "UNKNOWN": TokenType.NULL,
         }

@@ -153,7 +163,7 @@ class BigQuery(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
             "DATE_TRUNC": lambda args: exp.DateTrunc(
-                unit=exp.Literal.string(seq_get(args, 1).name),  # type: ignore
+                unit=exp.Literal.string(str(seq_get(args, 1))),
                 this=seq_get(args, 0),
             ),
             "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),

@@ -206,6 +216,12 @@ class BigQuery(Dialect):
             "NOT DETERMINISTIC": lambda self: self.expression(
                 exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
             ),
+            "OPTIONS": lambda self: self._parse_with_property(),
         }

+        CONSTRAINT_PARSERS = {
+            **parser.Parser.CONSTRAINT_PARSERS,  # type: ignore
+            "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
+        }
+
     class Generator(generator.Generator):

@@ -217,11 +233,11 @@ class BigQuery(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
-            **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES,  # type: ignore
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.AtTimeZone: lambda self, e: self.func(
                 "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
             ),
+            exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
             exp.DateSub: _date_add_sql("DATE", "SUB"),
             exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),

@@ -234,7 +250,9 @@ class BigQuery(Dialect):
             exp.IntDiv: rename_func("DIV"),
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
-            exp.Select: transforms.preprocess([_unqualify_unnest]),
+            exp.Select: transforms.preprocess(
+                [_unqualify_unnest, transforms.eliminate_distinct_on]
+            ),
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.TimeAdd: _date_add_sql("TIME", "ADD"),
             exp.TimeSub: _date_add_sql("TIME", "SUB"),

@@ -259,6 +277,7 @@ class BigQuery(Dialect):
             **generator.Generator.TYPE_MAPPING,  # type: ignore
+            exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
             exp.DataType.Type.BIGINT: "INT64",
             exp.DataType.Type.BINARY: "BYTES",
             exp.DataType.Type.BOOLEAN: "BOOL",
             exp.DataType.Type.CHAR: "STRING",
             exp.DataType.Type.DECIMAL: "NUMERIC",

@@ -272,6 +291,7 @@ class BigQuery(Dialect):
             exp.DataType.Type.TIMESTAMP: "DATETIME",
             exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
             exp.DataType.Type.TINYINT: "INT64",
+            exp.DataType.Type.VARBINARY: "BYTES",
             exp.DataType.Type.VARCHAR: "STRING",
             exp.DataType.Type.VARIANT: "ANY TYPE",
         }

@@ -310,3 +330,6 @@ class BigQuery(Dialect):
         if not expression.args.get("distinct", False):
             self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
         return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
+
+    def with_properties(self, properties: exp.Properties) -> str:
+        return self.properties(properties, prefix=self.seg("OPTIONS"))
@@ -22,6 +22,8 @@ class ClickHouse(Dialect):
     class Tokenizer(tokens.Tokenizer):
         COMMENTS = ["--", "#", "#!", ("/*", "*/")]
         IDENTIFIERS = ['"', "`"]
+        BIT_STRINGS = [("0b", "")]
+        HEX_STRINGS = [("0x", ""), ("0X", "")]

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,

@@ -31,10 +33,18 @@ class ClickHouse(Dialect):
             "FINAL": TokenType.FINAL,
             "FLOAT32": TokenType.FLOAT,
             "FLOAT64": TokenType.DOUBLE,
-            "INT16": TokenType.SMALLINT,
-            "INT32": TokenType.INT,
-            "INT64": TokenType.BIGINT,
             "INT8": TokenType.TINYINT,
+            "UINT8": TokenType.UTINYINT,
+            "INT16": TokenType.SMALLINT,
+            "UINT16": TokenType.USMALLINT,
+            "INT32": TokenType.INT,
+            "UINT32": TokenType.UINT,
+            "INT64": TokenType.BIGINT,
+            "UINT64": TokenType.UBIGINT,
+            "INT128": TokenType.INT128,
+            "UINT128": TokenType.UINT128,
+            "INT256": TokenType.INT256,
+            "UINT256": TokenType.UINT256,
             "TUPLE": TokenType.STRUCT,
         }

@@ -121,9 +131,17 @@ class ClickHouse(Dialect):
             exp.DataType.Type.ARRAY: "Array",
             exp.DataType.Type.STRUCT: "Tuple",
             exp.DataType.Type.TINYINT: "Int8",
+            exp.DataType.Type.UTINYINT: "UInt8",
             exp.DataType.Type.SMALLINT: "Int16",
+            exp.DataType.Type.USMALLINT: "UInt16",
             exp.DataType.Type.INT: "Int32",
+            exp.DataType.Type.UINT: "UInt32",
             exp.DataType.Type.BIGINT: "Int64",
+            exp.DataType.Type.UBIGINT: "UInt64",
+            exp.DataType.Type.INT128: "Int128",
+            exp.DataType.Type.UINT128: "UInt128",
+            exp.DataType.Type.INT256: "Int256",
+            exp.DataType.Type.UINT256: "UInt256",
             exp.DataType.Type.FLOAT: "Float32",
             exp.DataType.Type.DOUBLE: "Float64",
         }
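Together, the new UINT* tokens and their TYPE_MAPPING entries let ClickHouse's unsigned and extra-wide integer types survive a round trip instead of collapsing to their signed or narrower cousins. A sketch, assuming a sqlglot build with these mappings:

    import sqlglot

    # UInt64 and Int128 now round-trip through the ClickHouse dialect.
    ddl = "CREATE TABLE t (a UInt64, b Int128)"
    print(sqlglot.transpile(ddl, read="clickhouse", write="clickhouse")[0])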
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from sqlglot import exp
+from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import parse_date_delta
 from sqlglot.dialects.spark import Spark
 from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql

@@ -29,13 +29,20 @@ class Databricks(Spark):
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
             exp.JSONExtract: lambda self, e: self.binary(e, ":"),
+            exp.Select: transforms.preprocess(
+                [
+                    transforms.eliminate_distinct_on,
+                    transforms.unnest_to_explode,
+                ]
+            ),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
         }
-        TRANSFORMS.pop(exp.Select)  # Remove the ELIMINATE_QUALIFY transformation

         PARAMETER_TOKEN = "$"

     class Tokenizer(Spark.Tokenizer):
         HEX_STRINGS = []

         SINGLE_TOKENS = {
             **Spark.Tokenizer.SINGLE_TOKENS,
             "$": TokenType.PARAMETER,
@@ -28,6 +28,7 @@ class Dialects(str, Enum):
     REDSHIFT = "redshift"
     SNOWFLAKE = "snowflake"
     SPARK = "spark"
+    SPARK2 = "spark2"
     SQLITE = "sqlite"
     STARROCKS = "starrocks"
     TABLEAU = "tableau"

@@ -69,30 +70,17 @@ class _Dialect(type):
             klass.tokenizer_class._IDENTIFIERS.items()
         )[0]

-        if (
-            klass.tokenizer_class._BIT_STRINGS
-            and exp.BitString not in klass.generator_class.TRANSFORMS
-        ):
-            bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
-            klass.generator_class.TRANSFORMS[
-                exp.BitString
-            ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
-        if (
-            klass.tokenizer_class._HEX_STRINGS
-            and exp.HexString not in klass.generator_class.TRANSFORMS
-        ):
-            hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
-            klass.generator_class.TRANSFORMS[
-                exp.HexString
-            ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
-        if (
-            klass.tokenizer_class._BYTE_STRINGS
-            and exp.ByteString not in klass.generator_class.TRANSFORMS
-        ):
-            be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
-            klass.generator_class.TRANSFORMS[
-                exp.ByteString
-            ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"
+        klass.bit_start, klass.bit_end = seq_get(
+            list(klass.tokenizer_class._BIT_STRINGS.items()), 0
+        ) or (None, None)
+
+        klass.hex_start, klass.hex_end = seq_get(
+            list(klass.tokenizer_class._HEX_STRINGS.items()), 0
+        ) or (None, None)
+
+        klass.byte_start, klass.byte_end = seq_get(
+            list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
+        ) or (None, None)

         return klass
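The metaclass no longer installs per-dialect TRANSFORMS lambdas for bit/hex/byte literals; it just records the first (start, end) delimiter pair of each kind, and seq_get plus `or (None, None)` keeps dialects that define no such literals working. A paraphrase of the seq_get idiom (sqlglot.helper.seq_get is essentially a bounds-safe index; this standalone sketch is not the library's exact source):

    import typing as t

    T = t.TypeVar("T")

    def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
        # Bounds-safe indexing: returns None instead of raising IndexError.
        try:
            return seq[index]
        except IndexError:
            return None

    # The pattern used above, in miniature: an empty mapping yields (None, None).
    bit_strings: dict = {}
    start, end = seq_get(list(bit_strings.items()), 0) or (None, None)
    print(start, end)  # None None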
@@ -198,6 +186,12 @@ class Dialect(metaclass=_Dialect):
             **{
                 "quote_start": self.quote_start,
                 "quote_end": self.quote_end,
+                "bit_start": self.bit_start,
+                "bit_end": self.bit_end,
+                "hex_start": self.hex_start,
+                "hex_end": self.hex_end,
+                "byte_start": self.byte_start,
+                "byte_end": self.byte_end,
                 "identifier_start": self.identifier_start,
                 "identifier_end": self.identifier_end,
                 "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
@@ -2,7 +2,7 @@ from __future__ import annotations

 import typing as t

-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     create_with_partitions_sql,

@@ -145,6 +145,7 @@ class Drill(Dialect):
             exp.StrPosition: str_position_sql,
             exp.StrToDate: _str_to_date,
             exp.Pow: rename_func("POW"),
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
             exp.TimeStrToTime: timestrtotime_sql,
@@ -1,5 +1,7 @@
 from __future__ import annotations

+import typing as t
+
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,

@@ -23,52 +25,61 @@ from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType


-def _ts_or_ds_add(self, expression):
+def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
     return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"


-def _date_add(self, expression):
+def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
     return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"


-def _array_sort_sql(self, expression):
+def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
     if expression.expression:
         self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
     return f"ARRAY_SORT({self.sql(expression, 'this')})"


-def _sort_array_sql(self, expression):
+def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
     this = self.sql(expression, "this")
     if expression.args.get("asc") == exp.false():
         return f"ARRAY_REVERSE_SORT({this})"
     return f"ARRAY_SORT({this})"


-def _sort_array_reverse(args):
+def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
     return exp.SortArray(this=seq_get(args, 0), asc=exp.false())


-def _struct_sql(self, expression):
+def _parse_date_diff(args: t.Sequence) -> exp.Expression:
+    return exp.DateDiff(
+        this=seq_get(args, 2),
+        expression=seq_get(args, 1),
+        unit=seq_get(args, 0),
+    )
+
+
+def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
     args = [
         f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
     ]
     return f"{{{', '.join(args)}}}"


-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
     if expression.this == exp.DataType.Type.ARRAY:
         return f"{self.expressions(expression, flat=True)}[]"
     return self.datatype_sql(expression)


-def _regexp_extract_sql(self, expression):
+def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str:
     bad_args = list(filter(expression.args.get, ("position", "occurrence")))
     if bad_args:
         self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")

     return self.func(
         "REGEXP_EXTRACT",
         expression.args.get("this"),
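Note the argument order in the new _parse_date_diff: DuckDB's DATE_DIFF takes (unit, start, end), while sqlglot's exp.DateDiff stores the end date in `this` and the start in `expression`, hence args[2]/args[1]/args[0]. A sketch, assuming a sqlglot build with this parser hook:

    import sqlglot
    from sqlglot import exp

    # DATE_DIFF('day', start, end) maps to
    # DateDiff(this=end, expression=start, unit='day').
    node = sqlglot.parse_one("DATE_DIFF('day', x, y)", read="duckdb")
    assert isinstance(node, exp.DateDiff)
    print(node.this.sql(), node.expression.sql(), node.args["unit"].sql())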
@@ -108,6 +119,8 @@ class DuckDB(Dialect):
             "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
             "ARRAY_SORT": exp.SortArray.from_arg_list,
             "ARRAY_REVERSE_SORT": _sort_array_reverse,
+            "DATEDIFF": _parse_date_diff,
+            "DATE_DIFF": _parse_date_diff,
             "EPOCH": exp.TimeToUnix.from_arg_list,
             "EPOCH_MS": lambda args: exp.UnixToTime(
                 this=exp.Div(

@@ -115,18 +128,18 @@ class DuckDB(Dialect):
                     expression=exp.Literal.number(1000),
                 )
             ),
-            "LIST_SORT": exp.SortArray.from_arg_list,
             "LIST_REVERSE_SORT": _sort_array_reverse,
+            "LIST_SORT": exp.SortArray.from_arg_list,
             "LIST_VALUE": exp.Array.from_arg_list,
             "REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
             "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
-            "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
-            "STR_SPLIT": exp.Split.from_arg_list,
             "STRING_SPLIT": exp.Split.from_arg_list,
-            "STRING_TO_ARRAY": exp.Split.from_arg_list,
-            "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
+            "STRING_TO_ARRAY": exp.Split.from_arg_list,
+            "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
             "STRUCT_PACK": exp.Struct.from_arg_list,
+            "STR_SPLIT": exp.Split.from_arg_list,
+            "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
             "UNNEST": exp.Explode.from_arg_list,
         }

@@ -142,10 +155,11 @@ class DuckDB(Dialect):
     class Generator(generator.Generator):
+        JOIN_HINTS = False
+        TABLE_HINTS = False
+        LIMIT_FETCH = "LIMIT"
         STRUCT_DELIMITER = ("(", ")")

         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,  # type: ignore
+            **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
             if isinstance(seq_get(e.expressions, 0), exp.Select)

@@ -154,13 +168,16 @@ class DuckDB(Dialect):
             exp.ArraySort: _array_sort_sql,
             exp.ArraySum: rename_func("LIST_SUM"),
+            exp.CommentColumnConstraint: no_comment_column_constraint_sql,
             exp.CurrentDate: lambda self, e: "CURRENT_DATE",
             exp.CurrentTime: lambda self, e: "CURRENT_TIME",
             exp.CurrentTimestamp: lambda self, e: "CURRENT_TIMESTAMP",
+            exp.DayOfMonth: rename_func("DAYOFMONTH"),
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
+            exp.DayOfYear: rename_func("DAYOFYEAR"),
             exp.DataType: _datatype_sql,
-            exp.DateAdd: _date_add,
+            exp.DateAdd: _date_add_sql,
             exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
+                "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
             ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",

@@ -192,7 +209,7 @@ class DuckDB(Dialect):
             exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToUnix: rename_func("EPOCH"),
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
-            exp.TsOrDsAdd: _ts_or_ds_add,
+            exp.TsOrDsAdd: _ts_or_ds_add_sql,
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
             exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
             exp.UnixToTime: rename_func("TO_TIMESTAMP"),

@@ -201,7 +218,7 @@ class DuckDB(Dialect):
         }

         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,  # type: ignore
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BINARY: "BLOB",
             exp.DataType.Type.CHAR: "TEXT",
             exp.DataType.Type.FLOAT: "REAL",

@@ -212,17 +229,14 @@ class DuckDB(Dialect):
             exp.DataType.Type.VARCHAR: "TEXT",
         }

-        STAR_MAPPING = {
-            **generator.Generator.STAR_MAPPING,
-            "except": "EXCLUDE",
-        }
+        STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}

         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

-        LIMIT_FETCH = "LIMIT"
-
-        def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
-            return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
+        def tablesample_sql(
+            self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+        ) -> str:
+            return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
@@ -81,7 +81,20 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
     return f"{diff_sql}{multiplier_sql}"


-def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
+def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+    this = expression.this
+
+    if not this.type:
+        from sqlglot.optimizer.annotate_types import annotate_types
+
+        annotate_types(this)
+
+    if this.type.is_type(exp.DataType.Type.JSON):
+        return self.sql(this)
+    return self.func("TO_JSON", this, expression.args.get("options"))
+
+
+def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
     if expression.expression:
         self.unsupported("Hive SORT_ARRAY does not support a comparator")
     return f"SORT_ARRAY({self.sql(expression, 'this')})"
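_json_format_sql makes TO_JSON a no-op when its argument is already JSON-typed, and it discovers that by lazily annotating types. annotate_types is the standard optimizer entry point it leans on:

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    # annotate_types fills in expression.type, which _json_format_sql
    # consults to decide whether the TO_JSON call can be dropped.
    expr = annotate_types(parse_one("CAST(x AS INT) + 1"))
    print(expr.type)  # INT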
@@ -91,11 +104,11 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
     return f"'{expression.name}'={self.sql(expression, 'value')}"


-def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
+def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
     return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))


-def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format not in (Hive.time_format, Hive.date_format):

@@ -103,7 +116,7 @@ def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
     return f"CAST({this} AS DATE)"


-def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
+def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format not in (Hive.time_format, Hive.date_format):

@@ -214,6 +227,7 @@ class Hive(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
             "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
+            "BASE64": exp.ToBase64.from_arg_list,
             "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
             "DATE_ADD": lambda args: exp.TsOrDsAdd(
                 this=seq_get(args, 0),

@@ -251,6 +265,7 @@ class Hive(Dialect):
             "SPLIT": exp.RegexpSplit.from_arg_list,
             "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
             "TO_JSON": exp.JSONFormat.from_arg_list,
+            "UNBASE64": exp.FromBase64.from_arg_list,
             "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True),
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }

@@ -280,16 +295,20 @@ class Hive(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
-            **transforms.UNALIAS_GROUP,  # type: ignore
-            **transforms.ELIMINATE_QUALIFY,  # type: ignore
+            exp.Group: transforms.preprocess([transforms.unalias_group]),
             exp.Select: transforms.preprocess(
-                [transforms.eliminate_qualify, transforms.unnest_to_explode]
+                [
+                    transforms.eliminate_qualify,
+                    transforms.eliminate_distinct_on,
+                    transforms.unnest_to_explode,
+                ]
             ),
             exp.Property: _property_sql,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.ArrayConcat: rename_func("CONCAT"),
             exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
             exp.ArraySize: rename_func("SIZE"),
-            exp.ArraySort: _array_sort,
+            exp.ArraySort: _array_sort_sql,
             exp.With: no_recursive_cte_sql,
             exp.DateAdd: _add_date_sql,
             exp.DateDiff: _date_diff_sql,

@@ -298,12 +317,13 @@ class Hive(Dialect):
             exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
             exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
+            exp.FromBase64: rename_func("UNBASE64"),
             exp.If: if_sql,
             exp.Index: _index_sql,
             exp.ILike: no_ilike_sql,
             exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
             exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
-            exp.JSONFormat: rename_func("TO_JSON"),
+            exp.JSONFormat: _json_format_sql,
             exp.Map: var_map_sql,
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,

@@ -318,9 +338,9 @@ class Hive(Dialect):
             exp.SetAgg: rename_func("COLLECT_SET"),
             exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
             exp.StrPosition: strposition_to_locate_sql,
-            exp.StrToDate: _str_to_date,
-            exp.StrToTime: _str_to_time,
-            exp.StrToUnix: _str_to_unix,
+            exp.StrToDate: _str_to_date_sql,
+            exp.StrToTime: _str_to_time_sql,
+            exp.StrToUnix: _str_to_unix_sql,
             exp.StructExtract: struct_extract_sql,
             exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
             exp.TimeStrToDate: rename_func("TO_DATE"),

@@ -328,6 +348,7 @@ class Hive(Dialect):
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeToStr: _time_to_str,
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+            exp.ToBase64: rename_func("BASE64"),
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
             exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.TsOrDsToDate: _to_date_sql,
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,

@@ -403,6 +403,7 @@ class MySQL(Dialect):
             exp.Min: min_or_least,
             exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
             exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.StrPosition: strposition_to_locate_sql,
             exp.StrToDate: _str_to_date_sql,
             exp.StrToTime: _str_to_date_sql,
@@ -34,6 +34,8 @@ def _parse_xml_table(self) -> exp.XMLTable:


 class Oracle(Dialect):
+    alias_post_tablesample = True
+
     # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
     # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
     time_mapping = {

@@ -121,21 +123,23 @@ class Oracle(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
-            **transforms.UNALIAS_GROUP,  # type: ignore
             exp.DateStrToDate: lambda self, e: self.func(
                 "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
             ),
+            exp.Group: transforms.preprocess([transforms.unalias_group]),
             exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
             exp.ILike: no_ilike_sql,
-            exp.IfNull: rename_func("NVL"),
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
             exp.Substring: rename_func("SUBSTR"),
             exp.Table: lambda self, e: self.table_sql(e, sep=" "),
             exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.Trim: trim_sql,
             exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
+            exp.IfNull: rename_func("NVL"),
         }

         PROPERTIES_LOCATION = {

@@ -164,14 +168,19 @@ class Oracle(Dialect):
         return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"

     class Tokenizer(tokens.Tokenizer):
+        VAR_SINGLE_TOKENS = {"@"}
+
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
+            "(+)": TokenType.JOIN_MARKER,
             "BINARY_DOUBLE": TokenType.DOUBLE,
             "BINARY_FLOAT": TokenType.FLOAT,
             "COLUMNS": TokenType.COLUMN,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
             "MINUS": TokenType.EXCEPT,
             "NVARCHAR2": TokenType.NVARCHAR,
+            "RETURNING": TokenType.RETURNING,
             "SAMPLE": TokenType.TABLE_SAMPLE,
             "START": TokenType.BEGIN,
             "TOP": TokenType.TOP,
             "VARCHAR2": TokenType.VARCHAR,
@@ -1,6 +1,8 @@
 from __future__ import annotations

-from sqlglot import exp, generator, parser, tokens
+import typing as t
+
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,

@@ -20,7 +22,6 @@ from sqlglot.dialects.dialect import (
 from sqlglot.helper import seq_get
 from sqlglot.parser import binary_range_parser
 from sqlglot.tokens import TokenType
-from sqlglot.transforms import preprocess, remove_target_from_merge

 DATE_DIFF_FACTOR = {
     "MICROSECOND": " * 1000000",

@@ -274,8 +275,7 @@ class Postgres(Dialect):
             TokenType.HASH: exp.BitwiseXor,
         }

-        FACTOR = {
-            **parser.Parser.FACTOR,
+        EXPONENT = {
             TokenType.CARET: exp.Pow,
         }

@@ -286,6 +286,12 @@ class Postgres(Dialect):
             TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
         }

+        def _parse_factor(self) -> t.Optional[exp.Expression]:
+            return self._parse_tokens(self._parse_exponent, self.FACTOR)
+
+        def _parse_exponent(self) -> t.Optional[exp.Expression]:
+            return self._parse_tokens(self._parse_unary, self.EXPONENT)
+
         def _parse_date_part(self) -> exp.Expression:
             part = self._parse_type()
             self._match(TokenType.COMMA)
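Pulling TokenType.CARET out of FACTOR into a dedicated EXPONENT level, with _parse_factor chaining through _parse_exponent, makes ^ bind tighter than * and /, which matches Postgres operator precedence. A sketch of the resulting parse, assuming a sqlglot build with this change:

    import sqlglot
    from sqlglot import exp

    # 2 * 3 ^ 2 now groups as 2 * (3 ^ 2): Mul at the root, Pow underneath.
    tree = sqlglot.parse_one("2 * 3 ^ 2", read="postgres")
    assert isinstance(tree, exp.Mul) and isinstance(tree.expression, exp.Pow)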
@@ -316,7 +322,7 @@ class Postgres(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
-            exp.ColumnDef: preprocess(
+            exp.ColumnDef: transforms.preprocess(
                 [
                     _auto_increment_to_serial,
                     _serial_to_generated,

@@ -341,7 +347,7 @@ class Postgres(Dialect):
             exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
             exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
             exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
-            exp.Merge: preprocess([remove_target_from_merge]),
+            exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
             exp.RegexpLike: lambda self, e: self.binary(e, "~"),
             exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
             exp.StrPosition: str_position_sql,
@@ -130,7 +130,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
 def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
     start = expression.args["start"]
     end = expression.args["end"]
-    step = expression.args.get("step", 1)  # Postgres defaults to 1 for generate_series
+    step = expression.args.get("step")

     target_type = None

@@ -147,7 +147,11 @@ def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) ->
     else:
         start = exp.Cast(this=start, to=to)

-    return self.func("SEQUENCE", start, end, step)
+    sql = self.func("SEQUENCE", start, end, step)
+
+    if isinstance(expression.parent, exp.Table):
+        sql = f"UNNEST({sql})"
+
+    return sql


 def _ensure_utf8(charset: exp.Literal) -> None:
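The extra branch wraps SEQUENCE in UNNEST only when the GenerateSeries node is used as a relation: Presto's SEQUENCE returns an array, which cannot stand alone in a FROM clause. A sketch, assuming a sqlglot build with this change:

    import sqlglot

    # In a FROM clause, the generated SEQUENCE array must be unnested
    # to become a relation on Presto.
    sql = "SELECT * FROM GENERATE_SERIES(1, 5)"
    print(sqlglot.transpile(sql, read="postgres", write="presto")[0])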
|
@ -204,6 +208,7 @@ class Presto(Dialect):
|
|||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
|
||||
"APPROX_PERCENTILE": _approx_percentile,
|
||||
"CARDINALITY": exp.ArraySize.from_arg_list,
|
||||
"CONTAINS": exp.ArrayContains.from_arg_list,
|
||||
"DATE_ADD": lambda args: exp.DateAdd(
|
||||
|
@ -219,23 +224,23 @@ class Presto(Dialect):
|
|||
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
|
||||
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
|
||||
"DATE_TRUNC": date_trunc_to_time,
|
||||
"FROM_HEX": exp.Unhex.from_arg_list,
|
||||
"FROM_UNIXTIME": _from_unixtime,
|
||||
"FROM_UTF8": lambda args: exp.Decode(
|
||||
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
|
||||
),
|
||||
"NOW": exp.CurrentTimestamp.from_arg_list,
|
||||
"SEQUENCE": exp.GenerateSeries.from_arg_list,
|
||||
"STRPOS": lambda args: exp.StrPosition(
|
||||
this=seq_get(args, 0),
|
||||
substr=seq_get(args, 1),
|
||||
instance=seq_get(args, 2),
|
||||
),
|
||||
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
|
||||
"APPROX_PERCENTILE": _approx_percentile,
|
||||
"FROM_HEX": exp.Unhex.from_arg_list,
|
||||
"TO_HEX": exp.Hex.from_arg_list,
|
||||
"TO_UTF8": lambda args: exp.Encode(
|
||||
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
|
||||
),
|
||||
"FROM_UTF8": lambda args: exp.Decode(
|
||||
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
|
||||
),
|
||||
}
|
||||
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
|
||||
FUNCTION_PARSERS.pop("TRIM")
|
||||
|
@ -264,7 +269,6 @@ class Presto(Dialect):
|
|||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**transforms.UNALIAS_GROUP, # type: ignore
|
||||
exp.ApproxDistinct: _approx_distinct_sql,
|
||||
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
|
||||
exp.ArrayConcat: rename_func("CONCAT"),
|
||||
|
@ -290,6 +294,7 @@ class Presto(Dialect):
|
|||
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
|
||||
exp.Encode: _encode_sql,
|
||||
exp.GenerateSeries: _sequence_sql,
|
||||
exp.Group: transforms.preprocess([transforms.unalias_group]),
|
||||
exp.Hex: rename_func("TO_HEX"),
|
||||
exp.If: if_sql,
|
||||
exp.ILike: no_ilike_sql,
|
||||
|
@ -303,7 +308,11 @@ class Presto(Dialect):
|
|||
exp.SafeDivide: no_safe_divide_sql,
|
||||
exp.Schema: _schema_sql,
|
||||
exp.Select: transforms.preprocess(
|
||||
[transforms.eliminate_qualify, transforms.explode_to_unnest]
|
||||
[
|
||||
transforms.eliminate_qualify,
|
||||
transforms.eliminate_distinct_on,
|
||||
transforms.explode_to_unnest,
|
||||
]
|
||||
),
|
||||
exp.SortArray: _no_sort_array,
|
||||
exp.StrPosition: rename_func("STRPOS"),
|
||||
|
@ -327,6 +336,9 @@ class Presto(Dialect):
|
|||
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
|
||||
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
|
||||
exp.VariancePop: rename_func("VAR_POP"),
|
||||
exp.WithinGroup: transforms.preprocess(
|
||||
[transforms.remove_within_group_for_percentiles]
|
||||
),
|
||||
}
|
||||
|
||||
def interval_sql(self, expression: exp.Interval) -> str:
|
||||
|
|
|
@@ -52,6 +52,8 @@ class Redshift(Postgres):
         return this

     class Tokenizer(Postgres.Tokenizer):
+        BIT_STRINGS = []
+        HEX_STRINGS = []
         STRING_ESCAPES = ["\\"]

         KEYWORDS = {

@@ -90,7 +92,6 @@ class Redshift(Postgres):

         TRANSFORMS = {
             **Postgres.Generator.TRANSFORMS,  # type: ignore
-            **transforms.ELIMINATE_DISTINCT_ON,  # type: ignore
             exp.CurrentTimestamp: lambda self, e: "SYSDATE",
             exp.DateAdd: lambda self, e: self.func(
                 "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this

@@ -102,6 +103,7 @@ class Redshift(Postgres):
             exp.DistStyleProperty: lambda self, e: self.naked_property(e),
             exp.JSONExtract: _json_sql,
             exp.JSONExtractScalar: _json_sql,
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
         }
@@ -2,7 +2,7 @@ from __future__ import annotations

 import typing as t

-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     date_trunc_to_time,

@@ -252,6 +252,7 @@ class Snowflake(Dialect):
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
         STRING_ESCAPES = ["\\", "'"]
+        HEX_STRINGS = [("x'", "'"), ("X'", "'")]

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,

@@ -305,6 +306,7 @@ class Snowflake(Dialect):
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
             exp.StrPosition: lambda self, e: self.func(
                 "POSITION", e.args.get("substr"), e.this, e.args.get("position")
@@ -2,222 +2,54 @@ from __future__ import annotations

 import typing as t

-from sqlglot import exp, parser
-from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
-from sqlglot.dialects.hive import Hive
+from sqlglot import exp
+from sqlglot.dialects.spark2 import Spark2
 from sqlglot.helper import seq_get


-def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
-    kind = e.args["kind"]
-    properties = e.args.get("properties")
-
-    if kind.upper() == "TABLE" and any(
-        isinstance(prop, exp.TemporaryProperty)
-        for prop in (properties.expressions if properties else [])
-    ):
-        return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
-    return create_with_partitions_sql(self, e)
+def _parse_datediff(args: t.Sequence) -> exp.Expression:
+    """
+    Although Spark docs don't mention the "unit" argument, Spark3 added support for
+    it at some point. Databricks also supports this variation (see below).
+
+    For example, in spark-sql (v3.3.1):
+    - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
+    - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4
+
+    See also:
+    - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
+    - https://docs.databricks.com/sql/language-manual/functions/datediff.html
+    """
+    unit = None
+    this = seq_get(args, 0)
+    expression = seq_get(args, 1)
+
+    if len(args) == 3:
+        unit = this
+        this = args[2]
+
+    return exp.DateDiff(
+        this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
+    )


-def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
-    keys = self.sql(expression.args["keys"])
-    values = self.sql(expression.args["values"])
-    return f"MAP_FROM_ARRAYS({keys}, {values})"
-
-
-def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
-    this = self.sql(expression, "this")
-    time_format = self.format_time(expression)
-    if time_format == Hive.date_format:
-        return f"TO_DATE({this})"
-    return f"TO_DATE({this}, {time_format})"
-
-
-def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
-    scale = expression.args.get("scale")
-    timestamp = self.sql(expression, "this")
-    if scale is None:
-        return f"FROM_UNIXTIME({timestamp})"
-    if scale == exp.UnixToTime.SECONDS:
-        return f"TIMESTAMP_SECONDS({timestamp})"
-    if scale == exp.UnixToTime.MILLIS:
-        return f"TIMESTAMP_MILLIS({timestamp})"
-    if scale == exp.UnixToTime.MICROS:
-        return f"TIMESTAMP_MICROS({timestamp})"
-
-    raise ValueError("Improper scale for timestamp")
-
-
-class Spark(Hive):
-    class Parser(Hive.Parser):
+class Spark(Spark2):
+    class Parser(Spark2.Parser):
         FUNCTIONS = {
-            **Hive.Parser.FUNCTIONS,  # type: ignore
-            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
-            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
-            "LEFT": lambda args: exp.Substring(
-                this=seq_get(args, 0),
-                start=exp.Literal.number(1),
-                length=seq_get(args, 1),
-            ),
-            "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-            ),
-            "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-            ),
-            "RIGHT": lambda args: exp.Substring(
-                this=seq_get(args, 0),
-                start=exp.Sub(
-                    this=exp.Length(this=seq_get(args, 0)),
-                    expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
-                ),
-                length=seq_get(args, 1),
-            ),
-            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
-            "BOOLEAN": lambda args: exp.Cast(
-                this=seq_get(args, 0), to=exp.DataType.build("boolean")
-            ),
-            "IIF": exp.If.from_arg_list,
-            "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
-            "AGGREGATE": exp.Reduce.from_arg_list,
-            "DAYOFWEEK": lambda args: exp.DayOfWeek(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DAYOFMONTH": lambda args: exp.DayOfMonth(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DAYOFYEAR": lambda args: exp.DayOfYear(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
-            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
-                this=seq_get(args, 1),
-                unit=exp.var(seq_get(args, 0)),
-            ),
-            "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
-            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
-            "TIMESTAMP": lambda args: exp.Cast(
-                this=seq_get(args, 0), to=exp.DataType.build("timestamp")
-            ),
+            **Spark2.Parser.FUNCTIONS,  # type: ignore
+            "DATEDIFF": _parse_datediff,
         }

-        FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
-            "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
-            "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
-            "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
-            "MERGE": lambda self: self._parse_join_hint("MERGE"),
-            "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
-            "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
-            "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
-            "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
-        }
+    class Generator(Spark2.Generator):
+        TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
+        TRANSFORMS.pop(exp.DateDiff)

-        def _parse_add_column(self) -> t.Optional[exp.Expression]:
-            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
+        def datediff_sql(self, expression: exp.DateDiff) -> str:
+            unit = self.sql(expression, "unit")
+            end = self.sql(expression, "this")
+            start = self.sql(expression, "expression")

-        def _parse_drop_column(self) -> t.Optional[exp.Expression]:
-            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
-                exp.Drop,
-                this=self._parse_schema(),
-                kind="COLUMNS",
-            )
+            if unit:
+                return self.func("DATEDIFF", unit, start, end)

-        def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
-            # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
-            if len(pivot_columns) == 1:
-                return [""]
-
-            names = []
-            for agg in pivot_columns:
-                if isinstance(agg, exp.Alias):
-                    names.append(agg.alias)
-                else:
-                    """
-                    This case corresponds to aggregations without aliases being used as suffixes
-                    (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
-                    be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
-                    Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
-                    Moreover, function names are lowercased in order to mimic Spark's naming scheme.
-                    """
-                    agg_all_unquoted = agg.transform(
-                        lambda node: exp.Identifier(this=node.name, quoted=False)
-                        if isinstance(node, exp.Identifier)
-                        else node
-                    )
-                    names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
-
-            return names
-
-    class Generator(Hive.Generator):
-        TYPE_MAPPING = {
-            **Hive.Generator.TYPE_MAPPING,  # type: ignore
-            exp.DataType.Type.TINYINT: "BYTE",
-            exp.DataType.Type.SMALLINT: "SHORT",
-            exp.DataType.Type.BIGINT: "LONG",
-        }
-
-        PROPERTIES_LOCATION = {
-            **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
-            exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
-            exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
-            exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
-            exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
-        }
-
-        TRANSFORMS = {
-            **Hive.Generator.TRANSFORMS,  # type: ignore
-            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
-            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
-            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
-            exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
-            exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
-            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
-            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
-            exp.StrToDate: _str_to_date,
-            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time_sql,
-            exp.Create: _create_sql,
-            exp.Map: _map_sql,
-            exp.Reduce: rename_func("AGGREGATE"),
-            exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
-            exp.TimestampTrunc: lambda self, e: self.func(
-                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
-            ),
-            exp.Trim: trim_sql,
-            exp.VariancePop: rename_func("VAR_POP"),
-            exp.DateFromParts: rename_func("MAKE_DATE"),
-            exp.LogicalOr: rename_func("BOOL_OR"),
-            exp.LogicalAnd: rename_func("BOOL_AND"),
-            exp.DayOfWeek: rename_func("DAYOFWEEK"),
-            exp.DayOfMonth: rename_func("DAYOFMONTH"),
-            exp.DayOfYear: rename_func("DAYOFYEAR"),
-            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
-            exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
-        }
-        TRANSFORMS.pop(exp.ArraySort)
-        TRANSFORMS.pop(exp.ILike)
-
-        WRAP_DERIVED_VALUES = False
-        CREATE_FUNCTION_RETURN_AS = False
-
-        def cast_sql(self, expression: exp.Cast) -> str:
-            if isinstance(expression.this, exp.Cast) and expression.this.is_type(
-                exp.DataType.Type.JSON
-            ):
-                schema = f"'{self.sql(expression, 'to')}'"
-                return self.func("FROM_JSON", expression.this.this, schema)
-            if expression.to.is_type(exp.DataType.Type.JSON):
-                return self.func("TO_JSON", expression.this)
-
-            return super(Spark.Generator, self).cast_sql(expression)
-
-    class Tokenizer(Hive.Tokenizer):
-        HEX_STRINGS = [("X'", "'")]
+            return self.func("DATEDIFF", end, start)
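The new datediff_sql keeps Spark's classic two-argument DATEDIFF(end, start) unless a unit survived parsing, in which case it emits the three-argument DATEDIFF(unit, start, end) form that Spark 3 and Databricks accept. A sketch, assuming a sqlglot build with this dialect split; exact output may include TO_DATE wrappers:

    import sqlglot

    # Two-arg form is preserved; a unit switches to the three-arg variant.
    print(sqlglot.transpile("SELECT DATEDIFF('2020-01-05', '2020-01-01')",
                            read="spark", write="spark")[0])
    print(sqlglot.transpile("SELECT DATEDIFF(day, '2020-01-01', '2020-01-05')",
                            read="spark", write="spark")[0])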
sqlglot/dialects/spark2.py (new file, 238 lines)
@ -0,0 +1,238 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, parser, transforms
|
||||
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
|
||||
from sqlglot.dialects.hive import Hive
|
||||
from sqlglot.helper import seq_get
|
||||
|
||||
|
||||
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
|
||||
kind = e.args["kind"]
|
||||
properties = e.args.get("properties")
|
||||
|
||||
if kind.upper() == "TABLE" and any(
|
||||
isinstance(prop, exp.TemporaryProperty)
|
||||
for prop in (properties.expressions if properties else [])
|
||||
):
|
||||
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
|
||||
return create_with_partitions_sql(self, e)
|
||||
|
||||
|
||||
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
|
||||
keys = self.sql(expression.args["keys"])
|
||||
values = self.sql(expression.args["values"])
|
||||
return f"MAP_FROM_ARRAYS({keys}, {values})"
|
||||
|
||||
|
||||
def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
|
||||
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
|
||||
|
||||
|
||||
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
time_format = self.format_time(expression)
|
||||
if time_format == Hive.date_format:
|
||||
return f"TO_DATE({this})"
|
||||
return f"TO_DATE({this}, {time_format})"
|
||||
|
||||
|
||||
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
|
||||
scale = expression.args.get("scale")
|
||||
timestamp = self.sql(expression, "this")
|
||||
if scale is None:
|
||||
return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
|
||||
if scale == exp.UnixToTime.SECONDS:
|
||||
return f"TIMESTAMP_SECONDS({timestamp})"
|
||||
if scale == exp.UnixToTime.MILLIS:
|
||||
return f"TIMESTAMP_MILLIS({timestamp})"
|
||||
if scale == exp.UnixToTime.MICROS:
|
||||
return f"TIMESTAMP_MICROS({timestamp})"
|
||||
|
||||
raise ValueError("Improper scale for timestamp")
|
||||
|
||||
|
||||
class Spark2(Hive):
|
||||
class Parser(Hive.Parser):
|
||||
FUNCTIONS = {
|
||||
**Hive.Parser.FUNCTIONS, # type: ignore
|
||||
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
|
||||
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
|
||||
"LEFT": lambda args: exp.Substring(
|
||||
this=seq_get(args, 0),
|
||||
start=exp.Literal.number(1),
|
||||
length=seq_get(args, 1),
|
||||
),
|
||||
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
|
||||
this=seq_get(args, 0),
|
||||
expression=seq_get(args, 1),
|
||||
),
|
||||
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
|
||||
this=seq_get(args, 0),
|
||||
expression=seq_get(args, 1),
|
||||
),
|
||||
"RIGHT": lambda args: exp.Substring(
|
||||
this=seq_get(args, 0),
|
||||
start=exp.Sub(
|
||||
this=exp.Length(this=seq_get(args, 0)),
|
||||
expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
|
||||
),
|
||||
length=seq_get(args, 1),
|
||||
),
|
||||
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
|
||||
"IIF": exp.If.from_arg_list,
|
||||
"AGGREGATE": exp.Reduce.from_arg_list,
|
||||
"DAYOFWEEK": lambda args: exp.DayOfWeek(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
),
|
||||
"DAYOFMONTH": lambda args: exp.DayOfMonth(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
),
|
||||
"DAYOFYEAR": lambda args: exp.DayOfYear(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
),
|
||||
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
),
|
||||
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
|
||||
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
|
||||
this=seq_get(args, 1),
|
||||
unit=exp.var(seq_get(args, 0)),
|
||||
),
|
||||
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
|
||||
"BOOLEAN": _parse_as_cast("boolean"),
|
||||
"DOUBLE": _parse_as_cast("double"),
|
||||
"FLOAT": _parse_as_cast("float"),
|
||||
"INT": _parse_as_cast("int"),
|
||||
"STRING": _parse_as_cast("string"),
|
||||
"TIMESTAMP": _parse_as_cast("timestamp"),
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
**parser.Parser.FUNCTION_PARSERS, # type: ignore
|
||||
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
|
||||
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
|
||||
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
|
||||
"MERGE": lambda self: self._parse_join_hint("MERGE"),
|
||||
"SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
|
||||
"MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
|
||||
"SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
|
||||
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
|
||||
}
|
||||
|
||||
def _parse_add_column(self) -> t.Optional[exp.Expression]:
|
||||
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
|
||||
|
||||
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
|
||||
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
|
||||
exp.Drop,
|
||||
this=self._parse_schema(),
|
||||
kind="COLUMNS",
|
||||
)
|
||||
|
||||
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
|
||||
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
|
||||
if len(pivot_columns) == 1:
|
||||
return [""]
|
||||
|
||||
names = []
|
||||
for agg in pivot_columns:
|
||||
if isinstance(agg, exp.Alias):
|
||||
names.append(agg.alias)
|
||||
else:
|
||||
"""
|
||||
This case corresponds to aggregations without aliases being used as suffixes
|
||||
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
|
||||
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
|
||||
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
|
||||
|
||||
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
|
||||
"""
|
||||
agg_all_unquoted = agg.transform(
|
||||
lambda node: exp.Identifier(this=node.name, quoted=False)
|
||||
if isinstance(node, exp.Identifier)
|
||||
else node
|
||||
)
|
||||
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
|
||||
|
||||
return names
|
||||
|
||||
class Generator(Hive.Generator):
    TYPE_MAPPING = {
        **Hive.Generator.TYPE_MAPPING,  # type: ignore
        exp.DataType.Type.TINYINT: "BYTE",
        exp.DataType.Type.SMALLINT: "SHORT",
        exp.DataType.Type.BIGINT: "LONG",
    }

    PROPERTIES_LOCATION = {
        **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
        exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
        exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
        exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
        exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
    }

    TRANSFORMS = {
        **Hive.Generator.TRANSFORMS,  # type: ignore
        exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
        exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
        exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
        exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
        exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
        exp.Create: _create_sql,
        exp.DateFromParts: rename_func("MAKE_DATE"),
        exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
        exp.DayOfMonth: rename_func("DAYOFMONTH"),
        exp.DayOfWeek: rename_func("DAYOFWEEK"),
        exp.DayOfYear: rename_func("DAYOFYEAR"),
        exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
        exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
        exp.LogicalAnd: rename_func("BOOL_AND"),
        exp.LogicalOr: rename_func("BOOL_OR"),
        exp.Map: _map_sql,
        exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
        exp.Reduce: rename_func("AGGREGATE"),
        exp.StrToDate: _str_to_date,
        exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
        exp.TimestampTrunc: lambda self, e: self.func(
            "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
        ),
        exp.Trim: trim_sql,
        exp.UnixToTime: _unix_to_time_sql,
        exp.VariancePop: rename_func("VAR_POP"),
        exp.WeekOfYear: rename_func("WEEKOFYEAR"),
        exp.WithinGroup: transforms.preprocess(
            [transforms.remove_within_group_for_percentiles]
        ),
    }
    TRANSFORMS.pop(exp.ArrayJoin)
    TRANSFORMS.pop(exp.ArraySort)
    TRANSFORMS.pop(exp.ILike)

    WRAP_DERIVED_VALUES = False
    CREATE_FUNCTION_RETURN_AS = False

    def cast_sql(self, expression: exp.Cast) -> str:
        if isinstance(expression.this, exp.Cast) and expression.this.is_type(
            exp.DataType.Type.JSON
        ):
            schema = f"'{self.sql(expression, 'to')}'"
            return self.func("FROM_JSON", expression.this.this, schema)
        if expression.to.is_type(exp.DataType.Type.JSON):
            return self.func("TO_JSON", expression.this)

        return super(Hive.Generator, self).cast_sql(expression)

    def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
        return super().columndef_sql(
            expression,
            sep=": "
            if isinstance(expression.parent, exp.DataType)
            and expression.parent.is_type(exp.DataType.Type.STRUCT)
            else sep,
        )

class Tokenizer(Hive.Tokenizer):
    HEX_STRINGS = [("X'", "'")]
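
A hedged illustration of cast_sql above: casting to JSON should render as TO_JSON, and casting a JSON-cast value to another type as FROM_JSON with a schema string (output is an expected shape, not verified):

import sqlglot

print(sqlglot.transpile("SELECT CAST(x AS JSON)", read="spark", write="spark")[0])
# expected shape: SELECT TO_JSON(x)
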
@@ -22,6 +22,40 @@ def _date_add_sql(self, expression):
    return self.func("DATE", expression.this, modifier)


def _transform_create(expression: exp.Expression) -> exp.Expression:
    """Move primary key to a column and enforce auto_increment on primary keys."""
    schema = expression.this

    if isinstance(expression, exp.Create) and isinstance(schema, exp.Schema):
        defs = {}
        primary_key = None

        for e in schema.expressions:
            if isinstance(e, exp.ColumnDef):
                defs[e.name] = e
            elif isinstance(e, exp.PrimaryKey):
                primary_key = e

        if primary_key and len(primary_key.expressions) == 1:
            column = defs[primary_key.expressions[0].name]
            column.append(
                "constraints", exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint())
            )
            schema.expressions.remove(primary_key)
        else:
            for column in defs.values():
                auto_increment = None
                for constraint in column.constraints.copy():
                    if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint):
                        break
                    if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint):
                        auto_increment = constraint
                if auto_increment:
                    column.constraints.remove(auto_increment)

    return expression


class SQLite(Dialect):
    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]"), "`"]
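
A sketch of what _transform_create achieves once wired into the Create transform below (type mapping and exact rendering are assumptions):

import sqlglot

sql = "CREATE TABLE t (id INT, name TEXT, PRIMARY KEY (id))"
print(sqlglot.transpile(sql, write="sqlite")[0])
# expected shape: CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)
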
@@ -65,8 +99,8 @@ class SQLite(Dialect):

    TRANSFORMS = {
        **generator.Generator.TRANSFORMS,  # type: ignore
        **transforms.ELIMINATE_QUALIFY,  # type: ignore
        exp.CountIf: count_if_to_sum,
        exp.Create: transforms.preprocess([_transform_create]),
        exp.CurrentDate: lambda *_: "CURRENT_DATE",
        exp.CurrentTime: lambda *_: "CURRENT_TIME",
        exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",

@@ -80,14 +114,17 @@ class SQLite(Dialect):
        exp.Levenshtein: rename_func("EDITDIST3"),
        exp.LogicalOr: rename_func("MAX"),
        exp.LogicalAnd: rename_func("MIN"),
        exp.Select: transforms.preprocess(
            [transforms.eliminate_distinct_on, transforms.eliminate_qualify]
        ),
        exp.TableSample: no_tablesample_sql,
        exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
        exp.TryCast: no_trycast_sql,
    }

    PROPERTIES_LOCATION = {
        **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
        exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        k: exp.Properties.Location.UNSUPPORTED
        for k, v in generator.Generator.PROPERTIES_LOCATION.items()
    }

    LIMIT_FETCH = "LIMIT"


@@ -34,6 +34,7 @@ class StarRocks(MySQL):
        exp.JSONExtractScalar: arrow_json_extract_sql,
        exp.JSONExtract: arrow_json_extract_sql,
        exp.DateDiff: rename_func("DATEDIFF"),
        exp.RegexpLike: rename_func("REGEXP"),
        exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
        exp.TimestampTrunc: lambda self, e: self.func(
            "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this


@@ -1,6 +1,6 @@
from __future__ import annotations

from sqlglot import exp, generator, parser
from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect


@@ -29,6 +29,7 @@ class Tableau(Dialect):
        exp.If: _if_sql,
        exp.Coalesce: _coalesce_sql,
        exp.Count: _count_sql,
        exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
    }

    PROPERTIES_LOCATION = {


@@ -2,7 +2,7 @@ from __future__ import annotations

import typing as t

from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    format_time_lambda,

@@ -148,6 +148,7 @@ class Teradata(Dialect):
            **generator.Generator.TRANSFORMS,
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
            exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
        }


@@ -3,7 +3,7 @@ from __future__ import annotations
import re
import typing as t

from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    max_or_greatest,

@@ -259,8 +259,8 @@ class TSQL(Dialect):

    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]")]

        QUOTES = ["'", '"']
        HEX_STRINGS = [("0x", ""), ("0X", "")]

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,

@@ -463,17 +463,18 @@ class TSQL(Dialect):
        exp.DateDiff: generate_date_delta_with_unit_sql,
        exp.CurrentDate: rename_func("GETDATE"),
        exp.CurrentTimestamp: rename_func("GETDATE"),
        exp.If: rename_func("IIF"),
        exp.NumberToStr: _format_sql,
        exp.TimeToStr: _format_sql,
        exp.GroupConcat: _string_agg_sql,
        exp.If: rename_func("IIF"),
        exp.Max: max_or_greatest,
        exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
        exp.Min: min_or_least,
        exp.NumberToStr: _format_sql,
        exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
        exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
        exp.SHA2: lambda self, e: self.func(
            "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
        ),
        exp.TimeToStr: _format_sql,
    }

    TRANSFORMS.pop(exp.ReturnsProperty)
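
A hedged example of the new hash transforms (dialect key "tsql" assumed; printed lines are expected shapes):

import sqlglot

print(sqlglot.transpile("SELECT MD5(x)", write="tsql")[0])       # SELECT HASHBYTES('MD5', x)
print(sqlglot.transpile("SELECT SHA2(x, 512)", write="tsql")[0])  # SELECT HASHBYTES('SHA2_512', x)
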
@@ -64,6 +64,13 @@ class Expression(metaclass=_Expression):
    and representing expressions as strings.

    arg_types: determines what arguments (child nodes) are supported by an expression. It
        maps arg keys to booleans that indicate whether the corresponding args are optional.
    parent: a reference to the parent expression (or None, in case of root expressions).
    arg_key: the arg key an expression is associated with, i.e. the name its parent expression
        uses to refer to it.
    comments: a list of comments that are associated with a given expression. This is used in
        order to preserve comments when transpiling SQL code.
    _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
        optimizer, in order to enable some transformations that require type information.

    Example:
        >>> class Foo(Expression):

@@ -74,13 +81,6 @@ class Expression(metaclass=_Expression):

    Args:
        args: a mapping used for retrieving the arguments of an expression, given their arg keys.
        parent: a reference to the parent expression (or None, in case of root expressions).
        arg_key: the arg key an expression is associated with, i.e. the name its parent expression
            uses to refer to it.
        comments: a list of comments that are associated with a given expression. This is used in
            order to preserve comments when transpiling SQL code.
        _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
            optimizer, in order to enable some transformations that require type information.
    """

    key = "expression"

@@ -258,6 +258,12 @@ class Expression(metaclass=_Expression):
        new.parent = self.parent
        return new

    def add_comments(self, comments: t.Optional[t.List[str]]) -> None:
        if self.comments is None:
            self.comments = []
        if comments:
            self.comments.extend(comments)

    def append(self, arg_key, value):
        """
        Appends value to arg_key if it's a list or sets it as a new list.

@@ -650,7 +656,7 @@ ExpOrStr = t.Union[str, Expression]


class Condition(Expression):
    def and_(self, *expressions, dialect=None, **opts):
    def and_(self, *expressions, dialect=None, copy=True, **opts):
        """
        AND this condition with one or multiple expressions.

@@ -662,14 +668,15 @@ class Condition(Expression):
            *expressions (str | Expression): the SQL code strings to parse.
                If an `Expression` instance is passed, it will be used as-is.
            dialect (str): the dialect used to parse the input expression.
            copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
            opts (kwargs): other options to use to parse the input expressions.

        Returns:
            And: the new condition.
        """
        return and_(self, *expressions, dialect=dialect, **opts)
        return and_(self, *expressions, dialect=dialect, copy=copy, **opts)

    def or_(self, *expressions, dialect=None, **opts):
    def or_(self, *expressions, dialect=None, copy=True, **opts):
        """
        OR this condition with one or multiple expressions.

@@ -681,14 +688,15 @@ class Condition(Expression):
            *expressions (str | Expression): the SQL code strings to parse.
                If an `Expression` instance is passed, it will be used as-is.
            dialect (str): the dialect used to parse the input expression.
            copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
            opts (kwargs): other options to use to parse the input expressions.

        Returns:
            Or: the new condition.
        """
        return or_(self, *expressions, dialect=dialect, **opts)
        return or_(self, *expressions, dialect=dialect, copy=copy, **opts)

    def not_(self):
    def not_(self, copy=True):
        """
        Wrap this condition with NOT.

@@ -696,14 +704,17 @@ class Condition(Expression):
            >>> condition("x=1").not_().sql()
            'NOT x = 1'

        Args:
            copy (bool): whether or not to copy this object.

        Returns:
            Not: the new condition.
        """
        return not_(self)
        return not_(self, copy=copy)

    def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
        this = self
        other = convert(other)
        this = self.copy()
        other = convert(other, copy=True)
        if not isinstance(this, klass) and not isinstance(other, klass):
            this = _wrap(this, Binary)
            other = _wrap(other, Binary)

@@ -711,20 +722,25 @@ class Condition(Expression):
            return klass(this=other, expression=this)
        return klass(this=this, expression=other)

    def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
        if isinstance(other, slice):
            return Between(
                this=self,
                low=convert(other.start),
                high=convert(other.stop),
            )
        return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])
    def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]):
        return Bracket(
            this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)]
        )

    def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
    def isin(
        self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
    ) -> In:
        return In(
            this=self,
            expressions=[convert(e) for e in expressions],
            query=maybe_parse(query, **opts) if query else None,
            this=_maybe_copy(self, copy),
            expressions=[convert(e, copy=copy) for e in expressions],
            query=maybe_parse(query, copy=copy, **opts) if query else None,
        )

    def between(self, low: t.Any, high: t.Any, copy=True, **opts) -> Between:
        return Between(
            this=_maybe_copy(self, copy),
            low=convert(low, copy=copy, **opts),
            high=convert(high, copy=copy, **opts),
        )

    def like(self, other: ExpOrStr) -> Like:
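
A short sketch of the copy-aware condition builders added above (outputs are expected shapes):

from sqlglot import condition, exp

cond = condition("x = 1").and_("y = 2")  # copies by default; pass copy=False to mutate in place
print(cond.sql())                        # x = 1 AND y = 2

col = exp.column("x")
print(col.between(1, 10).sql())          # x BETWEEN 1 AND 10
print(col.isin(1, 2, 3).sql())           # x IN (1, 2, 3)
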
@@ -809,10 +825,10 @@ class Condition(Expression):
        return self._binop(Or, other, reverse=True)

    def __neg__(self) -> Neg:
        return Neg(this=_wrap(self, Binary))
        return Neg(this=_wrap(self.copy(), Binary))

    def __invert__(self) -> Not:
        return not_(self)
        return not_(self.copy())


class Predicate(Condition):

@@ -830,11 +846,7 @@ class DerivedTable(Expression):

    @property
    def selects(self):
        alias = self.args.get("alias")

        if alias:
            return alias.columns
        return []
        return self.this.selects if isinstance(self.this, Subqueryable) else []

    @property
    def named_selects(self):

@@ -904,7 +916,10 @@ class Unionable(Expression):


class UDTF(DerivedTable, Unionable):
    pass
    @property
    def selects(self):
        alias = self.args.get("alias")
        return alias.columns if alias else []


class Cache(Expression):

@@ -1073,6 +1088,10 @@ class ColumnDef(Expression):
        "position": False,
    }

    @property
    def constraints(self) -> t.List[ColumnConstraint]:
        return self.args.get("constraints") or []


class AlterColumn(Expression):
    arg_types = {

@@ -1100,6 +1119,10 @@ class Comment(Expression):
class ColumnConstraint(Expression):
    arg_types = {"this": False, "kind": True}

    @property
    def kind(self) -> ColumnConstraintKind:
        return self.args["kind"]


class ColumnConstraintKind(Expression):
    pass

@@ -1937,6 +1960,15 @@ class Reference(Expression):
class Tuple(Expression):
    arg_types = {"expressions": False}

    def isin(
        self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
    ) -> In:
        return In(
            this=_maybe_copy(self, copy),
            expressions=[convert(e, copy=copy) for e in expressions],
            query=maybe_parse(query, copy=copy, **opts) if query else None,
        )


class Subqueryable(Unionable):
    def subquery(self, alias=None, copy=True) -> Subquery:

@@ -2236,6 +2268,8 @@ class Select(Subqueryable):
        "expressions": False,
        "hint": False,
        "distinct": False,
        "struct": False,  # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#return_query_results_as_a_value_table
        "value": False,
        "into": False,
        "from": False,
        **QUERY_MODIFIERS,

@@ -2611,7 +2645,7 @@ class Select(Subqueryable):
            join.set("kind", kind.text)

        if on:
            on = and_(*ensure_collection(on), dialect=dialect, **opts)
            on = and_(*ensure_collection(on), dialect=dialect, copy=copy, **opts)
            join.set("on", on)

        if using:

@@ -2723,7 +2757,7 @@ class Select(Subqueryable):
            **opts,
        )

    def distinct(self, distinct=True, copy=True) -> Select:
    def distinct(self, *ons: ExpOrStr, distinct: bool = True, copy: bool = True) -> Select:
        """
        Set the DISTINCT expression.

@@ -2732,14 +2766,16 @@ class Select(Subqueryable):
            'SELECT DISTINCT x FROM tbl'

        Args:
            distinct (bool): whether the Select should be distinct
            copy (bool): if `False`, modify this expression instance in-place.
            ons: the expressions to distinct on
            distinct: whether the Select should be distinct
            copy: if `False`, modify this expression instance in-place.

        Returns:
            Select: the modified expression.
        """
        instance = _maybe_copy(self, copy)
        instance.set("distinct", Distinct() if distinct else None)
        on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons]) if ons else None
        instance.set("distinct", Distinct(on=on) if distinct else None)
        return instance

    def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
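
A hedged example of the extended distinct() builder (the DISTINCT ON rendering is assumed):

from sqlglot import select

q = select("a", "b").from_("t").distinct("a")
print(q.sql(dialect="postgres"))  # expected: SELECT DISTINCT ON (a) a, b FROM t
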
@@ -2969,6 +3005,10 @@ class DataType(Expression):
        USMALLINT = auto()
        BIGINT = auto()
        UBIGINT = auto()
        INT128 = auto()
        UINT128 = auto()
        INT256 = auto()
        UINT256 = auto()
        FLOAT = auto()
        DOUBLE = auto()
        DECIMAL = auto()

@@ -3022,6 +3062,8 @@ class DataType(Expression):
        Type.TINYINT,
        Type.SMALLINT,
        Type.BIGINT,
        Type.INT128,
        Type.INT256,
    }

    FLOAT_TYPES = {

@@ -3069,10 +3111,6 @@ class PseudoType(Expression):
    pass


class StructKwarg(Expression):
    arg_types = {"this": True, "expression": True}


# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
class SubqueryPredicate(Predicate):
    pass

@@ -3538,14 +3576,20 @@ class Case(Func):
    arg_types = {"this": False, "ifs": True, "default": False}

    def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
        this = self.copy() if copy else self
        this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
        return this
        instance = _maybe_copy(self, copy)
        instance.append(
            "ifs",
            If(
                this=maybe_parse(condition, copy=copy, **opts),
                true=maybe_parse(then, copy=copy, **opts),
            ),
        )
        return instance

    def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
        this = self.copy() if copy else self
        this.set("default", maybe_parse(condition, **opts))
        return this
        instance = _maybe_copy(self, copy)
        instance.set("default", maybe_parse(condition, copy=copy, **opts))
        return instance


class Cast(Func):
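
A minimal sketch of the Case builder with the new copy flag (output is the expected shape):

from sqlglot import exp

case = exp.Case().when("x = 1", "'one'").else_("'other'")
print(case.sql())  # CASE WHEN x = 1 THEN 'one' ELSE 'other' END
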
@@ -3760,6 +3804,14 @@ class Floor(Func):
    arg_types = {"this": True, "decimals": False}


class FromBase64(Func):
    pass


class ToBase64(Func):
    pass


class Greatest(Func):
    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True
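
These expressions back the BASE64/UNBASE64 change in sqlglot.dataframe earlier in this commit. A hedged sketch of the default rendering (name derivation assumed):

from sqlglot import exp

node = exp.ToBase64(this=exp.column("x"))
print(node.sql())  # expected: TO_BASE64(x)
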
@@ -3930,11 +3982,11 @@ class Pow(Binary, Func):


class PercentileCont(AggFunc):
    pass
    arg_types = {"this": True, "expression": False}


class PercentileDisc(AggFunc):
    pass
    arg_types = {"this": True, "expression": False}


class Quantile(AggFunc):

@@ -4405,14 +4457,16 @@ def _apply_conjunction_builder(
    if append and existing is not None:
        expressions = [existing.this if into else existing] + list(expressions)

    node = and_(*expressions, dialect=dialect, **opts)
    node = and_(*expressions, dialect=dialect, copy=copy, **opts)

    inst.set(arg, into(this=node) if into else node)
    return inst


def _combine(expressions, operator, dialect=None, **opts):
    expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
def _combine(expressions, operator, dialect=None, copy=True, **opts):
    expressions = [
        condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions
    ]
    this = expressions[0]
    if expressions[1:]:
        this = _wrap(this, Connector)

@@ -4626,7 +4680,7 @@ def delete(
    return delete_expr


def condition(expression, dialect=None, **opts) -> Condition:
def condition(expression, dialect=None, copy=True, **opts) -> Condition:
    """
    Initialize a logical condition expression.

@@ -4645,6 +4699,7 @@ def condition(expression, dialect=None, **opts) -> Condition:
            If an Expression instance is passed, this is used as-is.
        dialect (str): the dialect used to parse the input expression (in the case that the
            input expression is a SQL string).
        copy (bool): Whether or not to copy `expression` (only applies to expressions).
        **opts: other options to use to parse the input expressions (again, in the case
            that the input expression is a SQL string).

@@ -4655,11 +4710,12 @@ def condition(expression, dialect=None, **opts) -> Condition:
        expression,
        into=Condition,
        dialect=dialect,
        copy=copy,
        **opts,
    )


def and_(*expressions, dialect=None, **opts) -> And:
def and_(*expressions, dialect=None, copy=True, **opts) -> And:
    """
    Combine multiple conditions with an AND logical operator.

@@ -4671,15 +4727,16 @@ def and_(*expressions, dialect=None, **opts) -> And:
        *expressions (str | Expression): the SQL code strings to parse.
            If an Expression instance is passed, this is used as-is.
        dialect (str): the dialect used to parse the input expression.
        copy (bool): whether or not to copy `expressions` (only applies to Expressions).
        **opts: other options to use to parse the input expressions.

    Returns:
        And: the new condition
    """
    return _combine(expressions, And, dialect, **opts)
    return _combine(expressions, And, dialect, copy=copy, **opts)


def or_(*expressions, dialect=None, **opts) -> Or:
def or_(*expressions, dialect=None, copy=True, **opts) -> Or:
    """
    Combine multiple conditions with an OR logical operator.

@@ -4691,15 +4748,16 @@ def or_(*expressions, dialect=None, **opts) -> Or:
        *expressions (str | Expression): the SQL code strings to parse.
            If an Expression instance is passed, this is used as-is.
        dialect (str): the dialect used to parse the input expression.
        copy (bool): whether or not to copy `expressions` (only applies to Expressions).
        **opts: other options to use to parse the input expressions.

    Returns:
        Or: the new condition
    """
    return _combine(expressions, Or, dialect, **opts)
    return _combine(expressions, Or, dialect, copy=copy, **opts)


def not_(expression, dialect=None, **opts) -> Not:
def not_(expression, dialect=None, copy=True, **opts) -> Not:
    """
    Wrap a condition with a NOT operator.

@@ -4719,13 +4777,14 @@ def not_(expression, dialect=None, **opts) -> Not:
    this = condition(
        expression,
        dialect=dialect,
        copy=copy,
        **opts,
    )
    return Not(this=_wrap(this, Connector))


def paren(expression) -> Paren:
    return Paren(this=expression)
def paren(expression, copy=True) -> Paren:
    return Paren(this=_maybe_copy(expression, copy))


SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")

@@ -4998,29 +5057,20 @@ def values(
        alias: optional alias
        columns: Optional list of ordered column names or ordered dictionary of column names to types.
            If either are provided then an alias is also required.
            If a dictionary is provided then the first column of the values will be casted to the expected type
            in order to help with type inference.

    Returns:
        Values: the Values expression object
    """
    if columns and not alias:
        raise ValueError("Alias is required when providing columns")
    table_alias = (
        TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns])
        if columns
        else TableAlias(this=to_identifier(alias) if alias else None)
    )
    expressions = [convert(tup) for tup in values]
    if columns and isinstance(columns, dict):
        types = list(columns.values())
        expressions[0].set(
            "expressions",
            [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
        )

    return Values(
        expressions=expressions,
        alias=table_alias,
        expressions=[convert(tup) for tup in values],
        alias=(
            TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns])
            if columns
            else (TableAlias(this=to_identifier(alias)) if alias else None)
        ),
    )
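
A sketch of the simplified values() builder (note the per-type casting behavior was dropped; the printed shape is an assumption):

from sqlglot import exp

v = exp.values([(1, "a"), (2, "b")], alias="t", columns=["x", "y"])
print(v.sql())  # expected shape: (VALUES (1, 'a'), (2, 'b')) AS t(x, y)
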
@@ -5068,19 +5118,20 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
    )


def convert(value) -> Expression:
def convert(value: t.Any, copy: bool = False) -> Expression:
    """Convert a python value into an expression object.

    Raises an error if a conversion is not possible.

    Args:
        value (Any): a python object
        value: A python object.
        copy: Whether or not to copy `value` (only applies to Expressions and collections).

    Returns:
        Expression: the equivalent expression object
        Expression: the equivalent expression object.
    """
    if isinstance(value, Expression):
        return value
        return _maybe_copy(value, copy)
    if isinstance(value, str):
        return Literal.string(value)
    if isinstance(value, bool):

@@ -5098,13 +5149,13 @@ def convert(value) -> Expression:
        date_literal = Literal.string(value.strftime("%Y-%m-%d"))
        return DateStrToDate(this=date_literal)
    if isinstance(value, tuple):
        return Tuple(expressions=[convert(v) for v in value])
        return Tuple(expressions=[convert(v, copy=copy) for v in value])
    if isinstance(value, list):
        return Array(expressions=[convert(v) for v in value])
        return Array(expressions=[convert(v, copy=copy) for v in value])
    if isinstance(value, dict):
        return Map(
            keys=[convert(k) for k in value],
            values=[convert(v) for v in value.values()],
            keys=[convert(k, copy=copy) for k in value],
            values=[convert(v, copy=copy) for v in value.values()],
        )
    raise ValueError(f"Cannot convert {value}")
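
A quick sketch of convert() on plain Python values (tuples become SQL tuples, lists become arrays; the copy flag only matters for Expression inputs and the collections that contain them):

from sqlglot import exp

print(exp.convert((1, "x", None)).sql())  # (1, 'x', NULL)
print(exp.convert([1, 2]).sql())          # ARRAY(1, 2)
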
@@ -25,6 +25,12 @@ class Generator:
        quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
        identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
        identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
        bit_start (str): specifies which starting character to use to delimit bit literals. Default: None.
        bit_end (str): specifies which ending character to use to delimit bit literals. Default: None.
        hex_start (str): specifies which starting character to use to delimit hex literals. Default: None.
        hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
        byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
        byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
        identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always.
        normalize (bool): if set to True all identifiers will be lower cased
        string_escape (str): specifies a string escape character. Default: '.

@@ -227,6 +233,12 @@ class Generator:
        "quote_end",
        "identifier_start",
        "identifier_end",
        "bit_start",
        "bit_end",
        "hex_start",
        "hex_end",
        "byte_start",
        "byte_end",
        "identify",
        "normalize",
        "string_escape",

@@ -258,6 +270,12 @@ class Generator:
        quote_end=None,
        identifier_start=None,
        identifier_end=None,
        bit_start=None,
        bit_end=None,
        hex_start=None,
        hex_end=None,
        byte_start=None,
        byte_end=None,
        identify=False,
        normalize=False,
        string_escape=None,

@@ -284,6 +302,12 @@ class Generator:
        self.quote_end = quote_end or "'"
        self.identifier_start = identifier_start or '"'
        self.identifier_end = identifier_end or '"'
        self.bit_start = bit_start
        self.bit_end = bit_end
        self.hex_start = hex_start
        self.hex_end = hex_end
        self.byte_start = byte_start
        self.byte_end = byte_end
        self.identify = identify
        self.normalize = normalize
        self.string_escape = string_escape or "'"

@@ -361,7 +385,7 @@ class Generator:
        expression: t.Optional[exp.Expression] = None,
        comments: t.Optional[t.List[str]] = None,
    ) -> str:
        comments = (comments or (expression and expression.comments)) if self._comments else None  # type: ignore
        comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None  # type: ignore

        if not comments or isinstance(expression, exp.Binary):
            return sql

@@ -510,12 +534,12 @@ class Generator:
        position = self.sql(expression, "position")
        return f"{position}{this}"

    def columndef_sql(self, expression: exp.ColumnDef) -> str:
    def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
        column = self.sql(expression, "this")
        kind = self.sql(expression, "kind")
        constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
        exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
        kind = f" {kind}" if kind else ""
        kind = f"{sep}{kind}" if kind else ""
        constraints = f" {constraints}" if constraints else ""
        position = self.sql(expression, "position")
        position = f" {position}" if position else ""

@@ -524,7 +548,7 @@ class Generator:

    def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
        this = self.sql(expression, "this")
        kind_sql = self.sql(expression, "kind")
        kind_sql = self.sql(expression, "kind").strip()
        return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql

    def autoincrementcolumnconstraint_sql(self, _) -> str:

@@ -716,13 +740,22 @@ class Generator:
        return f"{alias}{columns}"

    def bitstring_sql(self, expression: exp.BitString) -> str:
        return self.sql(expression, "this")
        this = self.sql(expression, "this")
        if self.bit_start:
            return f"{self.bit_start}{this}{self.bit_end}"
        return f"{int(this, 2)}"

    def hexstring_sql(self, expression: exp.HexString) -> str:
        return self.sql(expression, "this")
        this = self.sql(expression, "this")
        if self.hex_start:
            return f"{self.hex_start}{this}{self.hex_end}"
        return f"{int(this, 16)}"

    def bytestring_sql(self, expression: exp.ByteString) -> str:
        return self.sql(expression, "this")
        this = self.sql(expression, "this")
        if self.byte_start:
            return f"{self.byte_start}{this}{self.byte_end}"
        return this

    def datatype_sql(self, expression: exp.DataType) -> str:
        type_value = expression.this
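
A hedged illustration of the fallback paths above: a dialect without hex delimiters renders the literal as its integer value (dialect key "bigquery" assumed, per its HEX_STRINGS setting in this commit):

import sqlglot

print(sqlglot.transpile("SELECT 0xCC", read="bigquery")[0])
# expected: SELECT 204  -- no hex_start in the default generator, so int(this, 16)
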
@@ -1115,10 +1148,12 @@ class Generator:

        return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"

    def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
    def tablesample_sql(
        self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
    ) -> str:
        if self.alias_post_tablesample and expression.this.alias:
            this = self.sql(expression.this, "this")
            alias = f" AS {self.sql(expression.this, 'alias')}"
            alias = f"{sep}{self.sql(expression.this, 'alias')}"
        else:
            this = self.sql(expression, "this")
            alias = ""

@@ -1447,16 +1482,16 @@ class Generator:
        )

    def select_sql(self, expression: exp.Select) -> str:
        kind = expression.args.get("kind")
        kind = f" AS {kind}" if kind else ""
        hint = self.sql(expression, "hint")
        distinct = self.sql(expression, "distinct")
        distinct = f" {distinct}" if distinct else ""
        kind = expression.args.get("kind")
        kind = f" AS {kind}" if kind else ""
        expressions = self.expressions(expression)
        expressions = f"{self.sep()}{expressions}" if expressions else expressions
        sql = self.query_modifiers(
            expression,
            f"SELECT{kind}{hint}{distinct}{expressions}",
            f"SELECT{hint}{distinct}{kind}{expressions}",
            self.sql(expression, "into", comment=False),
            self.sql(expression, "from", comment=False),
        )

@@ -1475,9 +1510,6 @@ class Generator:
        replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
        return f"*{except_}{replace}"

    def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
        return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"

    def parameter_sql(self, expression: exp.Parameter) -> str:
        this = self.sql(expression, "this")
        this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}"

@@ -1806,7 +1838,7 @@ class Generator:
        return self.binary(expression, op)

        sqls = tuple(
            self.maybe_comment(self.sql(e), e, e.parent.comments) if i != 1 else self.sql(e)
            self.maybe_comment(self.sql(e), e, e.parent.comments or []) if i != 1 else self.sql(e)
            for i, e in enumerate(expression.flatten(unnest=False))
        )


@@ -153,7 +153,7 @@ def join_condition(join):
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

    for condition in on.flatten():
        if isinstance(condition, exp.EQ):


@@ -29,6 +29,6 @@ def expand_laterals(expression: exp.Expression) -> exp.Expression:
        for column in projection.find_all(exp.Column):
            if not column.table and column.name in alias_to_expression:
                column.replace(alias_to_expression[column.name].copy())
            if isinstance(projection, exp.Alias):
                alias_to_expression[projection.alias] = projection.this
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = projection.this
    return expression
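
A minimal sketch of the lateral-alias expansion this fix targets (the printed output follows the module's documented behavior):

import sqlglot
from sqlglot.optimizer.expand_laterals import expand_laterals

expr = sqlglot.parse_one("SELECT x + 1 AS y, y + 2 AS z FROM t")
print(expand_laterals(expr).sql())
# expected: SELECT x + 1 AS y, x + 1 + 2 AS z FROM t
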
@@ -152,12 +152,14 @@ def _distribute(a, b, from_func, to_func, cache):
            lambda c: to_func(
                uniq_sort(flatten(from_func(c, b.left)), cache),
                uniq_sort(flatten(from_func(c, b.right)), cache),
                copy=False,
            ),
        )
    else:
        a = to_func(
            uniq_sort(flatten(from_func(a, b.left)), cache),
            uniq_sort(flatten(from_func(a, b.right)), cache),
            copy=False,
        )

    return a


@@ -10,7 +10,6 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities

@@ -30,7 +29,6 @@ RULES = (
    qualify_tables,
    isolate_table_selects,
    qualify_columns,
    expand_laterals,
    pushdown_projections,
    validate_qualify_columns,
    normalize,


@@ -3,11 +3,12 @@ import typing as t

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def qualify_columns(expression, schema):
def qualify_columns(expression, schema, expand_laterals=True):
    """
    Rewrite sqlglot AST to have fully qualified columns.

@@ -26,6 +27,9 @@ def qualify_columns(expression, schema):
    """
    schema = ensure_schema(schema)

    if not schema.mapping and expand_laterals:
        expression = _expand_laterals(expression)

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)

@@ -39,6 +43,9 @@ def qualify_columns(expression, schema):
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    if schema.mapping and expand_laterals:
        expression = _expand_laterals(expression)

    return expression


@@ -124,7 +131,7 @@ def _expand_using(scope, resolver):
            tables[join_table] = None

    join.args.pop("using")
    join.set("on", exp.and_(*conditions))
    join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:

@@ -240,7 +247,9 @@ def _qualify_columns(scope, resolver):
            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources:
        elif column_table not in scope.sources and (
            not scope.parent or column_table not in scope.parent.sources
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

@@ -376,10 +385,13 @@ def _qualify_outputs(scope):
        if not selection.output_name:
            selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                quoted=True
                if isinstance(selection, exp.Column) and selection.this.quoted
                else None,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))


@@ -7,21 +7,29 @@ from sqlglot.optimizer.scope import Scope, traverse_scope

def qualify_tables(expression, db=None, catalog=None, schema=None):
    """
    Rewrite sqlglot AST to have fully qualified tables.
    Rewrite sqlglot AST to have fully qualified tables. Additionally, this
    replaces "join constructs" (*) by equivalent SELECT * subqueries.

    Example:
    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
        >>> qualify_tables(expression).sql()
        'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'

    Args:
        expression (sqlglot.Expression): expression to qualify
        db (str): Database name
        catalog (str): Catalog name
        schema: A schema to populate

    Returns:
        sqlglot.Expression: qualified expression

    (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
    """
    sequence = itertools.count()

@@ -29,6 +37,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            # Expand join construct
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))

            if not derived_table.args.get("alias"):
                alias_ = f"_q_{next(sequence)}"
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))


@@ -510,6 +510,9 @@ def _traverse_scope(scope):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
        yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):
        pass
    else:

@@ -587,6 +590,9 @@ def _traverse_tables(scope):
    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    for expression in expressions:


@@ -60,6 +60,7 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression
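
A worked sketch of the rewrite (x BETWEEN lo AND hi becomes x >= lo AND x <= hi, as applied by simplify):

import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.parse_one("x BETWEEN 1 AND 10")).sql())
# expected: x >= 1 AND x <= 10
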
@@ -76,9 +77,17 @@ def simplify_not(expression):
    if isinstance(expression.this, exp.Paren):
        condition = expression.this.unnest()
        if isinstance(condition, exp.And):
            return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
            return exp.or_(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )
        if isinstance(condition, exp.Or):
            return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
            return exp.and_(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )
        if is_null(condition):
            return exp.null()
    if always_true(expression.this):
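
The two branches above apply De Morgan's laws, now without redundant copies; a short sketch:

import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.parse_one("NOT (a AND b)")).sql())  # expected: NOT a OR NOT b
print(simplify(sqlglot.parse_one("NOT (a OR b)")).sql())   # expected: NOT a AND NOT b
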
@@ -254,12 +263,12 @@ def uniq_sort(expression, cache=None, root=True):
    # A AND C AND B -> A AND B AND C
    for i, (sql, e) in enumerate(arr[1:]):
        if sql < arr[i][0]:
            expression = result_func(*(e for _, e in sorted(arr)))
            expression = result_func(*(e for _, e in sorted(arr)), copy=False)
            break
    else:
        # we didn't have to sort but maybe we need to dedup
        if len(deduped) < len(flattened):
            expression = result_func(*deduped.values())
            expression = result_func(*deduped.values(), copy=False)

    return expression


@@ -126,9 +126,17 @@ class Parser(metaclass=_Parser):
        TokenType.BIT,
        TokenType.BOOLEAN,
        TokenType.TINYINT,
        TokenType.UTINYINT,
        TokenType.SMALLINT,
        TokenType.USMALLINT,
        TokenType.INT,
        TokenType.UINT,
        TokenType.BIGINT,
        TokenType.UBIGINT,
        TokenType.INT128,
        TokenType.UINT128,
        TokenType.INT256,
        TokenType.UINT256,
        TokenType.FLOAT,
        TokenType.DOUBLE,
        TokenType.CHAR,

@@ -961,14 +969,15 @@ class Parser(metaclass=_Parser):
            The target expression.
        """
        instance = exp_class(**kwargs)
        if self._prev_comments:
            instance.comments = self._prev_comments
            self._prev_comments = None
        if comments:
            instance.comments = comments
        instance.add_comments(comments) if comments else self._add_comments(instance)
        self.validate_expression(instance)
        return instance

    def _add_comments(self, expression: t.Optional[exp.Expression]) -> None:
        if expression and self._prev_comments:
            expression.add_comments(self._prev_comments)
            self._prev_comments = None

    def validate_expression(
        self, expression: exp.Expression, args: t.Optional[t.List] = None
    ) -> None:

@@ -1567,7 +1576,7 @@ class Parser(metaclass=_Parser):
            value = self.expression(
                exp.Schema,
                this="TABLE",
                expressions=self._parse_csv(self._parse_struct_kwargs),
                expressions=self._parse_csv(self._parse_struct_types),
            )
            if not self._match(TokenType.GT):
                self.raise_error("Expecting >")

@@ -1802,14 +1811,15 @@ class Parser(metaclass=_Parser):
        elif self._match(TokenType.SELECT):
            comments = self._prev_comments

            hint = self._parse_hint()
            all_ = self._match(TokenType.ALL)
            distinct = self._match(TokenType.DISTINCT)

            kind = (
                self._match(TokenType.ALIAS)
                and self._match_texts(("STRUCT", "VALUE"))
                and self._prev.text
            )
            hint = self._parse_hint()
            all_ = self._match(TokenType.ALL)
            distinct = self._match(TokenType.DISTINCT)

            if distinct:
                distinct = self.expression(

@@ -2284,7 +2294,7 @@ class Parser(metaclass=_Parser):
        if not self._match(TokenType.UNNEST):
            return None

        expressions = self._parse_wrapped_csv(self._parse_column)
        expressions = self._parse_wrapped_csv(self._parse_type)
        ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
        alias = self._parse_table_alias()

@@ -2333,7 +2343,9 @@ class Parser(metaclass=_Parser):
        size = None
        seed = None

        kind = "TABLESAMPLE" if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
        kind = (
            self._prev.text if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
        )
        method = self._parse_var(tokens=(TokenType.ROW,))

        self._match(TokenType.L_PAREN)

@@ -2684,7 +2696,7 @@ class Parser(metaclass=_Parser):
            else:
                this = self.expression(exp.In, this=this, expressions=expressions)

            self._match_r_paren()
            self._match_r_paren(this)
        else:
            this = self.expression(exp.In, this=this, field=self._parse_field())

@@ -2798,7 +2810,7 @@ class Parser(metaclass=_Parser):

        if self._match(TokenType.L_PAREN):
            if is_struct:
                expressions = self._parse_csv(self._parse_struct_kwargs)
                expressions = self._parse_csv(self._parse_struct_types)
            elif nested:
                expressions = self._parse_csv(self._parse_types)
            else:

@@ -2833,7 +2845,7 @@ class Parser(metaclass=_Parser):
        values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
        if nested and self._match(TokenType.LT):
            if is_struct:
                expressions = self._parse_csv(self._parse_struct_kwargs)
                expressions = self._parse_csv(self._parse_struct_types)
            else:
                expressions = self._parse_csv(self._parse_types)

@@ -2891,16 +2903,10 @@ class Parser(metaclass=_Parser):
            prefix=prefix,
        )

    def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
        index = self._index
        this = self._parse_id_var()
    def _parse_struct_types(self) -> t.Optional[exp.Expression]:
        this = self._parse_type() or self._parse_id_var()
        self._match(TokenType.COLON)
        data_type = self._parse_types()

        if not data_type:
            self._retreat(index)
            return self._parse_types()
        return self.expression(exp.StructKwarg, this=this, expression=data_type)
        return self._parse_column_def(this)

    def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if not self._match(TokenType.AT_TIME_ZONE):
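
Since struct fields now reuse the column-def parser, a struct type should round-trip through _parse_struct_types (the printed output is an expected shape, not verified):

import sqlglot

print(sqlglot.parse_one("CAST(x AS STRUCT<a INT, b TEXT>)").sql())
# expected: CAST(x AS STRUCT<a INT, b TEXT>)
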
@ -2932,7 +2938,11 @@ class Parser(metaclass=_Parser):
|
|||
else exp.Literal.string(value)
|
||||
)
|
||||
else:
|
||||
field = self._parse_star() or self._parse_function() or self._parse_id_var()
|
||||
field = (
|
||||
self._parse_star()
|
||||
or self._parse_function(anonymous=True)
|
||||
or self._parse_id_var()
|
||||
)
|
||||
|
||||
if isinstance(field, exp.Func):
|
||||
# bigquery allows function calls like x.y.count(...)
|
||||
|
@ -2995,11 +3005,9 @@ class Parser(metaclass=_Parser):
|
|||
else:
|
||||
this = self.expression(exp.Paren, this=self._parse_set_operations(this))
|
||||
|
||||
self._match_r_paren()
|
||||
comments.extend(self._prev_comments)
|
||||
|
||||
if this and comments:
|
||||
this.comments = comments
|
||||
if this:
|
||||
this.add_comments(comments)
|
||||
self._match_r_paren(expression=this)
|
||||
|
||||
return this
|
||||
|
||||
|
@ -3017,7 +3025,7 @@ class Parser(metaclass=_Parser):
|
|||
)
|
||||
|
||||
def _parse_function(
|
||||
self, functions: t.Optional[t.Dict[str, t.Callable]] = None
|
||||
self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
|
||||
) -> t.Optional[exp.Expression]:
|
||||
if not self._curr:
|
||||
return None
|
||||
|
@ -3043,7 +3051,7 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
parser = self.FUNCTION_PARSERS.get(upper)
|
||||
|
||||
if parser:
|
||||
if parser and not anonymous:
|
||||
this = parser(self)
|
||||
else:
|
||||
subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
|
||||
|
@ -3059,7 +3067,7 @@ class Parser(metaclass=_Parser):
|
|||
function = functions.get(upper)
|
||||
args = self._parse_csv(self._parse_lambda)
|
||||
|
||||
if function:
|
||||
if function and not anonymous:
|
||||
# Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the
|
||||
# second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists.
|
||||
if count_params(function) == 2:
|
||||
|
@ -3148,12 +3156,7 @@ class Parser(metaclass=_Parser):
|
|||
if isinstance(left, exp.Column):
|
||||
left.replace(exp.Var(this=left.text("this")))
|
||||
|
||||
if self._match(TokenType.IGNORE_NULLS):
|
||||
this = self.expression(exp.IgnoreNulls, this=this)
|
||||
else:
|
||||
self._match(TokenType.RESPECT_NULLS)
|
||||
|
||||
return self._parse_limit(self._parse_order(this))
|
||||
return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this)))
|
||||
|
||||
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
|
||||
index = self._index
|
||||
|
@ -3177,6 +3180,9 @@ class Parser(metaclass=_Parser):
|
|||
return self.expression(exp.Schema, this=this, expressions=args)
|
||||
|
||||
def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
|
||||
# column defs are not really columns, they're identifiers
|
||||
if isinstance(this, exp.Column):
|
||||
this = this.this
|
||||
kind = self._parse_types()
|
||||
|
||||
if self._match_text_seq("FOR", "ORDINALITY"):
|
||||
|
@ -3420,7 +3426,7 @@ class Parser(metaclass=_Parser):
|
|||
elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE:
|
||||
self.raise_error("Expected }")
|
||||
|
||||
this.comments = self._prev_comments
|
||||
self._add_comments(this)
|
||||
return self._parse_bracket(this)
|
||||
|
||||
def _parse_slice(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
|
||||
|
@ -3584,7 +3590,9 @@ class Parser(metaclass=_Parser):
|
|||
exp.and_(
|
||||
exp.Is(this=expression.copy(), expression=exp.Null()),
|
||||
exp.Is(this=search.copy(), expression=exp.Null()),
|
||||
copy=False,
|
||||
),
|
||||
copy=False,
|
||||
)
|
||||
ifs.append(exp.If(this=cond, true=result))
|
||||
|
||||
|
@ -3717,15 +3725,15 @@ class Parser(metaclass=_Parser):
|
|||
if self._match_set(self.TRIM_TYPES):
|
||||
position = self._prev.text.upper()
|
||||
|
||||
expression = self._parse_term()
|
||||
expression = self._parse_bitwise()
|
||||
if self._match_set((TokenType.FROM, TokenType.COMMA)):
|
||||
this = self._parse_term()
|
||||
this = self._parse_bitwise()
|
||||
else:
|
||||
this = expression
|
||||
expression = None
|
||||
|
||||
if self._match(TokenType.COLLATE):
|
||||
collation = self._parse_term()
|
||||
collation = self._parse_bitwise()
|
||||
|
||||
return self.expression(
|
||||
exp.Trim,
|
||||
|
@ -3741,6 +3749,15 @@ class Parser(metaclass=_Parser):
|
|||
def _parse_named_window(self) -> t.Optional[exp.Expression]:
|
||||
return self._parse_window(self._parse_id_var(), alias=True)
|
||||
|
||||
def _parse_respect_or_ignore_nulls(
|
||||
self, this: t.Optional[exp.Expression]
|
||||
) -> t.Optional[exp.Expression]:
|
||||
if self._match(TokenType.IGNORE_NULLS):
|
||||
return self.expression(exp.IgnoreNulls, this=this)
|
||||
if self._match(TokenType.RESPECT_NULLS):
|
||||
return self.expression(exp.RespectNulls, this=this)
|
||||
return this
|
||||
|
||||
def _parse_window(
|
||||
self, this: t.Optional[exp.Expression], alias: bool = False
|
||||
) -> t.Optional[exp.Expression]:
|
||||
|
@ -3768,10 +3785,7 @@ class Parser(metaclass=_Parser):
        # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
        # and Snowflake chose to do the same for familiarity
        # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
        if self._match(TokenType.IGNORE_NULLS):
            this = self.expression(exp.IgnoreNulls, this=this)
        elif self._match(TokenType.RESPECT_NULLS):
            this = self.expression(exp.RespectNulls, this=this)
        this = self._parse_respect_or_ignore_nulls(this)

        # bigquery select from window x AS (partition by ...)
        if alias:
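Taken together, the hunks above fold both IGNORE NULLS call sites (after function arguments and before a window) into the new _parse_respect_or_ignore_nulls helper. A minimal round-trip sketch of the observable behavior via the public API; the input and expected output SQL are my own illustration, not taken from this commit:

import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT LAST_VALUE(x IGNORE NULLS) OVER (ORDER BY y) FROM t")

# IGNORE NULLS is parsed into its own wrapper node around the function.
assert ast.find(exp.IgnoreNulls) is not None
print(ast.sql())  # the clause should survive the round trip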
@ -3975,9 +3989,7 @@ class Parser(metaclass=_Parser):
        items = [parse_result] if parse_result is not None else []

        while self._match(sep):
            if parse_result and self._prev_comments:
                parse_result.comments = self._prev_comments

            self._add_comments(parse_result)
            parse_result = parse_method()
            if parse_result is not None:
                items.append(parse_result)
@ -4345,13 +4357,14 @@ class Parser(metaclass=_Parser):
            self._retreat(index)
            return None

    def _match(self, token_type, advance=True):
    def _match(self, token_type, advance=True, expression=None):
        if not self._curr:
            return None

        if self._curr.token_type == token_type:
            if advance:
                self._advance()
            self._add_comments(expression)
            return True

        return None
@ -4379,16 +4392,12 @@ class Parser(metaclass=_Parser):
            return None

    def _match_l_paren(self, expression=None):
        if not self._match(TokenType.L_PAREN):
        if not self._match(TokenType.L_PAREN, expression=expression):
            self.raise_error("Expecting (")
        if expression and self._prev_comments:
            expression.comments = self._prev_comments

    def _match_r_paren(self, expression=None):
        if not self._match(TokenType.R_PAREN):
        if not self._match(TokenType.R_PAREN, expression=expression):
            self.raise_error("Expecting )")
        if expression and self._prev_comments:
            expression.comments = self._prev_comments

    def _match_texts(self, texts, advance=True):
        if self._curr and self._curr.text.upper() in texts:
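These hunks route comment attachment through _add_comments: _match now takes an optional expression and attaches any pending comments to it, replacing the ad-hoc `expression.comments = self._prev_comments` assignments. A small sketch of the user-visible effect; the exact output string is an assumption:

import sqlglot

# Inline comments are carried through parsing and re-emitted on generation.
print(sqlglot.transpile("SELECT a /* keep me */ FROM t")[0])
# expected: SELECT a /* keep me */ FROM t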
@ -84,6 +84,10 @@ class TokenType(AutoName):
    UINT = auto()
    BIGINT = auto()
    UBIGINT = auto()
    INT128 = auto()
    UINT128 = auto()
    INT256 = auto()
    UINT256 = auto()
    FLOAT = auto()
    DOUBLE = auto()
    DECIMAL = auto()
@ -774,8 +778,6 @@ class Tokenizer(metaclass=_Tokenizer):
        "_end",
        "_peek",
        "_prev_token_line",
        "_prev_token_comments",
        "_prev_token_type",
    )

    def __init__(self) -> None:
@ -795,8 +797,6 @@ class Tokenizer(metaclass=_Tokenizer):
        self._end = False
        self._peek = ""
        self._prev_token_line = -1
        self._prev_token_comments: t.List[str] = []
        self._prev_token_type: t.Optional[TokenType] = None

    def tokenize(self, sql: str) -> t.List[Token]:
        """Returns a list of tokens corresponding to the SQL string `sql`."""
@ -846,7 +846,7 @@ class Tokenizer(metaclass=_Tokenizer):
            return self.sql[start:end]
        return ""

    def _advance(self, i: int = 1) -> None:
    def _advance(self, i: int = 1, alnum: bool = False) -> None:
        if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
            self._col = 1
            self._line += 1
@ -858,14 +858,30 @@ class Tokenizer(metaclass=_Tokenizer):
        self._char = self.sql[self._current - 1]
        self._peek = "" if self._end else self.sql[self._current]

        if alnum and self._char.isalnum():
            _col = self._col
            _current = self._current
            _end = self._end
            _peek = self._peek

            while _peek.isalnum():
                _col += 1
                _current += 1
                _end = _current >= self.size
                _peek = "" if _end else self.sql[_current]

            self._col = _col
            self._current = _current
            self._end = _end
            self._peek = _peek
            self._char = self.sql[_current - 1]

    @property
    def _text(self) -> str:
        return self.sql[self._start : self._current]

    def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
        self._prev_token_line = self._line
        self._prev_token_comments = self._comments
        self._prev_token_type = token_type
        self.tokens.append(
            Token(
                token_type,
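The new alnum=True path is a performance fast path: instead of re-entering _advance once per character, it consumes an entire run of alphanumeric characters with a tight loop over local variables and writes the cursor state back once. A standalone sketch of the same idea outside sqlglot; names here are illustrative, not sqlglot's:

def advance_alnum(sql: str, current: int) -> int:
    """Return the index just past the alphanumeric run starting at `current`."""
    end = len(sql)
    while current < end and sql[current].isalnum():
        current += 1
    return current

assert advance_alnum("foo123 bar", 0) == 6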
@ -966,13 +982,13 @@ class Tokenizer(metaclass=_Tokenizer):
            comment_end_size = len(comment_end)
            while not self._end and self._chars(comment_end_size) != comment_end:
                self._advance()
                self._advance(alnum=True)

            self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
            self._advance(comment_end_size - 1)
        else:
            while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
                self._advance()
                self._advance(alnum=True)
            self._comments.append(self._text[comment_start_size:])

        # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
@ -988,9 +1004,9 @@ class Tokenizer(metaclass=_Tokenizer):
        if self._char == "0":
            peek = self._peek.upper()
            if peek == "B":
                return self._scan_bits()
                return self._scan_bits() if self._BIT_STRINGS else self._add(TokenType.NUMBER)
            elif peek == "X":
                return self._scan_hex()
                return self._scan_hex() if self._HEX_STRINGS else self._add(TokenType.NUMBER)

        decimal = False
        scientific = 0
@ -1033,7 +1049,9 @@ class Tokenizer(metaclass=_Tokenizer):
        self._advance()
        value = self._extract_value()
        try:
            self._add(TokenType.BIT_STRING, f"{int(value, 2)}")
            # If `value` can't be converted to a binary, fallback to tokenizing it as an identifier
            int(value, 2)
            self._add(TokenType.BIT_STRING, value[2:])  # Drop the 0b
        except ValueError:
            self._add(TokenType.IDENTIFIER)
@ -1041,7 +1059,9 @@ class Tokenizer(metaclass=_Tokenizer):
        self._advance()
        value = self._extract_value()
        try:
            self._add(TokenType.HEX_STRING, f"{int(value, 16)}")
            # If `value` can't be converted to a hex, fallback to tokenizing it as an identifier
            int(value, 16)
            self._add(TokenType.HEX_STRING, value[2:])  # Drop the 0x
        except ValueError:
            self._add(TokenType.IDENTIFIER)
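With this change a hex literal keeps its original digits (minus the 0x prefix) instead of being normalized to a decimal string, and invalid digits fall back to an identifier token. An illustrative check against BigQuery, whose tokenizer defines 0x hex strings (per the HEX_STRINGS setting earlier in this commit); the expected token text is an assumption based on the hunk above:

from sqlglot.dialects.bigquery import BigQuery

tokens = BigQuery.Tokenizer().tokenize("SELECT 0xCC")
print([(tok.token_type.name, tok.text) for tok in tokens])
# expected to include ('HEX_STRING', 'CC') rather than a decimal '204'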
@ -1049,7 +1069,7 @@ class Tokenizer(metaclass=_Tokenizer):
        while True:
            char = self._peek.strip()
            if char and char not in self.SINGLE_TOKENS:
                self._advance()
                self._advance(alnum=True)
            else:
                break
@ -1066,7 +1086,7 @@ class Tokenizer(metaclass=_Tokenizer):
        self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
        return True

    # X'1234, b'0110', E'\\\\\' etc.
    # X'1234', b'0110', E'\\\\\' etc.
    def _scan_formatted_string(self, string_start: str) -> bool:
        if string_start in self._HEX_STRINGS:
            delimiters = self._HEX_STRINGS
@ -1087,60 +1107,43 @@ class Tokenizer(metaclass=_Tokenizer):
        string_end = delimiters[string_start]
        text = self._extract_string(string_end)

        if base is None:
            self._add(token_type, text)
        else:
        if base:
            try:
                self._add(token_type, f"{int(text, base)}")
                int(text, base)
            except:
                raise RuntimeError(
                    f"Numeric string contains invalid characters from {self._line}:{self._start}"
                )

        self._add(token_type, text)
        return True

    def _scan_identifier(self, identifier_end: str) -> None:
        text = ""
        identifier_end_is_escape = identifier_end in self._IDENTIFIER_ESCAPES

        while True:
            if self._end:
                raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")

            self._advance()
            if self._char == identifier_end:
                if identifier_end_is_escape and self._peek == identifier_end:
                    text += identifier_end
                    self._advance()
                    continue

                break

            text += self._char

        self._advance()
        text = self._extract_string(identifier_end, self._IDENTIFIER_ESCAPES)
        self._add(TokenType.IDENTIFIER, text)

    def _scan_var(self) -> None:
        while True:
            char = self._peek.strip()
            if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS):
                self._advance()
                self._advance(alnum=True)
            else:
                break

        self._add(
            TokenType.VAR
            if self._prev_token_type == TokenType.PARAMETER
            if self.tokens and self.tokens[-1].token_type == TokenType.PARAMETER
            else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
        )

    def _extract_string(self, delimiter: str) -> str:
    def _extract_string(self, delimiter: str, escapes=None) -> str:
        text = ""
        delim_size = len(delimiter)
        escapes = self._STRING_ESCAPES if escapes is None else escapes

        while True:
            if self._char in self._STRING_ESCAPES and (
                self._peek == delimiter or self._peek in self._STRING_ESCAPES
            ):
            if self._char in escapes and (self._peek == delimiter or self._peek in escapes):
                if self._peek == delimiter:
                    text += self._peek
                else:
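_scan_identifier now delegates to _extract_string with the identifier escapes, so doubled quotes inside a quoted identifier are still unescaped via the shared loop. A round-trip sketch using ANSI-style quote doubling; the expected output is an assumption:

import sqlglot

ast = sqlglot.parse_one('SELECT "a""b" FROM t')
# The escaped quote should survive parsing and generation.
print(ast.sql())  # expected: SELECT "a""b" FROM t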
@ -1158,7 +1161,9 @@ class Tokenizer(metaclass=_Tokenizer):

            if self._end:
                raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")

            text += self._char
            self._advance()

            current = self._current - 1
            self._advance(alnum=True)
            text += self.sql[current : self._current - 1]

        return text
@ -121,20 +121,9 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    return expression.transform(
        lambda node: exp.DataType(
            **{
                **node.args,
                "expressions": [
                    node_expression
                    for node_expression in node.expressions
                    if isinstance(node_expression, exp.DataType)
                ],
            }
        )
        if isinstance(node, exp.DataType)
        else node,
    )
    for node in expression.find_all(exp.DataType):
        node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])
    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
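The rewrite drops the node-rebuilding transform() in favor of mutating the matched DataType nodes in place; the behavior should be unchanged. A quick behavioral sketch, where the expected output is an assumption:

import sqlglot
from sqlglot.transforms import remove_precision_parameterized_types

ast = sqlglot.parse_one("SELECT CAST(x AS VARCHAR(20))")
# The precision argument is not a DataType, so it gets filtered out.
print(remove_precision_parameterized_types(ast).sql())
# expected: SELECT CAST(x AS VARCHAR)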
@ -240,12 +229,36 @@ def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression
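A hedged sketch of this new transform applied to a found WithinGroup node: the ORDER BY operand becomes the quantile input and the whole clause collapses to an approximate quantile call. The input SQL and expected output are illustrative:

import sqlglot
from sqlglot import exp
from sqlglot.transforms import remove_within_group_for_percentiles

ast = sqlglot.parse_one("SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t")
node = ast.find(exp.WithinGroup)
if node:
    remove_within_group_for_percentiles(node)  # replaces the node in the tree
print(ast.sql())
# expected: SELECT APPROX_QUANTILE(x, 0.5) FROM t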
def unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, exp.Pivot):
        expression.args["field"].transform(
            lambda node: exp.column(node.output_name) if isinstance(node, exp.Column) else node,
            copy=False,
        )

    return expression

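Similarly, a sketch for unqualify_pivot_columns, assuming Snowflake's PIVOT parse yields an exp.Pivot whose field argument holds qualified columns; the SQL strings here are illustrative:

import sqlglot
from sqlglot import exp
from sqlglot.transforms import unqualify_pivot_columns

ast = sqlglot.parse_one(
    "SELECT * FROM tbl PIVOT(SUM(v) FOR tbl.cat IN ('a', 'b'))", read="snowflake"
)
pivot = ast.find(exp.Pivot)
if pivot:
    unqualify_pivot_columns(pivot)  # tbl.cat -> cat inside the IN list
print(ast.sql(dialect="snowflake"))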
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using an appropriate `Generator.TRANSFORMS` function.
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.
@ -255,17 +268,28 @@ def preprocess(
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        expression = transforms[0](expression.copy())
        for t in transforms[1:]:
            expression = t(expression)
        return getattr(self, expression.key + "_sql")(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])}
ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])}
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
    exp.Cast: preprocess([remove_precision_parameterized_types])
}
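These dicts show the intended use of preprocess: a dialect merges them into its generator's TRANSFORMS, and the new _to_sql dispatch first looks for a `<key>_sql` method on the generator before falling back to TRANSFORMS. A hedged sketch of wiring one up directly (dialect-agnostic; the output shape is an assumption):

from sqlglot import exp, parse_one
from sqlglot.generator import Generator
from sqlglot.transforms import eliminate_distinct_on, preprocess

class MyGenerator(Generator):
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        exp.Select: preprocess([eliminate_distinct_on]),
    }

sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b"
# eliminate_distinct_on rewrites the query around ROW_NUMBER(), then the
# resulting Select is generated via the regular select_sql method.
print(MyGenerator().generate(parse_one(sql)))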