Merging upstream version 23.7.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent ebba7c6a18
commit d26905e4af

187 changed files with 86502 additions and 71397 deletions
sqlglot/__init__.py

@@ -45,7 +45,7 @@ from sqlglot.expressions import (
 from sqlglot.generator import Generator as Generator
 from sqlglot.parser import Parser as Parser
 from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema
-from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType
+from sqlglot.tokens import Token as Token, Tokenizer as Tokenizer, TokenType as TokenType
 
 if t.TYPE_CHECKING:
     from sqlglot._typing import E
@@ -69,6 +69,21 @@ schema = MappingSchema()
 """The default schema used by SQLGlot (e.g. in the optimizer)."""
 
 
+def tokenize(sql: str, read: DialectType = None, dialect: DialectType = None) -> t.List[Token]:
+    """
+    Tokenizes the given SQL string.
+
+    Args:
+        sql: the SQL code string to tokenize.
+        read: the SQL dialect to apply during tokenizing (eg. "spark", "hive", "presto", "mysql").
+        dialect: the SQL dialect (alias for read).
+
+    Returns:
+        The resulting list of tokens.
+    """
+    return Dialect.get_or_raise(read or dialect).tokenize(sql)
+
+
 def parse(
     sql: str, read: DialectType = None, dialect: DialectType = None, **opts
 ) -> t.List[t.Optional[Expression]]:
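A quick sketch of the new top-level helper added above (the query string and dialect are arbitrary examples):

```python
import sqlglot

# New in this release: sqlglot.tokenize dispatches to the chosen dialect's tokenizer
tokens = sqlglot.tokenize("SELECT a FROM t", read="spark")
print([token.token_type for token in tokens][:3])
```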
sqlglot/dataframe/sql/dataframe.py

@@ -18,8 +18,6 @@ from sqlglot.dataframe.sql.transforms import replace_id_value
 from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
 from sqlglot.dataframe.sql.window import Window
 from sqlglot.helper import ensure_list, object_to_dict, seq_get
-from sqlglot.optimizer import optimize as optimize_func
-from sqlglot.optimizer.qualify_columns import quote_identifiers
 
 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql._typing import (
@@ -121,7 +119,9 @@ class DataFrame:
                 self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
             )
             replacement_mapping[old_name_id] = new_hashed_id
-        expression = expression.transform(replace_id_value, replacement_mapping)
+        expression = expression.transform(replace_id_value, replacement_mapping).assert_is(
+            exp.Select
+        )
         return expression
 
     def _create_cte_from_expression(
@@ -306,11 +306,12 @@ class DataFrame:
         replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
 
         for expression_type, select_expression in select_expressions:
-            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
+            select_expression = select_expression.transform(
+                replace_id_value, replacement_mapping
+            ).assert_is(exp.Select)
             if optimize:
-                quote_identifiers(select_expression, dialect=dialect)
                 select_expression = t.cast(
-                    exp.Select, optimize_func(select_expression, dialect=dialect)
+                    exp.Select, self.spark._optimize(select_expression, dialect=dialect)
                 )
 
            select_expression = df._replace_cte_names_with_hashes(select_expression)
sqlglot/dataframe/sql/functions.py

@@ -184,7 +184,7 @@ def floor(col: ColumnOrName) -> Column:
 
 
 def log10(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, expression.Log10)
+    return Column.invoke_expression_over_column(lit(10), expression.Log, expression=col)
 
 
 def log1p(col: ColumnOrName) -> Column:
@@ -192,7 +192,7 @@ def log1p(col: ColumnOrName) -> Column:
 
 
 def log2(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, expression.Log2)
+    return Column.invoke_expression_over_column(lit(2), expression.Log, expression=col)
 
 
 def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
@@ -356,15 +356,15 @@ def coalesce(*cols: ColumnOrName) -> Column:
 
 
 def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col1, "CORR", col2)
+    return Column.invoke_expression_over_column(col1, expression.Corr, expression=col2)
 
 
 def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col1, "COVAR_POP", col2)
+    return Column.invoke_expression_over_column(col1, expression.CovarPop, expression=col2)
 
 
 def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2)
+    return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2)
 
 
 def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
@@ -971,10 +971,10 @@ def array_join(
 ) -> Column:
     if null_replacement is not None:
         return Column.invoke_expression_over_column(
-            col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement)
+            col, expression.ArrayToString, expression=lit(delimiter), null=lit(null_replacement)
         )
     return Column.invoke_expression_over_column(
-        col, expression.ArrayJoin, expression=lit(delimiter)
+        col, expression.ArrayToString, expression=lit(delimiter)
    )
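The ArrayJoin-to-ArrayToString rename above is mirrored in the dialect generators (see the Hive and Presto hunks later in this diff). A minimal sketch of how the renamed expression generates per dialect; the column name `a` and delimiter are placeholders:

```python
from sqlglot import exp

# Build the renamed expression directly and render it for two dialects
node = exp.ArrayToString(this=exp.column("a"), expression=exp.Literal.string("-"))
print(node.sql("presto"))  # expected: ARRAY_JOIN(a, '-')
print(node.sql("hive"))    # expected: CONCAT_WS('-', a)
```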
sqlglot/dataframe/sql/session.py

@@ -12,6 +12,8 @@ from sqlglot.dataframe.sql.readwriter import DataFrameReader
 from sqlglot.dataframe.sql.types import StructType
 from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
 from sqlglot.helper import classproperty
+from sqlglot.optimizer import optimize
+from sqlglot.optimizer.qualify_columns import quote_identifiers
 
 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
@@ -104,8 +106,15 @@ class SparkSession:
         sel_expression = exp.Select(**select_kwargs)
         return DataFrame(self, sel_expression)
 
+    def _optimize(
+        self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
+    ) -> exp.Expression:
+        dialect = dialect or self.dialect
+        quote_identifiers(expression, dialect=dialect)
+        return optimize(expression, dialect=dialect)
+
     def sql(self, sqlQuery: str) -> DataFrame:
-        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
+        expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect))
         if isinstance(expression, exp.Select):
             df = DataFrame(self, expression)
             df = df._convert_leaf_to_cte()
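A sketch of the new `_optimize` hook: `SparkSession.sql` now quotes identifiers and runs the optimizer on the parsed statement before wrapping it in a DataFrame. The query below is an arbitrary example, and the exact generated SQL depends on the optimizer rules:

```python
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.sql("SELECT 1 AS x")  # parse_one output is passed through _optimize first
print(df.sql())                  # a list of generated Spark SQL statements
```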
sqlglot/dialects/__init__.py

@@ -61,6 +61,7 @@ dialect implementations in order to understand how their various components can
 ----
 """
 
+from sqlglot.dialects.athena import Athena
 from sqlglot.dialects.bigquery import BigQuery
 from sqlglot.dialects.clickhouse import ClickHouse
 from sqlglot.dialects.databricks import Databricks
@@ -73,6 +74,7 @@ from sqlglot.dialects.mysql import MySQL
 from sqlglot.dialects.oracle import Oracle
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.dialects.presto import Presto
+from sqlglot.dialects.prql import PRQL
 from sqlglot.dialects.redshift import Redshift
 from sqlglot.dialects.snowflake import Snowflake
 from sqlglot.dialects.spark import Spark
sqlglot/dialects/athena.py (new file, +37)
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from sqlglot import exp
+from sqlglot.dialects.trino import Trino
+from sqlglot.tokens import TokenType
+
+
+class Athena(Trino):
+    class Parser(Trino.Parser):
+        STATEMENT_PARSERS = {
+            **Trino.Parser.STATEMENT_PARSERS,
+            TokenType.USING: lambda self: self._parse_as_command(self._prev),
+        }
+
+    class Generator(Trino.Generator):
+        PROPERTIES_LOCATION = {
+            **Trino.Generator.PROPERTIES_LOCATION,
+            exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
+        }
+
+        TYPE_MAPPING = {
+            **Trino.Generator.TYPE_MAPPING,
+            exp.DataType.Type.TEXT: "STRING",
+        }
+
+        TRANSFORMS = {
+            **Trino.Generator.TRANSFORMS,
+            exp.FileFormatProperty: lambda self, e: f"'FORMAT'={self.sql(e, 'this')}",
+        }
+
+        def property_sql(self, expression: exp.Property) -> str:
+            return (
+                f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"
+            )
+
+        def with_properties(self, properties: exp.Properties) -> str:
+            return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
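The new Athena dialect is a thin layer over Trino. A small usage sketch (table and column names are placeholders; expected output follows from the TYPE_MAPPING above):

```python
import sqlglot

# Athena's TYPE_MAPPING sends TEXT to STRING in generated DDL
print(sqlglot.transpile("CREATE TABLE t (c TEXT)", write="athena")[0])
# expected: CREATE TABLE t (c STRING)
```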
sqlglot/dialects/bigquery.py

@@ -24,6 +24,7 @@ from sqlglot.dialects.dialect import (
     rename_func,
     timestrtotime_sql,
     ts_or_ds_add_cast,
+    unit_to_var,
 )
 from sqlglot.helper import seq_get, split_num_words
 from sqlglot.tokens import TokenType
@@ -41,14 +42,22 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
     structs = []
     alias = expression.args.get("alias")
     for tup in expression.find_all(exp.Tuple):
-        field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions)))
+        field_aliases = (
+            alias.columns
+            if alias and alias.columns
+            else (f"_c{i}" for i in range(len(tup.expressions)))
+        )
         expressions = [
             exp.PropertyEQ(this=exp.to_identifier(name), expression=fld)
             for name, fld in zip(field_aliases, tup.expressions)
         ]
         structs.append(exp.Struct(expressions=expressions))
 
-    return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)]))
+    # Due to `UNNEST_COLUMN_ONLY`, it is expected that the table alias be contained in the columns expression
+    alias_name_only = exp.TableAlias(columns=[alias.this]) if alias else None
+    return self.unnest_sql(
+        exp.Unnest(expressions=[exp.array(*structs, copy=False)], alias=alias_name_only)
+    )
 
 
 def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
@@ -190,7 +199,7 @@ def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> st
 def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str:
     expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True))
     expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True))
-    unit = expression.args.get("unit") or "DAY"
+    unit = unit_to_var(expression)
     return self.func("DATE_DIFF", expression.this, expression.expression, unit)
 
 
@@ -238,16 +247,6 @@ class BigQuery(Dialect):
         "%E6S": "%S.%f",
     }
 
-    ESCAPE_SEQUENCES = {
-        "\\a": "\a",
-        "\\b": "\b",
-        "\\f": "\f",
-        "\\n": "\n",
-        "\\r": "\r",
-        "\\t": "\t",
-        "\\v": "\v",
-    }
-
     FORMAT_MAPPING = {
         "DD": "%d",
         "MM": "%m",
@@ -315,6 +314,7 @@ class BigQuery(Dialect):
             "BEGIN TRANSACTION": TokenType.BEGIN,
             "BYTES": TokenType.BINARY,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
+            "DATETIME": TokenType.TIMESTAMP,
             "DECLARE": TokenType.COMMAND,
             "ELSEIF": TokenType.COMMAND,
             "EXCEPTION": TokenType.COMMAND,
@@ -486,14 +486,14 @@ class BigQuery(Dialect):
                     table.set("db", exp.Identifier(this=parts[0]))
                     table.set("this", exp.Identifier(this=parts[1]))
 
-            if isinstance(table.this, exp.Identifier) and "." in table.name:
+            if any("." in p.name for p in table.parts):
                 catalog, db, this, *rest = (
-                    t.cast(t.Optional[exp.Expression], exp.to_identifier(x, quoted=True))
-                    for x in split_num_words(table.name, ".", 3)
+                    exp.to_identifier(p, quoted=True)
+                    for p in split_num_words(".".join(p.name for p in table.parts), ".", 3)
                )
 
                 if rest and this:
-                    this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))
+                    this = exp.Dot.build([this, *rest])  # type: ignore
 
                 table = exp.Table(this=this, db=db, catalog=catalog)
                 table.meta["quoted_table"] = True
@@ -527,7 +527,9 @@ class BigQuery(Dialect):
 
             return json_object
 
-        def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+        def _parse_bracket(
+            self, this: t.Optional[exp.Expression] = None
+        ) -> t.Optional[exp.Expression]:
             bracket = super()._parse_bracket(this)
 
             if this is bracket:
@@ -566,6 +568,7 @@ class BigQuery(Dialect):
         IGNORE_NULLS_IN_FUNC = True
         JSON_PATH_SINGLE_QUOTE_ESCAPE = True
         CAN_IMPLEMENT_ARRAY_ANY = True
+        SUPPORTS_TO_NUMBER = False
         NAMED_PLACEHOLDER_TOKEN = "@"
 
         TRANSFORMS = {
@@ -588,7 +591,7 @@ class BigQuery(Dialect):
             exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
             exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
             exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", e.this, e.expression, e.unit or "DAY"
+                "DATE_DIFF", e.this, e.expression, unit_to_var(e)
             ),
             exp.DateFromParts: rename_func("DATE"),
             exp.DateStrToDate: datestrtodate_sql,
@@ -607,6 +610,7 @@ class BigQuery(Dialect):
             exp.IntDiv: rename_func("DIV"),
             exp.JSONFormat: rename_func("TO_JSON_STRING"),
             exp.Max: max_or_greatest,
+            exp.Mod: rename_func("MOD"),
             exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
             exp.MD5Digest: rename_func("MD5"),
             exp.Min: min_or_least,
@@ -847,10 +851,10 @@ class BigQuery(Dialect):
             return inline_array_sql(self, expression)
 
         def bracket_sql(self, expression: exp.Bracket) -> str:
-            this = self.sql(expression, "this")
+            this = expression.this
             expressions = expression.expressions
 
-            if len(expressions) == 1:
+            if len(expressions) == 1 and this and this.is_type(exp.DataType.Type.STRUCT):
                 arg = expressions[0]
                 if arg.type is None:
                     from sqlglot.optimizer.annotate_types import annotate_types
@@ -858,10 +862,10 @@ class BigQuery(Dialect):
                     arg = annotate_types(arg)
 
                 if arg.type and arg.type.this in exp.DataType.TEXT_TYPES:
-                    # BQ doesn't support bracket syntax with string values
-                    return f"{this}.{arg.name}"
+                    # BQ doesn't support bracket syntax with string values for structs
+                    return f"{self.sql(this)}.{arg.name}"
 
-            expressions_sql = ", ".join(self.sql(e) for e in expressions)
+            expressions_sql = self.expressions(expression, flat=True)
             offset = expression.args.get("offset")
 
             if offset == 0:
@@ -874,7 +878,7 @@ class BigQuery(Dialect):
             if expression.args.get("safe"):
                 expressions_sql = f"SAFE_{expressions_sql}"
 
-            return f"{this}[{expressions_sql}]"
+            return f"{self.sql(this)}[{expressions_sql}]"
 
         def in_unnest_op(self, expression: exp.Unnest) -> str:
             return self.sql(expression)
sqlglot/dialects/clickhouse.py

@@ -15,7 +15,6 @@ from sqlglot.dialects.dialect import (
     rename_func,
     var_map_sql,
 )
-from sqlglot.errors import ParseError
 from sqlglot.helper import is_int, seq_get
 from sqlglot.tokens import Token, TokenType
 
@@ -49,8 +48,9 @@ class ClickHouse(Dialect):
     NULL_ORDERING = "nulls_are_last"
     SUPPORTS_USER_DEFINED_TYPES = False
     SAFE_DIVISION = True
+    LOG_BASE_FIRST: t.Optional[bool] = None
 
-    ESCAPE_SEQUENCES = {
+    UNESCAPED_SEQUENCES = {
         "\\0": "\0",
     }
 
@@ -105,6 +105,7 @@ class ClickHouse(Dialect):
        # * select x from t1 union all select x from t2 limit 1;
        # * select x from t1 union all (select x from t2 limit 1);
        MODIFIERS_ATTACHED_TO_UNION = False
+        INTERVAL_SPANS = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -260,6 +261,11 @@ class ClickHouse(Dialect):
             "ArgMax",
         ]
 
+        FUNC_TOKENS = {
+            *parser.Parser.FUNC_TOKENS,
+            TokenType.SET,
+        }
+
         AGG_FUNC_MAPPING = (
             lambda functions, suffixes: {
                 f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
@@ -305,6 +311,10 @@ class ClickHouse(Dialect):
             TokenType.SETTINGS,
         }
 
+        ALIAS_TOKENS = parser.Parser.ALIAS_TOKENS - {
+            TokenType.FORMAT,
+        }
+
         LOG_DEFAULTS_TO_LN = True
 
         QUERY_MODIFIER_PARSERS = {
@@ -316,6 +326,17 @@ class ClickHouse(Dialect):
             TokenType.FORMAT: lambda self: ("format", self._advance() or self._parse_id_var()),
         }
 
+        CONSTRAINT_PARSERS = {
+            **parser.Parser.CONSTRAINT_PARSERS,
+            "INDEX": lambda self: self._parse_index_constraint(),
+            "CODEC": lambda self: self._parse_compress(),
+        }
+
+        SCHEMA_UNNAMED_CONSTRAINTS = {
+            *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
+            "INDEX",
+        }
+
         def _parse_conjunction(self) -> t.Optional[exp.Expression]:
             this = super()._parse_conjunction()
 
@@ -381,21 +402,20 @@ class ClickHouse(Dialect):
 
         # https://clickhouse.com/docs/en/sql-reference/statements/select/with/
         def _parse_cte(self) -> exp.CTE:
-            index = self._index
-            try:
-                # WITH <identifier> AS <subquery expression>
-                return super()._parse_cte()
-            except ParseError:
-                # WITH <expression> AS <identifier>
-                self._retreat(index)
+            # WITH <identifier> AS <subquery expression>
+            cte: t.Optional[exp.CTE] = self._try_parse(super()._parse_cte)
 
-            return self.expression(
-                exp.CTE,
-                this=self._parse_conjunction(),
-                alias=self._parse_table_alias(),
-                scalar=True,
-            )
+            if not cte:
+                # WITH <expression> AS <identifier>
+                cte = self.expression(
+                    exp.CTE,
+                    this=self._parse_conjunction(),
+                    alias=self._parse_table_alias(),
+                    scalar=True,
+                )
+
+            return cte
 
         def _parse_join_parts(
             self,
         ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
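With `_try_parse`, the CTE parser now backtracks without raising, so both ClickHouse WITH forms parse. A sketch (query strings are arbitrary examples):

```python
import sqlglot

# Standard form: WITH <identifier> AS <subquery>
standard = sqlglot.parse_one("WITH t AS (SELECT 1) SELECT * FROM t", read="clickhouse")
# ClickHouse scalar form: WITH <expression> AS <identifier>
scalar = sqlglot.parse_one("WITH 1 + 1 AS x SELECT x", read="clickhouse")
print(scalar.sql("clickhouse"))
```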
@@ -508,6 +528,27 @@ class ClickHouse(Dialect):
                 self._retreat(index)
             return None
 
+        def _parse_index_constraint(
+            self, kind: t.Optional[str] = None
+        ) -> exp.IndexColumnConstraint:
+            # INDEX name1 expr TYPE type1(args) GRANULARITY value
+            this = self._parse_id_var()
+            expression = self._parse_conjunction()
+
+            index_type = self._match_text_seq("TYPE") and (
+                self._parse_function() or self._parse_var()
+            )
+
+            granularity = self._match_text_seq("GRANULARITY") and self._parse_term()
+
+            return self.expression(
+                exp.IndexColumnConstraint,
+                this=this,
+                expression=expression,
+                index_type=index_type,
+                granularity=granularity,
+            )
+
     class Generator(generator.Generator):
         QUERY_HINTS = False
         STRUCT_DELIMITER = ("(", ")")
@@ -517,6 +558,7 @@ class ClickHouse(Dialect):
         TABLESAMPLE_KEYWORDS = "SAMPLE"
         LAST_DAY_SUPPORTS_DATE_PART = False
         CAN_IMPLEMENT_ARRAY_ANY = True
+        SUPPORTS_TO_NUMBER = False
 
         STRING_TYPE_MAPPING = {
             exp.DataType.Type.CHAR: "String",
@@ -585,6 +627,9 @@ class ClickHouse(Dialect):
             exp.Array: inline_array_sql,
             exp.CastToStrType: rename_func("CAST"),
             exp.CountIf: rename_func("countIf"),
+            exp.CompressColumnConstraint: lambda self,
+            e: f"CODEC({self.expressions(e, key='this', flat=True)})",
+            exp.ComputedColumnConstraint: lambda self, e: f"ALIAS {self.sql(e, 'this')}",
             exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
             exp.DateAdd: date_delta_sql("DATE_ADD"),
             exp.DateDiff: date_delta_sql("DATE_DIFF"),
@@ -737,3 +782,15 @@ class ClickHouse(Dialect):
         def prewhere_sql(self, expression: exp.PreWhere) -> str:
             this = self.indent(self.sql(expression, "this"))
             return f"{self.seg('PREWHERE')}{self.sep()}{this}"
+
+        def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
+            this = self.sql(expression, "this")
+            this = f" {this}" if this else ""
+            expr = self.sql(expression, "expression")
+            expr = f" {expr}" if expr else ""
+            index_type = self.sql(expression, "index_type")
+            index_type = f" TYPE {index_type}" if index_type else ""
+            granularity = self.sql(expression, "granularity")
+            granularity = f" GRANULARITY {granularity}" if granularity else ""
+
+            return f"INDEX{this}{expr}{index_type}{granularity}"
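Together, `_parse_index_constraint` and `indexcolumnconstraint_sql` let ClickHouse data-skipping index definitions round-trip. A sketch with a placeholder table:

```python
import sqlglot

sql = "CREATE TABLE t (x INT, INDEX idx x TYPE minmax GRANULARITY 8192)"
print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
# expected (roughly): CREATE TABLE t (x Int32, INDEX idx x TYPE minmax GRANULARITY 8192)
```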
sqlglot/dialects/dialect.py

@@ -31,6 +31,7 @@ class Dialects(str, Enum):
 
     DIALECT = ""
 
+    ATHENA = "athena"
     BIGQUERY = "bigquery"
     CLICKHOUSE = "clickhouse"
     DATABRICKS = "databricks"
@@ -42,6 +43,7 @@ class Dialects(str, Enum):
     ORACLE = "oracle"
     POSTGRES = "postgres"
     PRESTO = "presto"
+    PRQL = "prql"
     REDSHIFT = "redshift"
     SNOWFLAKE = "snowflake"
     SPARK = "spark"
@@ -108,11 +110,18 @@ class _Dialect(type):
         klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
         klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 
-        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+        base = seq_get(bases, 0)
+        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
+        base_parser = (getattr(base, "parser_class", Parser),)
+        base_generator = (getattr(base, "generator_class", Generator),)
 
-        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
-        klass.parser_class = getattr(klass, "Parser", Parser)
-        klass.generator_class = getattr(klass, "Generator", Generator)
+        klass.tokenizer_class = klass.__dict__.get(
+            "Tokenizer", type("Tokenizer", base_tokenizer, {})
+        )
+        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
+        klass.generator_class = klass.__dict__.get(
+            "Generator", type("Generator", base_generator, {})
+        )
 
         klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
         klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
@@ -134,9 +143,31 @@ class _Dialect(type):
         klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
         klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 
+        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
+            klass.UNESCAPED_SEQUENCES = {
+                "\\a": "\a",
+                "\\b": "\b",
+                "\\f": "\f",
+                "\\n": "\n",
+                "\\r": "\r",
+                "\\t": "\t",
+                "\\v": "\v",
+                "\\\\": "\\",
+                **klass.UNESCAPED_SEQUENCES,
+            }
+
+        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
+
         if enum not in ("", "bigquery"):
             klass.generator_class.SELECT_KINDS = ()
 
+        if enum not in ("", "databricks", "hive", "spark", "spark2"):
+            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
+            for modifier in ("cluster", "distribute", "sort"):
+                modifier_transforms.pop(modifier, None)
+
+            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
+
         if not klass.SUPPORTS_SEMI_ANTI_JOIN:
             klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                 TokenType.ANTI,
@@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect):
         False: Disables function name normalization.
     """
 
-    LOG_BASE_FIRST = True
-    """Whether the base comes first in the `LOG` function."""
+    LOG_BASE_FIRST: t.Optional[bool] = True
+    """
+    Whether the base comes first in the `LOG` function.
+    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
+    """
 
     NULL_ORDERING = "nulls_are_small"
     """
@@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect):
     If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
     """
 
-    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
-    """Mapping of an unescaped escape sequence to the corresponding character."""
+    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
+    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 
     PSEUDOCOLUMNS: t.Set[str] = set()
     """
@@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect):
     INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
     INVERSE_TIME_TRIE: t.Dict = {}
 
-    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 
     # Delimiters for string literals and identifiers
     QUOTE_START = "'"
@@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) ->
     return ""
 
 
-def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
+def str_position_sql(
+    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
+) -> str:
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
     position = self.sql(expression, "position")
+    instance = expression.args.get("instance") if generate_instance else None
+    position_offset = ""
+
     if position:
-        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
-    return f"STRPOS({this}, {substr})"
+        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
+        this = self.func("SUBSTR", this, position)
+        position_offset = f" + {position} - 1"
+
+    return self.func("STRPOS", this, substr, instance) + position_offset
 
 
 def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
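A sketch of the normalized output: a `position` argument is folded into `SUBSTR(...)` plus an offset, so, for example, MySQL's three-argument LOCATE should transpile to Postgres as shown (column `x` is a placeholder):

```python
import sqlglot

print(sqlglot.transpile("SELECT LOCATE('a', x, 2)", read="mysql", write="postgres")[0])
# expected: SELECT STRPOS(SUBSTR(x, 2), 'a') + 2 - 1
```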
@@ -689,9 +731,7 @@ def build_date_delta_with_interval(
         if expression and expression.is_string:
             expression = exp.Literal.number(expression.this)
 
-        return expression_class(
-            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
-        )
+        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 
     return _builder
 
@@ -710,18 +750,14 @@ def date_add_interval_sql(
 ) -> t.Callable[[Generator, exp.Expression], str]:
     def func(self: Generator, expression: exp.Expression) -> str:
         this = self.sql(expression, "this")
-        unit = expression.args.get("unit")
-        unit = exp.var(unit.name.upper() if unit else "DAY")
-        interval = exp.Interval(this=expression.expression, unit=unit)
+        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
         return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 
     return func
 
 
 def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
-    return self.func(
-        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
-    )
+    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
 
 
 def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@@ -956,7 +992,7 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
 
         return self.func(
             name,
-            exp.var(expression.text("unit").upper() or "DAY"),
+            unit_to_var(expression),
             expression.expression,
             expression.this,
         )
@@ -964,6 +1000,24 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
     return _delta_sql
 
 
+def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+    unit = expression.args.get("unit")
+
+    if isinstance(unit, exp.Placeholder):
+        return unit
+    if unit:
+        return exp.Literal.string(unit.name)
+    return exp.Literal.string(default) if default else None
+
+
+def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+    unit = expression.args.get("unit")
+
+    if isinstance(unit, (exp.Var, exp.Placeholder)):
+        return unit
+    return exp.Var(this=default) if default else None
+
+
 def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
     trunc_curr_date = exp.func("date_trunc", "month", expression.this)
     plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
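The two new helpers differ only in how they render the unit; a sketch (the DateAdd node is a made-up example):

```python
from sqlglot import exp
from sqlglot.dialects.dialect import unit_to_str, unit_to_var

node = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1), unit=exp.var("MONTH"))
print(unit_to_str(node).sql())  # 'MONTH' -- a string literal, as Presto-style DATE_ADD expects
print(unit_to_var(node).sql())  # MONTH  -- a bare var, as BigQuery-style DATE_DIFF expects
```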
@@ -998,7 +1052,7 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 
 
 def build_json_extract_path(
-    expr_type: t.Type[F], zero_based_indexing: bool = True
+    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
 ) -> t.Callable[[t.List], F]:
     def _builder(args: t.List) -> F:
         segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
@@ -1018,7 +1072,11 @@ def build_json_extract_path(
 
         # This is done to avoid failing in the expression validator due to the arg count
         del args[2:]
-        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+        return expr_type(
+            this=seq_get(args, 0),
+            expression=exp.JSONPath(expressions=segments),
+            only_json_types=arrow_req_json_type,
+        )
 
     return _builder
 
@@ -1070,3 +1128,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> s
     unnest = exp.Unnest(expressions=[expression.this])
     filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
     return self.sql(exp.Array(expressions=[filtered]))
+
+
+def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
+    return self.func(
+        "TO_NUMBER",
+        expression.this,
+        expression.args.get("format"),
+        expression.args.get("nlsparam"),
+    )
sqlglot/dialects/doris.py

@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
     build_timestamp_trunc,
     rename_func,
     time_format,
+    unit_to_str,
 )
 from sqlglot.dialects.mysql import MySQL
 
@@ -27,7 +28,7 @@ class Doris(MySQL):
     }
 
     class Generator(MySQL.Generator):
-        CAST_MAPPING = {}
+        LAST_DAY_SUPPORTS_DATE_PART = False
 
         TYPE_MAPPING = {
             **MySQL.Generator.TYPE_MAPPING,
@@ -36,8 +37,7 @@ class Doris(MySQL):
             exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
         }
 
-        LAST_DAY_SUPPORTS_DATE_PART = False
-
+        CAST_MAPPING = {}
         TIMESTAMP_FUNC_TYPES = set()
 
         TRANSFORMS = {
@@ -49,9 +49,7 @@ class Doris(MySQL):
             exp.ArrayAgg: rename_func("COLLECT_LIST"),
             exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
             exp.CurrentTimestamp: lambda self, _: self.func("NOW"),
-            exp.DateTrunc: lambda self, e: self.func(
-                "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
-            ),
+            exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
             exp.JSONExtractScalar: arrow_json_extract_sql,
             exp.JSONExtract: arrow_json_extract_sql,
             exp.Map: rename_func("ARRAY_MAP"),
@@ -63,9 +61,7 @@ class Doris(MySQL):
             exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression),
             exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
-            exp.TimestampTrunc: lambda self, e: self.func(
-                "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
-            ),
+            exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
             exp.UnixToStr: lambda self, e: self.func(
                 "FROM_UNIXTIME", e.this, time_format("doris")(self, e)
             ),
sqlglot/dialects/drill.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import typing as t
 
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
@@ -12,18 +11,10 @@ from sqlglot.dialects.dialect import (
     str_position_sql,
     timestrtotime_sql,
 )
+from sqlglot.dialects.mysql import date_add_sql
 from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by
 
 
-def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
-    def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
-        this = self.sql(expression, "this")
-        unit = exp.var(expression.text("unit").upper() or "DAY")
-        return self.func(f"DATE_{kind}", this, exp.Interval(this=expression.expression, unit=unit))
-
-    return func
-
-
 def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
@@ -84,7 +75,6 @@ class Drill(Dialect):
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
-            "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "drill"),
             "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
             "TO_CHAR": build_formatted_time(exp.TimeToStr, "drill"),
         }
@@ -124,9 +114,9 @@ class Drill(Dialect):
             exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
             exp.ArraySize: rename_func("REPEATED_COUNT"),
             exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
-            exp.DateAdd: _date_add_sql("ADD"),
+            exp.DateAdd: date_add_sql("ADD"),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateSub: _date_add_sql("SUB"),
+            exp.DateSub: date_add_sql("SUB"),
             exp.DateToDi: lambda self,
             e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
             exp.DiToDate: lambda self,
sqlglot/dialects/duckdb.py

@@ -26,6 +26,7 @@ from sqlglot.dialects.dialect import (
     str_to_time_sql,
     timestamptrunc_sql,
     timestrtotime_sql,
+    unit_to_var,
 )
 from sqlglot.helper import flatten, seq_get
 from sqlglot.tokens import TokenType
@@ -33,15 +34,16 @@ from sqlglot.tokens import TokenType
 
 def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
     this = self.sql(expression, "this")
-    unit = self.sql(expression, "unit").strip("'") or "DAY"
-    interval = self.sql(exp.Interval(this=expression.expression, unit=unit))
+    interval = self.sql(exp.Interval(this=expression.expression, unit=unit_to_var(expression)))
     return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}"
 
 
-def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_delta_sql(
+    self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub | exp.TimeAdd
+) -> str:
     this = self.sql(expression, "this")
-    unit = self.sql(expression, "unit").strip("'") or "DAY"
-    op = "+" if isinstance(expression, exp.DateAdd) else "-"
+    unit = unit_to_var(expression)
+    op = "+" if isinstance(expression, (exp.DateAdd, exp.TimeAdd)) else "-"
     return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
@@ -186,6 +188,11 @@ class DuckDB(Dialect):
         return super().to_json_path(path)
 
     class Tokenizer(tokens.Tokenizer):
+        HEREDOC_STRINGS = ["$"]
+
+        HEREDOC_TAG_IS_IDENTIFIER = True
+        HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER
+
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "//": TokenType.DIV,
@@ -199,6 +206,7 @@ class DuckDB(Dialect):
             "LOGICAL": TokenType.BOOLEAN,
             "ONLY": TokenType.ONLY,
             "PIVOT_WIDER": TokenType.PIVOT,
+            "POSITIONAL": TokenType.POSITIONAL,
             "SIGNED": TokenType.INT,
             "STRING": TokenType.VARCHAR,
             "UBIGINT": TokenType.UBIGINT,
@@ -227,8 +235,8 @@ class DuckDB(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ARRAY_HAS": exp.ArrayContains.from_arg_list,
-            "ARRAY_SORT": exp.SortArray.from_arg_list,
             "ARRAY_REVERSE_SORT": _build_sort_array_desc,
+            "ARRAY_SORT": exp.SortArray.from_arg_list,
             "DATEDIFF": _build_date_diff,
             "DATE_DIFF": _build_date_diff,
             "DATE_TRUNC": date_trunc_to_time,
@@ -285,6 +293,11 @@ class DuckDB(Dialect):
         FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
         FUNCTION_PARSERS.pop("DECODE")
 
+        NO_PAREN_FUNCTION_PARSERS = {
+            **parser.Parser.NO_PAREN_FUNCTION_PARSERS,
+            "MAP": lambda self: self._parse_map(),
+        }
+
         TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
             TokenType.SEMI,
             TokenType.ANTI,
@@ -299,6 +312,13 @@ class DuckDB(Dialect):
             ),
         }
 
+        def _parse_map(self) -> exp.ToMap | exp.Map:
+            if self._match(TokenType.L_BRACE, advance=False):
+                return self.expression(exp.ToMap, this=self._parse_bracket())
+
+            args = self._parse_wrapped_csv(self._parse_conjunction)
+            return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1))
+
         def _parse_types(
             self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
         ) -> t.Optional[exp.Expression]:
@@ -345,6 +365,7 @@ class DuckDB(Dialect):
         SUPPORTS_CREATE_TABLE_LIKE = False
         MULTI_ARG_DISTINCT = False
         CAN_IMPLEMENT_ARRAY_ANY = True
+        SUPPORTS_TO_NUMBER = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -425,6 +446,7 @@ class DuckDB(Dialect):
                 "EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
             ),
             exp.Struct: _struct_sql,
+            exp.TimeAdd: _date_delta_sql,
             exp.Timestamp: no_timestamp_sql,
             exp.TimestampDiff: lambda self, e: self.func(
                 "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
@@ -478,7 +500,7 @@ class DuckDB(Dialect):
 
         STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
 
-        UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
+        UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren)
 
         # DuckDB doesn't generally support CREATE TABLE .. properties
         # https://duckdb.org/docs/sql/statements/create_table.html
@@ -569,3 +591,9 @@ class DuckDB(Dialect):
                 return rename_func("RANGE")(self, expression)
 
             return super().generateseries_sql(expression)
+
+        def bracket_sql(self, expression: exp.Bracket) -> str:
+            if isinstance(expression.this, exp.Array):
+                expression.this.replace(exp.paren(expression.this))
+
+            return super().bracket_sql(expression)
sqlglot/dialects/hive.py

@@ -319,7 +319,9 @@ class Hive(Dialect):
             "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"),
             "TO_JSON": exp.JSONFormat.from_arg_list,
             "UNBASE64": exp.FromBase64.from_arg_list,
-            "UNIX_TIMESTAMP": build_formatted_time(exp.StrToUnix, "hive", True),
+            "UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)(
+                args or [exp.CurrentTimestamp()]
+            ),
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }
 
@@ -431,6 +433,7 @@ class Hive(Dialect):
         NVL2_SUPPORTED = False
         LAST_DAY_SUPPORTS_DATE_PART = False
         JSON_PATH_SINGLE_QUOTE_ESCAPE = True
+        SUPPORTS_TO_NUMBER = False
 
         EXPRESSIONS_WITHOUT_NESTED_CTES = {
             exp.Insert,
@@ -472,7 +475,7 @@ class Hive(Dialect):
             exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
             exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
             exp.ArrayConcat: rename_func("CONCAT"),
-            exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
+            exp.ArrayToString: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
             exp.ArraySize: rename_func("SIZE"),
             exp.ArraySort: _array_sort_sql,
             exp.With: no_recursive_cte_sql,
sqlglot/dialects/mysql.py

@@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import (
     build_date_delta_with_interval,
     rename_func,
     strposition_to_locate_sql,
+    unit_to_var,
 )
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
@@ -109,14 +110,14 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
     return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
 
 
-def _date_add_sql(
+def date_add_sql(
     kind: str,
-) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
-    def func(self: MySQL.Generator, expression: exp.Expression) -> str:
-        this = self.sql(expression, "this")
-        unit = expression.text("unit").upper() or "DAY"
-        return (
-            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+) -> t.Callable[[generator.Generator, exp.Expression], str]:
+    def func(self: generator.Generator, expression: exp.Expression) -> str:
+        return self.func(
+            f"DATE_{kind}",
+            expression.this,
+            exp.Interval(this=expression.expression, unit=unit_to_var(expression)),
         )
 
     return func
@@ -291,6 +292,7 @@ class MySQL(Dialect):
             "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+            "FROM_UNIXTIME": build_formatted_time(exp.UnixToTime, "mysql"),
             "ISNULL": isnull_to_is_null,
             "LOCATE": locate_to_strposition,
             "MAKETIME": exp.TimeFromParts.from_arg_list,
@@ -319,11 +321,7 @@ class MySQL(Dialect):
         FUNCTION_PARSERS = {
             **parser.Parser.FUNCTION_PARSERS,
             "CHAR": lambda self: self._parse_chr(),
-            "GROUP_CONCAT": lambda self: self.expression(
-                exp.GroupConcat,
-                this=self._parse_lambda(),
-                separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
-            ),
+            "GROUP_CONCAT": lambda self: self._parse_group_concat(),
             # https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
             "VALUES": lambda self: self.expression(
                 exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()]
@@ -412,6 +410,11 @@ class MySQL(Dialect):
             "SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"),
         }
 
+        ALTER_PARSERS = {
+            **parser.Parser.ALTER_PARSERS,
+            "MODIFY": lambda self: self._parse_alter_table_alter(),
+        }
+
         SCHEMA_UNNAMED_CONSTRAINTS = {
             *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
             "FULLTEXT",
@@ -458,7 +461,7 @@ class MySQL(Dialect):
 
             this = self._parse_id_var(any_token=False)
             index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text
-            schema = self._parse_schema()
+            expressions = self._parse_wrapped_csv(self._parse_ordered)
 
             options = []
             while True:
@@ -478,9 +481,6 @@ class MySQL(Dialect):
                 elif self._match_text_seq("ENGINE_ATTRIBUTE"):
                     self._match(TokenType.EQ)
                     opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
-                elif self._match_text_seq("ENGINE_ATTRIBUTE"):
-                    self._match(TokenType.EQ)
-                    opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
                 elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"):
                     self._match(TokenType.EQ)
                     opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string())
@@ -495,7 +495,7 @@ class MySQL(Dialect):
             return self.expression(
                 exp.IndexColumnConstraint,
                 this=this,
-                schema=schema,
+                expressions=expressions,
                 kind=kind,
                 index_type=index_type,
                 options=options,
@@ -617,6 +617,39 @@ class MySQL(Dialect):
 
             return self.expression(exp.Chr, **kwargs)
 
+        def _parse_group_concat(self) -> t.Optional[exp.Expression]:
+            def concat_exprs(
+                node: t.Optional[exp.Expression], exprs: t.List[exp.Expression]
+            ) -> exp.Expression:
+                if isinstance(node, exp.Distinct) and len(node.expressions) > 1:
+                    concat_exprs = [
+                        self.expression(exp.Concat, expressions=node.expressions, safe=True)
+                    ]
+                    node.set("expressions", concat_exprs)
+                    return node
+                if len(exprs) == 1:
+                    return exprs[0]
+                return self.expression(exp.Concat, expressions=args, safe=True)
+
+            args = self._parse_csv(self._parse_lambda)
+
+            if args:
+                order = args[-1] if isinstance(args[-1], exp.Order) else None
+
+                if order:
+                    # Order By is the last (or only) expression in the list and has consumed the 'expr' before it,
+                    # remove 'expr' from exp.Order and add it back to args
+                    args[-1] = order.this
+                    order.set("this", concat_exprs(order.this, args))
+
+                this = order or concat_exprs(args[0], args)
+            else:
+                this = None
+
+            separator = self._parse_field() if self._match(TokenType.SEPARATOR) else None
+
+            return self.expression(exp.GroupConcat, this=this, separator=separator)
+
    class Generator(generator.Generator):
        LOCKING_READS_SUPPORTED = True
        NULL_ORDERING_SUPPORTED = None
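A sketch of what the dedicated GROUP_CONCAT parser accepts: multi-expression DISTINCT arguments are folded into CONCAT, and a trailing ORDER BY keeps its expression. The table and columns below are placeholders:

```python
import sqlglot

sql = "SELECT GROUP_CONCAT(DISTINCT a, b ORDER BY a SEPARATOR '-') FROM t"
print(sqlglot.parse_one(sql, read="mysql").sql("mysql"))
# expected (roughly): SELECT GROUP_CONCAT(DISTINCT CONCAT(a, b) ORDER BY a SEPARATOR '-') FROM t
```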
@@ -630,6 +663,7 @@ class MySQL(Dialect):
         JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
         JSON_PATH_BRACKETED_KEY_SUPPORTED = False
         JSON_KEY_VALUE_PAIR_SEP = ","
+        SUPPORTS_TO_NUMBER = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -637,9 +671,9 @@ class MySQL(Dialect):
             exp.DateDiff: _remove_ts_or_ds_to_date(
                 lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
             ),
-            exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
+            exp.DateAdd: _remove_ts_or_ds_to_date(date_add_sql("ADD")),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
+            exp.DateSub: _remove_ts_or_ds_to_date(date_add_sql("SUB")),
             exp.DateTrunc: _date_trunc_sql,
             exp.Day: _remove_ts_or_ds_to_date(),
             exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
@@ -672,7 +706,7 @@ class MySQL(Dialect):
             exp.TimeFromParts: rename_func("MAKETIME"),
             exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
             exp.TimestampDiff: lambda self, e: self.func(
-                "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+                "TIMESTAMPDIFF", unit_to_var(e), e.expression, e.this
             ),
             exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -682,9 +716,10 @@ class MySQL(Dialect):
             ),
             exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,
-            exp.TsOrDsAdd: _date_add_sql("ADD"),
+            exp.TsOrDsAdd: date_add_sql("ADD"),
             exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
             exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+            exp.UnixToTime: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
             exp.Week: _remove_ts_or_ds_to_date(),
             exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
             exp.Year: _remove_ts_or_ds_to_date(),
@@ -751,11 +786,6 @@ class MySQL(Dialect):
                 result = f"{result} UNSIGNED"
             return result
 
-        def xor_sql(self, expression: exp.Xor) -> str:
-            if expression.expressions:
-                return self.expressions(expression, sep=" XOR ")
-            return super().xor_sql(expression)
-
         def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
             return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"
sqlglot/dialects/oracle.py

@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
     build_formatted_time,
     no_ilike_sql,
     rename_func,
+    to_number_with_nls_param,
     trim_sql,
 )
 from sqlglot.helper import seq_get
@@ -246,6 +247,7 @@ class Oracle(Dialect):
             exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
             exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+            exp.ToNumber: to_number_with_nls_param,
             exp.Trim: trim_sql,
             exp.UnixToTime: lambda self,
             e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
sqlglot/dialects/postgres.py

@@ -278,6 +278,7 @@ class Postgres(Dialect):
             "REVOKE": TokenType.COMMAND,
             "SERIAL": TokenType.SERIAL,
             "SMALLSERIAL": TokenType.SMALLSERIAL,
+            "NAME": TokenType.NAME,
             "TEMP": TokenType.TEMPORARY,
             "CSTRING": TokenType.PSEUDO_TYPE,
             "OID": TokenType.OBJECT_IDENTIFIER,
@@ -356,6 +357,16 @@ class Postgres(Dialect):
 
         JSON_ARROWS_REQUIRE_JSON_TYPE = True
 
+        COLUMN_OPERATORS = {
+            **parser.Parser.COLUMN_OPERATORS,
+            TokenType.ARROW: lambda self, this, path: build_json_extract_path(
+                exp.JSONExtract, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
+            )([this, path]),
+            TokenType.DARROW: lambda self, this, path: build_json_extract_path(
+                exp.JSONExtractScalar, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
+            )([this, path]),
+        }
+
         def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
             while True:
                 if not self._match(TokenType.L_PAREN):
@@ -484,6 +495,7 @@ class Postgres(Dialect):
                 ]
             ),
             exp.StrPosition: str_position_sql,
+            exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
             exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
             exp.StructExtract: struct_extract_sql,
             exp.Substring: _substring_sql,
sqlglot/dialects/presto.py

@@ -22,10 +22,13 @@ from sqlglot.dialects.dialect import (
     rename_func,
     right_to_substring_sql,
     struct_extract_sql,
+    str_position_sql,
     timestamptrunc_sql,
     timestrtotime_sql,
     ts_or_ds_add_cast,
+    unit_to_str,
 )
+from sqlglot.dialects.hive import Hive
 from sqlglot.dialects.mysql import MySQL
 from sqlglot.helper import apply_index_offset, seq_get
 from sqlglot.tokens import TokenType
@@ -93,14 +96,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
 
 def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
     expression = ts_or_ds_add_cast(expression)
-    unit = exp.Literal.string(expression.text("unit") or "DAY")
+    unit = unit_to_str(expression)
     return self.func("DATE_ADD", unit, expression.expression, expression.this)
 
 
 def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
     this = exp.cast(expression.this, "TIMESTAMP")
     expr = exp.cast(expression.expression, "TIMESTAMP")
-    unit = exp.Literal.string(expression.text("unit") or "DAY")
+    unit = unit_to_str(expression)
     return self.func("DATE_DIFF", unit, expr, this)
 
 
@@ -196,6 +199,7 @@ class Presto(Dialect):
     SUPPORTS_SEMI_ANTI_JOIN = False
     TYPED_DIVISION = True
     TABLESAMPLE_SIZE_IS_PERCENT = True
+    LOG_BASE_FIRST: t.Optional[bool] = None
 
     # https://github.com/trinodb/trino/issues/17
     # https://github.com/trinodb/trino/issues/12289
@@ -289,6 +293,7 @@ class Presto(Dialect):
         SUPPORTS_SINGLE_ARG_CONCAT = False
         LIKE_PROPERTY_INSIDE_SCHEMA = True
         MULTI_ARG_DISTINCT = False
+        SUPPORTS_TO_NUMBER = False
 
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,
@@ -323,6 +328,7 @@ class Presto(Dialect):
             exp.ArrayConcat: rename_func("CONCAT"),
             exp.ArrayContains: rename_func("CONTAINS"),
             exp.ArraySize: rename_func("CARDINALITY"),
+            exp.ArrayToString: rename_func("ARRAY_JOIN"),
             exp.ArrayUniqueAgg: rename_func("SET_AGG"),
             exp.AtTimeZone: rename_func("AT_TIMEZONE"),
             exp.BitwiseAnd: lambda self, e: self.func("BITWISE_AND", e.this, e.expression),
@@ -339,19 +345,19 @@ class Presto(Dialect):
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DateAdd: lambda self, e: self.func(
                 "DATE_ADD",
-                exp.Literal.string(e.text("unit") or "DAY"),
+                unit_to_str(e),
                 _to_int(e.expression),
                 e.this,
             ),
             exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
+                "DATE_DIFF", unit_to_str(e), e.expression, e.this
             ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self,
             e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
             exp.DateSub: lambda self, e: self.func(
                 "DATE_ADD",
-                exp.Literal.string(e.text("unit") or "DAY"),
+                unit_to_str(e),
                 _to_int(e.expression * -1),
                 e.this,
             ),
@@ -397,13 +403,10 @@ class Presto(Dialect):
                 ]
             ),
             exp.SortArray: _no_sort_array,
-            exp.StrPosition: rename_func("STRPOS"),
+            exp.StrPosition: lambda self, e: str_position_sql(self, e, generate_instance=True),
             exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
             exp.StrToMap: rename_func("SPLIT_TO_MAP"),
             exp.StrToTime: _str_to_time_sql,
-            exp.StrToUnix: lambda self, e: self.func(
-                "TO_UNIXTIME", self.func("DATE_PARSE", e.this, self.format_time(e))
-            ),
             exp.StructExtract: struct_extract_sql,
             exp.Table: transforms.preprocess([_unnest_sequence]),
             exp.Timestamp: no_timestamp_sql,
@@ -436,6 +439,22 @@ class Presto(Dialect):
             exp.Xor: bool_xor_sql,
         }
 
+        def strtounix_sql(self, expression: exp.StrToUnix) -> str:
+            # Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one.
+            # To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a
+            # timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
+            # which seems to be using the same time mapping as Hive, as per:
+            # https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
+            value_as_text = exp.cast(expression.this, "text")
+            parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
+            parse_with_tz = self.func(
+                "PARSE_DATETIME",
+                value_as_text,
+                self.format_time(expression, Hive.INVERSE_TIME_MAPPING, Hive.INVERSE_TIME_TRIE),
+            )
+            coalesced = self.func("COALESCE", self.func("TRY", parse_without_tz), parse_with_tz)
+            return self.func("TO_UNIXTIME", coalesced)
+
         def bracket_sql(self, expression: exp.Bracket) -> str:
             if expression.args.get("safe"):
                 return self.func(
@@ -481,8 +500,7 @@ class Presto(Dialect):
             return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"
 
         def interval_sql(self, expression: exp.Interval) -> str:
-            unit = self.sql(expression, "unit")
-            if expression.this and unit.startswith("WEEK"):
+            if expression.this and expression.text("unit").upper().startswith("WEEK"):
                 return f"({expression.this.name} * INTERVAL '7' DAY)"
             return super().interval_sql(expression)
sqlglot/dialects/prql.py (new file, +109)
@ -0,0 +1,109 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, parser, tokens
|
||||
from sqlglot.dialects.dialect import Dialect
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
class PRQL(Dialect):
|
||||
class Tokenizer(tokens.Tokenizer):
|
||||
IDENTIFIERS = ["`"]
|
||||
QUOTES = ["'", '"']
|
||||
|
||||
SINGLE_TOKENS = {
|
||||
**tokens.Tokenizer.SINGLE_TOKENS,
|
||||
"=": TokenType.ALIAS,
|
||||
"'": TokenType.QUOTE,
|
||||
'"': TokenType.QUOTE,
|
||||
"`": TokenType.IDENTIFIER,
|
||||
"#": TokenType.COMMENT,
|
||||
}
|
||||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
TRANSFORM_PARSERS = {
|
||||
"DERIVE": lambda self, query: self._parse_selection(query),
|
||||
"SELECT": lambda self, query: self._parse_selection(query, append=False),
|
||||
"TAKE": lambda self, query: self._parse_take(query),
|
||||
}
|
||||
|
||||
def _parse_statement(self) -> t.Optional[exp.Expression]:
|
||||
expression = self._parse_expression()
|
||||
expression = expression if expression else self._parse_query()
|
||||
return expression
|
||||
|
||||
def _parse_query(self) -> t.Optional[exp.Query]:
|
||||
from_ = self._parse_from()
|
||||
|
||||
if not from_:
|
||||
return None
|
||||
|
||||
query = exp.select("*").from_(from_, copy=False)
|
||||
|
||||
while self._match_texts(self.TRANSFORM_PARSERS):
|
||||
query = self.TRANSFORM_PARSERS[self._prev.text.upper()](self, query)
|
||||
|
||||
return query
|
||||
|
||||
def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
|
||||
if self._match(TokenType.L_BRACE):
|
||||
selects = self._parse_csv(self._parse_expression)
|
||||
|
||||
if not self._match(TokenType.R_BRACE, expression=query):
|
||||
self.raise_error("Expecting }")
|
||||
else:
|
||||
expression = self._parse_expression()
|
||||
selects = [expression] if expression else []
|
||||
|
||||
projections = {
|
||||
select.alias_or_name: select.this if isinstance(select, exp.Alias) else select
|
||||
for select in query.selects
|
||||
}
|
||||
|
||||
selects = [
|
||||
select.transform(
|
||||
lambda s: (projections[s.name].copy() if s.name in projections else s)
|
||||
if isinstance(s, exp.Column)
|
||||
else s,
|
||||
copy=False,
|
||||
)
|
||||
for select in selects
|
||||
]
|
||||
|
||||
return query.select(*selects, append=append, copy=False)
|
||||
|
||||
def _parse_take(self, query: exp.Query) -> t.Optional[exp.Query]:
|
||||
num = self._parse_number() # TODO: TAKE for ranges a..b
|
||||
return query.limit(num) if num else None
|
||||
|
||||
def _parse_expression(self) -> t.Optional[exp.Expression]:
|
||||
if self._next and self._next.token_type == TokenType.ALIAS:
|
||||
alias = self._parse_id_var(True)
|
||||
self._match(TokenType.ALIAS)
|
||||
return self.expression(exp.Alias, this=self._parse_conjunction(), alias=alias)
|
||||
return self._parse_conjunction()
|
||||
|
||||
def _parse_table(
|
||||
self,
|
||||
schema: bool = False,
|
||||
joins: bool = False,
|
||||
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
|
||||
parse_bracket: bool = False,
|
||||
is_db_reference: bool = False,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
return self._parse_table_parts()
|
||||
|
||||
def _parse_from(
|
||||
self, joins: bool = False, skip_from_token: bool = False
|
||||
) -> t.Optional[exp.From]:
|
||||
if not skip_from_token and not self._match(TokenType.FROM):
|
||||
return None
|
||||
|
||||
return self.expression(
|
||||
exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins)
|
||||
)
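A minimal sketch of the new dialect in use (illustrative input and output; this parser only handles from, derive, select and take pipelines):

import sqlglot

# `select {...}` replaces the projection list and `take` becomes a LIMIT.
sql = sqlglot.transpile("from employees select {first_name, age} take 5", read="prql")[0]
print(sql)  # e.g. SELECT first_name, age FROM employees LIMIT 5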
@@ -92,23 +92,6 @@ class Redshift(Postgres):

            return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table

        def _parse_types(
            self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
        ) -> t.Optional[exp.Expression]:
            this = super()._parse_types(
                check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
            )

            if (
                isinstance(this, exp.DataType)
                and this.is_type("varchar")
                and this.expressions
                and this.expressions[0].this == exp.column("MAX")
            ):
                this.set("expressions", [exp.var("MAX")])

            return this

        def _parse_convert(
            self, strict: bool, safe: t.Optional[bool] = None
        ) -> t.Optional[exp.Expression]:

@@ -153,6 +136,7 @@ class Redshift(Postgres):
        NVL2_SUPPORTED = True
        LAST_DAY_SUPPORTS_DATE_PART = False
        CAN_IMPLEMENT_ARRAY_ANY = False
        MULTI_ARG_DISTINCT = True

        TYPE_MAPPING = {
            **Postgres.Generator.TYPE_MAPPING,

@@ -187,9 +171,13 @@ class Redshift(Postgres):
            ),
            exp.SortKeyProperty: lambda self,
            e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
            exp.StartsWith: lambda self,
            e: f"{self.sql(e.this)} LIKE {self.sql(e.expression)} || '%'",
            exp.TableSample: no_tablesample_sql,
            exp.TsOrDsAdd: date_delta_sql("DATEADD"),
            exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
            exp.UnixToTime: lambda self,
            e: f"(TIMESTAMP 'epoch' + {self.sql(e.this)} * INTERVAL '1 SECOND')",
        }

        # Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
@@ -20,8 +20,7 @@ from sqlglot.dialects.dialect import (
    timestrtotime_sql,
    var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import flatten, is_int, seq_get
from sqlglot.helper import flatten, is_float, is_int, seq_get
from sqlglot.tokens import TokenType

if t.TYPE_CHECKING:

@@ -29,33 +28,35 @@ if t.TYPE_CHECKING:

# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
    if len(args) == 2:
        first_arg, second_arg = args
        if second_arg.is_string:
            # case: <string_expr> [ , <format> ]
            return build_formatted_time(exp.StrToTime, "snowflake")(args)
        return exp.UnixToTime(this=first_arg, scale=second_arg)
def _build_datetime(
    name: str, kind: exp.DataType.Type, safe: bool = False
) -> t.Callable[[t.List], exp.Func]:
    def _builder(args: t.List) -> exp.Func:
        value = seq_get(args, 0)

    from sqlglot.optimizer.simplify import simplify_literals
        if isinstance(value, exp.Literal):
            int_value = is_int(value.this)

    # The first argument might be an expression like 40 * 365 * 86400, so we try to
    # reduce it using `simplify_literals` first and then check if it's a Literal.
    first_arg = seq_get(args, 0)
    if not isinstance(simplify_literals(first_arg, root=True), Literal):
        # case: <variant_expr> or other expressions such as columns
        return exp.TimeStrToTime.from_arg_list(args)
            # Converts calls like `TO_TIME('01:02:03')` into casts
            if len(args) == 1 and value.is_string and not int_value:
                return exp.cast(value, kind)

    if first_arg.is_string:
        if is_int(first_arg.this):
            # case: <integer>
            return exp.UnixToTime.from_arg_list(args)
            # Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special
            # cases so we can transpile them, since they're relatively common
            if kind == exp.DataType.Type.TIMESTAMP:
                if int_value:
                    return exp.UnixToTime(this=value, scale=seq_get(args, 1))
                if not is_float(value.this):
                    return build_formatted_time(exp.StrToTime, "snowflake")(args)

        # case: <date_expr>
        return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args)
        if len(args) == 2 and kind == exp.DataType.Type.DATE:
            formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args)
            formatted_exp.set("safe", safe)
            return formatted_exp

    # case: <numeric_expr>
    return exp.UnixToTime.from_arg_list(args)
        return exp.Anonymous(this=name, expressions=args)

    return _builder


def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:

@@ -77,6 +78,17 @@ def _build_datediff(args: t.List) -> exp.DateDiff:
    )


def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        return expr_type(
            this=seq_get(args, 2),
            expression=seq_get(args, 1),
            unit=_map_date_part(seq_get(args, 0)),
        )

    return _builder


# https://docs.snowflake.com/en/sql-reference/functions/div0
def _build_if_from_div0(args: t.List) -> exp.If:
    cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))

@@ -97,14 +109,6 @@ def _build_if_from_nullifzero(args: t.List) -> exp.If:
    return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))


def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str:
    if expression.is_type("array"):
        return "ARRAY"
    elif expression.is_type("map"):
        return "OBJECT"
    return self.datatype_sql(expression)


def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str:
    flag = expression.text("flag")

@@ -258,6 +262,25 @@ def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
    return expression


def _flatten_structured_types_unless_iceberg(expression: exp.Expression) -> exp.Expression:
    assert isinstance(expression, exp.Create)

    def _flatten_structured_type(expression: exp.DataType) -> exp.DataType:
        if expression.this in exp.DataType.NESTED_TYPES:
            expression.set("expressions", None)
        return expression

    props = expression.args.get("properties")
    if isinstance(expression.this, exp.Schema) and not (props and props.find(exp.IcebergProperty)):
        for schema_expression in expression.this.expressions:
            if isinstance(schema_expression, exp.ColumnDef):
                column_type = schema_expression.kind
                if isinstance(column_type, exp.DataType):
                    column_type.transform(_flatten_structured_type, copy=False)

    return expression


class Snowflake(Dialect):
    # https://docs.snowflake.com/en/sql-reference/identifiers-syntax
    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE

@@ -312,7 +335,13 @@ class Snowflake(Dialect):
    class Parser(parser.Parser):
        IDENTIFY_PIVOT_STRINGS = True

        ID_VAR_TOKENS = {
            *parser.Parser.ID_VAR_TOKENS,
            TokenType.MATCH_CONDITION,
        }

        TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW}
        TABLE_ALIAS_TOKENS.discard(TokenType.MATCH_CONDITION)

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,

@@ -327,17 +356,13 @@ class Snowflake(Dialect):
                end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)),
                step=seq_get(args, 2),
            ),
            "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
            "BITXOR": binary_from_function(exp.BitwiseXor),
            "BIT_XOR": binary_from_function(exp.BitwiseXor),
            "BOOLXOR": binary_from_function(exp.Xor),
            "CONVERT_TIMEZONE": _build_convert_timezone,
            "DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
            "DATE_TRUNC": _date_trunc_to_time,
            "DATEADD": lambda args: exp.DateAdd(
                this=seq_get(args, 2),
                expression=seq_get(args, 1),
                unit=_map_date_part(seq_get(args, 0)),
            ),
            "DATEADD": _build_date_time_add(exp.DateAdd),
            "DATEDIFF": _build_datediff,
            "DIV0": _build_if_from_div0,
            "FLATTEN": exp.Explode.from_arg_list,

@@ -349,17 +374,34 @@ class Snowflake(Dialect):
                this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
            ),
            "LISTAGG": exp.GroupConcat.from_arg_list,
            "MEDIAN": lambda args: exp.PercentileCont(
                this=seq_get(args, 0), expression=exp.Literal.number(0.5)
            ),
            "NULLIFZERO": _build_if_from_nullifzero,
            "OBJECT_CONSTRUCT": _build_object_construct,
            "REGEXP_REPLACE": _build_regexp_replace,
            "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
            "RLIKE": exp.RegexpLike.from_arg_list,
            "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
            "TIMEADD": _build_date_time_add(exp.TimeAdd),
            "TIMEDIFF": _build_datediff,
            "TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
            "TIMESTAMPDIFF": _build_datediff,
            "TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
            "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
            "TO_TIMESTAMP": _build_to_timestamp,
            "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
            "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
            "TO_NUMBER": lambda args: exp.ToNumber(
                this=seq_get(args, 0),
                format=seq_get(args, 1),
                precision=seq_get(args, 2),
                scale=seq_get(args, 3),
            ),
            "TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME),
            "TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP),
            "TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ),
            "TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP),
            "TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ),
            "TO_VARCHAR": exp.ToChar.from_arg_list,
            "ZEROIFNULL": _build_if_from_zeroifnull,
        }

@@ -377,7 +419,6 @@ class Snowflake(Dialect):
            **parser.Parser.RANGE_PARSERS,
            TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
            TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
            TokenType.COLON: lambda self, this: self._parse_colon_get_path(this),
        }

        ALTER_PARSERS = {

@@ -434,35 +475,35 @@ class Snowflake(Dialect):

        SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"}

        def _parse_colon_get_path(
            self: parser.Parser, this: t.Optional[exp.Expression]
        ) -> t.Optional[exp.Expression]:
            while True:
                path = self._parse_bitwise() or self._parse_var(any_token=True)
        def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
            this = super()._parse_column_ops(this)

            casts = []
            json_path = []

            while self._match(TokenType.COLON):
                path = super()._parse_column_ops(self._parse_field(any_token=True))

                # The cast :: operator has a lower precedence than the extraction operator :, so
                # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
                if isinstance(path, exp.Cast):
                    target_type = path.to
                while isinstance(path, exp.Cast):
                    casts.append(path.to)
                    path = path.this
                else:
                    target_type = None

                if isinstance(path, exp.Expression):
                    path = exp.Literal.string(path.sql(dialect="snowflake"))
                if path:
                    json_path.append(path.sql(dialect="snowflake", copy=False))

            # The extraction operator : is left-associative
            if json_path:
                this = self.expression(
                    exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
                    exp.JSONExtract,
                    this=this,
                    expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))),
                )

                if target_type:
                    this = exp.cast(this, target_type)
                while casts:
                    this = self.expression(exp.Cast, this=this, to=casts.pop())

                if not self._match(TokenType.COLON):
                    break

            return self._parse_range(this)
            return this

        # https://docs.snowflake.com/en/sql-reference/functions/date_part.html
        # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts

@@ -663,6 +704,7 @@ class Snowflake(Dialect):
            "EXCLUDE": TokenType.EXCEPT,
            "ILIKE ANY": TokenType.ILIKE_ANY,
            "LIKE ANY": TokenType.LIKE_ANY,
            "MATCH_CONDITION": TokenType.MATCH_CONDITION,
            "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
            "MINUS": TokenType.EXCEPT,
            "NCHAR VARYING": TokenType.VARCHAR,

@@ -703,6 +745,7 @@ class Snowflake(Dialect):
        LIMIT_ONLY_LITERALS = True
        JSON_KEY_VALUE_PAIR_SEP = ","
        INSERT_OVERWRITE = " OVERWRITE INTO"
        STRUCT_DELIMITER = ("(", ")")

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,

@@ -711,15 +754,14 @@ class Snowflake(Dialect):
            exp.Array: inline_array_sql,
            exp.ArrayConcat: rename_func("ARRAY_CAT"),
            exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
            exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
            exp.AtTimeZone: lambda self, e: self.func(
                "CONVERT_TIMEZONE", e.args.get("zone"), e.this
            ),
            exp.BitwiseXor: rename_func("BITXOR"),
            exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]),
            exp.DateAdd: date_delta_sql("DATEADD"),
            exp.DateDiff: date_delta_sql("DATEDIFF"),
            exp.DateStrToDate: datestrtodate_sql,
            exp.DataType: _datatype_sql,
            exp.DayOfMonth: rename_func("DAYOFMONTH"),
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
            exp.DayOfYear: rename_func("DAYOFYEAR"),

@@ -769,6 +811,7 @@ class Snowflake(Dialect):
            ),
            exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
            exp.Stuff: rename_func("INSERT"),
            exp.TimeAdd: date_delta_sql("TIMEADD"),
            exp.TimestampDiff: lambda self, e: self.func(
                "TIMESTAMPDIFF", e.unit, e.expression, e.this
            ),

@@ -783,6 +826,9 @@ class Snowflake(Dialect):
            exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
            exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
            exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
            exp.TsOrDsToDate: lambda self, e: self.func(
                "TRY_TO_DATE" if e.args.get("safe") else "TO_DATE", e.this, self.format_time(e)
            ),
            exp.UnixToTime: rename_func("TO_TIMESTAMP"),
            exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
            exp.WeekOfYear: rename_func("WEEKOFYEAR"),

@@ -797,6 +843,8 @@ class Snowflake(Dialect):

        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.NESTED: "OBJECT",
            exp.DataType.Type.STRUCT: "OBJECT",
            exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
        }

@@ -811,6 +859,37 @@ class Snowflake(Dialect):
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

        UNSUPPORTED_VALUES_EXPRESSIONS = {
            exp.Struct,
        }

        def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str:
            if expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS):
                values_as_table = False

            return super().values_sql(expression, values_as_table=values_as_table)

        def datatype_sql(self, expression: exp.DataType) -> str:
            expressions = expression.expressions
            if (
                expressions
                and expression.is_type(*exp.DataType.STRUCT_TYPES)
                and any(isinstance(field_type, exp.DataType) for field_type in expressions)
            ):
                # The correct syntax is OBJECT [ (<key> <value_type [NOT NULL] [, ...]) ]
                return "OBJECT"

            return super().datatype_sql(expression)

        def tonumber_sql(self, expression: exp.ToNumber) -> str:
            return self.func(
                "TO_NUMBER",
                expression.this,
                expression.args.get("format"),
                expression.args.get("precision"),
                expression.args.get("scale"),
            )

        def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
            milli = expression.args.get("milli")
            if milli is not None:
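A minimal sketch of the reworked colon parsing above (illustrative; chained keys collapse into a single JSON path and any trailing :: cast is applied outside the extraction):

import sqlglot

# The cast binds around the whole GET_PATH-style extraction, not its last key.
sql = sqlglot.transpile("SELECT v:a.b::string FROM t", read="snowflake")[0]
print(sql)  # e.g. SELECT CAST(GET_PATH(v, 'a.b') AS VARCHAR) FROM t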
@@ -3,7 +3,7 @@ from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.dialect import rename_func, unit_to_var
from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get

@@ -78,6 +78,8 @@ class Spark(Spark2):
            return this

    class Generator(Spark2.Generator):
        SUPPORTS_TO_NUMBER = True

        TYPE_MAPPING = {
            **Spark2.Generator.TYPE_MAPPING,
            exp.DataType.Type.MONEY: "DECIMAL(15, 4)",

@@ -100,7 +102,7 @@ class Spark(Spark2):
            e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
            exp.StartsWith: rename_func("STARTSWITH"),
            exp.TimestampAdd: lambda self, e: self.func(
                "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
                "DATEADD", unit_to_var(e), e.expression, e.this
            ),
            exp.TryCast: lambda self, e: (
                self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)

@@ -117,11 +119,10 @@ class Spark(Spark2):
            return self.function_fallback_sql(expression)

        def datediff_sql(self, expression: exp.DateDiff) -> str:
            unit = self.sql(expression, "unit")
            end = self.sql(expression, "this")
            start = self.sql(expression, "expression")

            if unit:
                return self.func("DATEDIFF", unit, start, end)
            if expression.unit:
                return self.func("DATEDIFF", unit_to_var(expression), start, end)

            return self.func("DATEDIFF", end, start)
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
    pivot_column_names,
    rename_func,
    trim_sql,
    unit_to_str,
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get

@@ -203,6 +204,7 @@ class Spark2(Hive):
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            exp.ArraySum: lambda self,
            e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
            exp.ArrayToString: rename_func("ARRAY_JOIN"),
            exp.AtTimeZone: lambda self, e: self.func(
                "FROM_UTC_TIMESTAMP", e.this, e.args.get("zone")
            ),

@@ -218,7 +220,7 @@ class Spark2(Hive):
                ]
            ),
            exp.DateFromParts: rename_func("MAKE_DATE"),
            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)),
            exp.DayOfMonth: rename_func("DAYOFMONTH"),
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
            exp.DayOfYear: rename_func("DAYOFYEAR"),

@@ -241,9 +243,7 @@ class Spark2(Hive):
            ),
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
            exp.TimestampTrunc: lambda self, e: self.func(
                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
            ),
            exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
            exp.Trim: trim_sql,
            exp.UnixToTime: _unix_to_time_sql,
            exp.VariancePop: rename_func("VAR_POP"),

@@ -252,7 +252,6 @@ class Spark2(Hive):
                [transforms.remove_within_group_for_percentiles]
            ),
        }
        TRANSFORMS.pop(exp.ArrayJoin)
        TRANSFORMS.pop(exp.ArraySort)
        TRANSFORMS.pop(exp.ILike)
        TRANSFORMS.pop(exp.Left)
@@ -33,6 +33,14 @@ def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> str:
    return arrow_json_extract_sql(self, expression)


def _build_strftime(args: t.List) -> exp.Anonymous | exp.TimeToStr:
    if len(args) == 1:
        args.append(exp.CurrentTimestamp())
    if len(args) == 2:
        return exp.TimeToStr(this=exp.TsOrDsToTimestamp(this=args[1]), format=args[0])
    return exp.Anonymous(this="STRFTIME", expressions=args)


def _transform_create(expression: exp.Expression) -> exp.Expression:
    """Move primary key to a column and enforce auto_increment on primary keys."""
    schema = expression.this

@@ -82,6 +90,7 @@ class SQLite(Dialect):
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "EDITDIST3": exp.Levenshtein.from_arg_list,
            "STRFTIME": _build_strftime,
        }
        STRING_ALIASES = True

@@ -93,6 +102,7 @@ class SQLite(Dialect):
        JSON_PATH_BRACKETED_KEY_SUPPORTED = False
        SUPPORTS_CREATE_TABLE_LIKE = False
        SUPPORTS_TABLE_ALIAS_COLUMNS = False
        SUPPORTS_TO_NUMBER = False

        SUPPORTED_JSON_PATH_PARTS = {
            exp.JSONPathKey,

@@ -151,7 +161,9 @@ class SQLite(Dialect):
            ),
            exp.TableSample: no_tablesample_sql,
            exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
            exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.args.get("format"), e.this),
            exp.TryCast: no_trycast_sql,
            exp.TsOrDsToTimestamp: lambda self, e: self.sql(e, "this"),
        }

        # SQLite doesn't generally support CREATE TABLE .. properties
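A minimal sketch of the STRFTIME handling above (illustrative; with a single argument the builder appends CURRENT_TIMESTAMP, and two-argument calls round-trip through exp.TimeToStr):

import sqlglot

print(sqlglot.transpile("SELECT STRFTIME('%Y-%m-%d')", read="sqlite")[0])
# e.g. SELECT STRFTIME('%Y-%m-%d', CURRENT_TIMESTAMP)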
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
    arrow_json_extract_sql,
    build_timestamp_trunc,
    rename_func,
    unit_to_str,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get

@@ -39,15 +40,13 @@ class StarRocks(MySQL):
            **MySQL.Generator.TRANSFORMS,
            exp.ApproxDistinct: approx_count_distinct_sql,
            exp.DateDiff: lambda self, e: self.func(
                "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.this, e.expression
                "DATE_DIFF", unit_to_str(e), e.this, e.expression
            ),
            exp.JSONExtractScalar: arrow_json_extract_sql,
            exp.JSONExtract: arrow_json_extract_sql,
            exp.RegexpLike: rename_func("REGEXP"),
            exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
            exp.TimestampTrunc: lambda self, e: self.func(
                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
            ),
            exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
            exp.TimeStrToDate: rename_func("TO_DATE"),
            exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
@@ -5,6 +5,8 @@ from sqlglot.dialects.dialect import Dialect, rename_func


class Tableau(Dialect):
    LOG_BASE_FIRST = False

    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = [("[", "]")]
        QUOTES = ["'", '"']
@@ -3,7 +3,13 @@ from __future__ import annotations

import typing as t

from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least, rename_func
from sqlglot.dialects.dialect import (
    Dialect,
    max_or_greatest,
    min_or_least,
    rename_func,
    to_number_with_nls_param,
)
from sqlglot.tokens import TokenType


@@ -206,6 +212,7 @@ class Teradata(Dialect):
            exp.StrToDate: lambda self,
            e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
            exp.ToNumber: to_number_with_nls_param,
            exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
        }
@@ -7,6 +7,7 @@ from sqlglot.dialects.presto import Presto

class Trino(Presto):
    SUPPORTS_USER_DEFINED_TYPES = False
    LOG_BASE_FIRST = True

    class Generator(Presto.Generator):
        TRANSFORMS = {
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
    NormalizationStrategy,
    any_value_to_max_sql,
    date_delta_sql,
    datestrtodate_sql,
    generatedasidentitycolumnconstraint_sql,
    max_or_greatest,
    min_or_least,

@@ -724,6 +725,7 @@ class TSQL(Dialect):
        TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
        SUPPORTS_SELECT_INTO = True
        JSON_PATH_BRACKETED_KEY_SUPPORTED = False
        SUPPORTS_TO_NUMBER = False

        EXPRESSIONS_WITHOUT_NESTED_CTES = {
            exp.Delete,

@@ -760,12 +762,14 @@ class TSQL(Dialect):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.AnyValue: any_value_to_max_sql,
            exp.ArrayToString: rename_func("STRING_AGG"),
            exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
            exp.DateAdd: date_delta_sql("DATEADD"),
            exp.DateDiff: date_delta_sql("DATEDIFF"),
            exp.CTE: transforms.preprocess([qualify_derived_table_outputs]),
            exp.CurrentDate: rename_func("GETDATE"),
            exp.CurrentTimestamp: rename_func("GETDATE"),
            exp.DateStrToDate: datestrtodate_sql,
            exp.Extract: rename_func("DATEPART"),
            exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
            exp.GroupConcat: _string_agg_sql,

@@ -808,6 +812,22 @@ class TSQL(Dialect):
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

        def select_sql(self, expression: exp.Select) -> str:
            if expression.args.get("offset"):
                if not expression.args.get("order"):
                    # ORDER BY is required in order to use OFFSET in a query, so we use
                    # a noop order by, since we don't really care about the order.
                    # See: https://www.microsoftpressstore.com/articles/article.aspx?p=2314819
                    expression.order_by(exp.select(exp.null()).subquery(), copy=False)

                limit = expression.args.get("limit")
                if isinstance(limit, exp.Limit):
                    # TOP and OFFSET can't be combined, so we need to use FETCH instead of TOP;
                    # we replace here because otherwise TOP would be generated in select_sql
                    limit.replace(exp.Fetch(direction="FIRST", count=limit.expression))

            return super().select_sql(expression)

        def convert_sql(self, expression: exp.Convert) -> str:
            name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT"
            return self.func(

@@ -862,12 +882,12 @@ class TSQL(Dialect):

            return rename_func("DATETIMEFROMPARTS")(self, expression)

        def set_operation(self, expression: exp.Union, op: str) -> str:
        def set_operations(self, expression: exp.Union) -> str:
            limit = expression.args.get("limit")
            if limit:
                return self.sql(expression.limit(limit.pop(), copy=False))

            return super().set_operation(expression, op)
            return super().set_operations(expression)

        def setitem_sql(self, expression: exp.SetItem) -> str:
            this = expression.this
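A minimal sketch of the select_sql override above (illustrative output; the no-op ORDER BY and the FETCH rewrite are exactly the two cases handled in the comments):

import sqlglot

sql = sqlglot.transpile("SELECT x FROM t LIMIT 10 OFFSET 5", write="tsql")[0]
print(sql)
# e.g. SELECT x FROM t ORDER BY (SELECT NULL) OFFSET 5 ROWS FETCH FIRST 10 ROWS ONLY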
@@ -103,7 +103,7 @@ def diff(
) -> t.Dict[int, exp.Expression]:
    return {
        id(old_node): new_node
        for (old_node, _, _), (new_node, _, _) in zip(original.walk(), copy.walk())
        for old_node, new_node in zip(original.walk(), copy.walk())
        if id(old_node) in matching_ids
    }

@@ -158,14 +158,10 @@ class ChangeDistiller:
        self._source = source
        self._target = target
        self._source_index = {
            id(n): n
            for n, *_ in self._source.bfs()
            if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
            id(n): n for n in self._source.bfs() if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
        }
        self._target_index = {
            id(n): n
            for n, *_ in self._target.bfs()
            if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
            id(n): n for n in self._target.bfs() if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
        }
        self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
        self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())

@@ -216,10 +212,10 @@ class ChangeDistiller:
        matching_set = leaves_matching_set.copy()

        ordered_unmatched_source_nodes = {
            id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes
            id(n): None for n in self._source.bfs() if id(n) in self._unmatched_source_nodes
        }
        ordered_unmatched_target_nodes = {
            id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes
            id(n): None for n in self._target.bfs() if id(n) in self._unmatched_target_nodes
        }

        for source_node_id in ordered_unmatched_source_nodes:

@@ -322,7 +318,7 @@ class ChangeDistiller:
def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
    has_child_exprs = False

    for _, node in expression.iter_expressions():
    for node in expression.iter_expressions():
        if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES):
            has_child_exprs = True
            yield from _get_leaves(node)
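The hunks above track an API change in expression traversal: walk() and bfs() now yield nodes directly instead of (node, parent, key) tuples. A minimal sketch:

import sqlglot

# Each AST node is yielded once; parents are still reachable via node.parent.
for node in sqlglot.parse_one("SELECT a + 1 AS b").walk():
    print(type(node).__name__)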
@@ -10,11 +10,13 @@ import logging
import time
import typing as t

from sqlglot import exp
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
from sqlglot.helper import dict_depth
from sqlglot.optimizer import optimize
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.planner import Plan
from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set


@@ -26,15 +28,11 @@ if t.TYPE_CHECKING:
    from sqlglot.schema import Schema


PYTHON_TYPE_TO_SQLGLOT = {
    "dict": "MAP",
}


def execute(
    sql: str | Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    read: DialectType = None,
    dialect: DialectType = None,
    tables: t.Optional[t.Dict] = None,
) -> Table:
    """

@@ -48,11 +46,13 @@ def execute(
        2. {db: {table: {col: type}}}
        3. {catalog: {db: {table: {col: type}}}}
        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
        dialect: the SQL dialect (alias for read).
        tables: additional tables to register.

    Returns:
        Simple columnar data structure.
    """
    read = read or dialect
    tables_ = ensure_tables(tables, dialect=read)

    if not schema:

@@ -64,8 +64,9 @@ def execute(
        assert table is not None

        for column in table.columns:
            py_type = type(table[0][column]).__name__
            nested_set(schema, [*keys, column], PYTHON_TYPE_TO_SQLGLOT.get(py_type) or py_type)
            value = table[0][column]
            column_type = annotate_types(exp.convert(value)).type or type(value).__name__
            nested_set(schema, [*keys, column], column_type)

    schema = ensure_schema(schema, dialect=read)
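A minimal sketch of the schema-inference change above (illustrative; types for unschema'd tables are now derived by annotating a literal built from the first row, replacing the removed PYTHON_TYPE_TO_SQLGLOT map):

from sqlglot.executor import execute

result = execute("SELECT a, b FROM t", tables={"t": [{"a": 1, "b": "x"}]})
print(result.columns, result.rows)  # e.g. ('a', 'b') [(1, 'x')]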
@@ -106,6 +106,13 @@ def cast(this, to):
        return this
    if isinstance(this, str):
        return datetime.date.fromisoformat(this)
    if to == exp.DataType.Type.TIME:
        if isinstance(this, datetime.datetime):
            return this.time()
        if isinstance(this, datetime.time):
            return this
        if isinstance(this, str):
            return datetime.time.fromisoformat(this)
    if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP):
        if isinstance(this, datetime.datetime):
            return this

@@ -139,7 +146,7 @@ def interval(this, unit):


@null_if_any("this", "expression")
def arrayjoin(this, expression, null=None):
def arraytostring(this, expression, null=None):
    return expression.join(x for x in (x if x is not None else null for x in this) if x is not None)


@@ -173,7 +180,7 @@ ENV = {
    "ABS": null_if_any(lambda this: abs(this)),
    "ADD": null_if_any(lambda e, this: e + this),
    "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
    "ARRAYJOIN": arrayjoin,
    "ARRAYTOSTRING": arraytostring,
    "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
    "BITWISEAND": null_if_any(lambda this, e: this & e),
    "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),

@@ -212,6 +219,7 @@ ENV = {
    "ORDERED": ordered,
    "POW": pow,
    "RIGHT": null_if_any(lambda this, e: this[-e:]),
    "ROUND": null_if_any(lambda this, decimals=None, truncate=None: round(this, ndigits=decimals)),
    "STRPOSITION": str_position,
    "SUB": null_if_any(lambda e, this: e - this),
    "SUBSTRING": substring,

@@ -225,10 +233,12 @@ ENV = {
    "CURRENTTIME": datetime.datetime.now,
    "CURRENTDATE": datetime.date.today,
    "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
    "STRTOTIME": null_if_any(lambda arg, format: datetime.datetime.strptime(arg, format)),
    "TRIM": null_if_any(lambda this, e=None: this.strip(e)),
    "STRUCT": lambda *args: {
        args[x]: args[x + 1]
        for x in range(0, len(args), 2)
        if (args[x + 1] is not None and args[x] is not None)
    },
    "UNIXTOTIME": null_if_any(lambda arg: datetime.datetime.utcfromtimestamp(arg)),
}
@@ -157,7 +157,7 @@ class PythonExecutor:
            yield context.table.reader

    def join(self, step, context):
        source = step.name
        source = step.source_name

        source_table = context.tables[source]
        source_context = self.context({source: source_table})

@@ -398,7 +398,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
            lambda n: (
                exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n
            )
        )
        ).assert_is(exp.Lambda)
    )

    return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
File diff suppressed because it is too large
@@ -46,9 +46,11 @@ class Generator(metaclass=_Generator):
            'safe': Only quote identifiers that are case insensitive.
        normalize: Whether to normalize identifiers to lowercase.
            Default: False.
        pad: The pad size in a formatted string.
        pad: The pad size in a formatted string. For example, this affects the indentation of
            a projection in a query, relative to its nesting level.
            Default: 2.
        indent: The indentation size in a formatted string.
        indent: The indentation size in a formatted string. For example, this affects the
            indentation of subqueries and filters under a `WHERE` clause.
            Default: 2.
        normalize_functions: How to normalize function names. Possible values are:
            "upper" or True (default): Convert names to uppercase.

@@ -73,6 +75,7 @@ class Generator(metaclass=_Generator):
    TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
        **JSON_PATH_PART_TRANSFORMS,
        exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
        exp.BackupProperty: lambda self, e: f"BACKUP {self.sql(e, 'this')}",
        exp.CaseSpecificColumnConstraint: lambda _,
        e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
        exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",

@@ -83,15 +86,15 @@ class Generator(metaclass=_Generator):
        exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
        exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
        exp.CopyGrantsProperty: lambda *_: "COPY GRANTS",
        exp.DateAdd: lambda self, e: self.func(
            "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
        ),
        exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
        exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
        exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
        exp.ExcludeColumnConstraint: lambda self, e: f"EXCLUDE {self.sql(e, 'this').lstrip()}",
        exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
        exp.ExternalProperty: lambda *_: "EXTERNAL",
        exp.GlobalProperty: lambda *_: "GLOBAL",
        exp.HeapProperty: lambda *_: "HEAP",
        exp.IcebergProperty: lambda *_: "ICEBERG",
        exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})",
        exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
        exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",

@@ -123,6 +126,7 @@ class Generator(metaclass=_Generator):
        exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
        exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
        exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
        exp.SharingProperty: lambda self, e: f"SHARING={self.sql(e, 'this')}",
        exp.SqlReadWriteProperty: lambda _, e: e.name,
        exp.SqlSecurityProperty: lambda _,
        e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",

@@ -130,13 +134,17 @@ class Generator(metaclass=_Generator):
        exp.TemporaryProperty: lambda *_: "TEMPORARY",
        exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
        exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
        exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}",
        exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
        exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
        exp.TransientProperty: lambda *_: "TRANSIENT",
        exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE",
        exp.UnloggedProperty: lambda *_: "UNLOGGED",
        exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
        exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}",
        exp.VolatileProperty: lambda *_: "VOLATILE",
        exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
        exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
    }

    # Whether null ordering is supported in order by

@@ -321,6 +329,9 @@ class Generator(metaclass=_Generator):
    # Whether any(f(x) for x in array) can be implemented by this dialect
    CAN_IMPLEMENT_ARRAY_ANY = False

    # Whether the function TO_NUMBER is supported
    SUPPORTS_TO_NUMBER = True

    TYPE_MAPPING = {
        exp.DataType.Type.NCHAR: "CHAR",
        exp.DataType.Type.NVARCHAR: "VARCHAR",

@@ -350,6 +361,18 @@ class Generator(metaclass=_Generator):
        "YEARS": "YEAR",
    }

    AFTER_HAVING_MODIFIER_TRANSFORMS = {
        "cluster": lambda self, e: self.sql(e, "cluster"),
        "distribute": lambda self, e: self.sql(e, "distribute"),
        "qualify": lambda self, e: self.sql(e, "qualify"),
        "sort": lambda self, e: self.sql(e, "sort"),
        "windows": lambda self, e: (
            self.seg("WINDOW ") + self.expressions(e, key="windows", flat=True)
            if e.args.get("windows")
            else ""
        ),
    }

    TOKEN_MAPPING: t.Dict[TokenType, str] = {}

    STRUCT_DELIMITER = ("<", ">")

@@ -361,6 +384,7 @@ class Generator(metaclass=_Generator):
        exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
        exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
        exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA,
        exp.BackupProperty: exp.Properties.Location.POST_SCHEMA,
        exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
        exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
        exp.ChecksumProperty: exp.Properties.Location.POST_NAME,

@@ -380,8 +404,10 @@ class Generator(metaclass=_Generator):
        exp.FallbackProperty: exp.Properties.Location.POST_NAME,
        exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
        exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
        exp.GlobalProperty: exp.Properties.Location.POST_CREATE,
        exp.HeapProperty: exp.Properties.Location.POST_WITH,
        exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA,
        exp.IcebergProperty: exp.Properties.Location.POST_CREATE,
        exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA,
        exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
        exp.JournalProperty: exp.Properties.Location.POST_NAME,

@@ -414,6 +440,8 @@ class Generator(metaclass=_Generator):
        exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SetProperty: exp.Properties.Location.POST_CREATE,
        exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SharingProperty: exp.Properties.Location.POST_EXPRESSION,
        exp.SequenceProperties: exp.Properties.Location.POST_EXPRESSION,
        exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,

@@ -423,6 +451,8 @@ class Generator(metaclass=_Generator):
        exp.TransientProperty: exp.Properties.Location.POST_CREATE,
        exp.TransformModelProperty: exp.Properties.Location.POST_SCHEMA,
        exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
        exp.UnloggedProperty: exp.Properties.Location.POST_CREATE,
        exp.ViewAttributeProperty: exp.Properties.Location.POST_SCHEMA,
        exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
        exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
        exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,

@@ -441,6 +471,7 @@ class Generator(metaclass=_Generator):
        exp.Insert,
        exp.Join,
        exp.Select,
        exp.Union,
        exp.Update,
        exp.Where,
        exp.With,

@@ -626,7 +657,7 @@ class Generator(metaclass=_Generator):
        if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
            return (
                f"{self.sep()}{comments_sql}{sql}"
                if sql[0].isspace()
                if not sql or sql[0].isspace()
                else f"{comments_sql}{self.sep()}{sql}"
            )

@@ -869,7 +900,9 @@ class Generator(metaclass=_Generator):
        this = f" {this}" if this else ""
        index_type = expression.args.get("index_type")
        index_type = f" USING {index_type}" if index_type else ""
        return f"UNIQUE{this}{index_type}"
        on_conflict = self.sql(expression, "on_conflict")
        on_conflict = f" {on_conflict}" if on_conflict else ""
        return f"UNIQUE{this}{index_type}{on_conflict}"

    def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
        return self.sql(expression, "this")

@@ -961,6 +994,31 @@ class Generator(metaclass=_Generator):
        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
        return self.prepend_ctes(expression, expression_sql)

    def sequenceproperties_sql(self, expression: exp.SequenceProperties) -> str:
        start = self.sql(expression, "start")
        start = f"START WITH {start}" if start else ""
        increment = self.sql(expression, "increment")
        increment = f" INCREMENT BY {increment}" if increment else ""
        minvalue = self.sql(expression, "minvalue")
        minvalue = f" MINVALUE {minvalue}" if minvalue else ""
        maxvalue = self.sql(expression, "maxvalue")
        maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else ""
        owned = self.sql(expression, "owned")
        owned = f" OWNED BY {owned}" if owned else ""

        cache = expression.args.get("cache")
        if cache is None:
            cache_str = ""
        elif cache is True:
            cache_str = " CACHE"
        else:
            cache_str = f" CACHE {cache}"

        options = self.expressions(expression, key="options", flat=True, sep=" ")
        options = f" {options}" if options else ""

        return f"{start}{increment}{minvalue}{maxvalue}{cache_str}{options}{owned}".lstrip()

    def clone_sql(self, expression: exp.Clone) -> str:
        this = self.sql(expression, "this")
        shallow = "SHALLOW " if expression.args.get("shallow") else ""

@@ -968,8 +1026,9 @@ class Generator(metaclass=_Generator):
        return f"{shallow}{keyword} {this}"

    def describe_sql(self, expression: exp.Describe) -> str:
        extended = " EXTENDED" if expression.args.get("extended") else ""
        return f"DESCRIBE{extended} {self.sql(expression, 'this')}"
        style = expression.args.get("style")
        style = f" {style}" if style else ""
        return f"DESCRIBE{style} {self.sql(expression, 'this')}"

    def heredoc_sql(self, expression: exp.Heredoc) -> str:
        tag = self.sql(expression, "tag")

@@ -993,7 +1052,14 @@ class Generator(metaclass=_Generator):

    def cte_sql(self, expression: exp.CTE) -> str:
        alias = self.sql(expression, "alias")
        return f"{alias} AS {self.wrap(expression)}"

        materialized = expression.args.get("materialized")
        if materialized is False:
            materialized = "NOT MATERIALIZED "
        elif materialized:
            materialized = "MATERIALIZED "

        return f"{alias} AS {materialized or ''}{self.wrap(expression)}"
def tablealias_sql(self, expression: exp.TableAlias) -> str:
|
||||
alias = self.sql(expression, "this")
|
||||
|
@ -1044,7 +1110,7 @@ class Generator(metaclass=_Generator):
|
|||
return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}"
|
||||
|
||||
def rawstring_sql(self, expression: exp.RawString) -> str:
|
||||
string = self.escape_str(expression.this.replace("\\", "\\\\"))
|
||||
string = self.escape_str(expression.this.replace("\\", "\\\\"), escape_backslash=False)
|
||||
return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}"
|
||||
|
||||
def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
|
||||
|
@ -1114,6 +1180,8 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
def drop_sql(self, expression: exp.Drop) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
expressions = self.expressions(expression, flat=True)
|
||||
expressions = f" ({expressions})" if expressions else ""
|
||||
kind = expression.args["kind"]
|
||||
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
|
||||
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
|
||||
|
@ -1121,15 +1189,10 @@ class Generator(metaclass=_Generator):
|
|||
cascade = " CASCADE" if expression.args.get("cascade") else ""
|
||||
constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
|
||||
purge = " PURGE" if expression.args.get("purge") else ""
|
||||
return (
|
||||
f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}"
|
||||
)
|
||||
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{expressions}{cascade}{constraints}{purge}"
|
||||
|
||||
def except_sql(self, expression: exp.Except) -> str:
|
||||
return self.prepend_ctes(
|
||||
expression,
|
||||
self.set_operation(expression, self.except_op(expression)),
|
||||
)
|
||||
return self.set_operations(expression)
|
||||
|
||||
def except_op(self, expression: exp.Except) -> str:
|
||||
return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}"
|
||||
|
@ -1163,17 +1226,9 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
return f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */"
|
||||
|
||||
def index_sql(self, expression: exp.Index) -> str:
|
||||
unique = "UNIQUE " if expression.args.get("unique") else ""
|
||||
primary = "PRIMARY " if expression.args.get("primary") else ""
|
||||
amp = "AMP " if expression.args.get("amp") else ""
|
||||
name = self.sql(expression, "this")
|
||||
name = f"{name} " if name else ""
|
||||
table = self.sql(expression, "table")
|
||||
table = f"{self.INDEX_ON} {table}" if table else ""
|
||||
def indexparameters_sql(self, expression: exp.IndexParameters) -> str:
|
||||
using = self.sql(expression, "using")
|
||||
using = f" USING {using}" if using else ""
|
||||
index = "INDEX " if not table else ""
|
||||
columns = self.expressions(expression, key="columns", flat=True)
|
||||
columns = f"({columns})" if columns else ""
|
||||
partition_by = self.expressions(expression, key="partition_by", flat=True)
|
||||
|
@ -1182,7 +1237,26 @@ class Generator(metaclass=_Generator):
|
|||
include = self.expressions(expression, key="include", flat=True)
|
||||
if include:
|
||||
include = f" INCLUDE ({include})"
|
||||
return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{include}{partition_by}{where}"
|
||||
with_storage = self.expressions(expression, key="with_storage", flat=True)
|
||||
with_storage = f" WITH ({with_storage})" if with_storage else ""
|
||||
tablespace = self.sql(expression, "tablespace")
|
||||
tablespace = f" USING INDEX TABLESPACE {tablespace}" if tablespace else ""
|
||||
|
||||
return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}"
|
||||
|
||||
def index_sql(self, expression: exp.Index) -> str:
|
||||
unique = "UNIQUE " if expression.args.get("unique") else ""
|
||||
primary = "PRIMARY " if expression.args.get("primary") else ""
|
||||
amp = "AMP " if expression.args.get("amp") else ""
|
||||
name = self.sql(expression, "this")
|
||||
name = f"{name} " if name else ""
|
||||
table = self.sql(expression, "table")
|
||||
table = f"{self.INDEX_ON} {table}" if table else ""
|
||||
|
||||
index = "INDEX " if not table else ""
|
||||
|
||||
params = self.sql(expression, "params")
|
||||
return f"{unique}{primary}{amp}{index}{name}{table}{params}"
|
||||
|
||||
def identifier_sql(self, expression: exp.Identifier) -> str:
|
||||
text = expression.name
|
||||
|
@ -1371,15 +1445,9 @@ class Generator(metaclass=_Generator):
|
|||
no = " NO" if no else ""
|
||||
concurrent = expression.args.get("concurrent")
|
||||
concurrent = " CONCURRENT" if concurrent else ""
|
||||
|
||||
for_ = ""
|
||||
if expression.args.get("for_all"):
|
||||
for_ = " FOR ALL"
|
||||
elif expression.args.get("for_insert"):
|
||||
for_ = " FOR INSERT"
|
||||
elif expression.args.get("for_none"):
|
||||
for_ = " FOR NONE"
|
||||
return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
|
||||
target = self.sql(expression, "target")
|
||||
target = f" {target}" if target else ""
|
||||
return f"WITH{no}{concurrent} ISOLATED LOADING{target}"
|
||||
|
||||
def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str:
|
||||
if isinstance(expression.this, list):
|
||||
|
@ -1437,6 +1505,7 @@ class Generator(metaclass=_Generator):
|
|||
return f"{sql})"
|
||||
|
||||
def insert_sql(self, expression: exp.Insert) -> str:
|
||||
hint = self.sql(expression, "hint")
|
||||
overwrite = expression.args.get("overwrite")
|
||||
|
||||
if isinstance(expression.this, exp.Directory):
|
||||
|
@ -1447,7 +1516,9 @@ class Generator(metaclass=_Generator):
|
|||
alternative = expression.args.get("alternative")
|
||||
alternative = f" OR {alternative}" if alternative else ""
|
||||
ignore = " IGNORE" if expression.args.get("ignore") else ""
|
||||
|
||||
is_function = expression.args.get("is_function")
|
||||
if is_function:
|
||||
this = f"{this} FUNCTION"
|
||||
this = f"{this} {self.sql(expression, 'this')}"
|
||||
|
||||
exists = " IF EXISTS" if expression.args.get("exists") else ""
|
||||
|
@ -1457,23 +1528,21 @@ class Generator(metaclass=_Generator):
|
|||
where = self.sql(expression, "where")
|
||||
where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
|
||||
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
|
||||
conflict = self.sql(expression, "conflict")
|
||||
on_conflict = self.sql(expression, "conflict")
|
||||
on_conflict = f" {on_conflict}" if on_conflict else ""
|
||||
by_name = " BY NAME" if expression.args.get("by_name") else ""
|
||||
returning = self.sql(expression, "returning")
|
||||
|
||||
if self.RETURNING_END:
|
||||
expression_sql = f"{expression_sql}{conflict}{returning}"
|
||||
expression_sql = f"{expression_sql}{on_conflict}{returning}"
|
||||
else:
|
||||
expression_sql = f"{returning}{expression_sql}{conflict}"
|
||||
expression_sql = f"{returning}{expression_sql}{on_conflict}"
|
||||
|
||||
sql = f"INSERT{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
|
||||
sql = f"INSERT{hint}{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
|
||||
return self.prepend_ctes(expression, sql)
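    # ---- Reviewer note (illustrative, not part of the diff) ----
    # insert_sql now threads the parsed hint through the generated INSERT and
    # renames the conflict fragment to on_conflict. A minimal sketch of the
    # upsert path this covers:
    #
    #   import sqlglot
    #   sql = "INSERT INTO t (a) VALUES (1) ON DUPLICATE KEY UPDATE a = 2"
    #   print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
    #   # INSERT INTO t (a) VALUES (1) ON DUPLICATE KEY UPDATE a = 2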

    def intersect_sql(self, expression: exp.Intersect) -> str:
        return self.prepend_ctes(
            expression,
            self.set_operation(expression, self.intersect_op(expression)),
        )
        return self.set_operations(expression)

    def intersect_op(self, expression: exp.Intersect) -> str:
        return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}"

@@ -1496,33 +1565,36 @@ class Generator(metaclass=_Generator):

    def onconflict_sql(self, expression: exp.OnConflict) -> str:
        conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"

        constraint = self.sql(expression, "constraint")
        if constraint:
            constraint = f"ON CONSTRAINT {constraint}"
        key = self.expressions(expression, key="key", flat=True)
        do = "" if expression.args.get("duplicate") else " DO "
        nothing = "NOTHING" if expression.args.get("nothing") else ""
        constraint = f" ON CONSTRAINT {constraint}" if constraint else ""

        conflict_keys = self.expressions(expression, key="conflict_keys", flat=True)
        conflict_keys = f"({conflict_keys}) " if conflict_keys else " "
        action = self.sql(expression, "action")

        expressions = self.expressions(expression, flat=True)
        set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else ""
        if expressions:
            expressions = f"UPDATE {set_keyword}{expressions}"
        return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}"
            set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else ""
            expressions = f" {set_keyword}{expressions}"

        return f"{conflict}{constraint}{conflict_keys}{action}{expressions}"

    def returning_sql(self, expression: exp.Returning) -> str:
        return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}"

    def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
        fields = expression.args.get("fields")
        fields = self.sql(expression, "fields")
        fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
        escaped = expression.args.get("escaped")
        escaped = self.sql(expression, "escaped")
        escaped = f" ESCAPED BY {escaped}" if escaped else ""
        items = expression.args.get("collection_items")
        items = self.sql(expression, "collection_items")
        items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else ""
        keys = expression.args.get("map_keys")
        keys = self.sql(expression, "map_keys")
        keys = f" MAP KEYS TERMINATED BY {keys}" if keys else ""
        lines = expression.args.get("lines")
        lines = self.sql(expression, "lines")
        lines = f" LINES TERMINATED BY {lines}" if lines else ""
        null = expression.args.get("null")
        null = self.sql(expression, "null")
        null = f" NULL DEFINED AS {null}" if null else ""
        return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"

@@ -1563,7 +1635,9 @@ class Generator(metaclass=_Generator):
        hints = f" {hints}" if hints and self.TABLE_HINTS else ""
        pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
        pivots = f" {pivots}" if pivots else ""
        joins = self.expressions(expression, key="joins", sep="", skip_first=True)
        joins = self.indent(
            self.expressions(expression, key="joins", sep="", flat=True), skip_first=True
        )
        laterals = self.expressions(expression, key="laterals", sep="")

        file_format = self.sql(expression, "format")

@@ -1673,9 +1747,11 @@ class Generator(metaclass=_Generator):
        sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}"
        return self.prepend_ctes(expression, sql)

    def values_sql(self, expression: exp.Values) -> str:
    def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str:
        values_as_table = values_as_table and self.VALUES_AS_TABLE

        # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
        if self.VALUES_AS_TABLE or not expression.find_ancestor(exp.From, exp.Join):
        if values_as_table or not expression.find_ancestor(exp.From, exp.Join):
            args = self.expressions(expression)
            alias = self.sql(expression, "alias")
            values = f"VALUES{self.seg('')}{args}"

@@ -1769,8 +1845,9 @@ class Generator(metaclass=_Generator):
    def connect_sql(self, expression: exp.Connect) -> str:
        start = self.sql(expression, "start")
        start = self.seg(f"START WITH {start}") if start else ""
        nocycle = " NOCYCLE" if expression.args.get("nocycle") else ""
        connect = self.sql(expression, "connect")
        connect = self.seg(f"CONNECT BY {connect}")
        connect = self.seg(f"CONNECT BY{nocycle} {connect}")
        return start + connect
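    # ---- Reviewer note (illustrative, not part of the diff) ----
    # connect_sql now emits the NOCYCLE keyword when the parser sets the
    # "nocycle" arg, e.g. for Oracle hierarchical queries:
    #
    #   import sqlglot
    #   sql = "SELECT id FROM t CONNECT BY NOCYCLE PRIOR id = parent_id"
    #   print(sqlglot.transpile(sql, read="oracle", write="oracle")[0])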

    def prior_sql(self, expression: exp.Prior) -> str:

@@ -1793,6 +1870,8 @@ class Generator(metaclass=_Generator):
            )
            if op
        )
        match_cond = self.sql(expression, "match_condition")
        match_cond = f" MATCH_CONDITION ({match_cond})" if match_cond else ""
        on_sql = self.sql(expression, "on")
        using = expression.args.get("using")

@@ -1816,7 +1895,7 @@ class Generator(metaclass=_Generator):
            return f", {this_sql}"

        op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
        return f"{self.seg(op_sql)} {this_sql}{on_sql}"
        return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"

    def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
        args = self.expressions(expression, flat=True)

@@ -1919,13 +1998,17 @@ class Generator(metaclass=_Generator):
            text = f"{self.dialect.QUOTE_START}{self.escape_str(text)}{self.dialect.QUOTE_END}"
        return text

    def escape_str(self, text: str) -> str:
        text = text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)
        if self.dialect.INVERSE_ESCAPE_SEQUENCES:
            text = "".join(self.dialect.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
        elif self.pretty:
    def escape_str(self, text: str, escape_backslash: bool = True) -> str:
        if self.dialect.ESCAPED_SEQUENCES:
            to_escaped = self.dialect.ESCAPED_SEQUENCES
            text = "".join(
                to_escaped.get(ch, ch) if escape_backslash or ch != "\\" else ch for ch in text
            )

        if self.pretty:
            text = text.replace("\n", self.SENTINEL_LINE_BREAK)
        return text

        return text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)

    def loaddata_sql(self, expression: exp.LoadData) -> str:
        local = " LOCAL" if expression.args.get("local") else ""

@@ -2016,7 +2099,7 @@ class Generator(metaclass=_Generator):
                self.unsupported(
                    f"'{nulls_sort_change.strip()}' translation not supported with positional ordering"
                )
            else:
            elif not isinstance(expression.this, exp.Rand):
                null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
                this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
                nulls_sort_change = ""

@@ -2059,24 +2142,13 @@ class Generator(metaclass=_Generator):
        return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"

    def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
        limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit")

        # If the limit is generated as TOP, we need to ensure it's not generated twice
        with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP
        limit = expression.args.get("limit")

        if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
            limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count")))
        elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
            limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))

        fetch = isinstance(limit, exp.Fetch)

        offset_limit_modifiers = (
            self.offset_limit_modifiers(expression, fetch, limit)
            if with_offset_limit_modifiers
            else []
        )

        options = self.expressions(expression, key="options")
        if options:
            options = f" OPTION{self.wrap(options)}"

@@ -2091,9 +2163,9 @@ class Generator(metaclass=_Generator):
            self.sql(expression, "where"),
            self.sql(expression, "group"),
            self.sql(expression, "having"),
            *self.after_having_modifiers(expression),
            *[gen(self, expression) for gen in self.AFTER_HAVING_MODIFIER_TRANSFORMS.values()],
            self.sql(expression, "order"),
            *offset_limit_modifiers,
            *self.offset_limit_modifiers(expression, isinstance(limit, exp.Fetch), limit),
            *self.after_limit_modifiers(expression),
            options,
            sep="",

@@ -2110,19 +2182,6 @@ class Generator(metaclass=_Generator):
            self.sql(limit) if fetch else self.sql(expression, "offset"),
        ]

    def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
        return [
            self.sql(expression, "qualify"),
            (
                self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
                if expression.args.get("windows")
                else ""
            ),
            self.sql(expression, "distribute"),
            self.sql(expression, "sort"),
            self.sql(expression, "cluster"),
        ]

    def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
        locks = self.expressions(expression, key="locks", sep=" ")
        locks = f" {locks}" if locks else ""

@@ -2137,12 +2196,13 @@ class Generator(metaclass=_Generator):
        distinct = self.sql(expression, "distinct")
        distinct = f" {distinct}" if distinct else ""
        kind = self.sql(expression, "kind")

        limit = expression.args.get("limit")
        top = (
            self.limit_sql(limit, top=True)
            if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP
            else ""
        )
        if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP:
            top = self.limit_sql(limit, top=True)
            limit.pop()
        else:
            top = ""

        expressions = self.expressions(expression)

@@ -2220,7 +2280,7 @@ class Generator(metaclass=_Generator):
        return f"@@{kind}{this}"

    def placeholder_sql(self, expression: exp.Placeholder) -> str:
        return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.name else "?"
        return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.this else "?"

    def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
        alias = self.sql(expression, "alias")

@@ -2236,11 +2296,32 @@ class Generator(metaclass=_Generator):
        this = self.indent(self.sql(expression, "this"))
        return f"{self.seg('QUALIFY')}{self.sep()}{this}"

    def set_operations(self, expression: exp.Union) -> str:
        sqls: t.List[str] = []
        stack: t.List[t.Union[str, exp.Expression]] = [expression]

        while stack:
            node = stack.pop()

            if isinstance(node, exp.Union):
                stack.append(node.expression)
                stack.append(
                    self.maybe_comment(
                        getattr(self, f"{node.key}_op")(node),
                        expression=node.this,
                        comments=node.comments,
                    )
                )
                stack.append(node.this)
            else:
                sqls.append(self.sql(node))

        this = self.sep().join(sqls)
        this = self.query_modifiers(expression, this)
        return self.prepend_ctes(expression, this)

    def union_sql(self, expression: exp.Union) -> str:
        return self.prepend_ctes(
            expression,
            self.set_operation(expression, self.union_op(expression)),
        )
        return self.set_operations(expression)

    def union_op(self, expression: exp.Union) -> str:
        kind = " DISTINCT" if self.EXPLICIT_UNION else ""

@@ -2345,8 +2426,10 @@ class Generator(metaclass=_Generator):

    def any_sql(self, expression: exp.Any) -> str:
        this = self.sql(expression, "this")
        if isinstance(expression.this, exp.UNWRAPPED_QUERIES):
            this = self.wrap(this)
        if isinstance(expression.this, (*exp.UNWRAPPED_QUERIES, exp.Paren)):
            if isinstance(expression.this, exp.UNWRAPPED_QUERIES):
                this = self.wrap(this)
            return f"ANY{this}"
        return f"ANY {this}"

    def exists_sql(self, expression: exp.Exists) -> str:

@@ -2632,13 +2715,8 @@ class Generator(metaclass=_Generator):
        return self.func(self.sql(expression, "this"), *expression.expressions)

    def paren_sql(self, expression: exp.Paren) -> str:
        if isinstance(expression.unnest(), exp.Select):
            sql = self.wrap(expression)
        else:
            sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
            sql = f"({sql}{self.seg(')', sep='')}"

        return self.prepend_ctes(expression, sql)
        sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
        return f"({sql}{self.seg(')', sep='')}"

    def neg_sql(self, expression: exp.Neg) -> str:
        # This makes sure we don't convert "- - 5" to "--5", which is a comment

@@ -2686,23 +2764,55 @@ class Generator(metaclass=_Generator):
    def add_sql(self, expression: exp.Add) -> str:
        return self.binary(expression, "+")

    def and_sql(self, expression: exp.And) -> str:
        return self.connector_sql(expression, "AND")
    def and_sql(
        self, expression: exp.And, stack: t.Optional[t.List[str | exp.Expression]] = None
    ) -> str:
        return self.connector_sql(expression, "AND", stack)

    def xor_sql(self, expression: exp.Xor) -> str:
        return self.connector_sql(expression, "XOR")
    def or_sql(
        self, expression: exp.Or, stack: t.Optional[t.List[str | exp.Expression]] = None
    ) -> str:
        return self.connector_sql(expression, "OR", stack)

    def connector_sql(self, expression: exp.Connector, op: str) -> str:
        if not self.pretty:
            return self.binary(expression, op)
    def xor_sql(
        self, expression: exp.Xor, stack: t.Optional[t.List[str | exp.Expression]] = None
    ) -> str:
        return self.connector_sql(expression, "XOR", stack)

        sqls = tuple(
            self.maybe_comment(self.sql(e), e, e.parent.comments or []) if i != 1 else self.sql(e)
            for i, e in enumerate(expression.flatten(unnest=False))
        )
    def connector_sql(
        self,
        expression: exp.Connector,
        op: str,
        stack: t.Optional[t.List[str | exp.Expression]] = None,
    ) -> str:
        if stack is not None:
            if expression.expressions:
                stack.append(self.expressions(expression, sep=f" {op} "))
            else:
                stack.append(expression.right)
                if expression.comments:
                    for comment in expression.comments:
                        op += f" /*{self.pad_comment(comment)}*/"
                stack.extend((op, expression.left))
            return op

        sep = "\n" if self.text_width(sqls) > self.max_text_width else " "
        return f"{sep}{op} ".join(sqls)
        stack = [expression]
        sqls: t.List[str] = []
        ops = set()

        while stack:
            node = stack.pop()
            if isinstance(node, exp.Connector):
                ops.add(getattr(self, f"{node.key}_sql")(node, stack))
            else:
                sql = self.sql(node)
                if sqls and sqls[-1] in ops:
                    sqls[-1] += f" {sql}"
                else:
                    sqls.append(sql)

        sep = "\n" if self.pretty and self.text_width(sqls) > self.max_text_width else " "
        return sep.join(sqls)

    def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
        return self.binary(expression, "&")

@@ -2727,7 +2837,9 @@ class Generator(metaclass=_Generator):
        format_sql = f" FORMAT {format_sql}" if format_sql else ""
        to_sql = self.sql(expression, "to")
        to_sql = f" {to_sql}" if to_sql else ""
        return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})"
        action = self.sql(expression, "action")
        action = f" {action}" if action else ""
        return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql}{action})"

    def currentdate_sql(self, expression: exp.CurrentDate) -> str:
        zone = self.sql(expression, "this")

@@ -2817,7 +2929,7 @@ class Generator(metaclass=_Generator):
        # Remove db from tables
        expression = expression.transform(
            lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n
        )
        ).assert_is(exp.RenameTable)
        this = self.sql(expression, "this")
        return f"RENAME TO {this}"

@@ -2889,30 +3001,6 @@ class Generator(metaclass=_Generator):
        kind = "MAX" if expression.args.get("max") else "MIN"
        return f"{this_sql} HAVING {kind} {expression_sql}"

    def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
        if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
            # The first modifier here will be the one closest to the AggFunc's arg
            mods = sorted(
                expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
                key=lambda x: 0
                if isinstance(x, exp.HavingMax)
                else (1 if isinstance(x, exp.Order) else 2),
            )

            if mods:
                mod = mods[0]
                this = expression.__class__(this=mod.this.copy())
                this.meta["inline"] = True
                mod.this.replace(this)
                return self.sql(expression.this)

            agg_func = expression.find(exp.AggFunc)

            if agg_func:
                return self.sql(agg_func)[:-1] + f" {text})"

        return f"{self.sql(expression, 'this')} {text}"

    def intdiv_sql(self, expression: exp.IntDiv) -> str:
        return self.sql(
            exp.Cast(

@@ -2933,9 +3021,7 @@ class Generator(metaclass=_Generator):
            r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0)))

        if self.dialect.TYPED_DIVISION and not expression.args.get("typed"):
            if not l.is_type(*exp.DataType.FLOAT_TYPES) and not r.is_type(
                *exp.DataType.FLOAT_TYPES
            ):
            if not l.is_type(*exp.DataType.REAL_TYPES) and not r.is_type(*exp.DataType.REAL_TYPES):
                l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE))

        elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"):

@@ -3019,9 +3105,6 @@ class Generator(metaclass=_Generator):
    def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str:
        return self.binary(expression, "IS DISTINCT FROM")

    def or_sql(self, expression: exp.Or) -> str:
        return self.connector_sql(expression, "OR")

    def slice_sql(self, expression: exp.Slice) -> str:
        return self.binary(expression, ":")

@@ -3035,8 +3118,13 @@ class Generator(metaclass=_Generator):
        this = expression.this
        expr = expression.expression

        if not self.dialect.LOG_BASE_FIRST:
        if self.dialect.LOG_BASE_FIRST is False:
            this, expr = expr, this
        elif self.dialect.LOG_BASE_FIRST is None and expr:
            if this.name in ("2", "10"):
                return self.func(f"LOG{this.name}", expr)

            self.unsupported(f"Unsupported logarithm with base {self.sql(this)}")

        return self.func("LOG", this, expr)

@@ -3088,11 +3176,16 @@ class Generator(metaclass=_Generator):
    def text_width(self, args: t.Iterable) -> int:
        return sum(len(arg) for arg in args)

    def format_time(self, expression: exp.Expression) -> t.Optional[str]:
    def format_time(
        self,
        expression: exp.Expression,
        inverse_time_mapping: t.Optional[t.Dict[str, str]] = None,
        inverse_time_trie: t.Optional[t.Dict] = None,
    ) -> t.Optional[str]:
        return format_time(
            self.sql(expression, "format"),
            self.dialect.INVERSE_TIME_MAPPING,
            self.dialect.INVERSE_TIME_TRIE,
            inverse_time_mapping or self.dialect.INVERSE_TIME_MAPPING,
            inverse_time_trie or self.dialect.INVERSE_TIME_TRIE,
        )

    def expressions(

@@ -3117,8 +3210,11 @@ class Generator(metaclass=_Generator):
        num_sqls = len(expressions)

        # These are calculated once in case we have the leading_comma / pretty option set, correspondingly
        pad = " " * self.pad
        stripped_sep = sep.strip()
        if self.pretty:
            if self.leading_comma:
                pad = " " * len(sep)
            else:
                stripped_sep = sep.strip()

        result_sqls = []
        for i, e in enumerate(expressions):

@@ -3154,13 +3250,6 @@ class Generator(metaclass=_Generator):
        self.unsupported(f"Unsupported property {expression.__class__.__name__}")
        return f"{property_name} {self.sql(expression, 'this')}"

    def set_operation(self, expression: exp.Union, op: str) -> str:
        this = self.maybe_comment(self.sql(expression, "this"), comments=expression.comments)
        op = self.seg(op)
        return self.query_modifiers(
            expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
        )

    def tag_sql(self, expression: exp.Tag) -> str:
        return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}"

@@ -3227,6 +3316,18 @@ class Generator(metaclass=_Generator):

        return self.sql(exp.cast(expression.this, "text"))

    def tonumber_sql(self, expression: exp.ToNumber) -> str:
        if not self.SUPPORTS_TO_NUMBER:
            self.unsupported("Unsupported TO_NUMBER function")
            return self.sql(exp.cast(expression.this, "double"))

        fmt = expression.args.get("format")
        if not fmt:
            self.unsupported("Conversion format is required for TO_NUMBER")
            return self.sql(exp.cast(expression.this, "double"))

        return self.func("TO_NUMBER", expression.this, fmt)

    def dictproperty_sql(self, expression: exp.DictProperty) -> str:
        this = self.sql(expression, "this")
        kind = self.sql(expression, "kind")

@@ -3320,11 +3421,11 @@ class Generator(metaclass=_Generator):
        this = f" {this}" if this else ""
        index_type = self.sql(expression, "index_type")
        index_type = f" USING {index_type}" if index_type else ""
        schema = self.sql(expression, "schema")
        schema = f" {schema}" if schema else ""
        expressions = self.expressions(expression, flat=True)
        expressions = f" ({expressions})" if expressions else ""
        options = self.expressions(expression, key="options", sep=" ")
        options = f" {options}" if options else ""
        return f"{kind}{this}{index_type}{schema}{options}"
        return f"{kind}{this}{index_type}{expressions}{options}"

    def nvl2_sql(self, expression: exp.Nvl2) -> str:
        if self.NVL2_SUPPORTED:

@@ -3396,6 +3497,13 @@ class Generator(metaclass=_Generator):

        return self.sql(exp.cast(this, "time"))

    def tsordstotimestamp_sql(self, expression: exp.TsOrDsToTimestamp) -> str:
        this = expression.this
        if isinstance(this, exp.TsOrDsToTimestamp) or this.is_type(exp.DataType.Type.TIMESTAMP):
            return self.sql(this)

        return self.sql(exp.cast(this, "timestamp"))

    def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str:
        this = expression.this
        time_format = self.format_time(expression)

@@ -3430,6 +3538,13 @@ class Generator(metaclass=_Generator):

        return self.func("LAST_DAY", expression.this)

    def dateadd_sql(self, expression: exp.DateAdd) -> str:
        from sqlglot.dialects.dialect import unit_to_str

        return self.func(
            "DATE_ADD", expression.this, expression.expression, unit_to_str(expression)
        )

    def arrayany_sql(self, expression: exp.ArrayAny) -> str:
        if self.CAN_IMPLEMENT_ARRAY_ANY:
            filtered = exp.ArrayFilter(this=expression.this, expression=expression.expression)

@@ -3445,30 +3560,6 @@ class Generator(metaclass=_Generator):

        return self.function_fallback_sql(expression)

    def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
        this = expression.this
        if isinstance(this, exp.JSONPathWildcard):
            this = self.json_path_part(this)
            return f".{this}" if this else ""

        if exp.SAFE_IDENTIFIER_RE.match(this):
            return f".{this}"

        this = self.json_path_part(this)
        return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"

    def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
        this = self.json_path_part(expression.this)
        return f"[{this}]" if this else ""

    def _simplify_unless_literal(self, expression: E) -> E:
        if not isinstance(expression, exp.Literal):
            from sqlglot.optimizer.simplify import simplify

            expression = simplify(expression, dialect=self.dialect)

        return expression

    def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
        expression.set("is_end_exclusive", None)
        return self.function_fallback_sql(expression)

@@ -3477,7 +3568,9 @@ class Generator(metaclass=_Generator):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                exp.alias_(e.expression, e.name if e.this.is_string else e.this)
                if isinstance(e, exp.PropertyEQ)
                else e
                for e in expression.expressions
            ],
        )

@@ -3553,3 +3646,51 @@ class Generator(metaclass=_Generator):
        transformed = cast(this=value, to=to, safe=safe)

        return self.sql(transformed)

    def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
        this = expression.this
        if isinstance(this, exp.JSONPathWildcard):
            this = self.json_path_part(this)
            return f".{this}" if this else ""

        if exp.SAFE_IDENTIFIER_RE.match(this):
            return f".{this}"

        this = self.json_path_part(this)
        return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"

    def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
        this = self.json_path_part(expression.this)
        return f"[{this}]" if this else ""

    def _simplify_unless_literal(self, expression: E) -> E:
        if not isinstance(expression, exp.Literal):
            from sqlglot.optimizer.simplify import simplify

            expression = simplify(expression, dialect=self.dialect)

        return expression

    def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
        if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
            # The first modifier here will be the one closest to the AggFunc's arg
            mods = sorted(
                expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
                key=lambda x: 0
                if isinstance(x, exp.HavingMax)
                else (1 if isinstance(x, exp.Order) else 2),
            )

            if mods:
                mod = mods[0]
                this = expression.__class__(this=mod.this.copy())
                this.meta["inline"] = True
                mod.this.replace(this)
                return self.sql(expression.this)

            agg_func = expression.find(exp.AggFunc)

            if agg_func:
                return self.sql(agg_func)[:-1] + f" {text})"

        return f"{self.sql(expression, 'this')} {text}"


@@ -181,7 +181,7 @@ def apply_index_offset(
        annotate_types(expression)
        if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
            logger.warning("Applying array index offset (%s)", offset)
            expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
            expression = simplify(expression + offset)
            return [expression]

    return expressions

@@ -204,13 +204,13 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
        The transformed expression.
    """
    while True:
        for n, *_ in reversed(tuple(expression.walk())):
        for n in reversed(tuple(expression.walk())):
            n._hash = hash(n)

        start = hash(expression)
        expression = func(expression)

        for n, *_ in expression.walk():
        for n in expression.walk():
            n._hash = None
        if start == hash(expression):
            break

@@ -317,8 +317,16 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:


def is_int(text: str) -> bool:
    return is_type(text, int)


def is_float(text: str) -> bool:
    return is_type(text, float)


def is_type(text: str, target_type: t.Type) -> bool:
    try:
        int(text)
        target_type(text)
        return True
    except ValueError:
        return False
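
# ---- Reviewer note (illustrative, not part of the diff) ----
# is_int is now a thin wrapper over the new is_type, and is_float is added;
# both simply probe the target constructor:
#
#   from sqlglot.helper import is_float, is_int
#   assert is_int("42") and not is_int("4.2")
#   assert is_float("4.2") and not is_float("abc")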

@@ -28,10 +28,7 @@ class Node:
            yield self

            for d in self.downstream:
                if isinstance(d, Node):
                    yield from d.walk()
                else:
                    yield d
                yield from d.walk()

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        nodes = {}

@@ -71,8 +68,10 @@ def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
    sources: t.Optional[t.Mapping[str, str | exp.Query]] = None,
    dialect: DialectType = None,
    scope: t.Optional[Scope] = None,
    trim_selects: bool = True,
    **kwargs,
) -> Node:
    """Build the lineage graph for a column of a SQL query.

@@ -83,6 +82,8 @@ def lineage(
        schema: The schema of tables.
        sources: A mapping of queries which will be used to continue building lineage.
        dialect: The dialect of input SQL.
        scope: A pre-created scope to use instead.
        trim_selects: Whether or not to clean up selects by trimming to only relevant columns.
        **kwargs: Qualification optimizer kwargs.

    Returns:

@@ -99,14 +100,15 @@ def lineage(
            dialect=dialect,
        )

    qualified = qualify.qualify(
        expression,
        dialect=dialect,
        schema=schema,
        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
    )
    if not scope:
        expression = qualify.qualify(
            expression,
            dialect=dialect,
            schema=schema,
            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
        )

    scope = build_scope(qualified)
        scope = build_scope(expression)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

@@ -114,7 +116,7 @@ def lineage(
    if not any(select.alias_or_name == column for select in scope.expression.selects):
        raise SqlglotError(f"Cannot find column '{column}' in query.")

    return to_node(column, scope, dialect)
    return to_node(column, scope, dialect, trim_selects=trim_selects)


def to_node(

@@ -125,6 +127,7 @@ def to_node(
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
    trim_selects: bool = True,
) -> Node:
    source_names = {
        dt.alias: dt.comments[0].split()[1]

@@ -143,6 +146,17 @@ def to_node(
        )
    )

    if isinstance(scope.expression, exp.Subquery):
        for source in scope.subquery_scopes:
            return to_node(
                column,
                scope=source,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )
    if isinstance(scope.expression, exp.Union):
        upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

@@ -170,11 +184,12 @@ def to_node(
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )

        return upstream

    if isinstance(scope.expression, exp.Select):
    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo

@@ -206,7 +221,13 @@ def to_node(
                continue

            for name in subquery.named_selects:
                to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)
                to_node(
                    name,
                    scope=subquery_scope,
                    dialect=dialect,
                    upstream=node,
                    trim_selects=trim_selects,
                )

    # if the select is a star add all scope sources as downstreams
    if select.is_star:

@@ -237,6 +258,7 @@ def to_node(
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=selected_node.name if selected_node else None,
                trim_selects=trim_selects,
            )
        else:
            # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
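    # ---- Reviewer note (illustrative, not part of the diff) ----
    # The new trim_selects flag (threaded through every to_node call) keeps the
    # full SELECT in each node's source when disabled, and the new scope
    # parameter lets callers reuse a pre-built, already-qualified scope:
    #
    #   from sqlglot.lineage import lineage
    #   node = lineage("a", "SELECT a FROM (SELECT a, b FROM t)", trim_selects=False)
    #   print(node.source.sql())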

@@ -168,8 +168,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,

@@ -266,26 +264,30 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }

    NESTED_TYPES = {

@@ -358,6 +360,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
                if isinstance(source.expression, exp.Lateral):
                    if isinstance(source.expression.this, exp.Explode):
                        values = [source.expression.this.this]
                elif isinstance(source.expression, exp.Unnest):
                    values = [source.expression]
                else:
                    values = source.expression.expressions[0].expressions

@@ -408,7 +412,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        )

    def _annotate_args(self, expression: E) -> E:
        for _, value in expression.iter_expressions():
        for value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

@@ -425,23 +429,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2
        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value

        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        self._annotate_args(expression)

        left, right = expression.left, expression.right
        left_type, right_type = left.type.this, right.type.this
        left_type, right_type = left.type.this, right.type.this  # type: ignore

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:

@@ -462,7 +456,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        self._annotate_args(expression)

@@ -473,7 +466,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)

@@ -484,25 +476,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):

        return expression

    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_struct_value(
        self, expression: exp.Expression
    ) -> t.Optional[exp.DataType] | exp.ColumnDef:
        alias = expression.args.get("alias")
        if alias:
            return exp.ColumnDef(this=alias.copy(), kind=expression.type)

        # Case: key = value or key := value
        if expression.expression:
            return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)

        return expression.type

    @t.no_type_check
    def _annotate_by_args(
        self,

@@ -510,7 +487,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        *args: str,
        promote: bool = False,
        array: bool = False,
        struct: bool = False,
    ) -> E:
        self._annotate_args(expression)

@@ -546,16 +522,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
                ),
            )

        if struct:
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.STRUCT,
                    expressions=[self._annotate_struct_value(expr) for expr in expressions],
                    nested=True,
                ),
            )

        return expression

    def _annotate_timeunit(

@@ -605,6 +571,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
            self._set_type(expression, exp.DataType.Type.BIGINT)
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))
            if expression.type and expression.type.this not in exp.DataType.REAL_TYPES:
                self._set_type(
                    expression, self._maybe_coerce(expression.type, exp.DataType.Type.DOUBLE)
                )

        return expression

@@ -631,3 +601,68 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        child = seq_get(expression.expressions, 0)
        self._set_type(expression, child and seq_get(child.type.expressions, 0))
        return expression

    def _annotate_struct_value(
        self, expression: exp.Expression
    ) -> t.Optional[exp.DataType] | exp.ColumnDef:
        alias = expression.args.get("alias")
        if alias:
            return exp.ColumnDef(this=alias.copy(), kind=expression.type)

        # Case: key = value or key := value
        if expression.expression:
            return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)

        return expression.type

    def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
        self._annotate_args(expression)
        self._set_type(
            expression,
            exp.DataType(
                this=exp.DataType.Type.STRUCT,
                expressions=[self._annotate_struct_value(expr) for expr in expression.expressions],
                nested=True,
            ),
        )
        return expression

    @t.overload
    def _annotate_map(self, expression: exp.Map) -> exp.Map: ...

    @t.overload
    def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ...

    def _annotate_map(self, expression):
        self._annotate_args(expression)

        keys = expression.args.get("keys")
        values = expression.args.get("values")

        map_type = exp.DataType(this=exp.DataType.Type.MAP)
        if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
            key_type = seq_get(keys.type.expressions, 0) or exp.DataType.Type.UNKNOWN
            value_type = seq_get(values.type.expressions, 0) or exp.DataType.Type.UNKNOWN

            if key_type != exp.DataType.Type.UNKNOWN and value_type != exp.DataType.Type.UNKNOWN:
                map_type.set("expressions", [key_type, value_type])
                map_type.set("nested", True)

        self._set_type(expression, map_type)
        return expression

    def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap:
        self._annotate_args(expression)

        map_type = exp.DataType(this=exp.DataType.Type.MAP)
        arg = expression.this
        if arg.is_type(exp.DataType.Type.STRUCT):
            for coldef in arg.type.expressions:
                kind = coldef.kind
                if kind != exp.DataType.Type.UNKNOWN:
                    map_type.set("expressions", [exp.DataType.build("varchar"), kind])
                    map_type.set("nested", True)
                    break

        self._set_type(expression, map_type)
        return expression
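    # ---- Reviewer note (illustrative, not part of the diff) ----
    # MAP/VarMap values are now annotated with key/value types via _annotate_map
    # instead of a bare MAP, e.g. (assuming DuckDB's MAP(['k'], [v]) syntax):
    #
    #   from sqlglot import parse_one
    #   from sqlglot.optimizer.annotate_types import annotate_types
    #   expr = annotate_types(parse_one("SELECT MAP(['a'], [1])", read="duckdb"))
    #   print(expr.selects[0].type.sql())  # e.g. MAP(VARCHAR, INT)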

@@ -16,16 +16,17 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
    Args:
        expression: The expression to canonicalize.
    """
    exp.replace_children(expression, canonicalize)

    expression = add_text_to_concat(expression)
    expression = replace_date_funcs(expression)
    expression = coerce_type(expression)
    expression = remove_redundant_casts(expression)
    expression = ensure_bools(expression, _replace_int_predicate)
    expression = remove_ascending_order(expression)
    def _canonicalize(expression: exp.Expression) -> exp.Expression:
        expression = add_text_to_concat(expression)
        expression = replace_date_funcs(expression)
        expression = coerce_type(expression)
        expression = remove_redundant_casts(expression)
        expression = ensure_bools(expression, _replace_int_predicate)
        expression = remove_ascending_order(expression)
        return expression

    return expression
    return exp.replace_tree(expression, _canonicalize)
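
# ---- Reviewer note (illustrative, not part of the diff) ----
# canonicalize now applies its local rules through exp.replace_tree, which
# walks the tree bottom-up and re-visits nodes created by the callback, rather
# than recursing via exp.replace_children:
#
#   from sqlglot import exp, parse_one
#   tree = parse_one("SELECT x + 1")
#   assert exp.replace_tree(tree, lambda node: node).sql() == "SELECT x + 1"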

def add_text_to_concat(node: exp.Expression) -> exp.Expression:

@@ -35,7 +36,11 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression:


def replace_date_funcs(node: exp.Expression) -> exp.Expression:
    if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
    if (
        isinstance(node, (exp.Date, exp.TsOrDsToDate))
        and not node.expressions
        and not node.args.get("zone")
    ):
        return exp.cast(node.this, to=exp.DataType.Type.DATE)
    if isinstance(node, exp.Timestamp) and not node.expression:
        if not node.type:

@@ -121,15 +126,11 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
    a = _coerce_timeunit_arg(a, b.unit)
    if (
        a.type
        and a.type.this == exp.DataType.Type.DATE
        and a.type.this in exp.DataType.TEMPORAL_TYPES
        and b.type
        and b.type.this
        not in (
            exp.DataType.Type.DATE,
            exp.DataType.Type.INTERVAL,
        )
        and b.type.this in exp.DataType.TEXT_TYPES
    ):
        _replace_cast(b, exp.DataType.Type.DATE)
        _replace_cast(b, exp.DataType.Type.DATETIME)


def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:

@@ -169,7 +170,7 @@ def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
    if isinstance(expression, exp.Coalesce):
        for _, child in expression.iter_expressions():
        for child in expression.iter_expressions():
            _replace_int_predicate(child)
    elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
        expression.replace(expression.neq(0))


@@ -32,7 +32,7 @@ def eliminate_ctes(expression):
            cte_node.pop()

            # Pop the entire WITH clause if this is the last CTE
            if len(with_node.expressions) <= 0:
            if with_node and len(with_node.expressions) <= 0:
                with_node.pop()

        # Decrement the ref count for all sources this CTE selects from


@@ -214,6 +214,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
        and not _is_recursive()
        and not (inner_select.args.get("order") and outer_scope.is_union)
    )


@@ -28,7 +28,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
    Returns:
        sqlglot.Expression: normalized expression
    """
    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
    for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
        if isinstance(node, exp.Connector):
            if normalized(node, dnf=dnf):
                continue


@@ -53,10 +53,8 @@ def normalize_identifiers(expression, dialect=None):
    if isinstance(expression, str):
        expression = exp.parse_identifier(expression, dialect=dialect)

    def _normalize(node: E) -> E:
    for node in expression.walk(prune=lambda n: n.meta.get("case_sensitive")):
        if not node.meta.get("case_sensitive"):
            exp.replace_children(node, _normalize)
            node = dialect.normalize_identifier(node)
        return node
            dialect.normalize_identifier(node)

    return _normalize(expression)
    return expression
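
# ---- Reviewer note (illustrative, not part of the diff) ----
# Identifier normalization is now an in-place walk that prunes subtrees marked
# case_sensitive, replacing the old recursive _normalize helper:
#
#   from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
#   print(normalize_identifiers("Foo", dialect="snowflake").sql())  # FOO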

@@ -82,13 +82,13 @@ def optimize(
        **kwargs,
    }

    expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
    optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
    for rule in rules:
        # Find any additional rule parameters, beyond `expression`
        rule_params = rule.__code__.co_varnames
        rule_kwargs = {
            param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
        }
        expression = rule(expression, **rule_kwargs)
        optimized = rule(optimized, **rule_kwargs)

    return t.cast(exp.Expression, expression)
    return optimized
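
# ---- Reviewer note (illustrative, not part of the diff) ----
# Binding the working tree to `optimized` makes it explicit that optimize()
# operates on a copy and never hands back the caller's argument:
#
#   import sqlglot
#   from sqlglot.optimizer import optimize
#   tree = sqlglot.parse_one("SELECT a FROM t")
#   assert optimize(tree, schema={"t": {"a": "INT"}}) is not tree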
|
||||
|
|
|
@ -77,13 +77,13 @@ def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
|
|||
pushdown_dnf(predicates, sources, scope_ref_count)
|
||||
|
||||
|
||||
def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
|
||||
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
|
||||
"""
|
||||
If the predicates are in CNF like form, we can simply replace each block in the parent.
|
||||
"""
|
||||
join_index = join_index or {}
|
||||
for predicate in predicates:
|
||||
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
|
||||
for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
|
||||
if isinstance(node, exp.Join):
|
||||
name = node.alias_or_name
|
||||
predicate_tables = exp.column_table_names(predicate, name)
|
||||
|
@ -103,7 +103,7 @@ def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
|
|||
node.where(inner_predicate, copy=False)
|
||||
|
||||
|
||||
def pushdown_dnf(predicates, scope, scope_ref_count):
|
||||
def pushdown_dnf(predicates, sources, scope_ref_count):
|
||||
"""
|
||||
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
|
||||
Additionally, we can't remove predicates from their original form.
|
||||
|
@ -127,7 +127,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
|
|||
# pushdown all predicates to their respective nodes
|
||||
for table in sorted(pushdown_tables):
|
||||
for predicate in predicates:
|
||||
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
|
||||
nodes = nodes_for_predicate(predicate, sources, scope_ref_count)
|
||||
|
||||
if table not in nodes:
|
||||
continue
|
||||
|
|
|
@ -54,11 +54,15 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
if any(select.is_star for select in right.expression.selects):
|
||||
referenced_columns[right] = parent_selections
|
||||
elif not any(select.is_star for select in left.expression.selects):
|
||||
referenced_columns[right] = [
|
||||
right.expression.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.expression.selects)
|
||||
if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
|
||||
]
|
||||
if scope.expression.args.get("by_name"):
|
||||
referenced_columns[right] = referenced_columns[left]
|
||||
else:
|
||||
referenced_columns[right] = [
|
||||
right.expression.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.expression.selects)
|
||||
if SELECT_ALL in parent_selections
|
||||
or select.alias_or_name in parent_selections
|
||||
]
|
||||
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
if remove_unused_selections:
|
||||
|
|
|
@ -209,7 +209,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
if not node:
|
||||
return
|
||||
|
||||
for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
|
||||
for column in walk_in_scope(node, prune=lambda node: node.is_star):
|
||||
if not isinstance(column, exp.Column):
|
||||
continue
|
||||
|
||||
|
@ -306,7 +306,7 @@ def _expand_positional_references(
|
|||
else:
|
||||
select = select.this
|
||||
|
||||
if isinstance(select, exp.Literal):
|
||||
if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
|
||||
new_nodes.append(node)
|
||||
else:
|
||||
new_nodes.append(select.copy())
|
||||
|
@ -425,7 +425,7 @@ def _expand_stars(
|
|||
raise OptimizeError(f"Unknown table: {table}")
|
||||
|
||||
columns = resolver.get_source_columns(table, only_visible=True)
|
||||
columns = columns or scope.outer_column_list
|
||||
columns = columns or scope.outer_columns
|
||||
|
||||
if pseudocolumns:
|
||||
columns = [name for name in columns if name.upper() not in pseudocolumns]
|
||||
|
@ -517,7 +517,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
|
|||
|
||||
new_selections = []
|
||||
for i, (selection, aliased_column) in enumerate(
|
||||
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
|
||||
itertools.zip_longest(scope.expression.selects, scope.outer_columns)
|
||||
):
|
||||
if selection is None:
|
||||
break
|
||||
|
@ -544,7 +544,7 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
|
|||
"""Makes sure all identifiers that need to be quoted are quoted."""
|
||||
return expression.transform(
|
||||
Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
|
||||
|
|
|
@ -56,7 +56,7 @@ def qualify_tables(
|
|||
table.set("catalog", catalog)
|
||||
|
||||
if not isinstance(expression, exp.Query):
|
||||
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)):
|
||||
for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
|
||||
if isinstance(node, exp.Table):
|
||||
_qualify(node)
|
||||
|
||||
|
@ -118,11 +118,11 @@ def qualify_tables(
|
|||
for i, e in enumerate(udtf.expressions[0].expressions):
|
||||
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
|
||||
else:
|
||||
for node, parent, _ in scope.walk():
|
||||
for node in scope.walk():
|
||||
if (
|
||||
isinstance(node, exp.Table)
|
||||
and not node.alias
|
||||
and isinstance(parent, (exp.From, exp.Join))
|
||||
and isinstance(node.parent, (exp.From, exp.Join))
|
||||
):
|
||||
# Mutates the table by attaching an alias to it
|
||||
alias(node, node.name, copy=False, table=True)
|
||||
|
|
|
@ -8,7 +8,7 @@ from enum import Enum, auto
|
|||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import ensure_collection, find_new_name
|
||||
from sqlglot.helper import ensure_collection, find_new_name, seq_get
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
|
@ -38,11 +38,11 @@ class Scope:
|
|||
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
|
||||
The LATERAL VIEW EXPLODE gets x as a source.
|
||||
cte_sources (dict[str, Scope]): Sources from CTES
|
||||
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
|
||||
defines a column list of it's alias of this scope, this is that list of columns.
|
||||
outer_columns (list[str]): If this is a derived table or CTE, and the outer query
|
||||
defines a column list for the alias of this scope, this is that list of columns.
|
||||
For example:
|
||||
SELECT * FROM (SELECT ...) AS y(col1, col2)
|
||||
The inner query would have `["col1", "col2"]` for its `outer_column_list`
|
||||
The inner query would have `["col1", "col2"]` for its `outer_columns`
|
||||
parent (Scope): Parent scope
|
||||
scope_type (ScopeType): Type of this scope, relative to it's parent
|
||||
subquery_scopes (list[Scope]): List of all child scopes for subqueries
|
||||
|
@ -58,7 +58,7 @@ class Scope:
|
|||
self,
|
||||
expression,
|
||||
sources=None,
|
||||
outer_column_list=None,
|
||||
outer_columns=None,
|
||||
parent=None,
|
||||
scope_type=ScopeType.ROOT,
|
||||
lateral_sources=None,
|
||||
|
@ -70,7 +70,7 @@ class Scope:
|
|||
self.cte_sources = cte_sources or {}
|
||||
self.sources.update(self.lateral_sources)
|
||||
self.sources.update(self.cte_sources)
|
||||
self.outer_column_list = outer_column_list or []
|
||||
self.outer_columns = outer_columns or []
|
||||
self.parent = parent
|
||||
self.scope_type = scope_type
|
||||
self.subquery_scopes = []
|
||||
|
@@ -119,10 +119,11 @@ class Scope:
         self._raw_columns = []
         self._join_hints = []

-        for node, parent, _ in self.walk(bfs=False):
+        for node in self.walk(bfs=False):
             if node is self.expression:
                 continue
-            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
+
+            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                 self._raw_columns.append(node)
             elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                 self._tables.append(node)
@@ -132,10 +133,8 @@ class Scope:
                 self._udtfs.append(node)
             elif isinstance(node, exp.CTE):
                 self._ctes.append(node)
-            elif (
-                isinstance(node, exp.Subquery)
-                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
-                and _is_derived_table(node)
+            elif _is_derived_table(node) and isinstance(
+                node.parent, (exp.From, exp.Join, exp.Subquery)
             ):
                 self._derived_tables.append(node)
             elif isinstance(node, exp.UNWRAPPED_QUERIES):
@@ -438,11 +437,21 @@ class Scope:
         Yields:
             Scope: scope instances in depth-first-search post-order
         """
-        for child_scope in itertools.chain(
-            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
-        ):
-            yield from child_scope.traverse()
-        yield self
+        stack = [self]
+        result = []
+        while stack:
+            scope = stack.pop()
+            result.append(scope)
+            stack.extend(
+                itertools.chain(
+                    scope.cte_scopes,
+                    scope.union_scopes,
+                    scope.table_scopes,
+                    scope.subquery_scopes,
+                )
+            )
+
+        yield from reversed(result)

     def ref_count(self):
         """
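
Note: Scope.traverse() is rewritten from recursive generators to an explicit stack, so deeply nested scope trees no longer risk hitting Python's recursion limit. The underlying trick, as a standalone sketch (postorder/children are illustrative names, not sqlglot API):

    def postorder(root, children):
        # Pre-order visit with an explicit stack, then replay backwards:
        # reversing a pre-order in which children are popped in reverse
        # push order yields a depth-first post-order.
        stack, visited = [root], []
        while stack:
            node = stack.pop()
            visited.append(node)
            stack.extend(children(node))
        yield from reversed(visited)
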
@@ -481,14 +490,28 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

     Args:
-        expression (exp.Expression): expression to traverse
+        expression: Expression to traverse

     Returns:
-        list[Scope]: scope instances
+        A list of the created scope instances
     """
-    if isinstance(expression, exp.Query) or (
-        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
-    ):
+    if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
+        # We ignore the DDL expression and build a scope for its query instead
+        ddl_with = expression.args.get("with")
+        expression = expression.expression
+
+        # If the DDL has CTEs attached, we need to add them to the query, or
+        # prepend them if the query itself already has CTEs attached to it
+        if ddl_with:
+            ddl_with.pop()
+            query_ctes = expression.ctes
+            if not query_ctes:
+                expression.set("with", ddl_with)
+            else:
+                expression.args["with"].set("recursive", ddl_with.recursive)
+                expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
+
+    if isinstance(expression, exp.Query):
         return list(_traverse_scope(Scope(expression)))

     return []
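
Note: traverse_scope() now unwraps a DDL node and scopes its inner query directly, merging any WITH clause attached to the DDL into that query first. A small check of the behavior (illustrative):

    import sqlglot
    from sqlglot.optimizer.scope import traverse_scope

    # The scopes built for a CTAS are the scopes of the inner SELECT;
    # the final element is its root scope.
    scopes = traverse_scope(sqlglot.parse_one("CREATE TABLE t AS SELECT a FROM x"))
    print([s.expression.sql() for s in scopes])
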
@@ -499,21 +522,21 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
     Build a scope tree.

     Args:
-        expression (exp.Expression): expression to build the scope tree for
+        expression: Expression to build the scope tree for.

     Returns:
-        Scope: root scope
+        The root scope
     """
-    scopes = traverse_scope(expression)
-    if scopes:
-        return scopes[-1]
-    return None
+    return seq_get(traverse_scope(expression), -1)


 def _traverse_scope(scope):
     if isinstance(scope.expression, exp.Select):
         yield from _traverse_select(scope)
     elif isinstance(scope.expression, exp.Union):
+        yield from _traverse_ctes(scope)
         yield from _traverse_union(scope)
+        return
     elif isinstance(scope.expression, exp.Subquery):
         if scope.is_root:
             yield from _traverse_select(scope)
@@ -523,8 +546,6 @@ def _traverse_scope(scope):
         yield from _traverse_tables(scope)
     elif isinstance(scope.expression, exp.UDTF):
         yield from _traverse_udtfs(scope)
-    elif isinstance(scope.expression, exp.DDL):
-        yield from _traverse_ddl(scope)
     else:
         logger.warning(
             "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@@ -541,30 +562,38 @@ def _traverse_select(scope):


 def _traverse_union(scope):
-    yield from _traverse_ctes(scope)
+    prev_scope = None
+    union_scope_stack = [scope]
+    expression_stack = [scope.expression.right, scope.expression.left]

-    # The last scope to be yield should be the top most scope
-    left = None
-    for left in _traverse_scope(
-        scope.branch(
-            scope.expression.left,
-            outer_column_list=scope.outer_column_list,
-            scope_type=ScopeType.UNION,
-        )
-    ):
-        yield left
+    while expression_stack:
+        expression = expression_stack.pop()
+        union_scope = union_scope_stack[-1]

-    right = None
-    for right in _traverse_scope(
-        scope.branch(
-            scope.expression.right,
-            outer_column_list=scope.outer_column_list,
-            scope_type=ScopeType.UNION,
-        )
-    ):
-        yield right
+        new_scope = union_scope.branch(
+            expression,
+            outer_columns=union_scope.outer_columns,
+            scope_type=ScopeType.UNION,
+        )

-    scope.union_scopes = [left, right]
+        if isinstance(expression, exp.Union):
+            yield from _traverse_ctes(new_scope)
+
+            union_scope_stack.append(new_scope)
+            expression_stack.extend([expression.right, expression.left])
+            continue
+
+        for scope in _traverse_scope(new_scope):
+            yield scope
+
+        if prev_scope:
+            union_scope_stack.pop()
+            union_scope.union_scopes = [prev_scope, scope]
+            prev_scope = union_scope
+
+            yield union_scope
+        else:
+            prev_scope = scope


 def _traverse_ctes(scope):
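
Note: _traverse_union() likewise trades recursion for expression/scope stacks, so long chains of set operations are scoped at constant stack depth. A quick way to exercise that path (illustrative):

    import sqlglot
    from sqlglot.optimizer.scope import traverse_scope

    # Hundreds of chained UNION ALLs parse into a deeply nested exp.Union
    # tree; the stack-based traversal handles it without RecursionError.
    sql = " UNION ALL ".join(f"SELECT {i} AS x" for i in range(500))
    scopes = traverse_scope(sqlglot.parse_one(sql))
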
@@ -588,7 +617,7 @@ def _traverse_ctes(scope):
             scope.branch(
                 cte.this,
                 cte_sources=sources,
-                outer_column_list=cte.alias_column_names,
+                outer_columns=cte.alias_column_names,
                 scope_type=ScopeType.CTE,
             )
         ):
@@ -615,7 +644,9 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
     as it doesn't introduce a new scope. If an alias is present, it shadows all names
     under the Subquery, so that's one exception to this rule.
     """
-    return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
+    return isinstance(expression, exp.Subquery) and bool(
+        expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
+    )


 def _traverse_tables(scope):
@@ -681,7 +712,7 @@ def _traverse_tables(scope):
             scope.branch(
                 expression,
                 lateral_sources=lateral_sources,
-                outer_column_list=expression.alias_column_names,
+                outer_columns=expression.alias_column_names,
                 scope_type=scope_type,
             )
         ):
@@ -719,13 +750,13 @@ def _traverse_udtfs(scope):

     sources = {}
     for expression in expressions:
-        if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
+        if _is_derived_table(expression):
             top = None
             for child_scope in _traverse_scope(
                 scope.branch(
                     expression,
                     scope_type=ScopeType.DERIVED_TABLE,
-                    outer_column_list=expression.alias_column_names,
+                    outer_columns=expression.alias_column_names,
                 )
             ):
                 yield child_scope
@@ -738,18 +769,6 @@ def _traverse_udtfs(scope):
     scope.sources.update(sources)


-def _traverse_ddl(scope):
-    yield from _traverse_ctes(scope)
-
-    query_scope = scope.branch(
-        scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
-    )
-    query_scope._collect()
-    query_scope._ctes = scope.ctes + query_scope._ctes
-
-    yield from _traverse_scope(query_scope)
-
-
 def walk_in_scope(expression, bfs=True, prune=None):
     """
     Returns a generator object which visits all nodes in the syntax tree, stopping at
@@ -769,23 +788,21 @@ def walk_in_scope(expression, bfs=True, prune=None):
     # Whenever we set it to True, we exclude a subtree from traversal.
     crossed_scope_boundary = False

-    for node, parent, key in expression.walk(
-        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
+    for node in expression.walk(
+        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
     ):
         crossed_scope_boundary = False

-        yield node, parent, key
+        yield node

         if node is expression:
             continue
         if (
             isinstance(node, exp.CTE)
             or (
-                isinstance(node, exp.Subquery)
-                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
-                and _is_derived_table(node)
+                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
+                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
             )
-            or isinstance(node, exp.UDTF)
             or isinstance(node, exp.UNWRAPPED_QUERIES)
         ):
             crossed_scope_boundary = True
@@ -812,7 +829,7 @@ def find_all_in_scope(expression, expression_types, bfs=True):
     Yields:
         exp.Expression: nodes
     """
-    for expression, *_ in walk_in_scope(expression, bfs=bfs):
+    for expression in walk_in_scope(expression, bfs=bfs):
         if isinstance(expression, tuple(ensure_collection(expression_types))):
             yield expression

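
Note: walk_in_scope() and find_all_in_scope() follow the same single-node protocol. For example (illustrative), columns behind a derived-table boundary are not yielded for the outer query:

    from sqlglot import exp, parse_one
    from sqlglot.optimizer.scope import find_all_in_scope

    tree = parse_one("SELECT a FROM (SELECT b FROM t) AS sub")
    # Yields only the outer column `a`; `b` belongs to the inner scope.
    print([c.sql() for c in find_all_in_scope(tree, exp.Column)])
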
@@ -9,19 +9,25 @@ from decimal import Decimal

 import sqlglot
 from sqlglot import Dialect, exp
-from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
+from sqlglot.helper import first, merge_ranges, while_changing
 from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope

 if t.TYPE_CHECKING:
     from sqlglot.dialects.dialect import DialectType

     DateTruncBinaryTransform = t.Callable[
-        [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
+        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
     ]

 # Final means that an expression should not be simplified
 FINAL = "final"

+# Value ranges for byte-sized signed/unsigned integers
+TINYINT_MIN = -128
+TINYINT_MAX = 127
+UTINYINT_MIN = 0
+UTINYINT_MAX = 255
+

 class UnsupportedUnit(Exception):
     pass
@@ -63,14 +69,14 @@ def simplify(
             group.meta[FINAL] = True

             for e in expression.selects:
-                for node, *_ in e.walk():
+                for node in e.walk():
                     if node in groups:
                         e.meta[FINAL] = True
                         break

             having = expression.args.get("having")
             if having:
-                for node, *_ in having.walk():
+                for node in having.walk():
                     if node in groups:
                         having.meta[FINAL] = True
                         break
@@ -304,6 +310,8 @@ def _simplify_comparison(expression, left, right, or_=False):
                 r = extract_date(r)
                 if not r:
                     return None
+                # python won't compare date and datetime, but many engines will upcast
+                l, r = cast_as_datetime(l), cast_as_datetime(r)

         for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
             if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
@@ -431,7 +439,7 @@ def propagate_constants(expression, root=True):
         and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
     ):
         constant_mapping = {}
-        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
+        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
             if isinstance(expr, exp.EQ):
                 l, r = expr.left, expr.right

@@ -544,7 +552,37 @@ def simplify_literals(expression, root=True):
     return expression


+NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
+
+
+def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
+    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
+        this = _simplify_integer_cast(expr.this)
+    else:
+        this = expr.this
+
+    if isinstance(expr, exp.Cast) and this.is_int:
+        num = int(this.name)
+
+        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
+        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
+        # engine-dependent
+        if (
+            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
+        ) or (
+            UTINYINT_MIN <= num <= UTINYINT_MAX
+            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
+        ):
+            return this
+
+    return expr
+
+
 def _simplify_binary(expression, a, b):
+    if isinstance(expression, COMPARISONS):
+        a = _simplify_integer_cast(a)
+        b = _simplify_integer_cast(b)
+
     if isinstance(expression, exp.Is):
         if isinstance(b, exp.Not):
             c = b.this
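
Note: the new _simplify_integer_cast() strips casts around byte-sized integer literals in comparisons, since an upcast of a value that provably fits in one byte cannot change the comparison's outcome (a downcast could overflow, so it is left alone). Sketch of the effect (illustrative; exact output may vary by version):

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    expr = sqlglot.parse_one("CAST(5 AS INT) > 0")
    print(simplify(expr).sql())  # expected to fold to: TRUE
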
@@ -558,7 +596,7 @@ def _simplify_binary(expression, a, b):
             return exp.true() if not_ else exp.false()
         if is_null(a):
             return exp.false() if not_ else exp.true()
-    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
+    elif isinstance(expression, NULL_OK):
         return None
     elif is_null(a) or is_null(b):
         return exp.null()
@@ -591,17 +629,17 @@ def _simplify_binary(expression, a, b):
         if boolean:
             return boolean
     elif _is_date_literal(a) and isinstance(b, exp.Interval):
-        a, b = extract_date(a), extract_interval(b)
-        if a and b:
+        date, b = extract_date(a), extract_interval(b)
+        if date and b:
             if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
-                return date_literal(a + b)
+                return date_literal(date + b, extract_type(a))
             if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
-                return date_literal(a - b)
+                return date_literal(date - b, extract_type(a))
     elif isinstance(a, exp.Interval) and _is_date_literal(b):
-        a, b = extract_interval(a), extract_date(b)
+        a, date = extract_interval(a), extract_date(b)
         # you cannot subtract a date from an interval
         if a and b and isinstance(expression, exp.Add):
-            return date_literal(a + b)
+            return date_literal(a + date, extract_type(b))
     elif _is_date_literal(a) and _is_date_literal(b):
         if isinstance(expression, exp.Predicate):
             a, b = extract_date(a), extract_date(b)
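
Note: date arithmetic on literals still folds to a single literal, but date_literal() now receives the original operand's type via extract_type(), so a DATE stays a DATE instead of being re-cast by guesswork. For instance (illustrative):

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    expr = sqlglot.parse_one("CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY")
    print(simplify(expr).sql())  # expected: CAST('2023-01-02' AS DATE)
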
@@ -618,12 +656,16 @@ def simplify_parens(expression):

     this = expression.this
     parent = expression.parent
+    parent_is_predicate = isinstance(parent, exp.Predicate)

     if not isinstance(this, exp.Select) and (
         not isinstance(parent, (exp.Condition, exp.Binary))
         or isinstance(parent, exp.Paren)
-        or not isinstance(this, exp.Binary)
-        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
+        or (
+            not isinstance(this, exp.Binary)
+            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
+        )
+        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
         or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
         or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
         or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
@@ -632,24 +674,12 @@ def simplify_parens(expression):
         return expression


-NONNULL_CONSTANTS = (
-    exp.Literal,
-    exp.Boolean,
-)
-
-CONSTANTS = (
-    exp.Literal,
-    exp.Boolean,
-    exp.Null,
-)
-
-
 def _is_nonnull_constant(expression: exp.Expression) -> bool:
-    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
+    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)


 def _is_constant(expression: exp.Expression) -> bool:
-    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
+    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)


 def simplify_coalesce(expression):
@@ -820,45 +850,55 @@ def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Opti
     return floor, floor + interval(unit)


-def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
+def _datetrunc_eq_expression(
+    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
+) -> exp.Expression:
     """Get the logical expression for a date range"""
     return exp.and_(
-        left >= date_literal(drange[0]),
-        left < date_literal(drange[1]),
+        left >= date_literal(drange[0], target_type),
+        left < date_literal(drange[1], target_type),
         copy=False,
     )


 def _datetrunc_eq(
-    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
+    left: exp.Expression,
+    date: datetime.date,
+    unit: str,
+    dialect: Dialect,
+    target_type: t.Optional[exp.DataType],
 ) -> t.Optional[exp.Expression]:
     drange = _datetrunc_range(date, unit, dialect)
     if not drange:
         return None

-    return _datetrunc_eq_expression(left, drange)
+    return _datetrunc_eq_expression(left, drange, target_type)


 def _datetrunc_neq(
-    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
+    left: exp.Expression,
+    date: datetime.date,
+    unit: str,
+    dialect: Dialect,
+    target_type: t.Optional[exp.DataType],
 ) -> t.Optional[exp.Expression]:
     drange = _datetrunc_range(date, unit, dialect)
     if not drange:
         return None

     return exp.and_(
-        left < date_literal(drange[0]),
-        left >= date_literal(drange[1]),
+        left < date_literal(drange[0], target_type),
+        left >= date_literal(drange[1], target_type),
         copy=False,
     )


 DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
-    exp.LT: lambda l, dt, u, d: l
-    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
-    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
-    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
-    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
+    exp.LT: lambda l, dt, u, d, t: l
+    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
+    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
+    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
+    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
     exp.EQ: _datetrunc_eq,
     exp.NEQ: _datetrunc_neq,
 }
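
Note: the DATE_TRUNC comparison rewrites gain the same target_type plumbing, so the synthesized range bounds are cast to the truncated operand's type when one is known. Illustrative behavior:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    expr = sqlglot.parse_one("DATE_TRUNC('YEAR', d) = CAST('2021-01-01' AS DATE)")
    print(simplify(expr).sql())
    # expected: d >= CAST('2021-01-01' AS DATE) AND d < CAST('2022-01-01' AS DATE)
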
@@ -876,9 +916,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
     comparison = expression.__class__

     if isinstance(expression, DATETRUNCS):
-        date = extract_date(expression.this)
+        this = expression.this
+        trunc_type = extract_type(this)
+        date = extract_date(this)
         if date and expression.unit:
-            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
+            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
     elif comparison not in DATETRUNC_COMPARISONS:
         return expression

@@ -889,14 +931,21 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
             return expression

         l = t.cast(exp.DateTrunc, l)
+        trunc_arg = l.this
         unit = l.unit.name.lower()
         date = extract_date(r)

         if not date:
             return expression

-        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
-    elif isinstance(expression, exp.In):
+        return (
+            DATETRUNC_BINARY_COMPARISONS[comparison](
+                trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
+            )
+            or expression
+        )
+
+    if isinstance(expression, exp.In):
         l = expression.this
         rs = expression.expressions

@@ -917,8 +966,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
                 return expression

         ranges = merge_ranges(ranges)
+        target_type = extract_type(l, *rs)

-        return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
+        return exp.or_(
+            *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
+        )

     return expression

@@ -954,7 +1006,7 @@ JOINS = {
 def remove_where_true(expression):
     for where in expression.find_all(exp.Where):
         if always_true(where.this):
-            where.parent.set("where", None)
+            where.pop()
     for join in expression.find_all(exp.Join):
         if (
             always_true(join.args.get("on"))
@@ -962,7 +1014,7 @@ def remove_where_true(expression):
             and not join.args.get("method")
             and (join.side, join.kind) in JOINS
         ):
-            join.set("on", None)
+            join.args["on"].pop()
             join.set("side", None)
             join.set("kind", "CROSS")

@@ -1067,15 +1119,25 @@ def extract_interval(expression):
     return None


-def date_literal(date):
-    return exp.cast(
-        exp.Literal.string(date),
-        (
-            exp.DataType.Type.DATETIME
-            if isinstance(date, datetime.datetime)
-            else exp.DataType.Type.DATE
-        ),
-    )
+def extract_type(*expressions):
+    target_type = None
+    for expression in expressions:
+        target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
+        if target_type:
+            break
+
+    return target_type
+
+
+def date_literal(date, target_type=None):
+    if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
+        target_type = (
+            exp.DataType.Type.DATETIME
+            if isinstance(date, datetime.datetime)
+            else exp.DataType.Type.DATE
+        )
+
+    return exp.cast(exp.Literal.string(date), target_type)


 def interval(unit: str, n: int = 1):
@@ -1169,73 +1231,251 @@ def gen(expression: t.Any) -> str:
     Sorting and deduping sql is a necessary step for optimization. Calling the actual
     generator is expensive so we have a bare minimum sql generator here.
     """
-    if expression is None:
-        return "_"
-    if is_iterable(expression):
-        return ",".join(gen(e) for e in expression)
-    if not isinstance(expression, exp.Expression):
-        return str(expression)
-
-    etype = type(expression)
-    if etype in GEN_MAP:
-        return GEN_MAP[etype](expression)
-    return f"{expression.key} {gen(expression.args.values())}"
-
-
-GEN_MAP = {
-    exp.Add: lambda e: _binary(e, "+"),
-    exp.And: lambda e: _binary(e, "AND"),
-    exp.Anonymous: lambda e: _anonymous(e),
-    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
-    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
-    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
-    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
-    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
-    exp.Div: lambda e: _binary(e, "/"),
-    exp.Dot: lambda e: _binary(e, "."),
-    exp.EQ: lambda e: _binary(e, "="),
-    exp.GT: lambda e: _binary(e, ">"),
-    exp.GTE: lambda e: _binary(e, ">="),
-    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
-    exp.ILike: lambda e: _binary(e, "ILIKE"),
-    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
-    exp.Is: lambda e: _binary(e, "IS"),
-    exp.Like: lambda e: _binary(e, "LIKE"),
-    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
-    exp.LT: lambda e: _binary(e, "<"),
-    exp.LTE: lambda e: _binary(e, "<="),
-    exp.Mod: lambda e: _binary(e, "%"),
-    exp.Mul: lambda e: _binary(e, "*"),
-    exp.Neg: lambda e: _unary(e, "-"),
-    exp.NEQ: lambda e: _binary(e, "<>"),
-    exp.Not: lambda e: _unary(e, "NOT"),
-    exp.Null: lambda e: "NULL",
-    exp.Or: lambda e: _binary(e, "OR"),
-    exp.Paren: lambda e: f"({gen(e.this)})",
-    exp.Sub: lambda e: _binary(e, "-"),
-    exp.Subquery: lambda e: f"({gen(e.args.values())})",
-    exp.Table: lambda e: gen(e.args.values()),
-    exp.Var: lambda e: e.name,
-}
-
-
-def _anonymous(e: exp.Anonymous) -> str:
-    this = e.this
-    if isinstance(this, str):
-        name = this.upper()
-    elif isinstance(this, exp.Identifier):
-        name = f'"{this.name}"' if this.quoted else this.name.upper()
-    else:
-        raise ValueError(
-            f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
-        )
-
-    return f"{name} {','.join(gen(e) for e in e.expressions)}"
-
-
-def _binary(e: exp.Binary, op: str) -> str:
-    return f"{gen(e.left)} {op} {gen(e.right)}"
-
-
-def _unary(e: exp.Unary, op: str) -> str:
-    return f"{op} {gen(e.this)}"
+    return Gen().gen(expression)
+
+
+class Gen:
+    def __init__(self):
+        self.stack = []
+        self.sqls = []
+
+    def gen(self, expression: exp.Expression) -> str:
+        self.stack = [expression]
+        self.sqls.clear()
+
+        while self.stack:
+            node = self.stack.pop()
+
+            if isinstance(node, exp.Expression):
+                exp_handler_name = f"{node.key}_sql"
+
+                if hasattr(self, exp_handler_name):
+                    getattr(self, exp_handler_name)(node)
+                elif isinstance(node, exp.Func):
+                    self._function(node)
+                else:
+                    key = node.key.upper()
+                    self.stack.append(f"{key} " if self._args(node) else key)
+            elif type(node) is list:
+                for n in reversed(node):
+                    if n is not None:
+                        self.stack.extend((n, ","))
+                if node:
+                    self.stack.pop()
+            else:
+                if node is not None:
+                    self.sqls.append(str(node))
+
+        return "".join(self.sqls)
+
+    def add_sql(self, e: exp.Add) -> None:
+        self._binary(e, " + ")
+
+    def alias_sql(self, e: exp.Alias) -> None:
+        self.stack.extend(
+            (
+                e.args.get("alias"),
+                " AS ",
+                e.args.get("this"),
+            )
+        )
+
+    def and_sql(self, e: exp.And) -> None:
+        self._binary(e, " AND ")
+
+    def anonymous_sql(self, e: exp.Anonymous) -> None:
+        this = e.this
+        if isinstance(this, str):
+            name = this.upper()
+        elif isinstance(this, exp.Identifier):
+            name = this.this
+            name = f'"{name}"' if this.quoted else name.upper()
+        else:
+            raise ValueError(
+                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
+            )
+
+        self.stack.extend(
+            (
+                ")",
+                e.expressions,
+                "(",
+                name,
+            )
+        )
+
+    def between_sql(self, e: exp.Between) -> None:
+        self.stack.extend(
+            (
+                e.args.get("high"),
+                " AND ",
+                e.args.get("low"),
+                " BETWEEN ",
+                e.this,
+            )
+        )
+
+    def boolean_sql(self, e: exp.Boolean) -> None:
+        self.stack.append("TRUE" if e.this else "FALSE")
+
+    def bracket_sql(self, e: exp.Bracket) -> None:
+        self.stack.extend(
+            (
+                "]",
+                e.expressions,
+                "[",
+                e.this,
+            )
+        )
+
+    def column_sql(self, e: exp.Column) -> None:
+        for p in reversed(e.parts):
+            self.stack.extend((p, "."))
+        self.stack.pop()
+
+    def datatype_sql(self, e: exp.DataType) -> None:
+        self._args(e, 1)
+        self.stack.append(f"{e.this.name} ")
+
+    def div_sql(self, e: exp.Div) -> None:
+        self._binary(e, " / ")
+
+    def dot_sql(self, e: exp.Dot) -> None:
+        self._binary(e, ".")
+
+    def eq_sql(self, e: exp.EQ) -> None:
+        self._binary(e, " = ")
+
+    def from_sql(self, e: exp.From) -> None:
+        self.stack.extend((e.this, "FROM "))
+
+    def gt_sql(self, e: exp.GT) -> None:
+        self._binary(e, " > ")
+
+    def gte_sql(self, e: exp.GTE) -> None:
+        self._binary(e, " >= ")
+
+    def identifier_sql(self, e: exp.Identifier) -> None:
+        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
+
+    def ilike_sql(self, e: exp.ILike) -> None:
+        self._binary(e, " ILIKE ")
+
+    def in_sql(self, e: exp.In) -> None:
+        self.stack.append(")")
+        self._args(e, 1)
+        self.stack.extend(
+            (
+                "(",
+                " IN ",
+                e.this,
+            )
+        )
+
+    def intdiv_sql(self, e: exp.IntDiv) -> None:
+        self._binary(e, " DIV ")
+
+    def is_sql(self, e: exp.Is) -> None:
+        self._binary(e, " IS ")
+
+    def like_sql(self, e: exp.Like) -> None:
+        self._binary(e, " Like ")
+
+    def literal_sql(self, e: exp.Literal) -> None:
+        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
+
+    def lt_sql(self, e: exp.LT) -> None:
+        self._binary(e, " < ")
+
+    def lte_sql(self, e: exp.LTE) -> None:
+        self._binary(e, " <= ")
+
+    def mod_sql(self, e: exp.Mod) -> None:
+        self._binary(e, " % ")
+
+    def mul_sql(self, e: exp.Mul) -> None:
+        self._binary(e, " * ")
+
+    def neg_sql(self, e: exp.Neg) -> None:
+        self._unary(e, "-")
+
+    def neq_sql(self, e: exp.NEQ) -> None:
+        self._binary(e, " <> ")
+
+    def not_sql(self, e: exp.Not) -> None:
+        self._unary(e, "NOT ")
+
+    def null_sql(self, e: exp.Null) -> None:
+        self.stack.append("NULL")
+
+    def or_sql(self, e: exp.Or) -> None:
+        self._binary(e, " OR ")
+
+    def paren_sql(self, e: exp.Paren) -> None:
+        self.stack.extend(
+            (
+                ")",
+                e.this,
+                "(",
+            )
+        )
+
+    def sub_sql(self, e: exp.Sub) -> None:
+        self._binary(e, " - ")
+
+    def subquery_sql(self, e: exp.Subquery) -> None:
+        self._args(e, 2)
+        alias = e.args.get("alias")
+        if alias:
+            self.stack.append(alias)
+        self.stack.extend((")", e.this, "("))
+
+    def table_sql(self, e: exp.Table) -> None:
+        self._args(e, 4)
+        alias = e.args.get("alias")
+        if alias:
+            self.stack.append(alias)
+        for p in reversed(e.parts):
+            self.stack.extend((p, "."))
+        self.stack.pop()
+
+    def tablealias_sql(self, e: exp.TableAlias) -> None:
+        columns = e.columns
+
+        if columns:
+            self.stack.extend((")", columns, "("))
+
+        self.stack.extend((e.this, " AS "))
+
+    def var_sql(self, e: exp.Var) -> None:
+        self.stack.append(e.this)
+
+    def _binary(self, e: exp.Binary, op: str) -> None:
+        self.stack.extend((e.expression, op, e.this))
+
+    def _unary(self, e: exp.Unary, op: str) -> None:
+        self.stack.extend((e.this, op))
+
+    def _function(self, e: exp.Func) -> None:
+        self.stack.extend(
+            (
+                ")",
+                list(e.args.values()),
+                "(",
+                e.sql_name(),
+            )
+        )
+
+    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
+        kvs = []
+        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types
+
+        for k in arg_types:
+            v = node.args.get(k)
+
+            if v is not None:
+                kvs.append([f":{k}", v])
+
+        if kvs:
+            self.stack.append(kvs)
+            return True
+        return False
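
Note: gen() keeps its contract of producing cheap, stable keys for sorting and deduping expressions (not valid SQL), but the recursive GEN_MAP lambdas are replaced by the stack-driven Gen emitter, which avoids both deep recursion and heavy f-string concatenation. Usage is unchanged (illustrative):

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import gen

    key = gen(parse_one("a + b AND c = 1"))  # a canonical string key
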
@@ -138,7 +138,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
         if isinstance(predicate, exp.Binary):
             key = (
                 predicate.right
-                if any(node is column for node, *_ in predicate.left.walk())
+                if any(node is column for node in predicate.left.walk())
                 else predicate.left
             )
         else:
File diff suppressed because it is too large
@@ -118,6 +118,7 @@ class Step:
         if joins:
             join = Join.from_joins(joins, ctes)
             join.name = step.name
+            join.source_name = step.name
             join.add_dependency(step)
             step = join

@@ -187,13 +188,13 @@ class Step:
             intermediate[v.name] = k

         for projection in projections:
-            for node, *_ in projection.walk():
+            for node in projection.walk():
                 name = intermediate.get(node)
                 if name:
                     node.replace(exp.column(name, step.name))

         if aggregate.condition:
-            for node, *_ in aggregate.condition.walk():
+            for node in aggregate.condition.walk():
                 name = intermediate.get(node) or intermediate.get(node.name)
                 if name:
                     node.replace(exp.column(name, step.name))
@@ -331,7 +332,7 @@ class Join(Step):
     @classmethod
     def from_joins(
         cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
-    ) -> Step:
+    ) -> Join:
         step = Join()

         for join in joins:
@@ -349,10 +350,11 @@ class Join(Step):

     def __init__(self) -> None:
         super().__init__()
+        self.source_name: t.Optional[str] = None
         self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}

     def _to_s(self, indent: str) -> t.List[str]:
-        lines = []
+        lines = [f"{indent}Source: {self.source_name or self.name}"]
         for name, join in self.joins.items():
             lines.append(f"{indent}{name}: {join['side'] or 'INNER'}")
             join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or []))
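
Note: Join steps now remember the source they join from (source_name) separately from the step name, which makes the rendered plan easier to read. A quick look (illustrative):

    import sqlglot
    from sqlglot.planner import Plan

    plan = Plan(sqlglot.parse_one("SELECT a.x FROM a JOIN b ON a.id = b.id"))
    print(plan.root.to_s())  # the Join step now prints a "Source: ..." line
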
@@ -423,7 +425,7 @@ class SetOperation(Step):
     @classmethod
     def from_expression(
         cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
-    ) -> Step:
+    ) -> SetOperation:
         assert isinstance(expression, exp.Union)

         left = Step.from_expression(expression.left, ctes)
@@ -135,6 +135,7 @@ class TokenType(AutoName):
     LONGBLOB = auto()
     TINYBLOB = auto()
     TINYTEXT = auto()
+    NAME = auto()
     BINARY = auto()
     VARBINARY = auto()
     JSON = auto()
@@ -290,6 +291,7 @@ class TokenType(AutoName):
     LOAD = auto()
     LOCK = auto()
     MAP = auto()
+    MATCH_CONDITION = auto()
     MATCH_RECOGNIZE = auto()
     MEMBER_OF = auto()
     MERGE = auto()
@@ -317,6 +319,7 @@ class TokenType(AutoName):
     PERCENT = auto()
     PIVOT = auto()
     PLACEHOLDER = auto()
+    POSITIONAL = auto()
     PRAGMA = auto()
     PREWHERE = auto()
     PRIMARY_KEY = auto()
@@ -340,6 +343,7 @@ class TokenType(AutoName):
     SELECT = auto()
     SEMI = auto()
     SEPARATOR = auto()
+    SEQUENCE = auto()
     SERDE_PROPERTIES = auto()
     SET = auto()
     SETTINGS = auto()
@@ -518,6 +522,7 @@ class _Tokenizer(type):
                 break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK],
                 dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON],
                 heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING],
+                raw_string=_TOKEN_TYPE_TO_INDEX[TokenType.RAW_STRING],
                 hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING],
                 identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER],
                 number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER],
@@ -562,8 +567,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "~": TokenType.TILDA,
         "?": TokenType.PLACEHOLDER,
         "@": TokenType.PARAMETER,
-        # used for breaking a var like x'y' but nothing else
-        # the token type doesn't matter
+        # Used for breaking a var like x'y' but nothing else; the token type doesn't matter
         "'": TokenType.QUOTE,
         "`": TokenType.IDENTIFIER,
         '"': TokenType.IDENTIFIER,
@@ -796,6 +800,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "LONG": TokenType.BIGINT,
         "BIGINT": TokenType.BIGINT,
         "INT8": TokenType.TINYINT,
+        "UINT": TokenType.UINT,
         "DEC": TokenType.DECIMAL,
         "DECIMAL": TokenType.DECIMAL,
         "BIGDECIMAL": TokenType.BIGDECIMAL,
@@ -856,6 +861,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "DATEMULTIRANGE": TokenType.DATEMULTIRANGE,
         "UNIQUE": TokenType.UNIQUE,
         "STRUCT": TokenType.STRUCT,
+        "SEQUENCE": TokenType.SEQUENCE,
         "VARIANT": TokenType.VARIANT,
         "ALTER": TokenType.ALTER,
         "ANALYZE": TokenType.COMMAND,
@@ -888,7 +894,7 @@ class Tokenizer(metaclass=_Tokenizer):

     COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN}

-    # handle numeric literals like in hive (3L = BIGINT)
+    # Handle numeric literals like in hive (3L = BIGINT)
     NUMERIC_LITERALS: t.Dict[str, str] = {}

     COMMENTS = ["--", ("/*", "*/")]
@@ -917,7 +923,7 @@ class Tokenizer(metaclass=_Tokenizer):

         if USE_RS_TOKENIZER:
             self._rs_dialect_settings = RsTokenizerDialectSettings(
-                escape_sequences=self.dialect.ESCAPE_SEQUENCES,
+                unescaped_sequences=self.dialect.UNESCAPED_SEQUENCES,
                 identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT,
             )

@@ -961,8 +967,7 @@ class Tokenizer(metaclass=_Tokenizer):
         while self.size and not self._end:
             current = self._current

-            # skip spaces inline rather than iteratively call advance()
-            # for performance reasons
+            # Skip spaces here rather than iteratively calling advance() for performance reasons
             while current < self.size:
                 char = self.sql[current]

@@ -971,12 +976,10 @@ class Tokenizer(metaclass=_Tokenizer):
                 else:
                     break

-            n = current - self._current
-            self._start = current
-            self._advance(n if n > 1 else 1)
+            offset = current - self._current if current > self._current else 1

-            if self._char is None:
-                break
+            self._start = current
+            self._advance(offset)

             if not self._char.isspace():
                 if self._char.isdigit():
@@ -1004,12 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer):
     def _advance(self, i: int = 1, alnum: bool = False) -> None:
         if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
             # Ensures we don't count an extra line if we get a \r\n line break sequence
-            if self._char == "\r" and self._peek == "\n":
-                i = 2
-                self._start += 1
-
-            self._col = 1
-            self._line += 1
+            if not (self._char == "\r" and self._peek == "\n"):
+                self._col = 1
+                self._line += 1
         else:
             self._col += i

@@ -1268,13 +1268,27 @@ class Tokenizer(metaclass=_Tokenizer):
                 return True

             self._advance()
-            tag = "" if self._char == end else self._extract_string(end)
+
+            if self._char == end:
+                tag = ""
+            else:
+                tag = self._extract_string(
+                    end,
+                    unescape_sequences=False,
+                    raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER,
+                )
+
+                if self._end and tag and self.HEREDOC_TAG_IS_IDENTIFIER:
+                    self._advance(-len(tag))
+                    self._add(self.HEREDOC_STRING_ALTERNATIVE)
+                    return True
+
             end = f"{start}{tag}{end}"
         else:
             return False

         self._advance(len(start))
-        text = self._extract_string(end)
+        text = self._extract_string(end, unescape_sequences=token_type != TokenType.RAW_STRING)

         if base:
             try:
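
Note: heredoc scanning is now escape-aware and can back out gracefully: for dialects whose heredoc tag may be a bare identifier (HEREDOC_TAG_IS_IDENTIFIER), an unterminated tag at end of input is re-tokenized via HEREDOC_STRING_ALTERNATIVE instead of raising. Ordinary dollar-quoted strings tokenize as before (illustrative):

    import sqlglot

    tokens = sqlglot.tokenize("SELECT $$hello$$", read="postgres")
    print([t.token_type for t in tokens])
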
@@ -1289,7 +1303,7 @@ class Tokenizer(metaclass=_Tokenizer):

     def _scan_identifier(self, identifier_end: str) -> None:
         self._advance()
-        text = self._extract_string(identifier_end, self._IDENTIFIER_ESCAPES)
+        text = self._extract_string(identifier_end, escapes=self._IDENTIFIER_ESCAPES)
         self._add(TokenType.IDENTIFIER, text)

     def _scan_var(self) -> None:
@@ -1306,12 +1320,29 @@ class Tokenizer(metaclass=_Tokenizer):
             else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
         )

-    def _extract_string(self, delimiter: str, escapes=None) -> str:
+    def _extract_string(
+        self,
+        delimiter: str,
+        escapes: t.Optional[t.Set[str]] = None,
+        unescape_sequences: bool = True,
+        raise_unmatched: bool = True,
+    ) -> str:
         text = ""
         delim_size = len(delimiter)
         escapes = self._STRING_ESCAPES if escapes is None else escapes

         while True:
+            if (
+                unescape_sequences
+                and self.dialect.UNESCAPED_SEQUENCES
+                and self._peek
+                and self._char in self.STRING_ESCAPES
+            ):
+                unescaped_sequence = self.dialect.UNESCAPED_SEQUENCES.get(self._char + self._peek)
+                if unescaped_sequence:
+                    self._advance(2)
+                    text += unescaped_sequence
+                    continue
+
             if (
                 self._char in escapes
                 and (self._peek == delimiter or self._peek in escapes)
@@ -1333,18 +1364,10 @@ class Tokenizer(metaclass=_Tokenizer):
                 break

             if self._end:
-                raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")
+                if not raise_unmatched:
+                    return text + self._char

-            if (
-                self.dialect.ESCAPE_SEQUENCES
-                and self._peek
-                and self._char in self.STRING_ESCAPES
-            ):
-                escaped_sequence = self.dialect.ESCAPE_SEQUENCES.get(self._char + self._peek)
-                if escaped_sequence:
-                    self._advance(2)
-                    text += escaped_sequence
-                    continue
+                raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")

             current = self._current - 1
             self._advance(alnum=True)
@@ -447,7 +447,7 @@ def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
             if inner_with.recursive:
                 top_level_with.set("recursive", True)

-            top_level_with.expressions.extend(inner_with.expressions)
+            top_level_with.set("expressions", inner_with.expressions + top_level_with.expressions)

     return expression

@@ -464,7 +464,7 @@ def ensure_bools(expression: exp.Expression) -> exp.Expression:
         ):
             node.replace(node.neq(0))

-    for node, *_ in expression.walk():
+    for node in expression.walk():
         ensure_bools(node, _ensure_bool)

     return expression
@@ -561,9 +561,7 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Exp


 def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
-    """
-    Convert struct arguments to aliases: STRUCT(1 AS y) .
-    """
+    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
     if isinstance(expression, exp.Struct):
         expression.set(
             "expressions",