
Merging upstream version 23.7.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Author: Daniel Baumann
Date: 2025-02-13 21:30:28 +01:00
Parent: ebba7c6a18
Commit: d26905e4af
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
187 changed files with 86502 additions and 71397 deletions

sqlglot/__init__.py

@ -45,7 +45,7 @@ from sqlglot.expressions import (
from sqlglot.generator import Generator as Generator
from sqlglot.parser import Parser as Parser
from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema
from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType
from sqlglot.tokens import Token as Token, Tokenizer as Tokenizer, TokenType as TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
@ -69,6 +69,21 @@ schema = MappingSchema()
"""The default schema used by SQLGlot (e.g. in the optimizer)."""
def tokenize(sql: str, read: DialectType = None, dialect: DialectType = None) -> t.List[Token]:
"""
Tokenizes the given SQL string.
Args:
sql: the SQL code string to tokenize.
read: the SQL dialect to apply during tokenizing (e.g. "spark", "hive", "presto", "mysql").
dialect: the SQL dialect (alias for read).
Returns:
The resulting list of tokens.
"""
return Dialect.get_or_raise(read or dialect).tokenize(sql)
def parse(
sql: str, read: DialectType = None, dialect: DialectType = None, **opts
) -> t.List[t.Optional[Expression]]:
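A minimal usage sketch of the new top-level `tokenize` helper (the `Token` re-export above makes the return type importable as before):

import sqlglot
from sqlglot.tokens import TokenType

# Tokenize without building an AST; `read` selects the dialect's tokenizer.
tokens = sqlglot.tokenize("SELECT a FROM t", read="mysql")
assert tokens[0].token_type == TokenType.SELECT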

sqlglot/dataframe/sql/dataframe.py

@ -18,8 +18,6 @@ from sqlglot.dataframe.sql.transforms import replace_id_value
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer import optimize as optimize_func
from sqlglot.optimizer.qualify_columns import quote_identifiers
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import (
@ -121,7 +119,9 @@ class DataFrame:
self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
)
replacement_mapping[old_name_id] = new_hashed_id
expression = expression.transform(replace_id_value, replacement_mapping)
expression = expression.transform(replace_id_value, replacement_mapping).assert_is(
exp.Select
)
return expression
def _create_cte_from_expression(
@ -306,11 +306,12 @@ class DataFrame:
replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
select_expression = select_expression.transform(
replace_id_value, replacement_mapping
).assert_is(exp.Select)
if optimize:
quote_identifiers(select_expression, dialect=dialect)
select_expression = t.cast(
exp.Select, optimize_func(select_expression, dialect=dialect)
exp.Select, self.spark._optimize(select_expression, dialect=dialect)
)
select_expression = df._replace_cte_names_with_hashes(select_expression)

sqlglot/dataframe/sql/functions.py

@ -184,7 +184,7 @@ def floor(col: ColumnOrName) -> Column:
def log10(col: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(col, expression.Log10)
return Column.invoke_expression_over_column(lit(10), expression.Log, expression=col)
def log1p(col: ColumnOrName) -> Column:
@ -192,7 +192,7 @@ def log1p(col: ColumnOrName) -> Column:
def log2(col: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(col, expression.Log2)
return Column.invoke_expression_over_column(lit(2), expression.Log, expression=col)
def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
@ -356,15 +356,15 @@ def coalesce(*cols: ColumnOrName) -> Column:
def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "CORR", col2)
return Column.invoke_expression_over_column(col1, expression.Corr, expression=col2)
def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "COVAR_POP", col2)
return Column.invoke_expression_over_column(col1, expression.CovarPop, expression=col2)
def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2)
return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2)
def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
@ -971,10 +971,10 @@ def array_join(
) -> Column:
if null_replacement is not None:
return Column.invoke_expression_over_column(
col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement)
col, expression.ArrayToString, expression=lit(delimiter), null=lit(null_replacement)
)
return Column.invoke_expression_over_column(
col, expression.ArrayJoin, expression=lit(delimiter)
col, expression.ArrayToString, expression=lit(delimiter)
)

sqlglot/dataframe/sql/session.py

@ -12,6 +12,8 @@ from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
from sqlglot.helper import classproperty
from sqlglot.optimizer import optimize
from sqlglot.optimizer.qualify_columns import quote_identifiers
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
@ -104,8 +106,15 @@ class SparkSession:
sel_expression = exp.Select(**select_kwargs)
return DataFrame(self, sel_expression)
def _optimize(
self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
) -> exp.Expression:
dialect = dialect or self.dialect
quote_identifiers(expression, dialect=dialect)
return optimize(expression, dialect=dialect)
def sql(self, sqlQuery: str) -> DataFrame:
expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect))
if isinstance(expression, exp.Select):
df = DataFrame(self, expression)
df = df._convert_leaf_to_cte()
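Routing `sql()` through the new `_optimize` helper means queries are identifier-quoted and optimized before the DataFrame is built. A rough sketch of the resulting behavior (assumes the public DataFrame API, which is not shown in this diff):

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
# parse_one -> quote_identifiers -> optimize, then leaf-to-CTE conversion
df = spark.sql("SELECT id FROM t")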

sqlglot/dialects/__init__.py

@ -61,6 +61,7 @@ dialect implementations in order to understand how their various components can
----
"""
from sqlglot.dialects.athena import Athena
from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
@ -73,6 +74,7 @@ from sqlglot.dialects.mysql import MySQL
from sqlglot.dialects.oracle import Oracle
from sqlglot.dialects.postgres import Postgres
from sqlglot.dialects.presto import Presto
from sqlglot.dialects.prql import PRQL
from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark

sqlglot/dialects/athena.py

@ -0,0 +1,37 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.trino import Trino
from sqlglot.tokens import TokenType
class Athena(Trino):
class Parser(Trino.Parser):
STATEMENT_PARSERS = {
**Trino.Parser.STATEMENT_PARSERS,
TokenType.USING: lambda self: self._parse_as_command(self._prev),
}
class Generator(Trino.Generator):
PROPERTIES_LOCATION = {
**Trino.Generator.PROPERTIES_LOCATION,
exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
}
TYPE_MAPPING = {
**Trino.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
}
TRANSFORMS = {
**Trino.Generator.TRANSFORMS,
exp.FileFormatProperty: lambda self, e: f"'FORMAT'={self.sql(e, 'this')}",
}
def property_sql(self, expression: exp.Property) -> str:
return (
f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"
)
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
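A rough illustration of the new dialect, based on the mappings above (output approximate):

import sqlglot

# TEXT maps to STRING in Athena's generator.
sqlglot.transpile("SELECT CAST(x AS TEXT) FROM t", read="trino", write="athena")
# -> ['SELECT CAST(x AS STRING) FROM t']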

sqlglot/dialects/bigquery.py

@ -24,6 +24,7 @@ from sqlglot.dialects.dialect import (
rename_func,
timestrtotime_sql,
ts_or_ds_add_cast,
unit_to_var,
)
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
@ -41,14 +42,22 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
structs = []
alias = expression.args.get("alias")
for tup in expression.find_all(exp.Tuple):
field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions)))
field_aliases = (
alias.columns
if alias and alias.columns
else (f"_c{i}" for i in range(len(tup.expressions)))
)
expressions = [
exp.PropertyEQ(this=exp.to_identifier(name), expression=fld)
for name, fld in zip(field_aliases, tup.expressions)
]
structs.append(exp.Struct(expressions=expressions))
return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)]))
# Due to `UNNEST_COLUMN_ONLY`, it is expected that the table alias be contained in the columns expression
alias_name_only = exp.TableAlias(columns=[alias.this]) if alias else None
return self.unnest_sql(
exp.Unnest(expressions=[exp.array(*structs, copy=False)], alias=alias_name_only)
)
def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
@ -190,7 +199,7 @@ def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> st
def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str:
expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True))
expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True))
unit = expression.args.get("unit") or "DAY"
unit = unit_to_var(expression)
return self.func("DATE_DIFF", expression.this, expression.expression, unit)
@ -238,16 +247,6 @@ class BigQuery(Dialect):
"%E6S": "%S.%f",
}
ESCAPE_SEQUENCES = {
"\\a": "\a",
"\\b": "\b",
"\\f": "\f",
"\\n": "\n",
"\\r": "\r",
"\\t": "\t",
"\\v": "\v",
}
FORMAT_MAPPING = {
"DD": "%d",
"MM": "%m",
@ -315,6 +314,7 @@ class BigQuery(Dialect):
"BEGIN TRANSACTION": TokenType.BEGIN,
"BYTES": TokenType.BINARY,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"DATETIME": TokenType.TIMESTAMP,
"DECLARE": TokenType.COMMAND,
"ELSEIF": TokenType.COMMAND,
"EXCEPTION": TokenType.COMMAND,
@ -486,14 +486,14 @@ class BigQuery(Dialect):
table.set("db", exp.Identifier(this=parts[0]))
table.set("this", exp.Identifier(this=parts[1]))
if isinstance(table.this, exp.Identifier) and "." in table.name:
if any("." in p.name for p in table.parts):
catalog, db, this, *rest = (
t.cast(t.Optional[exp.Expression], exp.to_identifier(x, quoted=True))
for x in split_num_words(table.name, ".", 3)
exp.to_identifier(p, quoted=True)
for p in split_num_words(".".join(p.name for p in table.parts), ".", 3)
)
if rest and this:
this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))
this = exp.Dot.build([this, *rest]) # type: ignore
table = exp.Table(this=this, db=db, catalog=catalog)
table.meta["quoted_table"] = True
@ -527,7 +527,9 @@ class BigQuery(Dialect):
return json_object
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
def _parse_bracket(
self, this: t.Optional[exp.Expression] = None
) -> t.Optional[exp.Expression]:
bracket = super()._parse_bracket(this)
if this is bracket:
@ -566,6 +568,7 @@ class BigQuery(Dialect):
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False
NAMED_PLACEHOLDER_TOKEN = "@"
TRANSFORMS = {
@ -588,7 +591,7 @@ class BigQuery(Dialect):
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", e.this, e.expression, e.unit or "DAY"
"DATE_DIFF", e.this, e.expression, unit_to_var(e)
),
exp.DateFromParts: rename_func("DATE"),
exp.DateStrToDate: datestrtodate_sql,
@ -607,6 +610,7 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.Max: max_or_greatest,
exp.Mod: rename_func("MOD"),
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
exp.Min: min_or_least,
@ -847,10 +851,10 @@ class BigQuery(Dialect):
return inline_array_sql(self, expression)
def bracket_sql(self, expression: exp.Bracket) -> str:
this = self.sql(expression, "this")
this = expression.this
expressions = expression.expressions
if len(expressions) == 1:
if len(expressions) == 1 and this and this.is_type(exp.DataType.Type.STRUCT):
arg = expressions[0]
if arg.type is None:
from sqlglot.optimizer.annotate_types import annotate_types
@ -858,10 +862,10 @@ class BigQuery(Dialect):
arg = annotate_types(arg)
if arg.type and arg.type.this in exp.DataType.TEXT_TYPES:
# BQ doesn't support bracket syntax with string values
return f"{this}.{arg.name}"
# BQ doesn't support bracket syntax with string values for structs
return f"{self.sql(this)}.{arg.name}"
expressions_sql = ", ".join(self.sql(e) for e in expressions)
expressions_sql = self.expressions(expression, flat=True)
offset = expression.args.get("offset")
if offset == 0:
@ -874,7 +878,7 @@ class BigQuery(Dialect):
if expression.args.get("safe"):
expressions_sql = f"SAFE_{expressions_sql}"
return f"{this}[{expressions_sql}]"
return f"{self.sql(this)}[{expressions_sql}]"
def in_unnest_op(self, expression: exp.Unnest) -> str:
return self.sql(expression)
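With `unit_to_var`, date-part units now round-trip as bare keywords instead of being stringified; a sketch (output approximate):

import sqlglot

# The WEEK unit is emitted as a bare var, with DAY as the fallback when absent.
sqlglot.transpile("SELECT DATE_DIFF(a, b, WEEK)", read="bigquery", write="bigquery")
# -> ['SELECT DATE_DIFF(a, b, WEEK)']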

sqlglot/dialects/clickhouse.py

@ -15,7 +15,6 @@ from sqlglot.dialects.dialect import (
rename_func,
var_map_sql,
)
from sqlglot.errors import ParseError
from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import Token, TokenType
@ -49,8 +48,9 @@ class ClickHouse(Dialect):
NULL_ORDERING = "nulls_are_last"
SUPPORTS_USER_DEFINED_TYPES = False
SAFE_DIVISION = True
LOG_BASE_FIRST: t.Optional[bool] = None
ESCAPE_SEQUENCES = {
UNESCAPED_SEQUENCES = {
"\\0": "\0",
}
@ -105,6 +105,7 @@ class ClickHouse(Dialect):
# * select x from t1 union all select x from t2 limit 1;
# * select x from t1 union all (select x from t2 limit 1);
MODIFIERS_ATTACHED_TO_UNION = False
INTERVAL_SPANS = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@ -260,6 +261,11 @@ class ClickHouse(Dialect):
"ArgMax",
]
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
TokenType.SET,
}
AGG_FUNC_MAPPING = (
lambda functions, suffixes: {
f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
@ -305,6 +311,10 @@ class ClickHouse(Dialect):
TokenType.SETTINGS,
}
ALIAS_TOKENS = parser.Parser.ALIAS_TOKENS - {
TokenType.FORMAT,
}
LOG_DEFAULTS_TO_LN = True
QUERY_MODIFIER_PARSERS = {
@ -316,6 +326,17 @@ class ClickHouse(Dialect):
TokenType.FORMAT: lambda self: ("format", self._advance() or self._parse_id_var()),
}
CONSTRAINT_PARSERS = {
**parser.Parser.CONSTRAINT_PARSERS,
"INDEX": lambda self: self._parse_index_constraint(),
"CODEC": lambda self: self._parse_compress(),
}
SCHEMA_UNNAMED_CONSTRAINTS = {
*parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
"INDEX",
}
def _parse_conjunction(self) -> t.Optional[exp.Expression]:
this = super()._parse_conjunction()
@ -381,21 +402,20 @@ class ClickHouse(Dialect):
# https://clickhouse.com/docs/en/sql-reference/statements/select/with/
def _parse_cte(self) -> exp.CTE:
index = self._index
try:
# WITH <identifier> AS <subquery expression>
return super()._parse_cte()
except ParseError:
# WITH <expression> AS <identifier>
self._retreat(index)
# WITH <identifier> AS <subquery expression>
cte: t.Optional[exp.CTE] = self._try_parse(super()._parse_cte)
return self.expression(
if not cte:
# WITH <expression> AS <identifier>
cte = self.expression(
exp.CTE,
this=self._parse_conjunction(),
alias=self._parse_table_alias(),
scalar=True,
)
return cte
def _parse_join_parts(
self,
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
@ -508,6 +528,27 @@ class ClickHouse(Dialect):
self._retreat(index)
return None
def _parse_index_constraint(
self, kind: t.Optional[str] = None
) -> exp.IndexColumnConstraint:
# INDEX name1 expr TYPE type1(args) GRANULARITY value
this = self._parse_id_var()
expression = self._parse_conjunction()
index_type = self._match_text_seq("TYPE") and (
self._parse_function() or self._parse_var()
)
granularity = self._match_text_seq("GRANULARITY") and self._parse_term()
return self.expression(
exp.IndexColumnConstraint,
this=this,
expression=expression,
index_type=index_type,
granularity=granularity,
)
class Generator(generator.Generator):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
@ -517,6 +558,7 @@ class ClickHouse(Dialect):
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
@ -585,6 +627,9 @@ class ClickHouse(Dialect):
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
exp.CountIf: rename_func("countIf"),
exp.CompressColumnConstraint: lambda self,
e: f"CODEC({self.expressions(e, key='this', flat=True)})",
exp.ComputedColumnConstraint: lambda self, e: f"ALIAS {self.sql(e, 'this')}",
exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
exp.DateAdd: date_delta_sql("DATE_ADD"),
exp.DateDiff: date_delta_sql("DATE_DIFF"),
@ -737,3 +782,15 @@ class ClickHouse(Dialect):
def prewhere_sql(self, expression: exp.PreWhere) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('PREWHERE')}{self.sep()}{this}"
def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
expr = self.sql(expression, "expression")
expr = f" {expr}" if expr else ""
index_type = self.sql(expression, "index_type")
index_type = f" TYPE {index_type}" if index_type else ""
granularity = self.sql(expression, "granularity")
granularity = f" GRANULARITY {granularity}" if granularity else ""
return f"INDEX{this}{expr}{index_type}{granularity}"

sqlglot/dialects/dialect.py

@ -31,6 +31,7 @@ class Dialects(str, Enum):
DIALECT = ""
ATHENA = "athena"
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
DATABRICKS = "databricks"
@ -42,6 +43,7 @@ class Dialects(str, Enum):
ORACLE = "oracle"
POSTGRES = "postgres"
PRESTO = "presto"
PRQL = "prql"
REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
@ -108,11 +110,18 @@ class _Dialect(type):
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
base = seq_get(bases, 0)
base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
base_parser = (getattr(base, "parser_class", Parser),)
base_generator = (getattr(base, "generator_class", Generator),)
klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
klass.tokenizer_class = klass.__dict__.get(
"Tokenizer", type("Tokenizer", base_tokenizer, {})
)
klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
klass.generator_class = klass.__dict__.get(
"Generator", type("Generator", base_generator, {})
)
klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
@ -134,9 +143,31 @@ class _Dialect(type):
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
if "\\" in klass.tokenizer_class.STRING_ESCAPES:
klass.UNESCAPED_SEQUENCES = {
"\\a": "\a",
"\\b": "\b",
"\\f": "\f",
"\\n": "\n",
"\\r": "\r",
"\\t": "\t",
"\\v": "\v",
"\\\\": "\\",
**klass.UNESCAPED_SEQUENCES,
}
klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()
if enum not in ("", "databricks", "hive", "spark", "spark2"):
modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
for modifier in ("cluster", "distribute", "sort"):
modifier_transforms.pop(modifier, None)
klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.ANTI,
@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect):
False: Disables function name normalization.
"""
LOG_BASE_FIRST = True
"""Whether the base comes first in the `LOG` function."""
LOG_BASE_FIRST: t.Optional[bool] = True
"""
Whether the base comes first in the `LOG` function.
Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
"""
NULL_ORDERING = "nulls_are_small"
"""
@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect):
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
"""
ESCAPE_SEQUENCES: t.Dict[str, str] = {}
"""Mapping of an unescaped escape sequence to the corresponding character."""
UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
"""Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
PSEUDOCOLUMNS: t.Set[str] = set()
"""
@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect):
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
ESCAPED_SEQUENCES: t.Dict[str, str] = {}
# Delimiters for string literals and identifiers
QUOTE_START = "'"
@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) ->
return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
def str_position_sql(
self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
position = self.sql(expression, "position")
instance = expression.args.get("instance") if generate_instance else None
position_offset = ""
if position:
return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
return f"STRPOS({this}, {substr})"
# Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
this = self.func("SUBSTR", this, position)
position_offset = f" + {position} - 1"
return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
@ -689,9 +731,7 @@ def build_date_delta_with_interval(
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)
return expression_class(
this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
)
return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
return _builder
@ -710,18 +750,14 @@ def date_add_interval_sql(
) -> t.Callable[[Generator, exp.Expression], str]:
def func(self: Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
interval = exp.Interval(this=expression.expression, unit=unit)
interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
return f"{data_type}_{kind}({this}, {self.sql(interval)})"
return func
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
return self.func(
"DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
)
return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@ -956,7 +992,7 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return self.func(
name,
exp.var(expression.text("unit").upper() or "DAY"),
unit_to_var(expression),
expression.expression,
expression.this,
)
@ -964,6 +1000,24 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
unit = expression.args.get("unit")
if isinstance(unit, exp.Placeholder):
return unit
if unit:
return exp.Literal.string(unit.name)
return exp.Literal.string(default) if default else None
def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
unit = expression.args.get("unit")
if isinstance(unit, (exp.Var, exp.Placeholder)):
return unit
return exp.Var(this=default) if default else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@ -998,7 +1052,7 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
def build_json_extract_path(
expr_type: t.Type[F], zero_based_indexing: bool = True
expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
def _builder(args: t.List) -> F:
segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
@ -1018,7 +1072,11 @@ def build_json_extract_path(
# This is done to avoid failing in the expression validator due to the arg count
del args[2:]
return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
return expr_type(
this=seq_get(args, 0),
expression=exp.JSONPath(expressions=segments),
only_json_types=arrow_req_json_type,
)
return _builder
@ -1070,3 +1128,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> s
unnest = exp.Unnest(expressions=[expression.this])
filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
return self.func(
"TO_NUMBER",
expression.this,
expression.args.get("format"),
expression.args.get("nlsparam"),
)
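The two new helpers centralize how date-part units are emitted: `unit_to_str` for dialects whose functions take the unit as a string literal (as in Presto's DATE_DIFF('DAY', ...)), and `unit_to_var` for those that take a bare keyword (as in BigQuery's DATE_DIFF(..., DAY)). A small sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import unit_to_str, unit_to_var

e = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1), unit=exp.var("MONTH"))
unit_to_str(e)  # exp.Literal 'MONTH' (string)
unit_to_var(e)  # exp.Var MONTH (bare keyword)
# With no unit set, both fall back to the given default ("DAY").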

sqlglot/dialects/doris.py

@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
build_timestamp_trunc,
rename_func,
time_format,
unit_to_str,
)
from sqlglot.dialects.mysql import MySQL
@ -27,7 +28,7 @@ class Doris(MySQL):
}
class Generator(MySQL.Generator):
CAST_MAPPING = {}
LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
@ -36,8 +37,7 @@ class Doris(MySQL):
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
LAST_DAY_SUPPORTS_DATE_PART = False
CAST_MAPPING = {}
TIMESTAMP_FUNC_TYPES = set()
TRANSFORMS = {
@ -49,9 +49,7 @@ class Doris(MySQL):
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.CurrentTimestamp: lambda self, _: self.func("NOW"),
exp.DateTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
),
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.Map: rename_func("ARRAY_MAP"),
@ -63,9 +61,7 @@ class Doris(MySQL):
exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression),
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
),
exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
exp.UnixToStr: lambda self, e: self.func(
"FROM_UNIXTIME", e.this, time_format("doris")(self, e)
),

sqlglot/dialects/drill.py

@ -1,6 +1,5 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
@ -12,18 +11,10 @@ from sqlglot.dialects.dialect import (
str_position_sql,
timestrtotime_sql,
)
from sqlglot.dialects.mysql import date_add_sql
from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by
def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = exp.var(expression.text("unit").upper() or "DAY")
return self.func(f"DATE_{kind}", this, exp.Interval(this=expression.expression, unit=unit))
return func
def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
@ -84,7 +75,6 @@ class Drill(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_FORMAT": build_formatted_time(exp.TimeToStr, "drill"),
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": build_formatted_time(exp.TimeToStr, "drill"),
}
@ -124,9 +114,9 @@ class Drill(Dialect):
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
exp.DateAdd: _date_add_sql("ADD"),
exp.DateAdd: date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateSub: date_add_sql("SUB"),
exp.DateToDi: lambda self,
e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self,

sqlglot/dialects/duckdb.py

@ -26,6 +26,7 @@ from sqlglot.dialects.dialect import (
str_to_time_sql,
timestamptrunc_sql,
timestrtotime_sql,
unit_to_var,
)
from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType
@ -33,15 +34,16 @@ from sqlglot.tokens import TokenType
def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
interval = self.sql(exp.Interval(this=expression.expression, unit=unit))
interval = self.sql(exp.Interval(this=expression.expression, unit=unit_to_var(expression)))
return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}"
def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
def _date_delta_sql(
self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub | exp.TimeAdd
) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
op = "+" if isinstance(expression, exp.DateAdd) else "-"
unit = unit_to_var(expression)
op = "+" if isinstance(expression, (exp.DateAdd, exp.TimeAdd)) else "-"
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
@ -186,6 +188,11 @@ class DuckDB(Dialect):
return super().to_json_path(path)
class Tokenizer(tokens.Tokenizer):
HEREDOC_STRINGS = ["$"]
HEREDOC_TAG_IS_IDENTIFIER = True
HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"//": TokenType.DIV,
@ -199,6 +206,7 @@ class DuckDB(Dialect):
"LOGICAL": TokenType.BOOLEAN,
"ONLY": TokenType.ONLY,
"PIVOT_WIDER": TokenType.PIVOT,
"POSITIONAL": TokenType.POSITIONAL,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
"UBIGINT": TokenType.UBIGINT,
@ -227,8 +235,8 @@ class DuckDB(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAY_HAS": exp.ArrayContains.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _build_sort_array_desc,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"DATEDIFF": _build_date_diff,
"DATE_DIFF": _build_date_diff,
"DATE_TRUNC": date_trunc_to_time,
@ -285,6 +293,11 @@ class DuckDB(Dialect):
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("DECODE")
NO_PAREN_FUNCTION_PARSERS = {
**parser.Parser.NO_PAREN_FUNCTION_PARSERS,
"MAP": lambda self: self._parse_map(),
}
TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.SEMI,
TokenType.ANTI,
@ -299,6 +312,13 @@ class DuckDB(Dialect):
),
}
def _parse_map(self) -> exp.ToMap | exp.Map:
if self._match(TokenType.L_BRACE, advance=False):
return self.expression(exp.ToMap, this=self._parse_bracket())
args = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1))
def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
@ -345,6 +365,7 @@ class DuckDB(Dialect):
SUPPORTS_CREATE_TABLE_LIKE = False
MULTI_ARG_DISTINCT = False
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -425,6 +446,7 @@ class DuckDB(Dialect):
"EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
),
exp.Struct: _struct_sql,
exp.TimeAdd: _date_delta_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
@ -478,7 +500,7 @@ class DuckDB(Dialect):
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren)
# DuckDB doesn't generally support CREATE TABLE .. properties
# https://duckdb.org/docs/sql/statements/create_table.html
@ -569,3 +591,9 @@ class DuckDB(Dialect):
return rename_func("RANGE")(self, expression)
return super().generateseries_sql(expression)
def bracket_sql(self, expression: exp.Bracket) -> str:
if isinstance(expression.this, exp.Array):
expression.this.replace(exp.paren(expression.this))
return super().bracket_sql(expression)
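The new `_parse_map` hook accepts both DuckDB map forms; a sketch:

import sqlglot

sqlglot.parse_one("SELECT MAP {'k': 1}", read="duckdb")     # brace literal -> exp.ToMap
sqlglot.parse_one("SELECT MAP(['k'], [1])", read="duckdb")  # key/value lists -> exp.Map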

sqlglot/dialects/hive.py

@ -319,7 +319,9 @@ class Hive(Dialect):
"TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
"UNIX_TIMESTAMP": build_formatted_time(exp.StrToUnix, "hive", True),
"UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)(
args or [exp.CurrentTimestamp()]
),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
@ -431,6 +433,7 @@ class Hive(Dialect):
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
SUPPORTS_TO_NUMBER = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
@ -472,7 +475,7 @@ class Hive(Dialect):
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArrayToString: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort_sql,
exp.With: no_recursive_cte_sql,
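The `UNIX_TIMESTAMP` builder now falls back to `CURRENT_TIMESTAMP` for the zero-argument form, matching Hive; a sketch:

import sqlglot

sqlglot.parse_one("SELECT UNIX_TIMESTAMP()", read="hive")
# parsed as exp.StrToUnix over exp.CurrentTimestamp, with Hive's default time format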

sqlglot/dialects/mysql.py

@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import (
build_date_delta_with_interval,
rename_func,
strposition_to_locate_sql,
unit_to_var,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -109,14 +110,14 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
def _date_add_sql(
def date_add_sql(
kind: str,
) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
def func(self: MySQL.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return (
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
) -> t.Callable[[generator.Generator, exp.Expression], str]:
def func(self: generator.Generator, expression: exp.Expression) -> str:
return self.func(
f"DATE_{kind}",
expression.this,
exp.Interval(this=expression.expression, unit=unit_to_var(expression)),
)
return func
@ -291,6 +292,7 @@ class MySQL(Dialect):
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"FROM_UNIXTIME": build_formatted_time(exp.UnixToTime, "mysql"),
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
"MAKETIME": exp.TimeFromParts.from_arg_list,
@ -319,11 +321,7 @@ class MySQL(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"CHAR": lambda self: self._parse_chr(),
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
),
"GROUP_CONCAT": lambda self: self._parse_group_concat(),
# https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
"VALUES": lambda self: self.expression(
exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()]
@ -412,6 +410,11 @@ class MySQL(Dialect):
"SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"),
}
ALTER_PARSERS = {
**parser.Parser.ALTER_PARSERS,
"MODIFY": lambda self: self._parse_alter_table_alter(),
}
SCHEMA_UNNAMED_CONSTRAINTS = {
*parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
"FULLTEXT",
@ -458,7 +461,7 @@ class MySQL(Dialect):
this = self._parse_id_var(any_token=False)
index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text
schema = self._parse_schema()
expressions = self._parse_wrapped_csv(self._parse_ordered)
options = []
while True:
@ -478,9 +481,6 @@ class MySQL(Dialect):
elif self._match_text_seq("ENGINE_ATTRIBUTE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
elif self._match_text_seq("ENGINE_ATTRIBUTE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string())
@ -495,7 +495,7 @@ class MySQL(Dialect):
return self.expression(
exp.IndexColumnConstraint,
this=this,
schema=schema,
expressions=expressions,
kind=kind,
index_type=index_type,
options=options,
@ -617,6 +617,39 @@ class MySQL(Dialect):
return self.expression(exp.Chr, **kwargs)
def _parse_group_concat(self) -> t.Optional[exp.Expression]:
def concat_exprs(
node: t.Optional[exp.Expression], exprs: t.List[exp.Expression]
) -> exp.Expression:
if isinstance(node, exp.Distinct) and len(node.expressions) > 1:
concat_exprs = [
self.expression(exp.Concat, expressions=node.expressions, safe=True)
]
node.set("expressions", concat_exprs)
return node
if len(exprs) == 1:
return exprs[0]
return self.expression(exp.Concat, expressions=args, safe=True)
args = self._parse_csv(self._parse_lambda)
if args:
order = args[-1] if isinstance(args[-1], exp.Order) else None
if order:
# Order By is the last (or only) expression in the list and has consumed the 'expr' before it,
# remove 'expr' from exp.Order and add it back to args
args[-1] = order.this
order.set("this", concat_exprs(order.this, args))
this = order or concat_exprs(args[0], args)
else:
this = None
separator = self._parse_field() if self._match(TokenType.SEPARATOR) else None
return self.expression(exp.GroupConcat, this=this, separator=separator)
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = None
@ -630,6 +663,7 @@ class MySQL(Dialect):
JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
JSON_KEY_VALUE_PAIR_SEP = ","
SUPPORTS_TO_NUMBER = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -637,9 +671,9 @@ class MySQL(Dialect):
exp.DateDiff: _remove_ts_or_ds_to_date(
lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
),
exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
exp.DateAdd: _remove_ts_or_ds_to_date(date_add_sql("ADD")),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
exp.DateSub: _remove_ts_or_ds_to_date(date_add_sql("SUB")),
exp.DateTrunc: _date_trunc_sql,
exp.Day: _remove_ts_or_ds_to_date(),
exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
@ -672,7 +706,7 @@ class MySQL(Dialect):
exp.TimeFromParts: rename_func("MAKETIME"),
exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
exp.TimestampDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
"TIMESTAMPDIFF", unit_to_var(e), e.expression, e.this
),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@ -682,9 +716,10 @@ class MySQL(Dialect):
),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: _date_add_sql("ADD"),
exp.TsOrDsAdd: date_add_sql("ADD"),
exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.UnixToTime: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
exp.Week: _remove_ts_or_ds_to_date(),
exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
exp.Year: _remove_ts_or_ds_to_date(),
@ -751,11 +786,6 @@ class MySQL(Dialect):
result = f"{result} UNSIGNED"
return result
def xor_sql(self, expression: exp.Xor) -> str:
if expression.expressions:
return self.expressions(expression, sep=" XOR ")
return super().xor_sql(expression)
def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"

sqlglot/dialects/oracle.py

@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
build_formatted_time,
no_ilike_sql,
rename_func,
to_number_with_nls_param,
trim_sql,
)
from sqlglot.helper import seq_get
@ -246,6 +247,7 @@ class Oracle(Dialect):
exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.ToNumber: to_number_with_nls_param,
exp.Trim: trim_sql,
exp.UnixToTime: lambda self,
e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",

sqlglot/dialects/postgres.py

@ -278,6 +278,7 @@ class Postgres(Dialect):
"REVOKE": TokenType.COMMAND,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"NAME": TokenType.NAME,
"TEMP": TokenType.TEMPORARY,
"CSTRING": TokenType.PSEUDO_TYPE,
"OID": TokenType.OBJECT_IDENTIFIER,
@ -356,6 +357,16 @@ class Postgres(Dialect):
JSON_ARROWS_REQUIRE_JSON_TYPE = True
COLUMN_OPERATORS = {
**parser.Parser.COLUMN_OPERATORS,
TokenType.ARROW: lambda self, this, path: build_json_extract_path(
exp.JSONExtract, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
)([this, path]),
TokenType.DARROW: lambda self, this, path: build_json_extract_path(
exp.JSONExtractScalar, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
)([this, path]),
}
def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
while True:
if not self._match(TokenType.L_PAREN):
@ -484,6 +495,7 @@ class Postgres(Dialect):
]
),
exp.StrPosition: str_position_sql,
exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,

sqlglot/dialects/presto.py

@ -22,10 +22,13 @@ from sqlglot.dialects.dialect import (
rename_func,
right_to_substring_sql,
struct_extract_sql,
str_position_sql,
timestamptrunc_sql,
timestrtotime_sql,
ts_or_ds_add_cast,
unit_to_str,
)
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
@ -93,14 +96,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
expression = ts_or_ds_add_cast(expression)
unit = exp.Literal.string(expression.text("unit") or "DAY")
unit = unit_to_str(expression)
return self.func("DATE_ADD", unit, expression.expression, expression.this)
def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
this = exp.cast(expression.this, "TIMESTAMP")
expr = exp.cast(expression.expression, "TIMESTAMP")
unit = exp.Literal.string(expression.text("unit") or "DAY")
unit = unit_to_str(expression)
return self.func("DATE_DIFF", unit, expr, this)
@ -196,6 +199,7 @@ class Presto(Dialect):
SUPPORTS_SEMI_ANTI_JOIN = False
TYPED_DIVISION = True
TABLESAMPLE_SIZE_IS_PERCENT = True
LOG_BASE_FIRST: t.Optional[bool] = None
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
@ -289,6 +293,7 @@ class Presto(Dialect):
SUPPORTS_SINGLE_ARG_CONCAT = False
LIKE_PROPERTY_INSIDE_SCHEMA = True
MULTI_ARG_DISTINCT = False
SUPPORTS_TO_NUMBER = False
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@ -323,6 +328,7 @@ class Presto(Dialect):
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
exp.ArrayToString: rename_func("ARRAY_JOIN"),
exp.ArrayUniqueAgg: rename_func("SET_AGG"),
exp.AtTimeZone: rename_func("AT_TIMEZONE"),
exp.BitwiseAnd: lambda self, e: self.func("BITWISE_AND", e.this, e.expression),
@ -339,19 +345,19 @@ class Presto(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD",
exp.Literal.string(e.text("unit") or "DAY"),
unit_to_str(e),
_to_int(e.expression),
e.this,
),
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
"DATE_DIFF", unit_to_str(e), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self,
e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
exp.Literal.string(e.text("unit") or "DAY"),
unit_to_str(e),
_to_int(e.expression * -1),
e.this,
),
@ -397,13 +403,10 @@ class Presto(Dialect):
]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
exp.StrPosition: lambda self, e: str_position_sql(self, e, generate_instance=True),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToMap: rename_func("SPLIT_TO_MAP"),
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: self.func(
"TO_UNIXTIME", self.func("DATE_PARSE", e.this, self.format_time(e))
),
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.Timestamp: no_timestamp_sql,
@ -436,6 +439,22 @@ class Presto(Dialect):
exp.Xor: bool_xor_sql,
}
def strtounix_sql(self, expression: exp.StrToUnix) -> str:
# Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one.
# To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a
# timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
# which seems to be using the same time mapping as Hive, as per:
# https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
value_as_text = exp.cast(expression.this, "text")
parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
parse_with_tz = self.func(
"PARSE_DATETIME",
value_as_text,
self.format_time(expression, Hive.INVERSE_TIME_MAPPING, Hive.INVERSE_TIME_TRIE),
)
coalesced = self.func("COALESCE", self.func("TRY", parse_without_tz), parse_with_tz)
return self.func("TO_UNIXTIME", coalesced)
def bracket_sql(self, expression: exp.Bracket) -> str:
if expression.args.get("safe"):
return self.func(
@ -481,8 +500,7 @@ class Presto(Dialect):
return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
if expression.this and unit.startswith("WEEK"):
if expression.this and expression.text("unit").upper().startswith("WEEK"):
return f"({expression.this.name} * INTERVAL '7' DAY)"
return super().interval_sql(expression)
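Concretely, the new `strtounix_sql` emits something like the following for a Hive-style call (output approximate):

import sqlglot

sqlglot.transpile("SELECT UNIX_TIMESTAMP(ts, 'yyyy-MM-dd')", read="hive", write="presto")
# -> SELECT TO_UNIXTIME(COALESCE(TRY(DATE_PARSE(CAST(ts AS VARCHAR), '%Y-%m-%d')),
#        PARSE_DATETIME(CAST(ts AS VARCHAR), 'yyyy-MM-dd')))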

sqlglot/dialects/prql.py (new file, 109 lines)

@ -0,0 +1,109 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, parser, tokens
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType
class PRQL(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ["`"]
QUOTES = ["'", '"']
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
"=": TokenType.ALIAS,
"'": TokenType.QUOTE,
'"': TokenType.QUOTE,
"`": TokenType.IDENTIFIER,
"#": TokenType.COMMENT,
}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
}
class Parser(parser.Parser):
TRANSFORM_PARSERS = {
"DERIVE": lambda self, query: self._parse_selection(query),
"SELECT": lambda self, query: self._parse_selection(query, append=False),
"TAKE": lambda self, query: self._parse_take(query),
}
def _parse_statement(self) -> t.Optional[exp.Expression]:
expression = self._parse_expression()
expression = expression if expression else self._parse_query()
return expression
def _parse_query(self) -> t.Optional[exp.Query]:
from_ = self._parse_from()
if not from_:
return None
query = exp.select("*").from_(from_, copy=False)
while self._match_texts(self.TRANSFORM_PARSERS):
query = self.TRANSFORM_PARSERS[self._prev.text.upper()](self, query)
return query
def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
if self._match(TokenType.L_BRACE):
selects = self._parse_csv(self._parse_expression)
if not self._match(TokenType.R_BRACE, expression=query):
self.raise_error("Expecting }")
else:
expression = self._parse_expression()
selects = [expression] if expression else []
projections = {
select.alias_or_name: select.this if isinstance(select, exp.Alias) else select
for select in query.selects
}
selects = [
select.transform(
lambda s: (projections[s.name].copy() if s.name in projections else s)
if isinstance(s, exp.Column)
else s,
copy=False,
)
for select in selects
]
return query.select(*selects, append=append, copy=False)
def _parse_take(self, query: exp.Query) -> t.Optional[exp.Query]:
num = self._parse_number() # TODO: TAKE for ranges a..b
return query.limit(num) if num else None
def _parse_expression(self) -> t.Optional[exp.Expression]:
if self._next and self._next.token_type == TokenType.ALIAS:
alias = self._parse_id_var(True)
self._match(TokenType.ALIAS)
return self.expression(exp.Alias, this=self._parse_conjunction(), alias=alias)
return self._parse_conjunction()
def _parse_table(
self,
schema: bool = False,
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
return self._parse_table_parts()
def _parse_from(
self, joins: bool = False, skip_from_token: bool = False
) -> t.Optional[exp.From]:
if not skip_from_token and not self._match(TokenType.FROM):
return None
return self.expression(
exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins)
)
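A first taste of the new dialect (output approximate and target-dependent):

import sqlglot

sqlglot.transpile(
    "from employees\nderive gross = salary + bonus\ntake 10",
    read="prql",
    write="duckdb",
)
# -> roughly: SELECT *, salary + bonus AS gross FROM employees LIMIT 10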

sqlglot/dialects/redshift.py

@ -92,23 +92,6 @@ class Redshift(Postgres):
return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table
def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
this = super()._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
if (
isinstance(this, exp.DataType)
and this.is_type("varchar")
and this.expressions
and this.expressions[0].this == exp.column("MAX")
):
this.set("expressions", [exp.var("MAX")])
return this
def _parse_convert(
self, strict: bool, safe: t.Optional[bool] = None
) -> t.Optional[exp.Expression]:
@ -153,6 +136,7 @@ class Redshift(Postgres):
NVL2_SUPPORTED = True
LAST_DAY_SUPPORTS_DATE_PART = False
CAN_IMPLEMENT_ARRAY_ANY = False
MULTI_ARG_DISTINCT = True
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@ -187,9 +171,13 @@ class Redshift(Postgres):
),
exp.SortKeyProperty: lambda self,
e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.StartsWith: lambda self,
e: f"{self.sql(e.this)} LIKE {self.sql(e.expression)} || '%'",
exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.UnixToTime: lambda self,
e: f"(TIMESTAMP 'epoch' + {self.sql(e.this)} * INTERVAL '1 SECOND')",
}
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots

sqlglot/dialects/snowflake.py

@ -20,8 +20,7 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import flatten, is_int, seq_get
from sqlglot.helper import flatten, is_float, is_int, seq_get
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
@ -29,33 +28,35 @@ if t.TYPE_CHECKING:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
# case: <string_expr> [ , <format> ]
return build_formatted_time(exp.StrToTime, "snowflake")(args)
return exp.UnixToTime(this=first_arg, scale=second_arg)
def _build_datetime(
name: str, kind: exp.DataType.Type, safe: bool = False
) -> t.Callable[[t.List], exp.Func]:
def _builder(args: t.List) -> exp.Func:
value = seq_get(args, 0)
from sqlglot.optimizer.simplify import simplify_literals
if isinstance(value, exp.Literal):
int_value = is_int(value.this)
# The first argument might be an expression like 40 * 365 * 86400, so we try to
# reduce it using `simplify_literals` first and then check if it's a Literal.
first_arg = seq_get(args, 0)
if not isinstance(simplify_literals(first_arg, root=True), Literal):
# case: <variant_expr> or other expressions such as columns
return exp.TimeStrToTime.from_arg_list(args)
# Converts calls like `TO_TIME('01:02:03')` into casts
if len(args) == 1 and value.is_string and not int_value:
return exp.cast(value, kind)
if first_arg.is_string:
if is_int(first_arg.this):
# case: <integer>
return exp.UnixToTime.from_arg_list(args)
# Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special
# cases so we can transpile them, since they're relatively common
if kind == exp.DataType.Type.TIMESTAMP:
if int_value:
return exp.UnixToTime(this=value, scale=seq_get(args, 1))
if not is_float(value.this):
return build_formatted_time(exp.StrToTime, "snowflake")(args)
# case: <date_expr>
return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args)
if len(args) == 2 and kind == exp.DataType.Type.DATE:
formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args)
formatted_exp.set("safe", safe)
return formatted_exp
# case: <numeric_expr>
return exp.UnixToTime.from_arg_list(args)
return exp.Anonymous(this=name, expressions=args)
return _builder
def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
@ -77,6 +78,17 @@ def _build_datediff(args: t.List) -> exp.DateDiff:
)
def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
def _builder(args: t.List) -> E:
return expr_type(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=_map_date_part(seq_get(args, 0)),
)
return _builder
# https://docs.snowflake.com/en/sql-reference/functions/div0
def _build_if_from_div0(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
@ -97,14 +109,6 @@ def _build_if_from_nullifzero(args: t.List) -> exp.If:
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return "ARRAY"
elif expression.is_type("map"):
return "OBJECT"
return self.datatype_sql(expression)
def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str:
flag = expression.text("flag")
@ -258,6 +262,25 @@ def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
return expression
def _flatten_structured_types_unless_iceberg(expression: exp.Expression) -> exp.Expression:
assert isinstance(expression, exp.Create)
def _flatten_structured_type(expression: exp.DataType) -> exp.DataType:
if expression.this in exp.DataType.NESTED_TYPES:
expression.set("expressions", None)
return expression
props = expression.args.get("properties")
if isinstance(expression.this, exp.Schema) and not (props and props.find(exp.IcebergProperty)):
for schema_expression in expression.this.expressions:
if isinstance(schema_expression, exp.ColumnDef):
column_type = schema_expression.kind
if isinstance(column_type, exp.DataType):
column_type.transform(_flatten_structured_type, copy=False)
return expression
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@ -312,7 +335,13 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
ID_VAR_TOKENS = {
*parser.Parser.ID_VAR_TOKENS,
TokenType.MATCH_CONDITION,
}
TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW}
TABLE_ALIAS_TOKENS.discard(TokenType.MATCH_CONDITION)
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@ -327,17 +356,13 @@ class Snowflake(Dialect):
end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)),
step=seq_get(args, 2),
),
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
"BITXOR": binary_from_function(exp.BitwiseXor),
"BIT_XOR": binary_from_function(exp.BitwiseXor),
"BOOLXOR": binary_from_function(exp.Xor),
"CONVERT_TIMEZONE": _build_convert_timezone,
"DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
"DATE_TRUNC": _date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=_map_date_part(seq_get(args, 0)),
),
"DATEADD": _build_date_time_add(exp.DateAdd),
"DATEDIFF": _build_datediff,
"DIV0": _build_if_from_div0,
"FLATTEN": exp.Explode.from_arg_list,
@@ -349,17 +374,34 @@ class Snowflake(Dialect):
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
),
"LISTAGG": exp.GroupConcat.from_arg_list,
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
),
"NULLIFZERO": _build_if_from_nullifzero,
"OBJECT_CONSTRUCT": _build_object_construct,
"REGEXP_REPLACE": _build_regexp_replace,
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TIMEADD": _build_date_time_add(exp.TimeAdd),
"TIMEDIFF": _build_datediff,
"TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
"TIMESTAMPDIFF": _build_datediff,
"TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
"TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
"TO_TIMESTAMP": _build_to_timestamp,
"TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
"TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
"TO_NUMBER": lambda args: exp.ToNumber(
this=seq_get(args, 0),
format=seq_get(args, 1),
precision=seq_get(args, 2),
scale=seq_get(args, 3),
),
"TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME),
"TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP),
"TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ),
"TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP),
"TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ),
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _build_if_from_zeroifnull,
}
@@ -377,7 +419,6 @@ class Snowflake(Dialect):
**parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
TokenType.COLON: lambda self, this: self._parse_colon_get_path(this),
}
ALTER_PARSERS = {
@@ -434,35 +475,35 @@ class Snowflake(Dialect):
SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"}
def _parse_colon_get_path(
self: parser.Parser, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
while True:
path = self._parse_bitwise() or self._parse_var(any_token=True)
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = super()._parse_column_ops(this)
casts = []
json_path = []
while self._match(TokenType.COLON):
path = super()._parse_column_ops(self._parse_field(any_token=True))
# The cast :: operator has a lower precedence than the extraction operator :, so
# we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
if isinstance(path, exp.Cast):
target_type = path.to
while isinstance(path, exp.Cast):
casts.append(path.to)
path = path.this
else:
target_type = None
if isinstance(path, exp.Expression):
path = exp.Literal.string(path.sql(dialect="snowflake"))
if path:
json_path.append(path.sql(dialect="snowflake", copy=False))
# The extraction operator : is left-associative
if json_path:
this = self.expression(
exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
exp.JSONExtract,
this=this,
expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))),
)
if target_type:
this = exp.cast(this, target_type)
while casts:
this = self.expression(exp.Cast, this=this, to=casts.pop())
if not self._match(TokenType.COLON):
break
return self._parse_range(this)
return this
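The rewrite above collapses chained `:` extractions into a single JSON path and re-applies any trailing `::` casts on top of the extraction. A rough illustration, assuming stock sqlglot (exact output may differ):

import sqlglot

print(sqlglot.transpile("SELECT payload:a:b::int FROM t", read="snowflake", write="snowflake")[0])
# expected shape: SELECT CAST(GET_PATH(payload, 'a.b') AS INT) FROM t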
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
@@ -663,6 +704,7 @@ class Snowflake(Dialect):
"EXCLUDE": TokenType.EXCEPT,
"ILIKE ANY": TokenType.ILIKE_ANY,
"LIKE ANY": TokenType.LIKE_ANY,
"MATCH_CONDITION": TokenType.MATCH_CONDITION,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NCHAR VARYING": TokenType.VARCHAR,
@@ -703,6 +745,7 @@ class Snowflake(Dialect):
LIMIT_ONLY_LITERALS = True
JSON_KEY_VALUE_PAIR_SEP = ","
INSERT_OVERWRITE = " OVERWRITE INTO"
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -711,15 +754,14 @@ class Snowflake(Dialect):
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.AtTimeZone: lambda self, e: self.func(
"CONVERT_TIMEZONE", e.args.get("zone"), e.this
),
exp.BitwiseXor: rename_func("BITXOR"),
exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]),
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -769,6 +811,7 @@ class Snowflake(Dialect):
),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.Stuff: rename_func("INSERT"),
exp.TimeAdd: date_delta_sql("TIMEADD"),
exp.TimestampDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.unit, e.expression, e.this
),
@@ -783,6 +826,9 @@ class Snowflake(Dialect):
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: lambda self, e: self.func(
"TRY_TO_DATE" if e.args.get("safe") else "TO_DATE", e.this, self.format_time(e)
),
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
@@ -797,6 +843,8 @@ class Snowflake(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.NESTED: "OBJECT",
exp.DataType.Type.STRUCT: "OBJECT",
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@@ -811,6 +859,37 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
UNSUPPORTED_VALUES_EXPRESSIONS = {
exp.Struct,
}
def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str:
if expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS):
values_as_table = False
return super().values_sql(expression, values_as_table=values_as_table)
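Since Snowflake's VALUES table literal can't hold structured values, the override above drops back to the generic rendering whenever a Struct appears in the VALUES tree. A hedged example (the DuckDB struct literal is just a convenient source syntax):

import sqlglot

sql = "SELECT * FROM (VALUES ({'a': 1})) AS t(c)"
# should fall back to SELECT ... UNION ALL ... instead of a VALUES literal
print(sqlglot.transpile(sql, read="duckdb", write="snowflake")[0])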
def datatype_sql(self, expression: exp.DataType) -> str:
expressions = expression.expressions
if (
expressions
and expression.is_type(*exp.DataType.STRUCT_TYPES)
and any(isinstance(field_type, exp.DataType) for field_type in expressions)
):
# The correct syntax is OBJECT [ (<key> <value_type> [NOT NULL] [, ...]) ]
return "OBJECT"
return super().datatype_sql(expression)
def tonumber_sql(self, expression: exp.ToNumber) -> str:
return self.func(
"TO_NUMBER",
expression.this,
expression.args.get("format"),
expression.args.get("precision"),
expression.args.get("scale"),
)
def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
milli = expression.args.get("milli")
if milli is not None:

View file

@@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.dialect import rename_func, unit_to_var
from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get
@@ -78,6 +78,8 @@ class Spark(Spark2):
return this
class Generator(Spark2.Generator):
SUPPORTS_TO_NUMBER = True
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
@@ -100,7 +102,7 @@ class Spark(Spark2):
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
"DATEADD", unit_to_var(e), e.expression, e.this
),
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
@@ -117,11 +119,10 @@ class Spark(Spark2):
return self.function_fallback_sql(expression)
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = self.sql(expression, "unit")
end = self.sql(expression, "this")
start = self.sql(expression, "expression")
if unit:
return self.func("DATEDIFF", unit, start, end)
if expression.unit:
return self.func("DATEDIFF", unit_to_var(expression), start, end)
return self.func("DATEDIFF", end, start)

View file

@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
pivot_column_names,
rename_func,
trim_sql,
unit_to_str,
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
@@ -203,6 +204,7 @@ class Spark2(Hive):
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self,
e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.ArrayToString: rename_func("ARRAY_JOIN"),
exp.AtTimeZone: lambda self, e: self.func(
"FROM_UTC_TIMESTAMP", e.this, e.args.get("zone")
),
@@ -218,7 +220,7 @@ class Spark2(Hive):
]
),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -241,9 +243,7 @@ class Spark2(Hive):
),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
exp.Trim: trim_sql,
exp.UnixToTime: _unix_to_time_sql,
exp.VariancePop: rename_func("VAR_POP"),
@@ -252,7 +252,6 @@ class Spark2(Hive):
[transforms.remove_within_group_for_percentiles]
),
}
TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
TRANSFORMS.pop(exp.Left)

View file

@@ -33,6 +33,14 @@ def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> str:
return arrow_json_extract_sql(self, expression)
def _build_strftime(args: t.List) -> exp.Anonymous | exp.TimeToStr:
if len(args) == 1:
args.append(exp.CurrentTimestamp())
if len(args) == 2:
return exp.TimeToStr(this=exp.TsOrDsToTimestamp(this=args[1]), format=args[0])
return exp.Anonymous(this="STRFTIME", expressions=args)
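A small sketch of the builder's defaulting behavior, assuming stock sqlglot:

import sqlglot

# a single-argument STRFTIME picks up CURRENT_TIMESTAMP as its subject
print(sqlglot.transpile("SELECT STRFTIME('%Y')", read="sqlite", write="sqlite")[0])
# expected shape: SELECT STRFTIME('%Y', CURRENT_TIMESTAMP)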
def _transform_create(expression: exp.Expression) -> exp.Expression:
"""Move primary key to a column and enforce auto_increment on primary keys."""
schema = expression.this
@@ -82,6 +90,7 @@ class SQLite(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
"STRFTIME": _build_strftime,
}
STRING_ALIASES = True
@@ -93,6 +102,7 @@ class SQLite(Dialect):
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
SUPPORTS_TABLE_ALIAS_COLUMNS = False
SUPPORTS_TO_NUMBER = False
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
@@ -151,7 +161,9 @@ class SQLite(Dialect):
),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.args.get("format"), e.this),
exp.TryCast: no_trycast_sql,
exp.TsOrDsToTimestamp: lambda self, e: self.sql(e, "this"),
}
# SQLite doesn't generally support CREATE TABLE .. properties

View file

@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
build_timestamp_trunc,
rename_func,
unit_to_str,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get
@@ -39,15 +40,13 @@ class StarRocks(MySQL):
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.this, e.expression
"DATE_DIFF", unit_to_str(e), e.this, e.expression
),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.RegexpLike: rename_func("REGEXP"),
exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),

View file

@@ -5,6 +5,8 @@ from sqlglot.dialects.dialect import Dialect, rename_func
class Tableau(Dialect):
LOG_BASE_FIRST = False
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = [("[", "]")]
QUOTES = ["'", '"']

View file

@@ -3,7 +3,13 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least, rename_func
from sqlglot.dialects.dialect import (
Dialect,
max_or_greatest,
min_or_least,
rename_func,
to_number_with_nls_param,
)
from sqlglot.tokens import TokenType
@@ -206,6 +212,7 @@ class Teradata(Dialect):
exp.StrToDate: lambda self,
e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.ToNumber: to_number_with_nls_param,
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}

View file

@@ -7,6 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
SUPPORTS_USER_DEFINED_TYPES = False
LOG_BASE_FIRST = True
class Generator(Presto.Generator):
TRANSFORMS = {

View file

@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
NormalizationStrategy,
any_value_to_max_sql,
date_delta_sql,
datestrtodate_sql,
generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
@@ -724,6 +725,7 @@ class TSQL(Dialect):
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_TO_NUMBER = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
@@ -760,12 +762,14 @@ class TSQL(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
exp.ArrayToString: rename_func("STRING_AGG"),
exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.CTE: transforms.preprocess([qualify_derived_table_outputs]),
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.DateStrToDate: datestrtodate_sql,
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GroupConcat: _string_agg_sql,
@@ -808,6 +812,22 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def select_sql(self, expression: exp.Select) -> str:
if expression.args.get("offset"):
if not expression.args.get("order"):
# ORDER BY is required in order to use OFFSET in a query, so we use
# a noop order by, since we don't really care about the order.
# See: https://www.microsoftpressstore.com/articles/article.aspx?p=2314819
expression.order_by(exp.select(exp.null()).subquery(), copy=False)
limit = expression.args.get("limit")
if isinstance(limit, exp.Limit):
# TOP and OFFSET can't be combined, so we need to use FETCH instead of TOP.
# We replace the limit here because otherwise TOP would be generated in select_sql.
limit.replace(exp.Fetch(direction="FIRST", count=limit.expression))
return super().select_sql(expression)
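Taken together, these rewrites let LIMIT/OFFSET queries land on valid T-SQL. A hedged example (exact output may vary by version):

import sqlglot

print(sqlglot.transpile("SELECT x FROM t LIMIT 10 OFFSET 5", write="tsql")[0])
# expected shape: SELECT x FROM t ORDER BY (SELECT NULL) OFFSET 5 ROWS FETCH FIRST 10 ROWS ONLY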
def convert_sql(self, expression: exp.Convert) -> str:
name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT"
return self.func(
@@ -862,12 +882,12 @@
return rename_func("DATETIMEFROMPARTS")(self, expression)
def set_operation(self, expression: exp.Union, op: str) -> str:
def set_operations(self, expression: exp.Union) -> str:
limit = expression.args.get("limit")
if limit:
return self.sql(expression.limit(limit.pop(), copy=False))
return super().set_operation(expression, op)
return super().set_operations(expression)
def setitem_sql(self, expression: exp.SetItem) -> str:
this = expression.this

View file

@@ -103,7 +103,7 @@ def diff(
) -> t.Dict[int, exp.Expression]:
return {
id(old_node): new_node
for (old_node, _, _), (new_node, _, _) in zip(original.walk(), copy.walk())
for old_node, new_node in zip(original.walk(), copy.walk())
if id(old_node) in matching_ids
}
@@ -158,14 +158,10 @@ class ChangeDistiller:
self._source = source
self._target = target
self._source_index = {
id(n): n
for n, *_ in self._source.bfs()
if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
id(n): n for n in self._source.bfs() if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
}
self._target_index = {
id(n): n
for n, *_ in self._target.bfs()
if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
id(n): n for n in self._target.bfs() if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
}
self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
@@ -216,10 +212,10 @@ class ChangeDistiller:
matching_set = leaves_matching_set.copy()
ordered_unmatched_source_nodes = {
id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes
id(n): None for n in self._source.bfs() if id(n) in self._unmatched_source_nodes
}
ordered_unmatched_target_nodes = {
id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes
id(n): None for n in self._target.bfs() if id(n) in self._unmatched_target_nodes
}
for source_node_id in ordered_unmatched_source_nodes:
@@ -322,7 +318,7 @@ class ChangeDistiller:
def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False
for _, node in expression.iter_expressions():
for node in expression.iter_expressions():
if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES):
has_child_exprs = True
yield from _get_leaves(node)
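These hunks track the new traversal API: walk(), bfs() and iter_expressions() now yield nodes directly instead of (node, parent, key) tuples. Client code iterates accordingly, for example:

import sqlglot

for node in sqlglot.parse_one("SELECT a + 1 FROM t").walk():
    print(type(node).__name__)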

View file

@@ -10,11 +10,13 @@ import logging
import time
import typing as t
from sqlglot import exp
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
from sqlglot.helper import dict_depth
from sqlglot.optimizer import optimize
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.planner import Plan
from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set
@@ -26,15 +28,11 @@ if t.TYPE_CHECKING:
from sqlglot.schema import Schema
PYTHON_TYPE_TO_SQLGLOT = {
"dict": "MAP",
}
def execute(
sql: str | Expression,
schema: t.Optional[t.Dict | Schema] = None,
read: DialectType = None,
dialect: DialectType = None,
tables: t.Optional[t.Dict] = None,
) -> Table:
"""
@@ -48,11 +46,13 @@ def execute(
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
dialect: the SQL dialect (alias for read).
tables: additional tables to register.
Returns:
Simple columnar data structure.
"""
read = read or dialect
tables_ = ensure_tables(tables, dialect=read)
if not schema:
@@ -64,8 +64,9 @@ def execute(
assert table is not None
for column in table.columns:
py_type = type(table[0][column]).__name__
nested_set(schema, [*keys, column], PYTHON_TYPE_TO_SQLGLOT.get(py_type) or py_type)
value = table[0][column]
column_type = annotate_types(exp.convert(value)).type or type(value).__name__
nested_set(schema, [*keys, column], column_type)
schema = ensure_schema(schema, dialect=read)
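The executor now infers column types with the optimizer's annotator instead of the removed PYTHON_TYPE_TO_SQLGLOT lookup. A minimal sketch of that inference step:

from sqlglot import exp
from sqlglot.optimizer.annotate_types import annotate_types

# convert a raw Python value to an expression and let the annotator type it
print(annotate_types(exp.convert(3.14)).type)  # e.g. DOUBLE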

View file

@@ -106,6 +106,13 @@ def cast(this, to):
return this
if isinstance(this, str):
return datetime.date.fromisoformat(this)
if to == exp.DataType.Type.TIME:
if isinstance(this, datetime.datetime):
return this.time()
if isinstance(this, datetime.time):
return this
if isinstance(this, str):
return datetime.time.fromisoformat(this)
if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP):
if isinstance(this, datetime.datetime):
return this
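With TIME added to the executor's cast(), time values can be evaluated in memory. A hedged usage sketch (table and column names are hypothetical):

from sqlglot.executor import execute

result = execute(
    "SELECT CAST(t AS TIME) AS t FROM x",
    tables={"x": [{"t": "12:30:45"}]},
)
print(result.rows)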
@@ -139,7 +146,7 @@ def interval(this, unit):
@null_if_any("this", "expression")
def arrayjoin(this, expression, null=None):
def arraytostring(this, expression, null=None):
return expression.join(x for x in (x if x is not None else null for x in this) if x is not None)
@@ -173,7 +180,7 @@ ENV = {
"ABS": null_if_any(lambda this: abs(this)),
"ADD": null_if_any(lambda e, this: e + this),
"ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
"ARRAYJOIN": arrayjoin,
"ARRAYTOSTRING": arraytostring,
"BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"BITWISEAND": null_if_any(lambda this, e: this & e),
"BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
@@ -212,6 +219,7 @@ ENV = {
"ORDERED": ordered,
"POW": pow,
"RIGHT": null_if_any(lambda this, e: this[-e:]),
"ROUND": null_if_any(lambda this, decimals=None, truncate=None: round(this, ndigits=decimals)),
"STRPOSITION": str_position,
"SUB": null_if_any(lambda e, this: e - this),
"SUBSTRING": substring,
@@ -225,10 +233,12 @@ ENV = {
"CURRENTTIME": datetime.datetime.now,
"CURRENTDATE": datetime.date.today,
"STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
"STRTOTIME": null_if_any(lambda arg, format: datetime.datetime.strptime(arg, format)),
"TRIM": null_if_any(lambda this, e=None: this.strip(e)),
"STRUCT": lambda *args: {
args[x]: args[x + 1]
for x in range(0, len(args), 2)
if (args[x + 1] is not None and args[x] is not None)
},
"UNIXTOTIME": null_if_any(lambda arg: datetime.datetime.utcfromtimestamp(arg)),
}

View file

@@ -157,7 +157,7 @@ class PythonExecutor:
yield context.table.reader
def join(self, step, context):
source = step.name
source = step.source_name
source_table = context.tables[source]
source_context = self.context({source: source_table})
@@ -398,7 +398,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
lambda n: (
exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n
)
)
).assert_is(exp.Lambda)
return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"

File diff suppressed because it is too large

View file

@@ -46,9 +46,11 @@ class Generator(metaclass=_Generator):
'safe': Only quote identifiers that are case insensitive.
normalize: Whether to normalize identifiers to lowercase.
Default: False.
pad: The pad size in a formatted string.
pad: The pad size in a formatted string. For example, this affects the indentation of
a projection in a query, relative to its nesting level.
Default: 2.
indent: The indentation size in a formatted string.
indent: The indentation size in a formatted string. For example, this affects the
indentation of subqueries and filters under a `WHERE` clause.
Default: 2.
normalize_functions: How to normalize function names. Possible values are:
"upper" or True (default): Convert names to uppercase.
@@ -73,6 +75,7 @@ class Generator(metaclass=_Generator):
TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
**JSON_PATH_PART_TRANSFORMS,
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
exp.BackupProperty: lambda self, e: f"BACKUP {self.sql(e, 'this')}",
exp.CaseSpecificColumnConstraint: lambda _,
e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
@@ -83,15 +86,15 @@
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda *_: "COPY GRANTS",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.ExcludeColumnConstraint: lambda self, e: f"EXCLUDE {self.sql(e, 'this').lstrip()}",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda *_: "EXTERNAL",
exp.GlobalProperty: lambda *_: "GLOBAL",
exp.HeapProperty: lambda *_: "HEAP",
exp.IcebergProperty: lambda *_: "ICEBERG",
exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
@@ -123,6 +126,7 @@ class Generator(metaclass=_Generator):
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SharingProperty: lambda self, e: f"SHARING={self.sql(e, 'this')}",
exp.SqlReadWriteProperty: lambda _, e: e.name,
exp.SqlSecurityProperty: lambda _,
e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
@@ -130,13 +134,17 @@ class Generator(metaclass=_Generator):
exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TransientProperty: lambda *_: "TRANSIENT",
exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE",
exp.UnloggedProperty: lambda *_: "UNLOGGED",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}",
exp.VolatileProperty: lambda *_: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
}
# Whether null ordering is supported in order by
@@ -321,6 +329,9 @@ class Generator(metaclass=_Generator):
# Whether any(f(x) for x in array) can be implemented by this dialect
CAN_IMPLEMENT_ARRAY_ANY = False
# Whether the function TO_NUMBER is supported
SUPPORTS_TO_NUMBER = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -350,6 +361,18 @@ class Generator(metaclass=_Generator):
"YEARS": "YEAR",
}
AFTER_HAVING_MODIFIER_TRANSFORMS = {
"cluster": lambda self, e: self.sql(e, "cluster"),
"distribute": lambda self, e: self.sql(e, "distribute"),
"qualify": lambda self, e: self.sql(e, "qualify"),
"sort": lambda self, e: self.sql(e, "sort"),
"windows": lambda self, e: (
self.seg("WINDOW ") + self.expressions(e, key="windows", flat=True)
if e.args.get("windows")
else ""
),
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
@@ -361,6 +384,7 @@ class Generator(metaclass=_Generator):
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA,
exp.BackupProperty: exp.Properties.Location.POST_SCHEMA,
exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
exp.ChecksumProperty: exp.Properties.Location.POST_NAME,
@@ -380,8 +404,10 @@ class Generator(metaclass=_Generator):
exp.FallbackProperty: exp.Properties.Location.POST_NAME,
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
exp.GlobalProperty: exp.Properties.Location.POST_CREATE,
exp.HeapProperty: exp.Properties.Location.POST_WITH,
exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA,
exp.IcebergProperty: exp.Properties.Location.POST_CREATE,
exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
exp.JournalProperty: exp.Properties.Location.POST_NAME,
@@ -414,6 +440,8 @@ class Generator(metaclass=_Generator):
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA,
exp.SharingProperty: exp.Properties.Location.POST_EXPRESSION,
exp.SequenceProperties: exp.Properties.Location.POST_EXPRESSION,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
@@ -423,6 +451,8 @@ class Generator(metaclass=_Generator):
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.TransformModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.UnloggedProperty: exp.Properties.Location.POST_CREATE,
exp.ViewAttributeProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
@@ -441,6 +471,7 @@ class Generator(metaclass=_Generator):
exp.Insert,
exp.Join,
exp.Select,
exp.Union,
exp.Update,
exp.Where,
exp.With,
@@ -626,7 +657,7 @@ class Generator(metaclass=_Generator):
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return (
f"{self.sep()}{comments_sql}{sql}"
if sql[0].isspace()
if not sql or sql[0].isspace()
else f"{comments_sql}{self.sep()}{sql}"
)
@@ -869,7 +900,9 @@ class Generator(metaclass=_Generator):
this = f" {this}" if this else ""
index_type = expression.args.get("index_type")
index_type = f" USING {index_type}" if index_type else ""
return f"UNIQUE{this}{index_type}"
on_conflict = self.sql(expression, "on_conflict")
on_conflict = f" {on_conflict}" if on_conflict else ""
return f"UNIQUE{this}{index_type}{on_conflict}"
def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
return self.sql(expression, "this")
@@ -961,6 +994,31 @@ class Generator(metaclass=_Generator):
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
return self.prepend_ctes(expression, expression_sql)
def sequenceproperties_sql(self, expression: exp.SequenceProperties) -> str:
start = self.sql(expression, "start")
start = f"START WITH {start}" if start else ""
increment = self.sql(expression, "increment")
increment = f" INCREMENT BY {increment}" if increment else ""
minvalue = self.sql(expression, "minvalue")
minvalue = f" MINVALUE {minvalue}" if minvalue else ""
maxvalue = self.sql(expression, "maxvalue")
maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else ""
owned = self.sql(expression, "owned")
owned = f" OWNED BY {owned}" if owned else ""
cache = expression.args.get("cache")
if cache is None:
cache_str = ""
elif cache is True:
cache_str = " CACHE"
else:
cache_str = f" CACHE {cache}"
options = self.expressions(expression, key="options", flat=True, sep=" ")
options = f" {options}" if options else ""
return f"{start}{increment}{minvalue}{maxvalue}{cache_str}{options}{owned}".lstrip()
def clone_sql(self, expression: exp.Clone) -> str:
this = self.sql(expression, "this")
shallow = "SHALLOW " if expression.args.get("shallow") else ""
@@ -968,8 +1026,9 @@ class Generator(metaclass=_Generator):
return f"{shallow}{keyword} {this}"
def describe_sql(self, expression: exp.Describe) -> str:
extended = " EXTENDED" if expression.args.get("extended") else ""
return f"DESCRIBE{extended} {self.sql(expression, 'this')}"
style = expression.args.get("style")
style = f" {style}" if style else ""
return f"DESCRIBE{style} {self.sql(expression, 'this')}"
def heredoc_sql(self, expression: exp.Heredoc) -> str:
tag = self.sql(expression, "tag")
@@ -993,7 +1052,14 @@ class Generator(metaclass=_Generator):
def cte_sql(self, expression: exp.CTE) -> str:
alias = self.sql(expression, "alias")
return f"{alias} AS {self.wrap(expression)}"
materialized = expression.args.get("materialized")
if materialized is False:
materialized = "NOT MATERIALIZED "
elif materialized:
materialized = "MATERIALIZED "
return f"{alias} AS {materialized or ''}{self.wrap(expression)}"
def tablealias_sql(self, expression: exp.TableAlias) -> str:
alias = self.sql(expression, "this")
@@ -1044,7 +1110,7 @@ class Generator(metaclass=_Generator):
return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}"
def rawstring_sql(self, expression: exp.RawString) -> str:
string = self.escape_str(expression.this.replace("\\", "\\\\"))
string = self.escape_str(expression.this.replace("\\", "\\\\"), escape_backslash=False)
return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}"
def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
@@ -1114,6 +1180,8 @@ class Generator(metaclass=_Generator):
def drop_sql(self, expression: exp.Drop) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
expressions = f" ({expressions})" if expressions else ""
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
@@ -1121,15 +1189,10 @@ class Generator(metaclass=_Generator):
cascade = " CASCADE" if expression.args.get("cascade") else ""
constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
purge = " PURGE" if expression.args.get("purge") else ""
return (
f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}"
)
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{expressions}{cascade}{constraints}{purge}"
def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.except_op(expression)),
)
return self.set_operations(expression)
def except_op(self, expression: exp.Except) -> str:
return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}"
@@ -1163,17 +1226,9 @@ class Generator(metaclass=_Generator):
return f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */"
def index_sql(self, expression: exp.Index) -> str:
unique = "UNIQUE " if expression.args.get("unique") else ""
primary = "PRIMARY " if expression.args.get("primary") else ""
amp = "AMP " if expression.args.get("amp") else ""
name = self.sql(expression, "this")
name = f"{name} " if name else ""
table = self.sql(expression, "table")
table = f"{self.INDEX_ON} {table}" if table else ""
def indexparameters_sql(self, expression: exp.IndexParameters) -> str:
using = self.sql(expression, "using")
using = f" USING {using}" if using else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
partition_by = self.expressions(expression, key="partition_by", flat=True)
@@ -1182,7 +1237,26 @@ class Generator(metaclass=_Generator):
include = self.expressions(expression, key="include", flat=True)
if include:
include = f" INCLUDE ({include})"
return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{include}{partition_by}{where}"
with_storage = self.expressions(expression, key="with_storage", flat=True)
with_storage = f" WITH ({with_storage})" if with_storage else ""
tablespace = self.sql(expression, "tablespace")
tablespace = f" USING INDEX TABLESPACE {tablespace}" if tablespace else ""
return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}"
def index_sql(self, expression: exp.Index) -> str:
unique = "UNIQUE " if expression.args.get("unique") else ""
primary = "PRIMARY " if expression.args.get("primary") else ""
amp = "AMP " if expression.args.get("amp") else ""
name = self.sql(expression, "this")
name = f"{name} " if name else ""
table = self.sql(expression, "table")
table = f"{self.INDEX_ON} {table}" if table else ""
index = "INDEX " if not table else ""
params = self.sql(expression, "params")
return f"{unique}{primary}{amp}{index}{name}{table}{params}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
@@ -1371,15 +1445,9 @@ class Generator(metaclass=_Generator):
no = " NO" if no else ""
concurrent = expression.args.get("concurrent")
concurrent = " CONCURRENT" if concurrent else ""
for_ = ""
if expression.args.get("for_all"):
for_ = " FOR ALL"
elif expression.args.get("for_insert"):
for_ = " FOR INSERT"
elif expression.args.get("for_none"):
for_ = " FOR NONE"
return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
target = self.sql(expression, "target")
target = f" {target}" if target else ""
return f"WITH{no}{concurrent} ISOLATED LOADING{target}"
def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str:
if isinstance(expression.this, list):
@@ -1437,6 +1505,7 @@ class Generator(metaclass=_Generator):
return f"{sql})"
def insert_sql(self, expression: exp.Insert) -> str:
hint = self.sql(expression, "hint")
overwrite = expression.args.get("overwrite")
if isinstance(expression.this, exp.Directory):
@ -1447,7 +1516,9 @@ class Generator(metaclass=_Generator):
alternative = expression.args.get("alternative")
alternative = f" OR {alternative}" if alternative else ""
ignore = " IGNORE" if expression.args.get("ignore") else ""
is_function = expression.args.get("is_function")
if is_function:
this = f"{this} FUNCTION"
this = f"{this} {self.sql(expression, 'this')}"
exists = " IF EXISTS" if expression.args.get("exists") else ""
@ -1457,23 +1528,21 @@ class Generator(metaclass=_Generator):
where = self.sql(expression, "where")
where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
conflict = self.sql(expression, "conflict")
on_conflict = self.sql(expression, "conflict")
on_conflict = f" {on_conflict}" if on_conflict else ""
by_name = " BY NAME" if expression.args.get("by_name") else ""
returning = self.sql(expression, "returning")
if self.RETURNING_END:
expression_sql = f"{expression_sql}{conflict}{returning}"
expression_sql = f"{expression_sql}{on_conflict}{returning}"
else:
expression_sql = f"{returning}{expression_sql}{conflict}"
expression_sql = f"{returning}{expression_sql}{on_conflict}"
sql = f"INSERT{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
sql = f"INSERT{hint}{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.intersect_op(expression)),
)
return self.set_operations(expression)
def intersect_op(self, expression: exp.Intersect) -> str:
return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}"
@@ -1496,33 +1565,36 @@ class Generator(metaclass=_Generator):
def onconflict_sql(self, expression: exp.OnConflict) -> str:
conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
constraint = self.sql(expression, "constraint")
if constraint:
constraint = f"ON CONSTRAINT {constraint}"
key = self.expressions(expression, key="key", flat=True)
do = "" if expression.args.get("duplicate") else " DO "
nothing = "NOTHING" if expression.args.get("nothing") else ""
constraint = f" ON CONSTRAINT {constraint}" if constraint else ""
conflict_keys = self.expressions(expression, key="conflict_keys", flat=True)
conflict_keys = f"({conflict_keys}) " if conflict_keys else " "
action = self.sql(expression, "action")
expressions = self.expressions(expression, flat=True)
set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else ""
if expressions:
expressions = f"UPDATE {set_keyword}{expressions}"
return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}"
set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else ""
expressions = f" {set_keyword}{expressions}"
return f"{conflict}{constraint}{conflict_keys}{action}{expressions}"
def returning_sql(self, expression: exp.Returning) -> str:
return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}"
def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
fields = expression.args.get("fields")
fields = self.sql(expression, "fields")
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
escaped = expression.args.get("escaped")
escaped = self.sql(expression, "escaped")
escaped = f" ESCAPED BY {escaped}" if escaped else ""
items = expression.args.get("collection_items")
items = self.sql(expression, "collection_items")
items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else ""
keys = expression.args.get("map_keys")
keys = self.sql(expression, "map_keys")
keys = f" MAP KEYS TERMINATED BY {keys}" if keys else ""
lines = expression.args.get("lines")
lines = self.sql(expression, "lines")
lines = f" LINES TERMINATED BY {lines}" if lines else ""
null = expression.args.get("null")
null = self.sql(expression, "null")
null = f" NULL DEFINED AS {null}" if null else ""
return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
@@ -1563,7 +1635,9 @@ class Generator(metaclass=_Generator):
hints = f" {hints}" if hints and self.TABLE_HINTS else ""
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="", skip_first=True)
joins = self.indent(
self.expressions(expression, key="joins", sep="", flat=True), skip_first=True
)
laterals = self.expressions(expression, key="laterals", sep="")
file_format = self.sql(expression, "format")
@@ -1673,9 +1747,11 @@ class Generator(metaclass=_Generator):
sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}"
return self.prepend_ctes(expression, sql)
def values_sql(self, expression: exp.Values) -> str:
def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str:
values_as_table = values_as_table and self.VALUES_AS_TABLE
# The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
if self.VALUES_AS_TABLE or not expression.find_ancestor(exp.From, exp.Join):
if values_as_table or not expression.find_ancestor(exp.From, exp.Join):
args = self.expressions(expression)
alias = self.sql(expression, "alias")
values = f"VALUES{self.seg('')}{args}"
@@ -1769,8 +1845,9 @@ class Generator(metaclass=_Generator):
def connect_sql(self, expression: exp.Connect) -> str:
start = self.sql(expression, "start")
start = self.seg(f"START WITH {start}") if start else ""
nocycle = " NOCYCLE" if expression.args.get("nocycle") else ""
connect = self.sql(expression, "connect")
connect = self.seg(f"CONNECT BY {connect}")
connect = self.seg(f"CONNECT BY{nocycle} {connect}")
return start + connect
def prior_sql(self, expression: exp.Prior) -> str:
@@ -1793,6 +1870,8 @@ class Generator(metaclass=_Generator):
)
if op
)
match_cond = self.sql(expression, "match_condition")
match_cond = f" MATCH_CONDITION ({match_cond})" if match_cond else ""
on_sql = self.sql(expression, "on")
using = expression.args.get("using")
@@ -1816,7 +1895,7 @@ class Generator(metaclass=_Generator):
return f", {this_sql}"
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
return f"{self.seg(op_sql)} {this_sql}{on_sql}"
return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
@@ -1919,13 +1998,17 @@ class Generator(metaclass=_Generator):
text = f"{self.dialect.QUOTE_START}{self.escape_str(text)}{self.dialect.QUOTE_END}"
return text
def escape_str(self, text: str) -> str:
text = text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)
if self.dialect.INVERSE_ESCAPE_SEQUENCES:
text = "".join(self.dialect.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
elif self.pretty:
def escape_str(self, text: str, escape_backslash: bool = True) -> str:
if self.dialect.ESCAPED_SEQUENCES:
to_escaped = self.dialect.ESCAPED_SEQUENCES
text = "".join(
to_escaped.get(ch, ch) if escape_backslash or ch != "\\" else ch for ch in text
)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
return text
return text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)
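The new escape_backslash flag lets rawstring_sql above skip backslash escaping while regular strings still go through the dialect's escape sequences. A hedged example for a dialect with raw string literals:

import sqlglot

# BigQuery raw strings are the motivating case: their backslashes should survive
print(sqlglot.transpile(r"SELECT r'\d+'", read="bigquery", write="bigquery")[0])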
def loaddata_sql(self, expression: exp.LoadData) -> str:
local = " LOCAL" if expression.args.get("local") else ""
@@ -2016,7 +2099,7 @@ class Generator(metaclass=_Generator):
self.unsupported(
f"'{nulls_sort_change.strip()}' translation not supported with positional ordering"
)
else:
elif not isinstance(expression.this, exp.Rand):
null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
nulls_sort_change = ""
@@ -2059,24 +2142,13 @@ class Generator(metaclass=_Generator):
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit")
# If the limit is generated as TOP, we need to ensure it's not generated twice
with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP
limit = expression.args.get("limit")
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count")))
elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))
fetch = isinstance(limit, exp.Fetch)
offset_limit_modifiers = (
self.offset_limit_modifiers(expression, fetch, limit)
if with_offset_limit_modifiers
else []
)
options = self.expressions(expression, key="options")
if options:
options = f" OPTION{self.wrap(options)}"
@@ -2091,9 +2163,9 @@ class Generator(metaclass=_Generator):
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
*self.after_having_modifiers(expression),
*[gen(self, expression) for gen in self.AFTER_HAVING_MODIFIER_TRANSFORMS.values()],
self.sql(expression, "order"),
*offset_limit_modifiers,
*self.offset_limit_modifiers(expression, isinstance(limit, exp.Fetch), limit),
*self.after_limit_modifiers(expression),
options,
sep="",
@@ -2110,19 +2182,6 @@ class Generator(metaclass=_Generator):
self.sql(limit) if fetch else self.sql(expression, "offset"),
]
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return [
self.sql(expression, "qualify"),
(
self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else ""
),
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
]
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
locks = self.expressions(expression, key="locks", sep=" ")
locks = f" {locks}" if locks else ""
@@ -2137,12 +2196,13 @@ class Generator(metaclass=_Generator):
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
kind = self.sql(expression, "kind")
limit = expression.args.get("limit")
top = (
self.limit_sql(limit, top=True)
if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP
else ""
)
if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP:
top = self.limit_sql(limit, top=True)
limit.pop()
else:
top = ""
expressions = self.expressions(expression)
@@ -2220,7 +2280,7 @@ class Generator(metaclass=_Generator):
return f"@@{kind}{this}"
def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.name else "?"
return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.this else "?"
def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
alias = self.sql(expression, "alias")
@@ -2236,11 +2296,32 @@ class Generator(metaclass=_Generator):
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
def set_operations(self, expression: exp.Union) -> str:
sqls: t.List[str] = []
stack: t.List[t.Union[str, exp.Expression]] = [expression]
while stack:
node = stack.pop()
if isinstance(node, exp.Union):
stack.append(node.expression)
stack.append(
self.maybe_comment(
getattr(self, f"{node.key}_op")(node),
expression=node.this,
comments=node.comments,
)
)
stack.append(node.this)
else:
sqls.append(self.sql(node))
this = self.sep().join(sqls)
this = self.query_modifiers(expression, this)
return self.prepend_ctes(expression, this)
def union_sql(self, expression: exp.Union) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.union_op(expression)),
)
return self.set_operations(expression)
def union_op(self, expression: exp.Union) -> str:
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
@@ -2345,8 +2426,10 @@ class Generator(metaclass=_Generator):
def any_sql(self, expression: exp.Any) -> str:
this = self.sql(expression, "this")
if isinstance(expression.this, exp.UNWRAPPED_QUERIES):
this = self.wrap(this)
if isinstance(expression.this, (*exp.UNWRAPPED_QUERIES, exp.Paren)):
if isinstance(expression.this, exp.UNWRAPPED_QUERIES):
this = self.wrap(this)
return f"ANY{this}"
return f"ANY {this}"
def exists_sql(self, expression: exp.Exists) -> str:
@@ -2632,13 +2715,8 @@ class Generator(metaclass=_Generator):
return self.func(self.sql(expression, "this"), *expression.expressions)
def paren_sql(self, expression: exp.Paren) -> str:
if isinstance(expression.unnest(), exp.Select):
sql = self.wrap(expression)
else:
sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
sql = f"({sql}{self.seg(')', sep='')}"
return self.prepend_ctes(expression, sql)
sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
return f"({sql}{self.seg(')', sep='')}"
def neg_sql(self, expression: exp.Neg) -> str:
# This makes sure we don't convert "- - 5" to "--5", which is a comment
@@ -2686,23 +2764,55 @@ class Generator(metaclass=_Generator):
def add_sql(self, expression: exp.Add) -> str:
return self.binary(expression, "+")
def and_sql(self, expression: exp.And) -> str:
return self.connector_sql(expression, "AND")
def and_sql(
self, expression: exp.And, stack: t.Optional[t.List[str | exp.Expression]] = None
) -> str:
return self.connector_sql(expression, "AND", stack)
def xor_sql(self, expression: exp.Xor) -> str:
return self.connector_sql(expression, "XOR")
def or_sql(
self, expression: exp.Or, stack: t.Optional[t.List[str | exp.Expression]] = None
) -> str:
return self.connector_sql(expression, "OR", stack)
def connector_sql(self, expression: exp.Connector, op: str) -> str:
if not self.pretty:
return self.binary(expression, op)
def xor_sql(
self, expression: exp.Xor, stack: t.Optional[t.List[str | exp.Expression]] = None
) -> str:
return self.connector_sql(expression, "XOR", stack)
sqls = tuple(
self.maybe_comment(self.sql(e), e, e.parent.comments or []) if i != 1 else self.sql(e)
for i, e in enumerate(expression.flatten(unnest=False))
)
def connector_sql(
self,
expression: exp.Connector,
op: str,
stack: t.Optional[t.List[str | exp.Expression]] = None,
) -> str:
if stack is not None:
if expression.expressions:
stack.append(self.expressions(expression, sep=f" {op} "))
else:
stack.append(expression.right)
if expression.comments:
for comment in expression.comments:
op += f" /*{self.pad_comment(comment)}*/"
stack.extend((op, expression.left))
return op
sep = "\n" if self.text_width(sqls) > self.max_text_width else " "
return f"{sep}{op} ".join(sqls)
stack = [expression]
sqls: t.List[str] = []
ops = set()
while stack:
node = stack.pop()
if isinstance(node, exp.Connector):
ops.add(getattr(self, f"{node.key}_sql")(node, stack))
else:
sql = self.sql(node)
if sqls and sqls[-1] in ops:
sqls[-1] += f" {sql}"
else:
sqls.append(sql)
sep = "\n" if self.pretty and self.text_width(sqls) > self.max_text_width else " "
return sep.join(sqls)
def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
return self.binary(expression, "&")
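The connector rewrite above applies the same idea: an explicit stack replaces per-operand recursion, so deeply nested AND/OR/XOR chains generate without a RecursionError. A rough sketch (depth arbitrary):

import sqlglot

sql = "SELECT * FROM t WHERE " + " AND ".join(f"c{i} = {i}" for i in range(500))
print(len(sqlglot.transpile(sql)[0]))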
@@ -2727,7 +2837,9 @@ class Generator(metaclass=_Generator):
format_sql = f" FORMAT {format_sql}" if format_sql else ""
to_sql = self.sql(expression, "to")
to_sql = f" {to_sql}" if to_sql else ""
return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})"
action = self.sql(expression, "action")
action = f" {action}" if action else ""
return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql}{action})"
def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
@@ -2817,7 +2929,7 @@ class Generator(metaclass=_Generator):
# Remove db from tables
expression = expression.transform(
lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n
)
).assert_is(exp.RenameTable)
this = self.sql(expression, "this")
return f"RENAME TO {this}"
@@ -2889,30 +3001,6 @@ class Generator(metaclass=_Generator):
kind = "MAX" if expression.args.get("max") else "MIN"
return f"{this_sql} HAVING {kind} {expression_sql}"
def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
# The first modifier here will be the one closest to the AggFunc's arg
mods = sorted(
expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
key=lambda x: 0
if isinstance(x, exp.HavingMax)
else (1 if isinstance(x, exp.Order) else 2),
)
if mods:
mod = mods[0]
this = expression.__class__(this=mod.this.copy())
this.meta["inline"] = True
mod.this.replace(this)
return self.sql(expression.this)
agg_func = expression.find(exp.AggFunc)
if agg_func:
return self.sql(agg_func)[:-1] + f" {text})"
return f"{self.sql(expression, 'this')} {text}"
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
@@ -2933,9 +3021,7 @@ class Generator(metaclass=_Generator):
r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0)))
if self.dialect.TYPED_DIVISION and not expression.args.get("typed"):
if not l.is_type(*exp.DataType.FLOAT_TYPES) and not r.is_type(
*exp.DataType.FLOAT_TYPES
):
if not l.is_type(*exp.DataType.REAL_TYPES) and not r.is_type(*exp.DataType.REAL_TYPES):
l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE))
elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"):
@@ -3019,9 +3105,6 @@ class Generator(metaclass=_Generator):
def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str:
return self.binary(expression, "IS DISTINCT FROM")
def or_sql(self, expression: exp.Or) -> str:
return self.connector_sql(expression, "OR")
def slice_sql(self, expression: exp.Slice) -> str:
return self.binary(expression, ":")
@@ -3035,8 +3118,13 @@ class Generator(metaclass=_Generator):
this = expression.this
expr = expression.expression
if not self.dialect.LOG_BASE_FIRST:
if self.dialect.LOG_BASE_FIRST is False:
this, expr = expr, this
elif self.dialect.LOG_BASE_FIRST is None and expr:
if this.name in ("2", "10"):
return self.func(f"LOG{this.name}", expr)
self.unsupported(f"Unsupported logarithm with base {self.sql(this)}")
return self.func("LOG", this, expr)
@@ -3088,11 +3176,16 @@ class Generator(metaclass=_Generator):
def text_width(self, args: t.Iterable) -> int:
return sum(len(arg) for arg in args)
def format_time(self, expression: exp.Expression) -> t.Optional[str]:
def format_time(
self,
expression: exp.Expression,
inverse_time_mapping: t.Optional[t.Dict[str, str]] = None,
inverse_time_trie: t.Optional[t.Dict] = None,
) -> t.Optional[str]:
return format_time(
self.sql(expression, "format"),
self.dialect.INVERSE_TIME_MAPPING,
self.dialect.INVERSE_TIME_TRIE,
inverse_time_mapping or self.dialect.INVERSE_TIME_MAPPING,
inverse_time_trie or self.dialect.INVERSE_TIME_TRIE,
)
def expressions(
@@ -3117,8 +3210,11 @@ class Generator(metaclass=_Generator):
num_sqls = len(expressions)
# These are calculated once, in case we have the leading_comma / pretty options set, respectively
pad = " " * self.pad
stripped_sep = sep.strip()
if self.pretty:
if self.leading_comma:
pad = " " * len(sep)
else:
stripped_sep = sep.strip()
result_sqls = []
for i, e in enumerate(expressions):
@@ -3154,13 +3250,6 @@ class Generator(metaclass=_Generator):
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
return f"{property_name} {self.sql(expression, 'this')}"
def set_operation(self, expression: exp.Union, op: str) -> str:
this = self.maybe_comment(self.sql(expression, "this"), comments=expression.comments)
op = self.seg(op)
return self.query_modifiers(
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
def tag_sql(self, expression: exp.Tag) -> str:
return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}"
@ -3227,6 +3316,18 @@ class Generator(metaclass=_Generator):
return self.sql(exp.cast(expression.this, "text"))
def tonumber_sql(self, expression: exp.ToNumber) -> str:
if not self.SUPPORTS_TO_NUMBER:
self.unsupported("Unsupported TO_NUMBER function")
return self.sql(exp.cast(expression.this, "double"))
fmt = expression.args.get("format")
if not fmt:
self.unsupported("Conversion format is required for TO_NUMBER")
return self.sql(exp.cast(expression.this, "double"))
return self.func("TO_NUMBER", expression.this, fmt)
def dictproperty_sql(self, expression: exp.DictProperty) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
@ -3320,11 +3421,11 @@ class Generator(metaclass=_Generator):
this = f" {this}" if this else ""
index_type = self.sql(expression, "index_type")
index_type = f" USING {index_type}" if index_type else ""
schema = self.sql(expression, "schema")
schema = f" {schema}" if schema else ""
expressions = self.expressions(expression, flat=True)
expressions = f" ({expressions})" if expressions else ""
options = self.expressions(expression, key="options", sep=" ")
options = f" {options}" if options else ""
return f"{kind}{this}{index_type}{schema}{options}"
return f"{kind}{this}{index_type}{expressions}{options}"
def nvl2_sql(self, expression: exp.Nvl2) -> str:
if self.NVL2_SUPPORTED:
@ -3396,6 +3497,13 @@ class Generator(metaclass=_Generator):
return self.sql(exp.cast(this, "time"))
def tsordstotimestamp_sql(self, expression: exp.TsOrDsToTimestamp) -> str:
this = expression.this
if isinstance(this, exp.TsOrDsToTimestamp) or this.is_type(exp.DataType.Type.TIMESTAMP):
return self.sql(this)
return self.sql(exp.cast(this, "timestamp"))
def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str:
this = expression.this
time_format = self.format_time(expression)
@ -3430,6 +3538,13 @@ class Generator(metaclass=_Generator):
return self.func("LAST_DAY", expression.this)
def dateadd_sql(self, expression: exp.DateAdd) -> str:
from sqlglot.dialects.dialect import unit_to_str
return self.func(
"DATE_ADD", expression.this, expression.expression, unit_to_str(expression)
)
def arrayany_sql(self, expression: exp.ArrayAny) -> str:
if self.CAN_IMPLEMENT_ARRAY_ANY:
filtered = exp.ArrayFilter(this=expression.this, expression=expression.expression)
@ -3445,30 +3560,6 @@ class Generator(metaclass=_Generator):
return self.function_fallback_sql(expression)
def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
this = expression.this
if isinstance(this, exp.JSONPathWildcard):
this = self.json_path_part(this)
return f".{this}" if this else ""
if exp.SAFE_IDENTIFIER_RE.match(this):
return f".{this}"
this = self.json_path_part(this)
return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"
def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
this = self.json_path_part(expression.this)
return f"[{this}]" if this else ""
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
expression = simplify(expression, dialect=self.dialect)
return expression
def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
expression.set("is_end_exclusive", None)
return self.function_fallback_sql(expression)
@ -3477,7 +3568,9 @@ class Generator(metaclass=_Generator):
expression.set(
"expressions",
[
exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
exp.alias_(e.expression, e.name if e.this.is_string else e.this)
if isinstance(e, exp.PropertyEQ)
else e
for e in expression.expressions
],
)
@ -3553,3 +3646,51 @@ class Generator(metaclass=_Generator):
transformed = cast(this=value, to=to, safe=safe)
return self.sql(transformed)
def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
this = expression.this
if isinstance(this, exp.JSONPathWildcard):
this = self.json_path_part(this)
return f".{this}" if this else ""
if exp.SAFE_IDENTIFIER_RE.match(this):
return f".{this}"
this = self.json_path_part(this)
return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"
def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
this = self.json_path_part(expression.this)
return f"[{this}]" if this else ""
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
expression = simplify(expression, dialect=self.dialect)
return expression
def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
# The first modifier here will be the one closest to the AggFunc's arg
mods = sorted(
expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
key=lambda x: 0
if isinstance(x, exp.HavingMax)
else (1 if isinstance(x, exp.Order) else 2),
)
if mods:
mod = mods[0]
this = expression.__class__(this=mod.this.copy())
this.meta["inline"] = True
mod.this.replace(this)
return self.sql(expression.this)
agg_func = expression.find(exp.AggFunc)
if agg_func:
return self.sql(agg_func)[:-1] + f" {text})"
return f"{self.sql(expression, 'this')} {text}"

View file

@ -181,7 +181,7 @@ def apply_index_offset(
annotate_types(expression)
if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
logger.warning("Applying array index offset (%s)", offset)
expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
expression = simplify(expression + offset)
return [expression]
return expressions
@ -204,13 +204,13 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
The transformed expression.
"""
while True:
for n, *_ in reversed(tuple(expression.walk())):
for n in reversed(tuple(expression.walk())):
n._hash = hash(n)
start = hash(expression)
expression = func(expression)
for n, *_ in expression.walk():
for n in expression.walk():
n._hash = None
if start == hash(expression):
break
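# The per-node hashing above caches structural hashes so the fixed-point check is
# cheap; a simplified sketch of the loop (with the _hash freezing elided):
def while_changing_sketch(expression, func):
    while True:
        start = hash(expression)  # assumes a structural __hash__ on Expression
        expression = func(expression)
        if start == hash(expression):
            return expression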
@ -317,8 +317,16 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
def is_int(text: str) -> bool:
return is_type(text, int)
def is_float(text: str) -> bool:
return is_type(text, float)
def is_type(text: str, target_type: t.Type) -> bool:
try:
int(text)
target_type(text)
return True
except ValueError:
return False
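# Behavior of the generalized helper follows directly from the definition:
assert is_int("42") and not is_int("4.2")
assert is_float("4.2") and not is_float("abc")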

View file

@ -28,10 +28,7 @@ class Node:
yield self
for d in self.downstream:
if isinstance(d, Node):
yield from d.walk()
else:
yield d
yield from d.walk()
def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
nodes = {}
@ -71,8 +68,10 @@ def lineage(
column: str | exp.Column,
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
sources: t.Optional[t.Mapping[str, str | exp.Query]] = None,
dialect: DialectType = None,
scope: t.Optional[Scope] = None,
trim_selects: bool = True,
**kwargs,
) -> Node:
"""Build the lineage graph for a column of a SQL query.
@ -83,6 +82,8 @@ def lineage(
schema: The schema of tables.
sources: A mapping of queries which will be used to continue building lineage.
dialect: The dialect of input SQL.
scope: A pre-created scope to use instead.
trim_selects: Whether or not to clean up selects by trimming to only relevant columns.
**kwargs: Qualification optimizer kwargs.
Returns:
@ -99,14 +100,15 @@ def lineage(
dialect=dialect,
)
qualified = qualify.qualify(
expression,
dialect=dialect,
schema=schema,
**{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore
)
if not scope:
expression = qualify.qualify(
expression,
dialect=dialect,
schema=schema,
**{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore
)
scope = build_scope(qualified)
scope = build_scope(expression)
if not scope:
raise SqlglotError("Cannot build lineage, sql must be SELECT")
@ -114,7 +116,7 @@ def lineage(
if not any(select.alias_or_name == column for select in scope.expression.selects):
raise SqlglotError(f"Cannot find column '{column}' in query.")
return to_node(column, scope, dialect)
return to_node(column, scope, dialect, trim_selects=trim_selects)
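# A hedged usage example of the new parameters (table names illustrative):
from sqlglot.lineage import lineage

node = lineage(
    "a",
    "SELECT a FROM y",
    sources={"y": "SELECT a FROM x"},  # lineage continues into this upstream query
    trim_selects=False,  # keep full SELECTs in node labels
)
for n in node.walk():
    print(n.name)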
def to_node(
@ -125,6 +127,7 @@ def to_node(
upstream: t.Optional[Node] = None,
source_name: t.Optional[str] = None,
reference_node_name: t.Optional[str] = None,
trim_selects: bool = True,
) -> Node:
source_names = {
dt.alias: dt.comments[0].split()[1]
@ -143,6 +146,17 @@ def to_node(
)
)
if isinstance(scope.expression, exp.Subquery):
for source in scope.subquery_scopes:
return to_node(
column,
scope=source,
dialect=dialect,
upstream=upstream,
source_name=source_name,
reference_node_name=reference_node_name,
trim_selects=trim_selects,
)
if isinstance(scope.expression, exp.Union):
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
@ -170,11 +184,12 @@ def to_node(
upstream=upstream,
source_name=source_name,
reference_node_name=reference_node_name,
trim_selects=trim_selects,
)
return upstream
if isinstance(scope.expression, exp.Select):
if trim_selects and isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
# a version that has only the column we care about.
# "x", SELECT x, y FROM foo
@ -206,7 +221,13 @@ def to_node(
continue
for name in subquery.named_selects:
to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)
to_node(
name,
scope=subquery_scope,
dialect=dialect,
upstream=node,
trim_selects=trim_selects,
)
# if the select is a star add all scope sources as downstreams
if select.is_star:
@ -237,6 +258,7 @@ def to_node(
upstream=node,
source_name=source_names.get(table) or source_name,
reference_node_name=selected_node.name if selected_node else None,
trim_selects=trim_selects,
)
else:
# The source is not a scope - we've reached the end of the line. At this point, if a source is not found

View file

@ -168,8 +168,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Exp,
exp.Ln,
exp.Log,
exp.Log2,
exp.Log10,
exp.Pow,
exp.Quantile,
exp.Round,
@ -266,26 +264,30 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Dot: lambda self, e: self._annotate_dot(e),
exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
e, exp.DataType.build("ARRAY<DATE>")
),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Literal: lambda self, e: self._annotate_literal(e),
exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.Map: lambda self, e: self._annotate_map(e),
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
exp.Struct: lambda self, e: self._annotate_struct(e),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
exp.ToMap: lambda self, e: self._annotate_to_map(e),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Unnest: lambda self, e: self._annotate_unnest(e),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.VarMap: lambda self, e: self._annotate_map(e),
}
NESTED_TYPES = {
@ -358,6 +360,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
elif isinstance(source.expression, exp.Unnest):
values = [source.expression]
else:
values = source.expression.expressions[0].expressions
@ -408,7 +412,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
def _annotate_args(self, expression: E) -> E:
for _, value in expression.iter_expressions():
for value in expression.iter_expressions():
self._maybe_annotate(value)
return expression
@ -425,23 +429,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.Type.UNKNOWN
if type1_value in self.NESTED_TYPES:
return type1
if type2_value in self.NESTED_TYPES:
return type2
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore
# Note: the following "no_type_check" decorators were added because mypy was yelling due
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
# This is a known mypy issue: https://github.com/python/mypy/issues/3004
@t.no_type_check
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
left, right = expression.left, expression.right
left_type, right_type = left.type.this, right.type.this
left_type, right_type = left.type.this, right.type.this # type: ignore
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@ -462,7 +456,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
@t.no_type_check
def _annotate_unary(self, expression: E) -> E:
self._annotate_args(expression)
@ -473,7 +466,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
@t.no_type_check
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
self._set_type(expression, exp.DataType.Type.VARCHAR)
@ -484,25 +476,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
@t.no_type_check
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
self._set_type(expression, target_type)
return self._annotate_args(expression)
@t.no_type_check
def _annotate_struct_value(
self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
alias = expression.args.get("alias")
if alias:
return exp.ColumnDef(this=alias.copy(), kind=expression.type)
# Case: key = value or key := value
if expression.expression:
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
return expression.type
@t.no_type_check
def _annotate_by_args(
self,
@ -510,7 +487,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
*args: str,
promote: bool = False,
array: bool = False,
struct: bool = False,
) -> E:
self._annotate_args(expression)
@ -546,16 +522,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
),
)
if struct:
self._set_type(
expression,
exp.DataType(
this=exp.DataType.Type.STRUCT,
expressions=[self._annotate_struct_value(expr) for expr in expressions],
nested=True,
),
)
return expression
def _annotate_timeunit(
@ -605,6 +571,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, exp.DataType.Type.BIGINT)
else:
self._set_type(expression, self._maybe_coerce(left_type, right_type))
if expression.type and expression.type.this not in exp.DataType.REAL_TYPES:
self._set_type(
expression, self._maybe_coerce(expression.type, exp.DataType.Type.DOUBLE)
)
return expression
@ -631,3 +601,68 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
child = seq_get(expression.expressions, 0)
self._set_type(expression, child and seq_get(child.type.expressions, 0))
return expression
def _annotate_struct_value(
self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
alias = expression.args.get("alias")
if alias:
return exp.ColumnDef(this=alias.copy(), kind=expression.type)
# Case: key = value or key := value
if expression.expression:
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
return expression.type
def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
self._annotate_args(expression)
self._set_type(
expression,
exp.DataType(
this=exp.DataType.Type.STRUCT,
expressions=[self._annotate_struct_value(expr) for expr in expression.expressions],
nested=True,
),
)
return expression
@t.overload
def _annotate_map(self, expression: exp.Map) -> exp.Map: ...
@t.overload
def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ...
def _annotate_map(self, expression):
self._annotate_args(expression)
keys = expression.args.get("keys")
values = expression.args.get("values")
map_type = exp.DataType(this=exp.DataType.Type.MAP)
if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
key_type = seq_get(keys.type.expressions, 0) or exp.DataType.Type.UNKNOWN
value_type = seq_get(values.type.expressions, 0) or exp.DataType.Type.UNKNOWN
if key_type != exp.DataType.Type.UNKNOWN and value_type != exp.DataType.Type.UNKNOWN:
map_type.set("expressions", [key_type, value_type])
map_type.set("nested", True)
self._set_type(expression, map_type)
return expression
def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap:
self._annotate_args(expression)
map_type = exp.DataType(this=exp.DataType.Type.MAP)
arg = expression.this
if arg.is_type(exp.DataType.Type.STRUCT):
for coldef in arg.type.expressions:
kind = coldef.kind
if kind != exp.DataType.Type.UNKNOWN:
map_type.set("expressions", [exp.DataType.build("varchar"), kind])
map_type.set("nested", True)
break
self._set_type(expression, map_type)
return expression
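# A hedged example of the finer-grained map typing (Presto syntax; expected
# output under these changes, e.g. MAP(VARCHAR, INT) instead of a bare MAP):
import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

expr = annotate_types(sqlglot.parse_one("SELECT MAP(ARRAY['a'], ARRAY[1])", read="presto"))
print(expr.selects[0].type)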

View file

@ -16,16 +16,17 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
Args:
expression: The expression to canonicalize.
"""
exp.replace_children(expression, canonicalize)
expression = add_text_to_concat(expression)
expression = replace_date_funcs(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bools(expression, _replace_int_predicate)
expression = remove_ascending_order(expression)
def _canonicalize(expression: exp.Expression) -> exp.Expression:
expression = add_text_to_concat(expression)
expression = replace_date_funcs(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bools(expression, _replace_int_predicate)
expression = remove_ascending_order(expression)
return expression
return expression
return exp.replace_tree(expression, _canonicalize)
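# The refactor funnels every rule through exp.replace_tree, which visits leaves
# before parents; a minimal illustration of that traversal style:
from sqlglot import exp, parse_one

def upper_idents(node):
    if isinstance(node, exp.Identifier):
        node.set("this", node.this.upper())  # children are rewritten before parents
    return node

print(exp.replace_tree(parse_one("SELECT a FROM t"), upper_idents).sql())  # SELECT A FROM T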
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
@ -35,7 +36,11 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression:
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
if (
isinstance(node, (exp.Date, exp.TsOrDsToDate))
and not node.expressions
and not node.args.get("zone")
):
return exp.cast(node.this, to=exp.DataType.Type.DATE)
if isinstance(node, exp.Timestamp) and not node.expression:
if not node.type:
@ -121,15 +126,11 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
a = _coerce_timeunit_arg(a, b.unit)
if (
a.type
and a.type.this == exp.DataType.Type.DATE
and a.type.this in exp.DataType.TEMPORAL_TYPES
and b.type
and b.type.this
not in (
exp.DataType.Type.DATE,
exp.DataType.Type.INTERVAL,
)
and b.type.this in exp.DataType.TEXT_TYPES
):
_replace_cast(b, exp.DataType.Type.DATE)
_replace_cast(b, exp.DataType.Type.DATETIME)
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
@ -169,7 +170,7 @@ def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
if isinstance(expression, exp.Coalesce):
for _, child in expression.iter_expressions():
for child in expression.iter_expressions():
_replace_int_predicate(child)
elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
expression.replace(expression.neq(0))

View file

@ -32,7 +32,7 @@ def eliminate_ctes(expression):
cte_node.pop()
# Pop the entire WITH clause if this is the last CTE
if len(with_node.expressions) <= 0:
if with_node and len(with_node.expressions) <= 0:
with_node.pop()
# Decrement the ref count for all sources this CTE selects from

View file

@ -214,6 +214,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation()
and not _is_recursive()
and not (inner_select.args.get("order") and outer_scope.is_union)
)

View file

@ -28,7 +28,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
continue

View file

@ -53,10 +53,8 @@ def normalize_identifiers(expression, dialect=None):
if isinstance(expression, str):
expression = exp.parse_identifier(expression, dialect=dialect)
def _normalize(node: E) -> E:
for node in expression.walk(prune=lambda n: n.meta.get("case_sensitive")):
if not node.meta.get("case_sensitive"):
exp.replace_children(node, _normalize)
node = dialect.normalize_identifier(node)
return node
dialect.normalize_identifier(node)
return _normalize(expression)
return expression
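# Example of the walk-based normalization (Snowflake folds unquoted identifiers
# to uppercase):
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

print(normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake"))  # FOO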

View file

@ -82,13 +82,13 @@ def optimize(
**kwargs,
}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
expression = rule(expression, **rule_kwargs)
optimized = rule(optimized, **rule_kwargs)
return t.cast(exp.Expression, expression)
return optimized

View file

@ -77,13 +77,13 @@ def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
pushdown_dnf(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
"""
If the predicates are in CNF-like form, we can simply replace each block in the parent.
"""
join_index = join_index or {}
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
if isinstance(node, exp.Join):
name = node.alias_or_name
predicate_tables = exp.column_table_names(predicate, name)
@ -103,7 +103,7 @@ def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
node.where(inner_predicate, copy=False)
def pushdown_dnf(predicates, scope, scope_ref_count):
def pushdown_dnf(predicates, sources, scope_ref_count):
"""
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
Additionally, we can't remove predicates from their original form.
@ -127,7 +127,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
# pushdown all predicates to their respective nodes
for table in sorted(pushdown_tables):
for predicate in predicates:
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
nodes = nodes_for_predicate(predicate, sources, scope_ref_count)
if table not in nodes:
continue

View file

@ -54,11 +54,15 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
if any(select.is_star for select in right.expression.selects):
referenced_columns[right] = parent_selections
elif not any(select.is_star for select in left.expression.selects):
referenced_columns[right] = [
right.expression.selects[i].alias_or_name
for i, select in enumerate(left.expression.selects)
if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
]
if scope.expression.args.get("by_name"):
referenced_columns[right] = referenced_columns[left]
else:
referenced_columns[right] = [
right.expression.selects[i].alias_or_name
for i, select in enumerate(left.expression.selects)
if SELECT_ALL in parent_selections
or select.alias_or_name in parent_selections
]
if isinstance(scope.expression, exp.Select):
if remove_unused_selections:

View file

@ -209,7 +209,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not node:
return
for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
for column in walk_in_scope(node, prune=lambda node: node.is_star):
if not isinstance(column, exp.Column):
continue
@ -306,7 +306,7 @@ def _expand_positional_references(
else:
select = select.this
if isinstance(select, exp.Literal):
if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
new_nodes.append(node)
else:
new_nodes.append(select.copy())
@ -425,7 +425,7 @@ def _expand_stars(
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
columns = columns or scope.outer_column_list
columns = columns or scope.outer_columns
if pseudocolumns:
columns = [name for name in columns if name.upper() not in pseudocolumns]
@ -517,7 +517,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
new_selections = []
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
itertools.zip_longest(scope.expression.selects, scope.outer_columns)
):
if selection is None:
break
@ -544,7 +544,7 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
"""Makes sure all identifiers that need to be quoted are quoted."""
return expression.transform(
Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
)
) # type: ignore
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:

View file

@ -56,7 +56,7 @@ def qualify_tables(
table.set("catalog", catalog)
if not isinstance(expression, exp.Query):
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)):
for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
if isinstance(node, exp.Table):
_qualify(node)
@ -118,11 +118,11 @@ def qualify_tables(
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
else:
for node, parent, _ in scope.walk():
for node in scope.walk():
if (
isinstance(node, exp.Table)
and not node.alias
and isinstance(parent, (exp.From, exp.Join))
and isinstance(node.parent, (exp.From, exp.Join))
):
# Mutates the table by attaching an alias to it
alias(node, node.name, copy=False, table=True)

View file

@ -8,7 +8,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name
from sqlglot.helper import ensure_collection, find_new_name, seq_get
logger = logging.getLogger("sqlglot")
@ -38,11 +38,11 @@ class Scope:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
cte_sources (dict[str, Scope]): Sources from CTEs
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list of it's alias of this scope, this is that list of columns.
outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`
The inner query would have `["col1", "col2"]` for its `outer_columns`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to its parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries
@ -58,7 +58,7 @@ class Scope:
self,
expression,
sources=None,
outer_column_list=None,
outer_columns=None,
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
@ -70,7 +70,7 @@ class Scope:
self.cte_sources = cte_sources or {}
self.sources.update(self.lateral_sources)
self.sources.update(self.cte_sources)
self.outer_column_list = outer_column_list or []
self.outer_columns = outer_columns or []
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
@ -119,10 +119,11 @@ class Scope:
self._raw_columns = []
self._join_hints = []
for node, parent, _ in self.walk(bfs=False):
for node in self.walk(bfs=False):
if node is self.expression:
continue
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
self._tables.append(node)
@ -132,10 +133,8 @@ class Scope:
self._udtfs.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
elif (
isinstance(node, exp.Subquery)
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
and _is_derived_table(node)
elif _is_derived_table(node) and isinstance(
node.parent, (exp.From, exp.Join, exp.Subquery)
):
self._derived_tables.append(node)
elif isinstance(node, exp.UNWRAPPED_QUERIES):
@ -438,11 +437,21 @@ class Scope:
Yields:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
stack = [self]
result = []
while stack:
scope = stack.pop()
result.append(scope)
stack.extend(
itertools.chain(
scope.cte_scopes,
scope.union_scopes,
scope.table_scopes,
scope.subquery_scopes,
)
)
yield from reversed(result)
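# The recursion-free traversal uses a standard trick: pre-order collection with
# an explicit stack, reversed at the end, yields depth-first post-order, so every
# child scope is emitted before its parent. A generic sketch:
def postorder(root, children):
    stack, result = [root], []
    while stack:
        node = stack.pop()
        result.append(node)
        stack.extend(children(node))
    yield from reversed(result)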
def ref_count(self):
"""
@ -481,14 +490,28 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Args:
expression (exp.Expression): expression to traverse
expression: Expression to traverse
Returns:
list[Scope]: scope instances
A list of the created scope instances
"""
if isinstance(expression, exp.Query) or (
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
):
if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
# We ignore the DDL expression and build a scope for its query instead
ddl_with = expression.args.get("with")
expression = expression.expression
# If the DDL has CTEs attached, we need to add them to the query, or
# prepend them if the query itself already has CTEs attached to it
if ddl_with:
ddl_with.pop()
query_ctes = expression.ctes
if not query_ctes:
expression.set("with", ddl_with)
else:
expression.args["with"].set("recursive", ddl_with.recursive)
expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
if isinstance(expression, exp.Query):
return list(_traverse_scope(Scope(expression)))
return []
@ -499,21 +522,21 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
Build a scope tree.
Args:
expression (exp.Expression): expression to build the scope tree for
expression: Expression to build the scope tree for.
Returns:
Scope: root scope
The root scope
"""
scopes = traverse_scope(expression)
if scopes:
return scopes[-1]
return None
return seq_get(traverse_scope(expression), -1)
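# Usage is unchanged by the seq_get simplification (assuming the input parses to
# a query):
import sqlglot
from sqlglot.optimizer.scope import build_scope

root = build_scope(sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y"))
assert root is not None
print(list(root.sources))  # ['y'] -- the derived table maps to its own Scope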
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
return
elif isinstance(scope.expression, exp.Subquery):
if scope.is_root:
yield from _traverse_select(scope)
@ -523,8 +546,6 @@ def _traverse_scope(scope):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
yield from _traverse_udtfs(scope)
elif isinstance(scope.expression, exp.DDL):
yield from _traverse_ddl(scope)
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@ -541,30 +562,38 @@ def _traverse_select(scope):
def _traverse_union(scope):
yield from _traverse_ctes(scope)
prev_scope = None
union_scope_stack = [scope]
expression_stack = [scope.expression.right, scope.expression.left]
# The last scope to be yield should be the top most scope
left = None
for left in _traverse_scope(
scope.branch(
scope.expression.left,
outer_column_list=scope.outer_column_list,
while expression_stack:
expression = expression_stack.pop()
union_scope = union_scope_stack[-1]
new_scope = union_scope.branch(
expression,
outer_columns=union_scope.outer_columns,
scope_type=ScopeType.UNION,
)
):
yield left
right = None
for right in _traverse_scope(
scope.branch(
scope.expression.right,
outer_column_list=scope.outer_column_list,
scope_type=ScopeType.UNION,
)
):
yield right
if isinstance(expression, exp.Union):
yield from _traverse_ctes(new_scope)
scope.union_scopes = [left, right]
union_scope_stack.append(new_scope)
expression_stack.extend([expression.right, expression.left])
continue
for scope in _traverse_scope(new_scope):
yield scope
if prev_scope:
union_scope_stack.pop()
union_scope.union_scopes = [prev_scope, scope]
prev_scope = union_scope
yield union_scope
else:
prev_scope = scope
def _traverse_ctes(scope):
@ -588,7 +617,7 @@ def _traverse_ctes(scope):
scope.branch(
cte.this,
cte_sources=sources,
outer_column_list=cte.alias_column_names,
outer_columns=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
):
@ -615,7 +644,9 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
as it doesn't introduce a new scope. If an alias is present, it shadows all names
under the Subquery, so that's one exception to this rule.
"""
return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
return isinstance(expression, exp.Subquery) and bool(
expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
)
def _traverse_tables(scope):
@ -681,7 +712,7 @@ def _traverse_tables(scope):
scope.branch(
expression,
lateral_sources=lateral_sources,
outer_column_list=expression.alias_column_names,
outer_columns=expression.alias_column_names,
scope_type=scope_type,
)
):
@ -719,13 +750,13 @@ def _traverse_udtfs(scope):
sources = {}
for expression in expressions:
if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
if _is_derived_table(expression):
top = None
for child_scope in _traverse_scope(
scope.branch(
expression,
scope_type=ScopeType.DERIVED_TABLE,
outer_column_list=expression.alias_column_names,
outer_columns=expression.alias_column_names,
)
):
yield child_scope
@ -738,18 +769,6 @@ def _traverse_udtfs(scope):
scope.sources.update(sources)
def _traverse_ddl(scope):
yield from _traverse_ctes(scope)
query_scope = scope.branch(
scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
)
query_scope._collect()
query_scope._ctes = scope.ctes + query_scope._ctes
yield from _traverse_scope(query_scope)
def walk_in_scope(expression, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in the syntax tree, stopping at
@ -769,23 +788,21 @@ def walk_in_scope(expression, bfs=True, prune=None):
# Whenever we set it to True, we exclude a subtree from traversal.
crossed_scope_boundary = False
for node, parent, key in expression.walk(
bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
for node in expression.walk(
bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
):
crossed_scope_boundary = False
yield node, parent, key
yield node
if node is expression:
continue
if (
isinstance(node, exp.CTE)
or (
isinstance(node, exp.Subquery)
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
and _is_derived_table(node)
isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
and (_is_derived_table(node) or isinstance(node, exp.UDTF))
)
or isinstance(node, exp.UDTF)
or isinstance(node, exp.UNWRAPPED_QUERIES)
):
crossed_scope_boundary = True
@ -812,7 +829,7 @@ def find_all_in_scope(expression, expression_types, bfs=True):
Yields:
exp.Expression: nodes
"""
for expression, *_ in walk_in_scope(expression, bfs=bfs):
for expression in walk_in_scope(expression, bfs=bfs):
if isinstance(expression, tuple(ensure_collection(expression_types))):
yield expression

View file

@ -9,19 +9,25 @@ from decimal import Decimal
import sqlglot
from sqlglot import Dialect, exp
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
DateTruncBinaryTransform = t.Callable[
[exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
[exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
]
# Final means that an expression should not be simplified
FINAL = "final"
# Value ranges for byte-sized signed/unsigned integers
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(Exception):
pass
@ -63,14 +69,14 @@ def simplify(
group.meta[FINAL] = True
for e in expression.selects:
for node, *_ in e.walk():
for node in e.walk():
if node in groups:
e.meta[FINAL] = True
break
having = expression.args.get("having")
if having:
for node, *_ in having.walk():
for node in having.walk():
if node in groups:
having.meta[FINAL] = True
break
@ -304,6 +310,8 @@ def _simplify_comparison(expression, left, right, or_=False):
r = extract_date(r)
if not r:
return None
# Python won't compare date and datetime, but many engines will upcast
l, r = cast_as_datetime(l), cast_as_datetime(r)
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
@ -431,7 +439,7 @@ def propagate_constants(expression, root=True):
and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
):
constant_mapping = {}
for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
if isinstance(expr, exp.EQ):
l, r = expr.left, expr.right
@ -544,7 +552,37 @@ def simplify_literals(expression, root=True):
return expression
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
this = _simplify_integer_cast(expr.this)
else:
this = expr.this
if isinstance(expr, exp.Cast) and this.is_int:
num = int(this.name)
# Remove the (up)cast from small (byte-sized) integers in predicates, which is side-effect free. Downcasts on any
# integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
# engine-dependent
if (
TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
) or (
UTINYINT_MIN <= num <= UTINYINT_MAX
and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
):
return this
return expr
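# A hedged example of the effect: both literals fit in TINYINT, so the casts are
# safe to strip inside the comparison, letting the literals fold (output assumed):
import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.parse_one("CAST(1 AS TINYINT) < CAST(2 AS INT)")).sql())  # TRUE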
def _simplify_binary(expression, a, b):
if isinstance(expression, COMPARISONS):
a = _simplify_integer_cast(a)
b = _simplify_integer_cast(b)
if isinstance(expression, exp.Is):
if isinstance(b, exp.Not):
c = b.this
@ -558,7 +596,7 @@ def _simplify_binary(expression, a, b):
return exp.true() if not_ else exp.false()
if is_null(a):
return exp.false() if not_ else exp.true()
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
elif isinstance(expression, NULL_OK):
return None
elif is_null(a) or is_null(b):
return exp.null()
@ -591,17 +629,17 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
elif _is_date_literal(a) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if a and b:
date, b = extract_date(a), extract_interval(b)
if date and b:
if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
return date_literal(a + b)
return date_literal(date + b, extract_type(a))
if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
return date_literal(a - b)
return date_literal(date - b, extract_type(a))
elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, b = extract_interval(a), extract_date(b)
a, date = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
return date_literal(a + b)
return date_literal(a + date, extract_type(b))
elif _is_date_literal(a) and _is_date_literal(b):
if isinstance(expression, exp.Predicate):
a, b = extract_date(a), extract_date(b)
@ -618,12 +656,16 @@ def simplify_parens(expression):
this = expression.this
parent = expression.parent
parent_is_predicate = isinstance(parent, exp.Predicate)
if not isinstance(this, exp.Select) and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(parent, exp.Paren)
or not isinstance(this, exp.Binary)
or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
or (
not isinstance(this, exp.Binary)
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
)
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
@ -632,24 +674,12 @@ def simplify_parens(expression):
return expression
NONNULL_CONSTANTS = (
exp.Literal,
exp.Boolean,
)
CONSTANTS = (
exp.Literal,
exp.Boolean,
exp.Null,
)
def _is_nonnull_constant(expression: exp.Expression) -> bool:
return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
def _is_constant(expression: exp.Expression) -> bool:
return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
def simplify_coalesce(expression):
@ -820,45 +850,55 @@ def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Opti
return floor, floor + interval(unit)
def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
def _datetrunc_eq_expression(
left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
) -> exp.Expression:
"""Get the logical expression for a date range"""
return exp.and_(
left >= date_literal(drange[0]),
left < date_literal(drange[1]),
left >= date_literal(drange[0], target_type),
left < date_literal(drange[1], target_type),
copy=False,
)
def _datetrunc_eq(
left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
left: exp.Expression,
date: datetime.date,
unit: str,
dialect: Dialect,
target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
return _datetrunc_eq_expression(left, drange)
return _datetrunc_eq_expression(left, drange, target_type)
def _datetrunc_neq(
left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
left: exp.Expression,
date: datetime.date,
unit: str,
dialect: Dialect,
target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
return exp.and_(
left < date_literal(drange[0]),
left >= date_literal(drange[1]),
left < date_literal(drange[0], target_type),
left >= date_literal(drange[1], target_type),
copy=False,
)
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
exp.LT: lambda l, dt, u, d: l
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
exp.LT: lambda l, dt, u, d, t: l
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
exp.EQ: _datetrunc_eq,
exp.NEQ: _datetrunc_neq,
}
@ -876,9 +916,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
comparison = expression.__class__
if isinstance(expression, DATETRUNCS):
date = extract_date(expression.this)
this = expression.this
trunc_type = extract_type(this)
date = extract_date(this)
if date and expression.unit:
return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
elif comparison not in DATETRUNC_COMPARISONS:
return expression
@ -889,14 +931,21 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return expression
l = t.cast(exp.DateTrunc, l)
trunc_arg = l.this
unit = l.unit.name.lower()
date = extract_date(r)
if not date:
return expression
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
elif isinstance(expression, exp.In):
return (
DATETRUNC_BINARY_COMPARISONS[comparison](
trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
)
or expression
)
if isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
@ -917,8 +966,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return expression
ranges = merge_ranges(ranges)
target_type = extract_type(l, *rs)
return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
return exp.or_(
*[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
)
return expression
@ -954,7 +1006,7 @@ JOINS = {
def remove_where_true(expression):
for where in expression.find_all(exp.Where):
if always_true(where.this):
where.parent.set("where", None)
where.pop()
for join in expression.find_all(exp.Join):
if (
always_true(join.args.get("on"))
@ -962,7 +1014,7 @@ def remove_where_true(expression):
and not join.args.get("method")
and (join.side, join.kind) in JOINS
):
join.set("on", None)
join.args["on"].pop()
join.set("side", None)
join.set("kind", "CROSS")
@ -1067,15 +1119,25 @@ def extract_interval(expression):
return None
def date_literal(date):
return exp.cast(
exp.Literal.string(date),
(
def extract_type(*expressions):
target_type = None
for expression in expressions:
target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
if target_type:
break
return target_type
def date_literal(date, target_type=None):
if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
target_type = (
exp.DataType.Type.DATETIME
if isinstance(date, datetime.datetime)
else exp.DataType.Type.DATE
),
)
)
return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
@ -1169,73 +1231,251 @@ def gen(expression: t.Any) -> str:
Sorting and deduping sql is a necessary step for optimization. Calling the actual
generator is expensive so we have a bare minimum sql generator here.
"""
if expression is None:
return "_"
if is_iterable(expression):
return ",".join(gen(e) for e in expression)
if not isinstance(expression, exp.Expression):
return str(expression)
etype = type(expression)
if etype in GEN_MAP:
return GEN_MAP[etype](expression)
return f"{expression.key} {gen(expression.args.values())}"
return Gen().gen(expression)
GEN_MAP = {
exp.Add: lambda e: _binary(e, "+"),
exp.And: lambda e: _binary(e, "AND"),
exp.Anonymous: lambda e: _anonymous(e),
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
exp.Div: lambda e: _binary(e, "/"),
exp.Dot: lambda e: _binary(e, "."),
exp.EQ: lambda e: _binary(e, "="),
exp.GT: lambda e: _binary(e, ">"),
exp.GTE: lambda e: _binary(e, ">="),
exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
exp.ILike: lambda e: _binary(e, "ILIKE"),
exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
exp.Is: lambda e: _binary(e, "IS"),
exp.Like: lambda e: _binary(e, "LIKE"),
exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
exp.LT: lambda e: _binary(e, "<"),
exp.LTE: lambda e: _binary(e, "<="),
exp.Mod: lambda e: _binary(e, "%"),
exp.Mul: lambda e: _binary(e, "*"),
exp.Neg: lambda e: _unary(e, "-"),
exp.NEQ: lambda e: _binary(e, "<>"),
exp.Not: lambda e: _unary(e, "NOT"),
exp.Null: lambda e: "NULL",
exp.Or: lambda e: _binary(e, "OR"),
exp.Paren: lambda e: f"({gen(e.this)})",
exp.Sub: lambda e: _binary(e, "-"),
exp.Subquery: lambda e: f"({gen(e.args.values())})",
exp.Table: lambda e: gen(e.args.values()),
exp.Var: lambda e: e.name,
}
class Gen:
def __init__(self):
self.stack = []
self.sqls = []
def gen(self, expression: exp.Expression) -> str:
self.stack = [expression]
self.sqls.clear()
def _anonymous(e: exp.Anonymous) -> str:
this = e.this
if isinstance(this, str):
name = this.upper()
elif isinstance(this, exp.Identifier):
name = f'"{this.name}"' if this.quoted else this.name.upper()
else:
raise ValueError(
f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
while self.stack:
node = self.stack.pop()
if isinstance(node, exp.Expression):
exp_handler_name = f"{node.key}_sql"
if hasattr(self, exp_handler_name):
getattr(self, exp_handler_name)(node)
elif isinstance(node, exp.Func):
self._function(node)
else:
key = node.key.upper()
self.stack.append(f"{key} " if self._args(node) else key)
elif type(node) is list:
for n in reversed(node):
if n is not None:
self.stack.extend((n, ","))
if node:
self.stack.pop()
else:
if node is not None:
self.sqls.append(str(node))
return "".join(self.sqls)
def add_sql(self, e: exp.Add) -> None:
self._binary(e, " + ")
def alias_sql(self, e: exp.Alias) -> None:
self.stack.extend(
(
e.args.get("alias"),
" AS ",
e.args.get("this"),
)
)
return f"{name} {','.join(gen(e) for e in e.expressions)}"
def and_sql(self, e: exp.And) -> None:
self._binary(e, " AND ")
def anonymous_sql(self, e: exp.Anonymous) -> None:
this = e.this
if isinstance(this, str):
name = this.upper()
elif isinstance(this, exp.Identifier):
name = this.this
name = f'"{name}"' if this.quoted else name.upper()
else:
raise ValueError(
f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
)
def _binary(e: exp.Binary, op: str) -> str:
return f"{gen(e.left)} {op} {gen(e.right)}"
self.stack.extend(
(
")",
e.expressions,
"(",
name,
)
)
def between_sql(self, e: exp.Between) -> None:
self.stack.extend(
(
e.args.get("high"),
" AND ",
e.args.get("low"),
" BETWEEN ",
e.this,
)
)
def _unary(e: exp.Unary, op: str) -> str:
return f"{op} {gen(e.this)}"
def boolean_sql(self, e: exp.Boolean) -> None:
self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: exp.Bracket) -> None:
self.stack.extend(
(
"]",
e.expressions,
"[",
e.this,
)
)
def column_sql(self, e: exp.Column) -> None:
for p in reversed(e.parts):
self.stack.extend((p, "."))
self.stack.pop()
def datatype_sql(self, e: exp.DataType) -> None:
self._args(e, 1)
self.stack.append(f"{e.this.name} ")
def div_sql(self, e: exp.Div) -> None:
self._binary(e, " / ")
def dot_sql(self, e: exp.Dot) -> None:
self._binary(e, ".")
def eq_sql(self, e: exp.EQ) -> None:
self._binary(e, " = ")
def from_sql(self, e: exp.From) -> None:
self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: exp.GT) -> None:
self._binary(e, " > ")
def gte_sql(self, e: exp.GTE) -> None:
self._binary(e, " >= ")
def identifier_sql(self, e: exp.Identifier) -> None:
self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: exp.ILike) -> None:
self._binary(e, " ILIKE ")
def in_sql(self, e: exp.In) -> None:
self.stack.append(")")
self._args(e, 1)
self.stack.extend(
(
"(",
" IN ",
e.this,
)
)
def intdiv_sql(self, e: exp.IntDiv) -> None:
self._binary(e, " DIV ")
def is_sql(self, e: exp.Is) -> None:
self._binary(e, " IS ")
def like_sql(self, e: exp.Like) -> None:
self._binary(e, " Like ")
def literal_sql(self, e: exp.Literal) -> None:
self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: exp.LT) -> None:
self._binary(e, " < ")
def lte_sql(self, e: exp.LTE) -> None:
self._binary(e, " <= ")
def mod_sql(self, e: exp.Mod) -> None:
self._binary(e, " % ")
def mul_sql(self, e: exp.Mul) -> None:
self._binary(e, " * ")
def neg_sql(self, e: exp.Neg) -> None:
self._unary(e, "-")
def neq_sql(self, e: exp.NEQ) -> None:
self._binary(e, " <> ")
def not_sql(self, e: exp.Not) -> None:
self._unary(e, "NOT ")
def null_sql(self, e: exp.Null) -> None:
self.stack.append("NULL")
def or_sql(self, e: exp.Or) -> None:
self._binary(e, " OR ")
def paren_sql(self, e: exp.Paren) -> None:
self.stack.extend(
(
")",
e.this,
"(",
)
)
def sub_sql(self, e: exp.Sub) -> None:
self._binary(e, " - ")
def subquery_sql(self, e: exp.Subquery) -> None:
self._args(e, 2)
alias = e.args.get("alias")
if alias:
self.stack.append(alias)
self.stack.extend((")", e.this, "("))
def table_sql(self, e: exp.Table) -> None:
self._args(e, 4)
alias = e.args.get("alias")
if alias:
self.stack.append(alias)
for p in reversed(e.parts):
self.stack.extend((p, "."))
self.stack.pop()
def tablealias_sql(self, e: exp.TableAlias) -> None:
columns = e.columns
if columns:
self.stack.extend((")", columns, "("))
self.stack.extend((e.this, " AS "))
def var_sql(self, e: exp.Var) -> None:
self.stack.append(e.this)
def _binary(self, e: exp.Binary, op: str) -> None:
self.stack.extend((e.expression, op, e.this))
def _unary(self, e: exp.Unary, op: str) -> None:
self.stack.extend((e.this, op))
def _function(self, e: exp.Func) -> None:
self.stack.extend(
(
")",
list(e.args.values()),
"(",
e.sql_name(),
)
)
def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
kvs = []
arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types
for k in arg_types:
v = node.args.get(k)
if v is not None:
kvs.append([f":{k}", v])
if kvs:
self.stack.append(kvs)
return True
return False
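# gen() keeps its contract after the rewrite: a cheap, canonical-ish rendering
# used as a sort/dedup key, not valid SQL for every node:
from sqlglot import parse_one
from sqlglot.optimizer.simplify import gen

print(gen(parse_one("a + b")))  # a + b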

View file

@ -138,7 +138,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
if isinstance(predicate, exp.Binary):
key = (
predicate.right
if any(node is column for node, *_ in predicate.left.walk())
if any(node is column for node in predicate.left.walk())
else predicate.left
)
else:

File diff suppressed because it is too large

View file

@ -118,6 +118,7 @@ class Step:
if joins:
join = Join.from_joins(joins, ctes)
join.name = step.name
join.source_name = step.name
join.add_dependency(step)
step = join
@ -187,13 +188,13 @@ class Step:
intermediate[v.name] = k
for projection in projections:
for node, *_ in projection.walk():
for node in projection.walk():
name = intermediate.get(node)
if name:
node.replace(exp.column(name, step.name))
if aggregate.condition:
for node, *_ in aggregate.condition.walk():
for node in aggregate.condition.walk():
name = intermediate.get(node) or intermediate.get(node.name)
if name:
node.replace(exp.column(name, step.name))
@ -331,7 +332,7 @@ class Join(Step):
@classmethod
def from_joins(
cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
) -> Join:
step = Join()
for join in joins:
@ -349,10 +350,11 @@ class Join(Step):
def __init__(self) -> None:
super().__init__()
self.source_name: t.Optional[str] = None
self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}
def _to_s(self, indent: str) -> t.List[str]:
lines = []
lines = [f"{indent}Source: {self.source_name or self.name}"]
for name, join in self.joins.items():
lines.append(f"{indent}{name}: {join['side'] or 'INNER'}")
join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or []))
@ -423,7 +425,7 @@ class SetOperation(Step):
@classmethod
def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
) -> SetOperation:
assert isinstance(expression, exp.Union)
left = Step.from_expression(expression.left, ctes)

View file

@ -135,6 +135,7 @@ class TokenType(AutoName):
LONGBLOB = auto()
TINYBLOB = auto()
TINYTEXT = auto()
NAME = auto()
BINARY = auto()
VARBINARY = auto()
JSON = auto()
@@ -290,6 +291,7 @@ class TokenType(AutoName):
LOAD = auto()
LOCK = auto()
MAP = auto()
MATCH_CONDITION = auto()
MATCH_RECOGNIZE = auto()
MEMBER_OF = auto()
MERGE = auto()
@@ -317,6 +319,7 @@ class TokenType(AutoName):
PERCENT = auto()
PIVOT = auto()
PLACEHOLDER = auto()
POSITIONAL = auto()
PRAGMA = auto()
PREWHERE = auto()
PRIMARY_KEY = auto()
@@ -340,6 +343,7 @@ class TokenType(AutoName):
SELECT = auto()
SEMI = auto()
SEPARATOR = auto()
SEQUENCE = auto()
SERDE_PROPERTIES = auto()
SET = auto()
SETTINGS = auto()
@@ -518,6 +522,7 @@ class _Tokenizer(type):
break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK],
dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON],
heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING],
raw_string=_TOKEN_TYPE_TO_INDEX[TokenType.RAW_STRING],
hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING],
identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER],
number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER],
@@ -562,8 +567,7 @@ class Tokenizer(metaclass=_Tokenizer):
"~": TokenType.TILDA,
"?": TokenType.PLACEHOLDER,
"@": TokenType.PARAMETER,
# used for breaking a var like x'y' but nothing else
# the token type doesn't matter
# Used for breaking a var like x'y' but nothing else; the token type doesn't matter
"'": TokenType.QUOTE,
"`": TokenType.IDENTIFIER,
'"': TokenType.IDENTIFIER,
@@ -796,6 +800,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LONG": TokenType.BIGINT,
"BIGINT": TokenType.BIGINT,
"INT8": TokenType.TINYINT,
"UINT": TokenType.UINT,
"DEC": TokenType.DECIMAL,
"DECIMAL": TokenType.DECIMAL,
"BIGDECIMAL": TokenType.BIGDECIMAL,
@@ -856,6 +861,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DATEMULTIRANGE": TokenType.DATEMULTIRANGE,
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"SEQUENCE": TokenType.SEQUENCE,
"VARIANT": TokenType.VARIANT,
"ALTER": TokenType.ALTER,
"ANALYZE": TokenType.COMMAND,
@@ -888,7 +894,7 @@ class Tokenizer(metaclass=_Tokenizer):
COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN}
# handle numeric literals like in hive (3L = BIGINT)
# Handle numeric literals like in hive (3L = BIGINT)
NUMERIC_LITERALS: t.Dict[str, str] = {}
COMMENTS = ["--", ("/*", "*/")]
@@ -917,7 +923,7 @@ class Tokenizer(metaclass=_Tokenizer):
if USE_RS_TOKENIZER:
self._rs_dialect_settings = RsTokenizerDialectSettings(
escape_sequences=self.dialect.ESCAPE_SEQUENCES,
unescaped_sequences=self.dialect.UNESCAPED_SEQUENCES,
identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT,
)
@@ -961,8 +967,7 @@ class Tokenizer(metaclass=_Tokenizer):
while self.size and not self._end:
current = self._current
# skip spaces inline rather than iteratively call advance()
# for performance reasons
# Skip spaces here rather than iteratively calling advance() for performance reasons
while current < self.size:
char = self.sql[current]
@@ -971,12 +976,10 @@ class Tokenizer(metaclass=_Tokenizer):
else:
break
n = current - self._current
self._start = current
self._advance(n if n > 1 else 1)
offset = current - self._current if current > self._current else 1
if self._char is None:
break
self._start = current
self._advance(offset)
if not self._char.isspace():
if self._char.isdigit():
@@ -1004,12 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer):
def _advance(self, i: int = 1, alnum: bool = False) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
# Ensures we don't count an extra line if we get a \r\n line break sequence
if self._char == "\r" and self._peek == "\n":
i = 2
self._start += 1
self._col = 1
self._line += 1
if not (self._char == "\r" and self._peek == "\n"):
self._col = 1
self._line += 1
else:
self._col += i
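
The reworked branch counts a "\r\n" pair as a single line break: the line counter is bumped only when the current character is not the "\r" of a "\r\n" sequence, replacing the old approach of advancing two characters at once. A quick sanity check (line numbers shown as expected):

import sqlglot

# The SELECT token should land on line 2 for both newline styles
for newline in ("\n", "\r\n"):
    tokens = sqlglot.tokenize(f"-- comment{newline}SELECT 1")
    print(tokens[0].token_type.name, tokens[0].line)  # expected: SELECT 2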
@@ -1268,13 +1268,27 @@ class Tokenizer(metaclass=_Tokenizer):
return True
self._advance()
tag = "" if self._char == end else self._extract_string(end)
if self._char == end:
tag = ""
else:
tag = self._extract_string(
end,
unescape_sequences=False,
raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER,
)
if self._end and tag and self.HEREDOC_TAG_IS_IDENTIFIER:
self._advance(-len(tag))
self._add(self.HEREDOC_STRING_ALTERNATIVE)
return True
end = f"{start}{tag}{end}"
else:
return False
self._advance(len(start))
text = self._extract_string(end)
text = self._extract_string(end, unescape_sequences=token_type != TokenType.RAW_STRING)
if base:
try:
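
The heredoc branch above now extracts the tag without unescaping and, in dialects where a heredoc tag may double as a plain identifier, backtracks and emits the alternative token instead of raising on an unterminated tag. Postgres dollar-quoted strings are the familiar case this machinery handles; a small example (token name shown as expected):

import sqlglot

tokens = sqlglot.tokenize("SELECT $tag$hello world$tag$", read="postgres")
print(tokens[-1].token_type.name)  # expected: HEREDOC_STRING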
@@ -1289,7 +1303,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _scan_identifier(self, identifier_end: str) -> None:
self._advance()
text = self._extract_string(identifier_end, self._IDENTIFIER_ESCAPES)
text = self._extract_string(identifier_end, escapes=self._IDENTIFIER_ESCAPES)
self._add(TokenType.IDENTIFIER, text)
def _scan_var(self) -> None:
@@ -1306,12 +1320,29 @@ class Tokenizer(metaclass=_Tokenizer):
else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
)
def _extract_string(self, delimiter: str, escapes=None) -> str:
def _extract_string(
self,
delimiter: str,
escapes: t.Optional[t.Set[str]] = None,
unescape_sequences: bool = True,
raise_unmatched: bool = True,
) -> str:
text = ""
delim_size = len(delimiter)
escapes = self._STRING_ESCAPES if escapes is None else escapes
while True:
if (
unescape_sequences
and self.dialect.UNESCAPED_SEQUENCES
and self._peek
and self._char in self.STRING_ESCAPES
):
unescaped_sequence = self.dialect.UNESCAPED_SEQUENCES.get(self._char + self._peek)
if unescaped_sequence:
self._advance(2)
text += unescaped_sequence
continue
if (
self._char in escapes
and (self._peek == delimiter or self._peek in escapes)
@@ -1333,18 +1364,10 @@ class Tokenizer(metaclass=_Tokenizer):
break
if self._end:
raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")
if not raise_unmatched:
return text + self._char
if (
self.dialect.ESCAPE_SEQUENCES
and self._peek
and self._char in self.STRING_ESCAPES
):
escaped_sequence = self.dialect.ESCAPE_SEQUENCES.get(self._char + self._peek)
if escaped_sequence:
self._advance(2)
text += escaped_sequence
continue
raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")
current = self._current - 1
self._advance(alnum=True)
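
Taken together, these _extract_string hunks rename the dialect setting from ESCAPE_SEQUENCES to UNESCAPED_SEQUENCES and move its handling to the top of the scan loop, so two-character sequences such as a backslash followed by "n" are resolved before the delimiter and escape checks (and can be skipped for raw strings via unescape_sequences=False). Illustrative usage against a dialect with C-style escapes (output shown as expected):

import sqlglot

tokens = sqlglot.tokenize(r"SELECT 'a\nb'", read="mysql")
print(repr(tokens[-1].text))  # expected: 'a\nb' with a real newline in place of \n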

View file

@@ -447,7 +447,7 @@ def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
if inner_with.recursive:
top_level_with.set("recursive", True)
top_level_with.expressions.extend(inner_with.expressions)
top_level_with.set("expressions", inner_with.expressions + top_level_with.expressions)
return expression
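
The switch from extend() to prepending (inner_with.expressions + top_level_with.expressions) matters for ordering: a CTE hoisted out of a nested WITH must be defined before the outer CTE that references it. Roughly (illustrative, output paraphrased rather than verified):

import sqlglot
from sqlglot.transforms import move_ctes_to_top_level

sql = (
    "WITH outer_cte AS ("
    "WITH inner_cte AS (SELECT 1 AS x) SELECT x FROM inner_cte"
    ") SELECT x FROM outer_cte"
)
print(move_ctes_to_top_level(sqlglot.parse_one(sql)).sql())
# expected, roughly:
# WITH inner_cte AS (SELECT 1 AS x), outer_cte AS (SELECT x FROM inner_cte)
# SELECT x FROM outer_cte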
@@ -464,7 +464,7 @@ def ensure_bools(expression: exp.Expression) -> exp.Expression:
):
node.replace(node.neq(0))
for node, *_ in expression.walk():
for node in expression.walk():
ensure_bools(node, _ensure_bool)
return expression
@@ -561,9 +561,7 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Exp
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
"""
Convert struct arguments to aliases: STRUCT(1 AS y) .
"""
"""Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
if isinstance(expression, exp.Struct):
expression.set(
"expressions",