
Merging upstream version 15.0.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:57:23 +01:00
parent 8deb804d23
commit fc63828ee4
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
167 changed files with 58268 additions and 51337 deletions

View file

@ -6,6 +6,7 @@
from __future__ import annotations
import logging
import typing as t
from sqlglot import expressions as exp
@ -45,12 +46,19 @@ from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema
from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType as DialectType
T = t.TypeVar("T", bound=Expression)
logger = logging.getLogger("sqlglot")
__version__ = "12.2.0"
try:
from sqlglot._version import __version__, __version_tuple__
except ImportError:
logger.error(
"Unable to set __version__, run `pip install -e .` or `python setup.py develop` first."
)
pretty = False
"""Whether to format generated SQL by default."""
@ -79,9 +87,9 @@ def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expre
def parse_one(
sql: str,
read: None = None,
into: t.Type[T] = ...,
into: t.Type[E] = ...,
**opts,
) -> T:
) -> E:
...
@ -89,9 +97,9 @@ def parse_one(
def parse_one(
sql: str,
read: DialectType,
into: t.Type[T],
into: t.Type[E],
**opts,
) -> T:
) -> E:
...
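A minimal usage sketch of the typed parse_one overloads above, assuming the public sqlglot API of this release; the printed type names are expectations rather than captured output:

import sqlglot
from sqlglot import exp

# With `into`, the parsed tree is narrowed to the requested expression type.
select = sqlglot.parse_one("SELECT 1 AS x", into=exp.Select)
print(type(select).__name__)  # expected: Select

# Without `into`, the parser picks the expression type from the statement itself.
stmt = sqlglot.parse_one("UPDATE t SET x = 1", read="mysql")
print(type(stmt).__name__)  # expected: Update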

sqlglot/_typing.py (new file, 8 lines added)
View file

@ -0,0 +1,8 @@
from __future__ import annotations
import typing as t
import sqlglot
E = t.TypeVar("E", bound="sqlglot.exp.Expression")
T = t.TypeVar("T")

View file

@ -11,6 +11,8 @@ if t.TYPE_CHECKING:
ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
ColumnOrName = t.Union[Column, str]
ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
ColumnOrLiteral = t.Union[
Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime
]
SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]

View file

@ -127,7 +127,7 @@ class DataFrame:
sequence_id: t.Optional[str] = None,
**kwargs,
) -> t.Tuple[exp.CTE, str]:
name = self.spark._random_name
name = self._create_hash_from_expression(expression)
expression_to_cte = expression.copy()
expression_to_cte.set("with", None)
cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
@ -263,7 +263,7 @@ class DataFrame:
return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]
@classmethod
def _create_hash_from_expression(cls, expression: exp.Select):
def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
value = expression.sql(dialect="spark").encode("utf-8")
return f"t{zlib.crc32(value)}"[:6]
@ -299,7 +299,7 @@ class DataFrame:
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
select_expression = optimize_func(select_expression, identify="always")
select_expression = t.cast(exp.Select, optimize_func(select_expression))
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
@ -570,9 +570,9 @@ class DataFrame:
r_expressions.append(l_column)
r_columns_unused.remove(l_column)
else:
r_expressions.append(exp.alias_(exp.Null(), l_column))
r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
for r_column in r_columns_unused:
l_expressions.append(exp.alias_(exp.Null(), r_column))
l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
r_expressions.append(r_column)
r_df = (
other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
@ -761,7 +761,7 @@ class DataFrame:
raise ValueError("Tried to rename a column that doesn't exist")
for existing_column in existing_columns:
if isinstance(existing_column, exp.Column):
existing_column.replace(exp.alias_(existing_column.copy(), new))
existing_column.replace(exp.alias_(existing_column, new))
else:
existing_column.set("alias", exp.to_identifier(new))
return self.copy(expression=expression)

View file

@ -41,7 +41,7 @@ def operation(op: Operation):
self.last_op = Operation.NO_OP
last_op = self.last_op
new_op = op if op != Operation.NO_OP else last_op
if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT):
if new_op < last_op or (last_op == new_op == Operation.SELECT):
self = self._convert_leaf_to_cte()
df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
df.last_op = new_op # type: ignore

View file

@ -87,15 +87,13 @@ class SparkSession:
select_kwargs = {
"expressions": sel_columns,
"from": exp.From(
expressions=[
exp.Values(
expressions=data_expressions,
alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
this=exp.Values(
expressions=data_expressions,
alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
],
),
),
}
@ -127,10 +125,6 @@ class SparkSession:
self.incrementing_id += 1
return name
@property
def _random_name(self) -> str:
return "r" + uuid.uuid4().hex
@property
def _random_branch_id(self) -> str:
id = self._random_id
@ -145,7 +139,7 @@ class SparkSession:
@property
def _random_id(self) -> str:
id = self._random_name
id = "r" + uuid.uuid4().hex
self.known_ids.add(id)
return id

View file

@ -27,6 +27,6 @@ def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.T
if not expression.args.get("joins"):
return []
left_table = expression.args["from"].args["expressions"][0]
left_table = expression.args["from"].this
other_tables = [join.this for join in expression.args["joins"]]
return [left_table] + other_tables
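The helper now reads the left table from the From expression's `this` argument rather than an expressions list. A hedged illustration of the expected shape, assuming the parser of this release:

from sqlglot import exp, parse_one

select = parse_one("SELECT * FROM a JOIN b ON a.id = b.id")

left_table = select.args["from"].this                    # exp.Table for "a"
joined = [join.this for join in select.args["joins"]]    # [exp.Table for "b"]
print(left_table.name, [t.name for t in joined])         # expected: a ['b']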

View file

@ -1,5 +1,3 @@
"""Supports BigQuery Standard SQL."""
from __future__ import annotations
import re
@ -18,11 +16,9 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
E = t.TypeVar("E", bound=exp.Expression)
def _date_add_sql(
data_type: str, kind: str
@ -96,19 +92,12 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
These are added by the optimizer's qualify_column step.
"""
if isinstance(expression, exp.Select):
unnests = {
unnest.alias
for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
if isinstance(unnest, exp.Unnest) and unnest.alias
}
if unnests:
expression = expression.copy()
for select in expression.expressions:
for column in select.find_all(exp.Column):
if column.table in unnests:
column.set("table", None)
for unnest in expression.find_all(exp.Unnest):
if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias:
for select in expression.selects:
for column in select.find_all(exp.Column):
if column.table == unnest.alias:
column.set("table", None)
return expression
@ -127,16 +116,20 @@ class BigQuery(Dialect):
}
class Tokenizer(tokens.Tokenizer):
QUOTES = [
(prefix + quote, quote) if prefix else quote
for quote in ["'", '"', '"""', "'''"]
for prefix in ["", "r", "R"]
]
QUOTES = ["'", '"', '"""', "'''"]
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
BYTE_STRINGS = [("b'", "'"), ("B'", "'")]
BYTE_STRINGS = [
(prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("b", "B")
]
RAW_STRINGS = [
(prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("r", "R")
]
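With the BYTE_STRINGS and RAW_STRINGS definitions above, prefixed literals such as r'...' and b'...' are tokenized as dedicated string kinds rather than being folded into the quote pairs. A hedged round-trip sketch; the exact output formatting is assumed, not taken from this diff:

import sqlglot

sql = r"SELECT r'\d+', b'abc'"
print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])
# Expected to preserve the raw-string and byte-string prefixes in the output.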
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -144,11 +137,11 @@ class BigQuery(Dialect):
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"BYTES": TokenType.BINARY,
"DECLARE": TokenType.COMMAND,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
"BYTES": TokenType.BINARY,
"RECORD": TokenType.STRUCT,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
}
@ -161,7 +154,7 @@ class BigQuery(Dialect):
LOG_DEFAULTS_TO_LN = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
@ -191,28 +184,28 @@ class BigQuery(Dialect):
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
**parser.Parser.FUNCTION_PARSERS,
"ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
}
FUNCTION_PARSERS.pop("TRIM")
NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore
**parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
}
NESTED_TYPE_TOKENS = {
*parser.Parser.NESTED_TYPE_TOKENS, # type: ignore
*parser.Parser.NESTED_TYPE_TOKENS,
TokenType.TABLE,
}
ID_VAR_TOKENS = {
*parser.Parser.ID_VAR_TOKENS, # type: ignore
*parser.Parser.ID_VAR_TOKENS,
TokenType.VALUES,
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, # type: ignore
**parser.Parser.PROPERTY_PARSERS,
"NOT DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
),
@ -220,19 +213,50 @@ class BigQuery(Dialect):
}
CONSTRAINT_PARSERS = {
**parser.Parser.CONSTRAINT_PARSERS, # type: ignore
**parser.Parser.CONSTRAINT_PARSERS,
"OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
}
def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
this = super()._parse_table_part(schema=schema)
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names
if isinstance(this, exp.Identifier):
table_name = this.name
while self._match(TokenType.DASH, advance=False) and self._next:
self._advance(2)
table_name += f"-{self._prev.text}"
this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))
return this
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = super()._parse_table_parts(schema=schema)
if isinstance(table.this, exp.Identifier) and "." in table.name:
catalog, db, this, *rest = (
t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
for x in split_num_words(table.name, ".", 3)
)
if rest and this:
this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))
table = exp.Table(this=this, db=db, catalog=catalog)
return table
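The two parser hooks above let BigQuery table references use dash-separated project names and single quoted identifiers that pack catalog, dataset and table into one string. A hedged sketch of both forms; the emitted quoting is an expectation, not quoted from this commit:

import sqlglot

# Dashes between identifier parts are folded into a single table-name segment.
print(sqlglot.transpile("SELECT * FROM my-project.mydataset.mytable", read="bigquery")[0])

# A quoted identifier containing dots is split into catalog, db and table parts.
print(sqlglot.transpile("SELECT * FROM `my-project.mydataset.mytable`", read="bigquery")[0])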
class Generator(generator.Generator):
EXPLICIT_UNION = True
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
RENAME_TABLE_WITH_DB = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.AtTimeZone: lambda self, e: self.func(
"TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
@ -259,6 +283,7 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@ -274,7 +299,7 @@ class BigQuery(Dialect):
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.BINARY: "BYTES",
@ -297,7 +322,7 @@ class BigQuery(Dialect):
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

View file

@ -3,11 +3,16 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.dialects.dialect import (
Dialect,
inline_array_sql,
no_pivot_sql,
rename_func,
var_map_sql,
)
from sqlglot.errors import ParseError
from sqlglot.helper import ensure_list, seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
from sqlglot.tokens import Token, TokenType
def _lower_func(sql: str) -> str:
@ -28,65 +33,122 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ASOF": TokenType.ASOF,
"GLOBAL": TokenType.GLOBAL,
"DATETIME64": TokenType.DATETIME,
"ATTACH": TokenType.COMMAND,
"DATETIME64": TokenType.DATETIME64,
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"INT8": TokenType.TINYINT,
"UINT8": TokenType.UTINYINT,
"INT16": TokenType.SMALLINT,
"UINT16": TokenType.USMALLINT,
"INT32": TokenType.INT,
"UINT32": TokenType.UINT,
"INT64": TokenType.BIGINT,
"UINT64": TokenType.UBIGINT,
"GLOBAL": TokenType.GLOBAL,
"INT128": TokenType.INT128,
"UINT128": TokenType.UINT128,
"INT16": TokenType.SMALLINT,
"INT256": TokenType.INT256,
"UINT256": TokenType.UINT256,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
"INT8": TokenType.TINYINT,
"MAP": TokenType.MAP,
"TUPLE": TokenType.STRUCT,
"UINT128": TokenType.UINT128,
"UINT16": TokenType.USMALLINT,
"UINT256": TokenType.UINT256,
"UINT32": TokenType.UINT,
"UINT64": TokenType.UBIGINT,
"UINT8": TokenType.UTINYINT,
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg(
this=seq_get(args, 0),
time=seq_get(args, 1),
decay=seq_get(params, 0),
),
"GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray(
this=seq_get(args, 0), size=seq_get(params, 0)
),
"HISTOGRAM": lambda params, args: exp.Histogram(
this=seq_get(args, 0), bins=seq_get(params, 0)
),
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
"QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
"QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
"QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
"UNIQ": exp.ApproxDistinct.from_arg_list,
}
FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"QUANTILE": lambda self: self._parse_quantile(),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("MATCH")
NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy()
NO_PAREN_FUNCTION_PARSERS.pop(TokenType.ANY)
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN)
and self._parse_in(this, is_global=True),
}
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
# The PLACEHOLDER entry is popped because 1) it doesn't affect Clickhouse (it corresponds to
# the postgres-specific JSONBContains parser) and 2) it makes parsing the ternary op simpler.
COLUMN_OPERATORS = parser.Parser.COLUMN_OPERATORS.copy()
COLUMN_OPERATORS.pop(TokenType.PLACEHOLDER)
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
JOIN_KINDS = {
*parser.Parser.JOIN_KINDS,
TokenType.ANY,
TokenType.ASOF,
TokenType.ANTI,
TokenType.SEMI,
}
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
TokenType.ANY,
TokenType.ASOF,
TokenType.SEMI,
TokenType.ANTI,
TokenType.SETTINGS,
TokenType.FORMAT,
}
LOG_DEFAULTS_TO_LN = True
def _parse_in(
self, this: t.Optional[exp.Expression], is_global: bool = False
) -> exp.Expression:
QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
"settings": lambda self: self._parse_csv(self._parse_conjunction)
if self._match(TokenType.SETTINGS)
else None,
"format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None,
}
def _parse_conjunction(self) -> t.Optional[exp.Expression]:
this = super()._parse_conjunction()
if self._match(TokenType.PLACEHOLDER):
return self.expression(
exp.If,
this=this,
true=self._parse_conjunction(),
false=self._match(TokenType.COLON) and self._parse_conjunction(),
)
return this
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
"""
Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier}
https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters
"""
if not self._match(TokenType.L_BRACE):
return None
this = self._parse_id_var()
self._match(TokenType.COLON)
kind = self._parse_types(check_func=False) or (
self._match_text_seq("IDENTIFIER") and "Identifier"
)
if not kind:
self.raise_error("Expecting a placeholder type or 'Identifier' for tables")
elif not self._match(TokenType.R_BRACE):
self.raise_error("Expecting }")
return self.expression(exp.Placeholder, this=this, kind=kind)
def _parse_in(self, this: t.Optional[exp.Expression], is_global: bool = False) -> exp.In:
this = super()._parse_in(this)
this.set("is_global", is_global)
return this
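Taken together, the overridden _parse_conjunction and _parse_placeholder above add ClickHouse's ternary operator and typed query parameters. A hedged parsing sketch, assuming the dialect names and node types of this release:

import sqlglot
from sqlglot import exp

# cond ? a : b is rewritten into an exp.If node.
ternary = sqlglot.parse_one("SELECT x > 1 ? 'big' : 'small' FROM t", read="clickhouse")
print(ternary.find(exp.If) is not None)  # expected: True

# {name: Type} parameters become exp.Placeholder with a kind, and round-trip
# through the generator's placeholder_sql further down.
param = sqlglot.parse_one("SELECT {abc: UInt32}", read="clickhouse")
print(param.sql(dialect="clickhouse"))  # expected: SELECT {abc: UInt32}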
@ -120,81 +182,142 @@ class ClickHouse(Dialect):
return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
def _parse_join_side_and_kind(
self,
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
is_global = self._match(TokenType.GLOBAL) and self._prev
kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev
if kind_pre:
kind = self._match_set(self.JOIN_KINDS) and self._prev
side = self._match_set(self.JOIN_SIDES) and self._prev
return is_global, side, kind
return (
is_global,
self._match_set(self.JOIN_SIDES) and self._prev,
self._match_set(self.JOIN_KINDS) and self._prev,
)
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
join = super()._parse_join(skip_join_token)
if join:
join.set("global", join.args.pop("natural", None))
return join
def _parse_function(
self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
) -> t.Optional[exp.Expression]:
func = super()._parse_function(functions, anonymous)
if isinstance(func, exp.Anonymous):
params = self._parse_func_params(func)
if params:
return self.expression(
exp.ParameterizedAgg,
this=func.this,
expressions=func.expressions,
params=params,
)
return func
def _parse_func_params(
self, this: t.Optional[exp.Func] = None
) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
return self._parse_csv(self._parse_lambda)
if self._match(TokenType.L_PAREN):
params = self._parse_csv(self._parse_lambda)
self._match_r_paren(this)
return params
return None
def _parse_quantile(self) -> exp.Quantile:
this = self._parse_lambda()
params = self._parse_func_params()
if params:
return self.expression(exp.Quantile, this=params[0], quantile=this)
return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5))
def _parse_wrapped_id_vars(
self, optional: bool = False
) -> t.List[t.Optional[exp.Expression]]:
return super()._parse_wrapped_id_vars(optional=True)
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map",
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.DATETIME64: "DateTime64",
exp.DataType.Type.DOUBLE: "Float64",
exp.DataType.Type.FLOAT: "Float32",
exp.DataType.Type.INT: "Int32",
exp.DataType.Type.INT128: "Int128",
exp.DataType.Type.INT256: "Int256",
exp.DataType.Type.MAP: "Map",
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.STRUCT: "Tuple",
exp.DataType.Type.TINYINT: "Int8",
exp.DataType.Type.UTINYINT: "UInt8",
exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.USMALLINT: "UInt16",
exp.DataType.Type.INT: "Int32",
exp.DataType.Type.UINT: "UInt32",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.UBIGINT: "UInt64",
exp.DataType.Type.INT128: "Int128",
exp.DataType.Type.UINT: "UInt32",
exp.DataType.Type.UINT128: "UInt128",
exp.DataType.Type.INT256: "Int256",
exp.DataType.Type.UINT256: "UInt256",
exp.DataType.Type.FLOAT: "Float32",
exp.DataType.Type.DOUBLE: "Float64",
exp.DataType.Type.USMALLINT: "UInt16",
exp.DataType.Type.UTINYINT: "UInt8",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.AnyValue: rename_func("any"),
exp.ApproxDistinct: rename_func("uniq"),
exp.Array: inline_array_sql,
exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}",
exp.CastToStrType: rename_func("CAST"),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}",
exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Pivot: no_pivot_sql,
exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile"))
+ f"({self.sql(e, 'this')})",
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
}
JOIN_HINTS = False
TABLE_HINTS = False
EXPLICIT_UNION = True
def _param_args_sql(
self,
expression: exp.Expression,
param_names: str | t.List[str],
arg_names: str | t.List[str],
) -> str:
params = self.format_args(
*(
arg
for name in ensure_list(param_names)
for arg in ensure_list(expression.args.get(name))
)
)
args = self.format_args(
*(
arg
for name in ensure_list(arg_names)
for arg in ensure_list(expression.args.get(name))
)
)
return f"({params})({args})"
GROUPINGS_SEP = ""
def cte_sql(self, expression: exp.CTE) -> str:
if isinstance(expression.this, exp.Alias):
return self.sql(expression, "this")
return super().cte_sql(expression)
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
return super().after_limit_modifiers(expression) + [
self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
if expression.args.get("settings")
else "",
self.seg("FORMAT ") + self.sql(expression, "format")
if expression.args.get("format")
else "",
]
def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
params = self.expressions(expression, "params", flat=True)
return self.func(expression.name, *expression.expressions) + f"({params})"
def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"

View file

@ -25,7 +25,7 @@ class Databricks(Spark):
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS, # type: ignore
**Spark.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),

View file

@ -8,10 +8,16 @@ from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
E = t.TypeVar("E", bound=exp.Expression)
if t.TYPE_CHECKING:
from sqlglot._typing import E
# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
class Dialects(str, Enum):
@ -42,6 +48,19 @@ class Dialects(str, Enum):
class _Dialect(type):
classes: t.Dict[str, t.Type[Dialect]] = {}
def __eq__(cls, other: t.Any) -> bool:
if cls is other:
return True
if isinstance(other, str):
return cls is cls.get(other)
if isinstance(other, Dialect):
return cls is type(other)
return False
def __hash__(cls) -> int:
return hash(cls.__name__.lower())
@classmethod
def __getitem__(cls, key: str) -> t.Type[Dialect]:
return cls.classes[key]
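The metaclass-level __eq__ and __hash__ above (plus the instance-level pair added further down) let dialect classes and instances be compared directly against their registered names. A hedged illustration, assuming the dialect registry of this release:

from sqlglot.dialects.dialect import Dialect
from sqlglot.dialects.bigquery import BigQuery

print(BigQuery == "bigquery")                        # expected: True (class vs. name)
print(Dialect.get_or_raise("bigquery") == BigQuery)  # expected: True (lookup returns the class)
print(BigQuery() == BigQuery)                        # expected: True (instance vs. class)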
@ -70,17 +89,20 @@ class _Dialect(type):
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
klass.bit_start, klass.bit_end = seq_get(
list(klass.tokenizer_class._BIT_STRINGS.items()), 0
) or (None, None)
def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
return next(
(
(s, e)
for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
if t == token_type
),
(None, None),
)
klass.hex_start, klass.hex_end = seq_get(
list(klass.tokenizer_class._HEX_STRINGS.items()), 0
) or (None, None)
klass.byte_start, klass.byte_end = seq_get(
list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
) or (None, None)
klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
return klass
@ -110,6 +132,12 @@ class Dialect(metaclass=_Dialect):
parser_class = None
generator_class = None
def __eq__(self, other: t.Any) -> bool:
return type(self) == other
def __hash__(self) -> int:
return hash(type(self))
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
if not dialect:
@ -192,6 +220,8 @@ class Dialect(metaclass=_Dialect):
"hex_end": self.hex_end,
"byte_start": self.byte_start,
"byte_end": self.byte_end,
"raw_start": self.raw_start,
"raw_end": self.raw_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
@ -275,7 +305,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
self.unsupported("PIVOT unsupported")
return self.sql(expression)
return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
@ -328,7 +358,7 @@ def var_map_sql(
def format_time_lambda(
exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
) -> t.Callable[[t.List], E]:
"""Helper used for time expressions.
Args:
@ -340,7 +370,7 @@ def format_time_lambda(
A callable that can be used to return the appropriately formatted time expression.
"""
def _format_time(args: t.Sequence):
def _format_time(args: t.List):
return exp_class(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
@ -377,12 +407,12 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
def parse_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
def inner_func(args: t.Sequence) -> E:
) -> t.Callable[[t.List], E]:
def inner_func(args: t.List) -> E:
unit_based = len(args) == 3
this = args[2] if unit_based else seq_get(args, 0)
unit = args[0] if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
return inner_func
@ -390,8 +420,8 @@ def parse_date_delta(
def parse_date_delta_with_interval(
expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
def func(args: t.Sequence) -> t.Optional[E]:
) -> t.Callable[[t.List], t.Optional[E]]:
def func(args: t.List) -> t.Optional[E]:
if len(args) < 2:
return None
@ -409,7 +439,7 @@ def parse_date_delta_with_interval(
return func
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
unit = seq_get(args, 0)
this = seq_get(args, 1)
@ -424,7 +454,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
)
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
def locate_to_strposition(args: t.List) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
@ -483,7 +513,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql(self, expression: exp.Expression) -> str:
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
return self.func("STRPTIME", expression.this, self.format_time(expression))
@ -496,3 +526,26 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
return _ts_or_ds_to_date_sql
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
names = []
for agg in aggregations:
if isinstance(agg, exp.Alias):
names.append(agg.alias)
else:
"""
This case corresponds to aggregations without aliases being used as suffixes
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
"""
agg_all_unquoted = agg.transform(
lambda node: exp.Identifier(this=node.name, quoted=False)
if isinstance(node, exp.Identifier)
else node
)
names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
return names
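The helper above names unaliased PIVOT aggregations by rendering them with identifiers unquoted and function names lowercased. A hedged standalone sketch of that unquoting step; the input expression is illustrative:

from sqlglot import exp, parse_one

agg = parse_one("AVG(`foo`)", read="spark")
unquoted = agg.transform(
    lambda node: exp.Identifier(this=node.name, quoted=False)
    if isinstance(node, exp.Identifier)
    else node
)
print(unquoted.sql(dialect="spark", normalize_functions="lower"))  # expected: avg(foo)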

View file

@ -95,7 +95,7 @@ class Drill(Dialect):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
**parser.Parser.FUNCTIONS,
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"),
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
@ -108,7 +108,7 @@ class Drill(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
@ -121,13 +121,13 @@ class Drill(Dialect):
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),

View file

@ -11,9 +11,9 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
no_comment_column_constraint_sql,
no_pivot_sql,
no_properties_sql,
no_safe_divide_sql,
pivot_column_names,
rename_func,
str_position_sql,
str_to_time_sql,
@ -31,10 +31,11 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
op = "+" if isinstance(expression, exp.DateAdd) else "-"
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
@ -50,11 +51,11 @@ def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str
return f"ARRAY_SORT({this})"
def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
def _sort_array_reverse(args: t.List) -> exp.Expression:
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
def _parse_date_diff(args: t.Sequence) -> exp.Expression:
def _parse_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
@ -89,11 +90,14 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract
class DuckDB(Dialect):
null_ordering = "nulls_are_last"
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~": TokenType.RLIKE,
":=": TokenType.EQ,
"//": TokenType.DIV,
"ATTACH": TokenType.COMMAND,
"BINARY": TokenType.VARBINARY,
"BPCHAR": TokenType.TEXT,
@ -104,6 +108,7 @@ class DuckDB(Dialect):
"INT1": TokenType.TINYINT,
"LOGICAL": TokenType.BOOLEAN,
"NUMERIC": TokenType.DOUBLE,
"PIVOT_WIDER": TokenType.PIVOT,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
"UBIGINT": TokenType.UBIGINT,
@ -114,8 +119,7 @@ class DuckDB(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
**parser.Parser.FUNCTIONS,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _sort_array_reverse,
@ -152,11 +156,17 @@ class DuckDB(Dialect):
TokenType.UTINYINT,
}
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
if len(aggregations) == 1:
return super()._pivot_column_names(aggregations)
return pivot_column_names(aggregations, dialect="duckdb")
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
RENAME_TABLE_WITH_DB = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -175,7 +185,8 @@ class DuckDB(Dialect):
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add_sql,
exp.DateAdd: _date_delta_sql,
exp.DateSub: _date_delta_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
),
@ -183,13 +194,13 @@ class DuckDB(Dialect):
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql,
exp.RegexpExtract: _regexp_extract_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@ -232,11 +243,11 @@ class DuckDB(Dialect):
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
) -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)

View file

@ -147,13 +147,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
return f"TO_DATE({this})"
def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
return f"{this} ON TABLE {table} {columns}"
class Hive(Dialect):
alias_post_tablesample = True
@ -225,8 +218,7 @@ class Hive(Dialect):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
**parser.Parser.FUNCTIONS,
"BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
@ -271,21 +263,29 @@ class Hive(Dialect):
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, # type: ignore
**parser.Parser.PROPERTY_PARSERS,
"WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties(
expressions=self._parse_wrapped_csv(self._parse_property)
),
}
QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
"distribute": lambda self: self._parse_sort(exp.Distribute, "DISTRIBUTE", "BY"),
"sort": lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
"cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
}
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
TABLESAMPLE_SIZE_IS_PERCENT = True
JOIN_HINTS = False
TABLE_HINTS = False
INDEX_ON = "ON TABLE"
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
@ -294,7 +294,7 @@ class Hive(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Select: transforms.preprocess(
[
@ -319,7 +319,6 @@ class Hive(Dialect):
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
@ -342,7 +341,6 @@ class Hive(Dialect):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: _str_to_unix_sql,
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@ -363,14 +361,13 @@ class Hive(Dialect):
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
exp.National: lambda self, e: self.sql(e, "this"),
exp.National: lambda self, e: self.national_sql(e, prefix=""),
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
@ -396,3 +393,10 @@ class Hive(Dialect):
expression = exp.DataType.build(expression.this)
return super().datatype_sql(expression)
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return super().after_having_modifiers(expression) + [
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
]
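With the sort-style query modifiers registered in the parser above and re-emitted by after_having_modifiers, Hive's DISTRIBUTE BY, SORT BY and CLUSTER BY clauses survive a round trip. A hedged sketch; the output spelling is an expectation:

import sqlglot

sql = "SELECT a, b FROM t DISTRIBUTE BY a SORT BY b"
print(sqlglot.transpile(sql, read="hive", write="hive")[0])
# Expected to preserve both trailing clauses in the generated Hive SQL.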

View file

@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import (
min_or_least,
no_ilike_sql,
no_paren_current_date_sql,
no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
parse_date_delta_with_interval,
@ -21,14 +24,14 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _show_parser(*args, **kwargs):
def _parse(self):
def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], exp.Show]:
def _parse(self: MySQL.Parser) -> exp.Show:
return self._parse_show_mysql(*args, **kwargs)
return _parse
def _date_trunc_sql(self, expression):
def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
unit = expression.text("unit")
@ -54,17 +57,17 @@ def _date_trunc_sql(self, expression):
return f"STR_TO_DATE({concat}, '{date_format}')"
def _str_to_date(args):
def _str_to_date(args: t.List) -> exp.StrToDate:
date_format = MySQL.format_time(seq_get(args, 1))
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
def _str_to_date_sql(self, expression):
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
def _trim_sql(self, expression):
def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
@ -79,8 +82,8 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
def _date_add_sql(kind):
def func(self, expression):
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return (
@ -175,10 +178,10 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
class Parser(parser.Parser):
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE}
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
**parser.Parser.FUNCTIONS,
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
@ -191,7 +194,7 @@ class MySQL(Dialect):
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
**parser.Parser.FUNCTION_PARSERS,
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@ -199,13 +202,8 @@ class MySQL(Dialect):
),
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, # type: ignore
"ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
}
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS, # type: ignore
**parser.Parser.STATEMENT_PARSERS,
TokenType.SHOW: lambda self: self._parse_show(),
}
@ -286,7 +284,13 @@ class MySQL(Dialect):
LOG_DEFAULTS_TO_LN = True
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
def _parse_show_mysql(
self,
this: str,
target: bool | str = False,
full: t.Optional[bool] = None,
global_: t.Optional[bool] = None,
) -> exp.Show:
if target:
if isinstance(target, str):
self._match_text_seq(target)
@ -342,10 +346,12 @@ class MySQL(Dialect):
offset=offset,
limit=limit,
mutex=mutex,
**{"global": global_},
**{"global": global_}, # type: ignore
)
def _parse_oldstyle_limit(self):
def _parse_oldstyle_limit(
self,
) -> t.Tuple[t.Optional[exp.Expression], t.Optional[exp.Expression]]:
limit = None
offset = None
if self._match_text_seq("LIMIT"):
@ -355,23 +361,20 @@ class MySQL(Dialect):
elif len(parts) == 2:
limit = parts[1]
offset = parts[0]
return offset, limit
def _parse_set_item_charset(self, kind):
def _parse_set_item_charset(self, kind: str) -> exp.Expression:
this = self._parse_string() or self._parse_id_var()
return self.expression(exp.SetItem, this=this, kind=kind)
return self.expression(
exp.SetItem,
this=this,
kind=kind,
)
def _parse_set_item_names(self):
def _parse_set_item_names(self) -> exp.Expression:
charset = self._parse_string() or self._parse_id_var()
if self._match_text_seq("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
return self.expression(
exp.SetItem,
this=charset,
@ -386,7 +389,7 @@ class MySQL(Dialect):
TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.DateAdd: _date_add_sql("ADD"),
@ -403,6 +406,7 @@ class MySQL(Dialect):
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
@ -422,7 +426,7 @@ class MySQL(Dialect):
TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

View file

@ -8,7 +8,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _parse_xml_table(self) -> exp.XMLTable:
def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
this = self._parse_string()
passing = None
@ -66,7 +66,7 @@ class Oracle(Dialect):
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
**parser.Parser.FUNCTIONS,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
}
@ -107,7 +107,7 @@ class Oracle(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@ -122,7 +122,7 @@ class Oracle(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
@ -143,7 +143,7 @@ class Oracle(Dialect):
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

View file

@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
no_paren_current_date_sql,
no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
rename_func,
@ -33,8 +34,8 @@ DATE_DIFF_FACTOR = {
}
def _date_add_sql(kind):
def func(self, expression):
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
from sqlglot.optimizer.simplify import simplify
this = self.sql(expression, "this")
@ -51,7 +52,7 @@ def _date_add_sql(kind):
return func
def _date_diff_sql(self, expression):
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = DATE_DIFF_FACTOR.get(unit)
@ -77,7 +78,7 @@ def _date_diff_sql(self, expression):
return f"CAST({unit} AS BIGINT)"
def _substring_sql(self, expression):
def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
this = self.sql(expression, "this")
start = self.sql(expression, "start")
length = self.sql(expression, "length")
@ -88,7 +89,7 @@ def _substring_sql(self, expression):
return f"SUBSTRING({this}{from_part}{for_part})"
def _string_agg_sql(self, expression):
def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
@ -102,13 +103,13 @@ def _string_agg_sql(self, expression):
return f"STRING_AGG({self.format_args(this, separator)}{order})"
def _datatype_sql(self, expression):
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
def _auto_increment_to_serial(expression):
def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
auto = expression.find(exp.AutoIncrementColumnConstraint)
if auto:
@ -126,7 +127,7 @@ def _auto_increment_to_serial(expression):
return expression
def _serial_to_generated(expression):
def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
kind = expression.args["kind"]
if kind.this == exp.DataType.Type.SERIAL:
@ -144,6 +145,7 @@ def _serial_to_generated(expression):
constraints = expression.args["constraints"]
generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint())
if notnull not in constraints:
constraints.insert(0, notnull)
if generated not in constraints:
@ -152,7 +154,7 @@ def _serial_to_generated(expression):
return expression
def _generate_series(args):
def _generate_series(args: t.List) -> exp.Expression:
# The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
step = seq_get(args, 2)
@ -168,11 +170,12 @@ def _generate_series(args):
return exp.GenerateSeries.from_arg_list(args)
def _to_timestamp(args):
def _to_timestamp(args: t.List) -> exp.Expression:
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)
# https://www.postgresql.org/docs/current/functions-formatting.html
return format_time_lambda(exp.StrToTime, "postgres")(args)
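The branch above sends the single-argument numeric form of TO_TIMESTAMP to UnixToTime and the two-argument (text, format) form to StrToTime. A hedged parsing sketch using the examples from the Postgres documentation:

from sqlglot import exp, parse_one

one_arg = parse_one("SELECT TO_TIMESTAMP(1284352323)", read="postgres")
two_arg = parse_one("SELECT TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')", read="postgres")

print(one_arg.find(exp.UnixToTime) is not None)  # expected: True
print(two_arg.find(exp.StrToTime) is not None)   # expected: True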
@ -255,7 +258,7 @@ class Postgres(Dialect):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
@ -271,7 +274,7 @@ class Postgres(Dialect):
}
BITWISE = {
**parser.Parser.BITWISE, # type: ignore
**parser.Parser.BITWISE,
TokenType.HASH: exp.BitwiseXor,
}
@ -280,7 +283,7 @@ class Postgres(Dialect):
}
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS, # type: ignore
**parser.Parser.RANGE_PARSERS,
TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
@ -303,14 +306,14 @@ class Postgres(Dialect):
return self.expression(exp.Extract, this=part, expression=value)
class Generator(generator.Generator):
INTERVAL_ALLOWS_PLURAL_FORM = False
SINGLE_STRING_INTERVAL = True
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
@ -320,14 +323,9 @@ class Postgres(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: transforms.preprocess(
[
_auto_increment_to_serial,
_serial_to_generated,
],
),
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
@ -348,6 +346,7 @@ class Postgres(Dialect):
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
exp.Pivot: no_pivot_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
@ -369,7 +368,7 @@ class Postgres(Dialect):
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

View file

@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
no_ilike_sql,
no_pivot_sql,
no_safe_divide_sql,
rename_func,
struct_extract_sql,
@ -127,39 +128,12 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
)
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step")
target_type = None
if isinstance(start, exp.Cast):
target_type = start.to
elif isinstance(end, exp.Cast):
target_type = end.to
if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
to = target_type.copy()
if target_type is start.to:
end = exp.Cast(this=end, to=to)
else:
start = exp.Cast(this=start, to=to)
sql = self.func("SEQUENCE", start, end, step)
if isinstance(expression.parent, exp.Table):
sql = f"UNNEST({sql})"
return sql
def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
def _approx_percentile(args: t.Sequence) -> exp.Expression:
def _approx_percentile(args: t.List) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
@ -176,7 +150,7 @@ def _approx_percentile(args: t.Sequence) -> exp.Expression:
return exp.ApproxQuantile.from_arg_list(args)
def _from_unixtime(args: t.Sequence) -> exp.Expression:
def _from_unixtime(args: t.List) -> exp.Expression:
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
@ -191,22 +165,39 @@ def _from_unixtime(args: t.Sequence) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Table):
if isinstance(expression.this, exp.GenerateSeries):
unnest = exp.Unnest(expressions=[expression.this])
if expression.alias:
return exp.alias_(
unnest,
alias="_u",
table=[expression.alias],
copy=False,
)
return unnest
return expression
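The transform above rewrites a GENERATE_SERIES/SEQUENCE call used directly as a table into an UNNEST, which Presto requires; generateseries_sql further down emits the SEQUENCE call itself. A hedged transpilation sketch, with the source dialect assumed to be Postgres:

import sqlglot

print(sqlglot.transpile("SELECT * FROM GENERATE_SERIES(1, 5)", read="postgres", write="presto")[0])
# Expected to produce something like: SELECT * FROM UNNEST(SEQUENCE(1, 5))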
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
time_format = MySQL.time_format # type: ignore
time_mapping = MySQL.time_mapping # type: ignore
time_format = MySQL.time_format
time_mapping = MySQL.time_mapping
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"ROW": TokenType.STRUCT,
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
**parser.Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"CARDINALITY": exp.ArraySize.from_arg_list,
@ -252,13 +243,13 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@ -268,8 +259,9 @@ class Presto(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: _approx_distinct_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
@ -293,7 +285,7 @@ class Presto(Dialect):
exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.GenerateSeries: _sequence_sql,
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
@ -301,10 +293,10 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
@ -320,8 +312,7 @@ class Presto(Dialect):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
@ -336,6 +327,7 @@ class Presto(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
@ -351,3 +343,25 @@ class Presto(Dialect):
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"
def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step")
if isinstance(start, exp.Cast):
target_type = start.to
elif isinstance(end, exp.Cast):
target_type = end.to
else:
target_type = None
if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP):
to = target_type.copy()
if target_type is start.to:
end = exp.Cast(this=end, to=to)
else:
start = exp.Cast(this=start, to=to)
return self.func("SEQUENCE", start, end, step)

View file

@ -8,21 +8,21 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _json_sql(self, e) -> str:
return f'{self.sql(e, "this")}."{e.expression.name}"'
def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
return f'{self.sql(expression, "this")}."{expression.expression.name}"'
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
**Postgres.time_mapping, # type: ignore
**Postgres.time_mapping,
"MON": "%b",
"HH": "%H",
}
class Parser(Postgres.Parser):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS, # type: ignore
**Postgres.Parser.FUNCTIONS,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
@ -45,7 +45,7 @@ class Redshift(Postgres):
isinstance(this, exp.DataType)
and this.this == exp.DataType.Type.VARCHAR
and this.expressions
and this.expressions[0] == exp.column("MAX")
and this.expressions[0].this == exp.column("MAX")
):
this.set("expressions", [exp.Var(this="MAX")])
@ -57,9 +57,7 @@ class Redshift(Postgres):
STRING_ESCAPES = ["\\"]
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
**Postgres.Tokenizer.KEYWORDS,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
@ -76,22 +74,22 @@ class Redshift(Postgres):
class Generator(Postgres.Generator):
LOCKING_READS_SUPPORTED = False
SINGLE_STRING_INTERVAL = True
RENAME_TABLE_WITH_DB = False
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING, # type: ignore
**Postgres.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "VARBYTE",
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}
PROPERTIES_LOCATION = {
**Postgres.Generator.PROPERTIES_LOCATION, # type: ignore
**Postgres.Generator.PROPERTIES_LOCATION,
exp.LikeProperty: exp.Properties.Location.POST_WITH,
}
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**Postgres.Generator.TRANSFORMS,
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@ -107,10 +105,13 @@ class Redshift(Postgres):
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
TRANSFORMS.pop(exp.Pivot)
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"}
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
def values_sql(self, expression: exp.Values) -> str:
"""
@ -120,37 +121,36 @@ class Redshift(Postgres):
evaluate the expression. You may need to increase `sys.setrecursionlimit` for it to run, and it can
also be very slow.
"""
if not isinstance(expression.unnest().parent, exp.From):
# The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
if not expression.find_ancestor(exp.From, exp.Join):
return super().values_sql(expression)
rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
column_names = expression.alias and expression.args["alias"].columns
selects = []
rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
for i, row in enumerate(rows):
if i == 0 and expression.alias:
if i == 0 and column_names:
row = [
exp.alias_(value, column_name)
for value, column_name in zip(row, expression.args["alias"].args["columns"])
for value, column_name in zip(row, column_names)
]
selects.append(exp.Select(expressions=row))
subquery_expression = selects[0]
subquery_expression: exp.Select | exp.Union = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(subquery_expression, select, distinct=False)
return self.subquery_sql(subquery_expression.subquery(expression.alias))
def with_properties(self, properties: exp.Properties) -> str:
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")
def renametable_sql(self, expression: exp.RenameTable) -> str:
"""Redshift only supports defining the table name itself (not the db) when renaming tables"""
expression = expression.copy()
target_table = expression.this
for arg in target_table.args:
if arg != "this":
target_table.set(arg, None)
this = self.sql(expression, "this")
return f"RENAME TO {this}"
def datatype_sql(self, expression: exp.DataType) -> str:
"""
Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
@ -162,6 +162,8 @@ class Redshift(Postgres):
expression = expression.copy()
expression.set("this", exp.DataType.Type.VARCHAR)
precision = expression.args.get("expressions")
if not precision:
expression.append("expressions", exp.Var(this="MAX"))
return super().datatype_sql(expression)
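# Illustrative note (not part of the diff): with the added default, a bare TEXT
# column should now render as VARCHAR(MAX) rather than falling back to the
# implicit VARCHAR(255) mentioned above, while an explicit TEXT(128) still
# becomes VARCHAR(128).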

View file

@ -18,7 +18,7 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import flatten, seq_get
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
@ -30,7 +30,7 @@ def _check_int(s: str) -> bool:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@ -52,8 +52,12 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
return exp.UnixToTime(this=first_arg, scale=timescale)
from sqlglot.optimizer.simplify import simplify_literals
# The first argument might be an expression like 40 * 365 * 86400, so we try to
# reduce it using `simplify_literals` first and then check if it's a Literal.
first_arg = seq_get(args, 0)
if not isinstance(first_arg, Literal):
if not isinstance(simplify_literals(first_arg, root=True), Literal):
# case: <variant_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@ -69,6 +73,19 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
return exp.UnixToTime.from_arg_list(args)
def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
expression = parser.parse_var_map(args)
if isinstance(expression, exp.StarMap):
return expression
return exp.Struct(
expressions=[
t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values)
]
)
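# Illustrative note (not part of the diff): OBJECT_CONSTRUCT('a', 1, 'b', 2)
# should parse into a Struct whose expressions are the equalities a = 1 and
# b = 2, while a star argument keeps the StarMap produced by parse_var_map.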
def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
@ -116,7 +133,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
# https://docs.snowflake.com/en/sql-reference/functions/div0
def _div0_to_if(args: t.Sequence) -> exp.Expression:
def _div0_to_if(args: t.List) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@ -124,13 +141,13 @@ def _div0_to_if(args: t.Sequence) -> exp.Expression:
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
def _zeroifnull_to_if(args: t.List) -> exp.Expression:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/nullifzero
def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
def _nullifzero_to_if(args: t.List) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
@ -143,6 +160,12 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
def _parse_convert_timezone(args: t.List) -> exp.Expression:
if len(args) == 3:
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
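# Illustrative note (not part of the diff): the two-argument form
# CONVERT_TIMEZONE('UTC', col) maps to col AT TIME ZONE 'UTC', while the
# three-argument (source_tz, target_tz, timestamp) form is preserved as an
# anonymous function because it has no direct AtTimeZone equivalent.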
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
@ -177,17 +200,14 @@ class Snowflake(Dialect):
}
class Parser(parser.Parser):
QUOTED_PIVOT_COLUMNS = True
IDENTIFY_PIVOT_STRINGS = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
"CONVERT_TIMEZONE": lambda args: exp.AtTimeZone(
this=seq_get(args, 1),
zone=seq_get(args, 0),
),
"CONVERT_TIMEZONE": _parse_convert_timezone,
"DATE_TRUNC": date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
@ -202,7 +222,7 @@ class Snowflake(Dialect):
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": parser.parse_var_map,
"OBJECT_CONSTRUCT": _parse_object_construct,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TO_ARRAY": exp.Array.from_arg_list,
@ -224,7 +244,7 @@ class Snowflake(Dialect):
}
COLUMN_OPERATORS = {
**parser.Parser.COLUMN_OPERATORS, # type: ignore
**parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket,
this=this,
@ -232,14 +252,16 @@ class Snowflake(Dialect):
),
}
TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME}
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS, # type: ignore
**parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny),
}
ALTER_PARSERS = {
**parser.Parser.ALTER_PARSERS, # type: ignore
**parser.Parser.ALTER_PARSERS,
"UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
"SET": lambda self: self._parse_alter_table_set_tag(),
}
@ -256,17 +278,20 @@ class Snowflake(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"CHAR VARYING": TokenType.VARCHAR,
"CHARACTER VARYING": TokenType.VARCHAR,
"EXCLUDE": TokenType.EXCEPT,
"ILIKE ANY": TokenType.ILIKE_ANY,
"LIKE ANY": TokenType.LIKE_ANY,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NCHAR VARYING": TokenType.VARCHAR,
"PUT": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
"MINUS": TokenType.EXCEPT,
"SAMPLE": TokenType.TABLE_SAMPLE,
}
@ -285,7 +310,7 @@ class Snowflake(Dialect):
TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
@ -299,6 +324,7 @@ class Snowflake(Dialect):
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.Extract: rename_func("DATE_PART"),
exp.If: rename_func("IFF"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
@ -312,6 +338,10 @@ class Snowflake(Dialect):
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Struct: lambda self, e: self.func(
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),
),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.TimeToStr: lambda self, e: self.func(
@ -326,7 +356,7 @@ class Snowflake(Dialect):
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@ -336,7 +366,7 @@ class Snowflake(Dialect):
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
@ -351,53 +381,10 @@ class Snowflake(Dialect):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
def values_sql(self, expression: exp.Values) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
We also want to make sure that after we find matches where we need to unquote a column that we prevent users
from adding quotes to the column by using the `identify` argument when generating the SQL.
"""
alias = expression.args.get("alias")
if alias and alias.args.get("columns"):
expression = expression.transform(
lambda node: exp.Identifier(**{**node.args, "quoted": False})
if isinstance(node, exp.Identifier)
and isinstance(node.parent, exp.TableAlias)
and node.arg_key == "columns"
else node,
)
return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
return super().values_sql(expression)
def settag_sql(self, expression: exp.SetTag) -> str:
action = "UNSET" if expression.args.get("unset") else "SET"
return f"{action} TAG {self.expressions(expression)}"
def select_sql(self, expression: exp.Select) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
generating the SQL.
Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
expression. This might not be true in a case where the same column name can be sourced from another table that can
properly quote but should be true in most cases.
"""
values_identifiers = set(
flatten(
(v.args.get("alias") or exp.Alias()).args.get("columns", [])
for v in expression.find_all(exp.Values)
)
)
if values_identifiers:
expression = expression.transform(
lambda node: exp.Identifier(**{**node.args, "quoted": False})
if isinstance(node, exp.Identifier) and node in values_identifiers
else node,
)
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
return super().select_sql(expression)
def describe_sql(self, expression: exp.Describe) -> str:
# Default to table if kind is unknown
kind_value = expression.args.get("kind") or "TABLE"

View file

@ -7,10 +7,10 @@ from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
def _parse_datediff(args: t.Sequence) -> exp.Expression:
def _parse_datediff(args: t.List) -> exp.Expression:
"""
Although Spark docs don't mention the "unit" argument, Spark3 added support for
it at some point. Databricks also supports this variation (see below).
it at some point. Databricks also supports this variant (see below).
For example, in spark-sql (v3.3.1):
- SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
@ -36,7 +36,7 @@ def _parse_datediff(args: t.Sequence) -> exp.Expression:
class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS, # type: ignore
**Spark2.Parser.FUNCTIONS,
"DATEDIFF": _parse_datediff,
}

View file

@ -3,7 +3,12 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, parser, transforms
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.dialect import (
create_with_partitions_sql,
pivot_column_names,
rename_func,
trim_sql,
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
@ -26,7 +31,7 @@ def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
return f"MAP_FROM_ARRAYS({keys}, {values})"
def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
@ -53,10 +58,56 @@ def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
raise ValueError("Improper scale for timestamp")
def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
"""
Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a
pivoted source in a subquery with the same alias to preserve the query's semantics.
Example:
>>> from sqlglot import parse_one
>>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
>>> print(_unalias_pivot(expr).sql(dialect="spark"))
SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
"""
if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
pivot = expression.this.args["pivots"][0]
if pivot.alias:
alias = pivot.args["alias"].pop()
return exp.From(
this=expression.this.replace(
exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
)
)
return expression
def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
"""
Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
so we need to unqualify it.
Example:
>>> from sqlglot import parse_one
>>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
>>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
"""
if isinstance(expression, exp.Pivot):
expression.args["field"].transform(
lambda node: exp.column(node.output_name, quoted=node.this.quoted)
if isinstance(node, exp.Column)
else node,
copy=False,
)
return expression
class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS, # type: ignore
**Hive.Parser.FUNCTIONS,
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
@ -110,7 +161,7 @@ class Spark2(Hive):
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
**parser.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@ -131,43 +182,21 @@ class Spark2(Hive):
kind="COLUMNS",
)
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
if len(pivot_columns) == 1:
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
if len(aggregations) == 1:
return [""]
names = []
for agg in pivot_columns:
if isinstance(agg, exp.Alias):
names.append(agg.alias)
else:
"""
This case corresponds to aggregations without aliases being used as suffixes
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
"""
agg_all_unquoted = agg.transform(
lambda node: exp.Identifier(this=node.name, quoted=False)
if isinstance(node, exp.Identifier)
else node
)
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
return names
return pivot_column_names(aggregations, dialect="spark")
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore
**Hive.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION, # type: ignore
**Hive.Generator.PROPERTIES_LOCATION,
exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
@ -175,7 +204,7 @@ class Spark2(Hive):
}
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
**Hive.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
@ -188,11 +217,12 @@ class Spark2(Hive):
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
exp.Reduce: rename_func("AGGREGATE"),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",

View file

@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
count_if_to_sum,
no_ilike_sql,
no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
rename_func,
@ -14,7 +15,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
def _date_add_sql(self, expression):
def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
modifier = expression.expression
modifier = modifier.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
@ -67,7 +68,7 @@ class SQLite(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
**parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
@ -76,7 +77,7 @@ class SQLite(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
@ -98,7 +99,7 @@ class SQLite(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
@ -114,6 +115,7 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_qualify]
),
@ -163,12 +165,15 @@ class SQLite(Dialect):
return f"CAST({sql} AS INTEGER)"
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def groupconcat_sql(self, expression):
def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
this = expression.this
distinct = expression.find(exp.Distinct)
if distinct:
this = distinct.expressions[0]
distinct = "DISTINCT "
distinct_sql = "DISTINCT "
else:
distinct_sql = ""
if isinstance(expression.this, exp.Order):
self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
@ -176,7 +181,7 @@ class SQLite(Dialect):
this = expression.this.this
separator = expression.args.get("separator")
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
return f"GROUP_CONCAT({distinct_sql}{self.format_args(this, separator)})"
def least_sql(self, expression: exp.Least) -> str:
if len(expression.expressions) > 1:

View file

@ -11,25 +11,24 @@ from sqlglot.helper import seq_get
class StarRocks(MySQL):
class Parser(MySQL.Parser): # type: ignore
class Parser(MySQL.Parser):
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
}
class Generator(MySQL.Generator): # type: ignore
class Generator(MySQL.Generator):
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING, # type: ignore
**MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS, # type: ignore
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
@ -43,4 +42,5 @@ class StarRocks(MySQL):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}
TRANSFORMS.pop(exp.DateTrunc)

View file

@ -4,41 +4,38 @@ from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect
def _if_sql(self, expression):
return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END"
def _coalesce_sql(self, expression):
return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
def _count_sql(self, expression):
this = expression.this
if isinstance(this, exp.Distinct):
return f"COUNTD({self.expressions(this, flat=True)})"
return f"COUNT({self.sql(expression, 'this')})"
class Tableau(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
**generator.Generator.TRANSFORMS,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def if_sql(self, expression: exp.If) -> str:
this = self.sql(expression, "this")
true = self.sql(expression, "true")
false = self.sql(expression, "false")
return f"IF {this} THEN {true} ELSE {false} END"
def coalesce_sql(self, expression: exp.Coalesce) -> str:
return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
def count_sql(self, expression: exp.Count) -> str:
this = expression.this
if isinstance(this, exp.Distinct):
return f"COUNTD({self.expressions(this, flat=True)})"
return f"COUNT({self.sql(expression, 'this')})"
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
**parser.Parser.FUNCTIONS,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}

View file

@ -75,12 +75,12 @@ class Teradata(Dialect):
FUNC_TOKENS.remove(TokenType.REPLACE)
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS, # type: ignore
**parser.Parser.STATEMENT_PARSERS,
TokenType.REPLACE: lambda self: self._parse_create(),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
**parser.Parser.FUNCTION_PARSERS,
"RANGE_N": lambda self: self._parse_rangen(),
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
}
@ -106,7 +106,7 @@ class Teradata(Dialect):
exp.Update,
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"from": self._parse_from(),
"from": self._parse_from(modifiers=True),
"expressions": self._match(TokenType.SET)
and self._parse_csv(self._parse_equality),
"where": self._parse_where(),
@ -135,13 +135,15 @@ class Teradata(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
**generator.Generator.PROPERTIES_LOCATION,
exp.OnCommitProperty: exp.Properties.Location.POST_INDEX,
exp.PartitionedByProperty: exp.Properties.Location.POST_EXPRESSION,
exp.StabilityProperty: exp.Properties.Location.POST_CREATE,
}
TRANSFORMS = {

View file

@ -7,7 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS, # type: ignore
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}

View file

@ -16,6 +16,9 @@ from sqlglot.helper import seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
FULL_FORMAT_TIME_MAPPING = {
"weekday": "%A",
"dw": "%A",
@ -50,13 +53,17 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args):
def _format_time_lambda(
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
) -> t.Callable[[t.List], E]:
def _format_time(args: t.List) -> E:
assert len(args) == 2
return exp_class(
this=seq_get(args, 1),
this=args[1],
format=exp.Literal.string(
format_time(
seq_get(args, 0).name or (TSQL.time_format if default is True else default),
args[0].name,
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
else TSQL.time_mapping,
@ -67,13 +74,17 @@ def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
return _format_time
def _parse_format(args):
fmt = seq_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
def _parse_format(args: t.List) -> exp.Expression:
assert len(args) == 2
fmt = args[1]
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)
if number_fmt:
return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
return exp.NumberToStr(this=args[0], format=fmt)
return exp.TimeToStr(
this=seq_get(args, 0),
this=args[0],
format=exp.Literal.string(
format_time(fmt.name, TSQL.format_time_mapping)
if len(fmt.name) == 1
@ -82,7 +93,7 @@ def _parse_format(args):
)
def _parse_eomonth(args):
def _parse_eomonth(args: t.List) -> exp.Expression:
date = seq_get(args, 0)
month_lag = seq_get(args, 1)
unit = DATE_DELTA_INTERVAL.get("month")
@ -96,7 +107,7 @@ def _parse_eomonth(args):
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
def _parse_hashbytes(args):
def _parse_hashbytes(args: t.List) -> exp.Expression:
kind, data = args
kind = kind.name.upper() if kind.is_string else ""
@ -110,40 +121,47 @@ def _parse_hashbytes(args):
return exp.SHA2(this=data, length=exp.Literal.number(256))
if kind == "SHA2_512":
return exp.SHA2(this=data, length=exp.Literal.number(512))
return exp.func("HASHBYTES", *args)
def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return self.func(func, e.text("unit"), e.expression, e.this)
def generate_date_delta_with_unit_sql(
self: generator.Generator, expression: exp.DateAdd | exp.DateDiff
) -> str:
func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
return self.func(func, expression.text("unit"), expression.expression, expression.this)
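# Illustrative note (not part of the diff): DateAdd(this=col, expression=1,
# unit='day') should render as DATEADD(day, 1, col); DateDiff maps to DATEDIFF
# with the same argument order.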
def _format_sql(self, e):
def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
fmt = (
e.args["format"]
if isinstance(e, exp.NumberToStr)
else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping))
expression.args["format"]
if isinstance(expression, exp.NumberToStr)
else exp.Literal.string(
format_time(
expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping)
)
)
)
return self.func("FORMAT", e.this, fmt)
return self.func("FORMAT", expression.this, fmt)
def _string_agg_sql(self, e):
e = e.copy()
def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
this = e.this
distinct = e.find(exp.Distinct)
this = expression.this
distinct = expression.find(exp.Distinct)
if distinct:
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
this = distinct.pop().expressions[0]
order = ""
if isinstance(e.this, exp.Order):
if e.this.this:
this = e.this.this.pop()
order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
if isinstance(expression.this, exp.Order):
if expression.this.this:
this = expression.this.this.pop()
order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})" # Order has a leading space
separator = e.args.get("separator") or exp.Literal.string(",")
separator = expression.args.get("separator") or exp.Literal.string(",")
return f"STRING_AGG({self.format_args(this, separator)}){order}"
@ -292,7 +310,7 @@ class TSQL(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
**parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
@ -332,13 +350,13 @@ class TSQL(Dialect):
DataType.Type.NCHAR,
}
RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore
RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
TokenType.TABLE,
*parser.Parser.TYPE_TOKENS, # type: ignore
*parser.Parser.TYPE_TOKENS,
}
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS, # type: ignore
**parser.Parser.STATEMENT_PARSERS,
TokenType.END: lambda self: self._parse_command(),
}
@ -377,7 +395,7 @@ class TSQL(Dialect):
return system_time
def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = super()._parse_table_parts(schema=schema)
table.set("system_time", self._parse_system_time())
return table
@ -450,7 +468,7 @@ class TSQL(Dialect):
LOCKING_READS_SUPPORTED = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
@ -458,7 +476,7 @@ class TSQL(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
@ -480,7 +498,7 @@ class TSQL(Dialect):
TRANSFORMS.pop(exp.ReturnsProperty)
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

View file

@ -53,7 +53,8 @@ class Keep:
if t.TYPE_CHECKING:
T = t.TypeVar("T")
from sqlglot._typing import T
Edit = t.Union[Insert, Remove, Move, Update, Keep]
@ -240,7 +241,7 @@ class ChangeDistiller:
return matching_set
def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = []
candidate_matchings: t.List[t.Tuple[float, int, int, exp.Expression, exp.Expression]] = []
source_leaves = list(_get_leaves(self._source))
target_leaves = list(_get_leaves(self._target))
for source_leaf in source_leaves:
@ -252,6 +253,7 @@ class ChangeDistiller:
candidate_matchings,
(
-similarity_score,
-_parent_similarity_score(source_leaf, target_leaf),
len(candidate_matchings),
source_leaf,
target_leaf,
@ -261,7 +263,7 @@ class ChangeDistiller:
# Pick best matchings based on the highest score
matching_set = set()
while candidate_matchings:
_, _, source_leaf, target_leaf = heappop(candidate_matchings)
_, _, _, source_leaf, target_leaf = heappop(candidate_matchings)
if (
id(source_leaf) in self._unmatched_source_nodes
and id(target_leaf) in self._unmatched_target_nodes
@ -327,6 +329,15 @@ def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
return False
def _parent_similarity_score(
source: t.Optional[exp.Expression], target: t.Optional[exp.Expression]
) -> int:
if source is None or target is None or type(source) is not type(target):
return 0
return 1 + _parent_similarity_score(source.parent, target.parent)
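# Illustrative note (not part of the diff): the score counts how many ancestor
# levels of the two leaves share the same node type; the heap entry above
# negates it, so when the textual similarity score ties, leaves with more
# closely matching ancestry are paired first.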
def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:

View file

@ -14,9 +14,10 @@ from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
from sqlglot.helper import dict_depth
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
from sqlglot.schema import ensure_schema
from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set
logger = logging.getLogger("sqlglot")
@ -52,10 +53,15 @@ def execute(
tables_ = ensure_tables(tables)
if not schema:
schema = {
name: {column: type(table[0][column]).__name__ for column in table.columns}
for name, table in tables_.mapping.items()
}
schema = {}
flattened_tables = flatten_schema(tables_.mapping, depth=dict_depth(tables_.mapping))
for keys in flattened_tables:
table = nested_get(tables_.mapping, *zip(keys, keys))
assert table is not None
for column in table.columns:
nested_set(schema, [*keys, column], type(table[0][column]).__name__)
schema = ensure_schema(schema, dialect=read)

View file

@ -5,6 +5,7 @@ import statistics
from functools import wraps
from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import PYTHON_VERSION
@ -102,6 +103,8 @@ def cast(this, to):
return datetime.date.fromisoformat(this)
if to == exp.DataType.Type.DATETIME:
return datetime.datetime.fromisoformat(this)
if to == exp.DataType.Type.BOOLEAN:
return bool(this)
if to in exp.DataType.TEXT_TYPES:
return str(this)
if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
@ -119,9 +122,11 @@ def ordered(this, desc, nulls_first):
@null_if_any
def interval(this, unit):
if unit == "DAY":
return datetime.timedelta(days=float(this))
raise NotImplementedError
unit = unit.lower()
plural = unit + "s"
if plural in Generator.TIME_PART_SINGULARS:
unit = plural
return datetime.timedelta(**{unit: float(this)})
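# Illustrative note (not part of the diff): interval("2", "HOUR") should now
# lowercase and pluralize the unit and return timedelta(hours=2.0), instead of
# raising NotImplementedError for every unit other than DAY as before.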
ENV = {
@ -147,7 +152,9 @@ ENV = {
"COALESCE": lambda *args: next((a for a in args if a is not None), None),
"CONCAT": null_if_any(lambda *args: "".join(args)),
"CONCATWS": null_if_any(lambda this, *args: this.join(args)),
"DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
"DIV": null_if_any(lambda e, this: e / this),
"DOT": null_if_any(lambda e, this: e[this]),
"EQ": null_if_any(lambda this, e: this == e),
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
"GT": null_if_any(lambda this, e: this > e),
@ -162,6 +169,7 @@ ENV = {
"LOWER": null_if_any(lambda arg: arg.lower()),
"LT": null_if_any(lambda this, e: this < e),
"LTE": null_if_any(lambda this, e: this <= e),
"MAP": null_if_any(lambda *args: dict(zip(*args))), # type: ignore
"MOD": null_if_any(lambda e, this: e % this),
"MUL": null_if_any(lambda e, this: e * this),
"NEQ": null_if_any(lambda this, e: this != e),
@ -180,4 +188,5 @@ ENV = {
"CURRENTTIMESTAMP": datetime.datetime.now,
"CURRENTTIME": datetime.datetime.now,
"CURRENTDATE": datetime.date.today,
"STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
}

View file

@ -360,11 +360,19 @@ def _ordered_py(self, expression):
def _rename(self, e):
try:
if "expressions" in e.args:
this = self.sql(e, "this")
this = f"{this}, " if this else ""
return f"{e.key.upper()}({this}{self.expressions(e)})"
return self.func(e.key, *e.args.values())
values = list(e.args.values())
if len(values) == 1:
values = values[0]
if not isinstance(values, list):
return self.func(e.key, values)
return self.func(e.key, *values)
if isinstance(e, exp.Func) and e.is_var_len_args:
*head, tail = values
return self.func(e.key, *head, *tail)
return self.func(e.key, *values)
except Exception as ex:
raise Exception(f"Could not rename {repr(e)}") from ex
@ -413,6 +421,7 @@ class Python(Dialect):
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",

File diff suppressed because it is too large

View file

@ -31,6 +31,8 @@ class Generator:
hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
raw_start (str): specifies which starting character to use to delimit raw literals. Default: None.
raw_end (str): specifies which ending character to use to delimit raw literals. Default: None.
identify (bool | str): 'always': always quote, 'safe': only quote identifiers that don't contain an uppercase character, True defaults to 'always'.
normalize (bool): if set to True, all identifiers will be lowercased
string_escape (str): specifies a string escape character. Default: '.
@ -76,11 +78,12 @@ class Generator:
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
exp.OnCommitProperty: lambda self, e: "ON COMMIT PRESERVE ROWS",
exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY",
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.StabilityProperty: lambda self, e: e.name,
exp.VolatileProperty: lambda self, e: "VOLATILE",
@ -133,6 +136,15 @@ class Generator:
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
# Whether a table is allowed to be renamed with a db
RENAME_TABLE_WITH_DB = True
# The separator for grouping sets and rollups
GROUPINGS_SEP = ","
# The string used when creating an index on a table
INDEX_ON = "ON"
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -167,7 +179,6 @@ class Generator:
PARAMETER_TOKEN = "@"
PROPERTIES_LOCATION = {
exp.AfterJournalProperty: exp.Properties.Location.POST_NAME,
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
@ -196,7 +207,9 @@ class Generator:
exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION,
exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
@ -204,13 +217,15 @@ class Generator:
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
exp.Set: exp.Properties.Location.POST_SCHEMA,
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
@ -221,7 +236,7 @@ class Generator:
RESERVED_KEYWORDS: t.Set[str] = set()
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column)
UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
@ -239,6 +254,8 @@ class Generator:
"hex_end",
"byte_start",
"byte_end",
"raw_start",
"raw_end",
"identify",
"normalize",
"string_escape",
@ -276,6 +293,8 @@ class Generator:
hex_end=None,
byte_start=None,
byte_end=None,
raw_start=None,
raw_end=None,
identify=False,
normalize=False,
string_escape=None,
@ -308,6 +327,8 @@ class Generator:
self.hex_end = hex_end
self.byte_start = byte_start
self.byte_end = byte_end
self.raw_start = raw_start
self.raw_end = raw_end
self.identify = identify
self.normalize = normalize
self.string_escape = string_escape or "'"
@ -399,7 +420,11 @@ class Generator:
return sql
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"{comments_sql}{self.sep()}{sql}"
return (
f"{self.sep()}{comments_sql}{sql}"
if sql[0].isspace()
else f"{comments_sql}{self.sep()}{sql}"
)
return f"{sql} {comments_sql}"
@ -567,7 +592,9 @@ class Generator:
) -> str:
this = ""
if expression.this is not None:
this = " ALWAYS " if expression.this else " BY DEFAULT "
on_null = "ON NULL " if expression.args.get("on_null") else ""
this = " ALWAYS " if expression.this else f" BY DEFAULT {on_null}"
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
increment = expression.args.get("increment")
@ -578,14 +605,20 @@ class Generator:
maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else ""
cycle = expression.args.get("cycle")
cycle_sql = ""
if cycle is not None:
cycle_sql = f"{' NO' if not cycle else ''} CYCLE"
cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql
sequence_opts = ""
if start or increment or cycle_sql:
sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}"
sequence_opts = f" ({sequence_opts.strip()})"
return f"GENERATED{this}AS IDENTITY{sequence_opts}"
expr = self.sql(expression, "expression")
expr = f"({expr})" if expr else "IDENTITY"
return f"GENERATED{this}AS {expr}{sequence_opts}"
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@ -596,8 +629,10 @@ class Generator:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
def uniquecolumnconstraint_sql(self, _) -> str:
return "UNIQUE"
def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
return f"UNIQUE{this}"
def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind").upper()
@ -653,33 +688,9 @@ class Generator:
prefix=" ",
)
indexes = expression.args.get("indexes")
if indexes:
indexes_sql: t.List[str] = []
for index in indexes:
ind_unique = " UNIQUE" if index.args.get("unique") else ""
ind_primary = " PRIMARY" if index.args.get("primary") else ""
ind_amp = " AMP" if index.args.get("amp") else ""
ind_name = f" {index.name}" if index.name else ""
ind_columns = (
f' ({self.expressions(index, key="columns", flat=True)})'
if index.args.get("columns")
else ""
)
ind_sql = f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
if indexes_sql:
indexes_sql.append(ind_sql)
else:
indexes_sql.append(
f"{ind_sql}{postindex_props_sql}"
if index.args.get("primary")
else f"{postindex_props_sql}{ind_sql}"
)
index_sql = "".join(indexes_sql)
else:
index_sql = postindex_props_sql
indexes = self.expressions(expression, key="indexes", indent=False, sep=" ")
indexes = f" {indexes}" if indexes else ""
index_sql = indexes + postindex_props_sql
replace = " OR REPLACE" if expression.args.get("replace") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
@ -711,9 +722,23 @@ class Generator:
" WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
)
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}"
clone = self.sql(expression, "clone")
clone = f" {clone}" if clone else ""
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
return self.prepend_ctes(expression, expression_sql)
def clone_sql(self, expression: exp.Clone) -> str:
this = self.sql(expression, "this")
when = self.sql(expression, "when")
if when:
kind = self.sql(expression, "kind")
expr = self.sql(expression, "expression")
return f"CLONE {this} {when} ({kind} => {expr})"
return f"CLONE {this}"
def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
@ -757,6 +782,17 @@ class Generator:
return f"{self.byte_start}{this}{self.byte_end}"
return this
def rawstring_sql(self, expression: exp.RawString) -> str:
if self.raw_start:
return f"{self.raw_start}{expression.name}{self.raw_end}"
return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))
def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
this = self.sql(expression, "this")
specifier = self.sql(expression, "expression")
specifier = f" {specifier}" if specifier else ""
return f"{this}{specifier}"
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
@ -768,7 +804,8 @@ class Generator:
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
if expression.args.get("values") is not None:
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}"
values = self.expressions(expression, key="values", flat=True)
values = f"{delimiters[0]}{values}{delimiters[1]}"
else:
nested = f"({interior})"
@ -836,10 +873,17 @@ class Generator:
return ""
def index_sql(self, expression: exp.Index) -> str:
this = self.sql(expression, "this")
unique = "UNIQUE " if expression.args.get("unique") else ""
primary = "PRIMARY " if expression.args.get("primary") else ""
amp = "AMP " if expression.args.get("amp") else ""
name = f"{expression.name} " if expression.name else ""
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
return f"{this} ON {table} {columns}"
table = f"{self.INDEX_ON} {table} " if table else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
partition_by = self.expressions(expression, key="partition_by", flat=True)
partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
@ -861,8 +905,9 @@ class Generator:
output_format = f"OUTPUTFORMAT {output_format}" if output_format else ""
return self.sep().join((input_format, output_format))
def national_sql(self, expression: exp.National) -> str:
return f"N{self.sql(expression, 'this')}"
def national_sql(self, expression: exp.National, prefix: str = "N") -> str:
string = self.sql(exp.Literal.string(expression.name))
return f"{prefix}{string}"
def partition_sql(self, expression: exp.Partition) -> str:
return f"PARTITION({self.expressions(expression)})"
@ -955,23 +1000,18 @@ class Generator:
def journalproperty_sql(self, expression: exp.JournalProperty) -> str:
no = "NO " if expression.args.get("no") else ""
local = expression.args.get("local")
local = f"{local} " if local else ""
dual = "DUAL " if expression.args.get("dual") else ""
before = "BEFORE " if expression.args.get("before") else ""
return f"{no}{dual}{before}JOURNAL"
after = "AFTER " if expression.args.get("after") else ""
return f"{no}{local}{dual}{before}{after}JOURNAL"
def freespaceproperty_sql(self, expression: exp.FreespaceProperty) -> str:
freespace = self.sql(expression, "this")
percent = " PERCENT" if expression.args.get("percent") else ""
return f"FREESPACE={freespace}{percent}"
def afterjournalproperty_sql(self, expression: exp.AfterJournalProperty) -> str:
no = "NO " if expression.args.get("no") else ""
dual = "DUAL " if expression.args.get("dual") else ""
local = ""
if expression.args.get("local") is not None:
local = "LOCAL " if expression.args.get("local") else "NOT LOCAL "
return f"{no}{dual}{local}AFTER JOURNAL"
def checksumproperty_sql(self, expression: exp.ChecksumProperty) -> str:
if expression.args.get("default"):
property = "DEFAULT"
@ -992,19 +1032,19 @@ class Generator:
def datablocksizeproperty_sql(self, expression: exp.DataBlocksizeProperty) -> str:
default = expression.args.get("default")
min = expression.args.get("min")
if default is not None or min is not None:
minimum = expression.args.get("minimum")
maximum = expression.args.get("maximum")
if default or minimum or maximum:
if default:
property = "DEFAULT"
elif min:
property = "MINIMUM"
prop = "DEFAULT"
elif minimum:
prop = "MINIMUM"
else:
property = "MAXIMUM"
return f"{property} DATABLOCKSIZE"
else:
units = expression.args.get("units")
units = f" {units}" if units else ""
return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}"
prop = "MAXIMUM"
return f"{prop} DATABLOCKSIZE"
units = expression.args.get("units")
units = f" {units}" if units else ""
return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}"
def blockcompressionproperty_sql(self, expression: exp.BlockCompressionProperty) -> str:
autotemp = expression.args.get("autotemp")
@ -1014,16 +1054,16 @@ class Generator:
never = expression.args.get("never")
if autotemp is not None:
property = f"AUTOTEMP({self.expressions(autotemp)})"
prop = f"AUTOTEMP({self.expressions(autotemp)})"
elif always:
property = "ALWAYS"
prop = "ALWAYS"
elif default:
property = "DEFAULT"
prop = "DEFAULT"
elif manual:
property = "MANUAL"
prop = "MANUAL"
elif never:
property = "NEVER"
return f"BLOCKCOMPRESSION={property}"
prop = "NEVER"
return f"BLOCKCOMPRESSION={prop}"
def isolatedloadingproperty_sql(self, expression: exp.IsolatedLoadingProperty) -> str:
no = expression.args.get("no")
@ -1138,21 +1178,24 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=", ", flat=True)
hints = self.expressions(expression, key="hints", flat=True)
hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
laterals = self.expressions(expression, key="laterals", sep="")
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="")
laterals = self.expressions(expression, key="laterals", sep="")
system_time = expression.args.get("system_time")
system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"
return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
if self.alias_post_tablesample and expression.this.alias:
this = self.sql(expression.this, "this")
table = expression.this.copy()
table.set("alias", None)
this = self.sql(table)
alias = f"{sep}{self.sql(expression.this, 'alias')}"
else:
this = self.sql(expression, "this")
@ -1177,14 +1220,22 @@ class Generator:
return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"
def pivot_sql(self, expression: exp.Pivot) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
if expression.this:
this = self.sql(expression, "this")
on = f"{self.seg('ON')} {expressions}"
using = self.expressions(expression, key="using", flat=True)
using = f"{self.seg('USING')} {using}" if using else ""
group = self.sql(expression, "group")
return f"PIVOT {this}{on}{using}{group}"
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
expressions = self.expressions(expression, key="expressions")
field = self.sql(expression, "field")
return f"{this} {direction}({expressions} FOR {field}){alias}"
return f"{direction}({expressions} FOR {field}){alias}"
def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
@ -1218,8 +1269,7 @@ class Generator:
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
def from_sql(self, expression: exp.From) -> str:
expressions = self.expressions(expression, flat=True)
return f"{self.seg('FROM')} {expressions}"
return f"{self.seg('FROM')} {self.sql(expression, 'this')}"
def group_sql(self, expression: exp.Group) -> str:
group_by = self.op_expressions("GROUP BY", expression)
@ -1242,10 +1292,16 @@ class Generator:
rollup_sql = self.expressions(expression, key="rollup", indent=False)
rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else ""
groupings = csv(grouping_sets, cube_sql, rollup_sql, sep=",")
groupings = csv(
grouping_sets,
cube_sql,
rollup_sql,
self.seg("WITH TOTALS") if expression.args.get("totals") else "",
sep=self.GROUPINGS_SEP,
)
if expression.args.get("expressions") and groupings:
group_by = f"{group_by},"
group_by = f"{group_by}{self.GROUPINGS_SEP}"
return f"{group_by}{groupings}"
@ -1254,18 +1310,16 @@ class Generator:
return f"{self.seg('HAVING')}{self.sep()}{this}"
def join_sql(self, expression: exp.Join) -> str:
op_sql = self.seg(
" ".join(
op
for op in (
"NATURAL" if expression.args.get("natural") else None,
expression.side,
expression.kind,
expression.hint if self.JOIN_HINTS else None,
"JOIN",
)
if op
op_sql = " ".join(
op
for op in (
"NATURAL" if expression.args.get("natural") else None,
"GLOBAL" if expression.args.get("global") else None,
expression.side,
expression.kind,
expression.hint if self.JOIN_HINTS else None,
)
if op
)
on_sql = self.sql(expression, "on")
using = expression.args.get("using")
@ -1273,6 +1327,8 @@ class Generator:
if not on_sql and using:
on_sql = csv(*(self.sql(column) for column in using))
this_sql = self.sql(expression, "this")
if on_sql:
on_sql = self.indent(on_sql, skip_first=True)
space = self.seg(" " * self.pad) if self.pretty else " "
@ -1280,10 +1336,11 @@ class Generator:
on_sql = f"{space}USING ({on_sql})"
else:
on_sql = f"{space}ON {on_sql}"
elif not op_sql:
return f", {this_sql}"
expression_sql = self.sql(expression, "expression")
this_sql = self.sql(expression, "this")
return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
return f"{self.seg(op_sql)} {this_sql}{on_sql}"
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
@ -1336,12 +1393,22 @@ class Generator:
return f"PRAGMA {self.sql(expression, 'this')}"
def lock_sql(self, expression: exp.Lock) -> str:
if self.LOCKING_READS_SUPPORTED:
lock_type = "UPDATE" if expression.args["update"] else "SHARE"
return self.seg(f"FOR {lock_type}")
if not self.LOCKING_READS_SUPPORTED:
self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
return ""
self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
return ""
lock_type = "FOR UPDATE" if expression.args["update"] else "FOR SHARE"
expressions = self.expressions(expression, flat=True)
expressions = f" OF {expressions}" if expressions else ""
wait = expression.args.get("wait")
if wait is not None:
if isinstance(wait, exp.Literal):
wait = f" WAIT {self.sql(wait)}"
else:
wait = " NOWAIT" if wait else " SKIP LOCKED"
return f"{lock_type}{expressions}{wait or ''}"
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
@ -1460,26 +1527,32 @@ class Generator:
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins") or []],
*[self.sql(join) for join in expression.args.get("joins") or []],
self.sql(expression, "match"),
*[self.sql(sql) for sql in expression.args.get("laterals") or []],
*[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
*self.after_having_modifiers(expression),
self.sql(expression, "order"),
self.sql(expression, "offset") if fetch else self.sql(limit),
self.sql(limit) if fetch else self.sql(expression, "offset"),
*self.after_limit_modifiers(expression),
sep="",
)
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return [
self.sql(expression, "qualify"),
self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
self.sql(expression, "order"),
self.sql(expression, "offset") if fetch else self.sql(limit),
self.sql(limit) if fetch else self.sql(expression, "offset"),
self.sql(expression, "lock"),
self.sql(expression, "sample"),
sep="",
)
]
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
locks = self.expressions(expression, key="locks", sep=" ")
locks = f" {locks}" if locks else ""
return [locks, self.sql(expression, "sample")]
def select_sql(self, expression: exp.Select) -> str:
hint = self.sql(expression, "hint")
@ -1529,13 +1602,10 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
sql = self.query_modifiers(
expression,
self.wrap(expression),
alias,
self.expressions(expression, key="pivots", sep=" "),
)
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots)
return self.prepend_ctes(expression, sql)
def qualify_sql(self, expression: exp.Qualify) -> str:
@ -1712,10 +1782,6 @@ class Generator:
options = f" {options}" if options else ""
return f"PRIMARY KEY ({expressions}){options}"
def unique_sql(self, expression: exp.Unique) -> str:
columns = self.expressions(expression, key="expressions")
return f"UNIQUE ({columns})"
def if_sql(self, expression: exp.If) -> str:
return self.case_sql(
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
@ -1745,6 +1811,26 @@ class Generator:
encoding = f" ENCODING {encoding}" if encoding else ""
return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
path = self.sql(expression, "path")
path = f" {path}" if path else ""
as_json = " AS JSON" if expression.args.get("as_json") else ""
return f"{this} {kind}{path}{as_json}"
def openjson_sql(self, expression: exp.OpenJSON) -> str:
this = self.sql(expression, "this")
path = self.sql(expression, "path")
path = f", {path}" if path else ""
expressions = self.expressions(expression)
with_ = (
f" WITH ({self.seg(self.indent(expressions), sep='')}{self.seg(')', sep='')}"
if expressions
else ""
)
return f"OPENJSON({this}{path}){with_}"
def in_sql(self, expression: exp.In) -> str:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
@ -1773,7 +1859,7 @@ class Generator:
if self.SINGLE_STRING_INTERVAL:
this = expression.this.name if expression.this else ""
return f"INTERVAL '{this}{unit}'"
return f"INTERVAL '{this}{unit}'" if this else f"INTERVAL{unit}"
this = self.sql(expression, "this")
if this:
@ -1883,6 +1969,28 @@ class Generator:
expression_sql = self.sql(expression, "expression")
return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}"
def mergetreettlaction_sql(self, expression: exp.MergeTreeTTLAction) -> str:
this = self.sql(expression, "this")
delete = " DELETE" if expression.args.get("delete") else ""
recompress = self.sql(expression, "recompress")
recompress = f" RECOMPRESS {recompress}" if recompress else ""
to_disk = self.sql(expression, "to_disk")
to_disk = f" TO DISK {to_disk}" if to_disk else ""
to_volume = self.sql(expression, "to_volume")
to_volume = f" TO VOLUME {to_volume}" if to_volume else ""
return f"{this}{delete}{recompress}{to_disk}{to_volume}"
def mergetreettl_sql(self, expression: exp.MergeTreeTTL) -> str:
where = self.sql(expression, "where")
group = self.sql(expression, "group")
aggregates = self.expressions(expression, key="aggregates")
aggregates = self.seg("SET") + self.seg(aggregates) if aggregates else ""
if not (where or group or aggregates) and len(expression.expressions) == 1:
return f"TTL {self.expressions(expression, flat=True)}"
return f"TTL{self.seg(self.expressions(expression))}{where}{group}{aggregates}"
def transaction_sql(self, expression: exp.Transaction) -> str:
return "BEGIN"
@ -1919,6 +2027,11 @@ class Generator:
return f"ALTER COLUMN {this} DROP DEFAULT"
def renametable_sql(self, expression: exp.RenameTable) -> str:
if not self.RENAME_TABLE_WITH_DB:
# Remove db from tables
expression = expression.transform(
lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n
)
this = self.sql(expression, "this")
return f"RENAME TO {this}"
@ -2208,3 +2321,12 @@ class Generator:
self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function")
return self.sql(exp.cast(expression.this, "text"))
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
) -> t.Callable[[exp.Expression], str]:
"""Returns a cached generator."""
cache = {} if cache is None else cache
generator = Generator(normalize=True, identify="safe")
return lambda e: generator.generate(e, cache)
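# A minimal usage sketch (illustrative, assuming the public sqlglot API): one cached
# generator can be reused to render many expressions.
from sqlglot import parse_one
from sqlglot.generator import cached_generator
generate = cached_generator()
sql = generate(parse_one("SELECT a FROM x"))  # later calls reuse the same cache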


@ -9,14 +9,14 @@ from collections.abc import Collection
from contextlib import contextmanager
from copy import copy
from enum import Enum
from itertools import count
if t.TYPE_CHECKING:
from sqlglot import exp
from sqlglot._typing import E, T
from sqlglot.dialects.dialect import DialectType
from sqlglot.expressions import Expression
T = t.TypeVar("T")
E = t.TypeVar("E", bound=Expression)
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")
@ -25,7 +25,7 @@ logger = logging.getLogger("sqlglot")
class AutoName(Enum):
"""This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
def _generate_next_value_(name, _start, _count, _last_values): # type: ignore
def _generate_next_value_(name, _start, _count, _last_values):
return name
@ -92,7 +92,7 @@ def ensure_collection(value):
)
def csv(*args, sep: str = ", ") -> str:
def csv(*args: str, sep: str = ", ") -> str:
"""
Formats any number of string arguments as CSV.
@ -304,9 +304,18 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
return new
def name_sequence(prefix: str) -> t.Callable[[], str]:
"""Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
sequence = count()
return lambda: f"{prefix}{next(sequence)}"
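# A quick sketch of the behavior described above: successive calls yield a0, a1, ...
from sqlglot.helper import name_sequence
new_name = name_sequence("a")
assert new_name() == "a0"
assert new_name() == "a1"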
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
"""Returns a dictionary created from an object's attributes."""
return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}
return {
**{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
**kwargs,
}
def split_num_words(
@ -381,15 +390,6 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
yield value
def count_params(function: t.Callable) -> int:
"""
Returns the number of formal parameters expected by a function, without counting "self"
and "cls", in case of instance and class methods, respectively.
"""
count = function.__code__.co_argcount
return count - 1 if inspect.ismethod(function) else count
def dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
@ -430,12 +430,23 @@ def first(it: t.Iterable[T]) -> T:
return next(i for i in it)
def should_identify(text: str, identify: str | bool) -> bool:
def case_sensitive(text: str, dialect: DialectType) -> bool:
"""Checks if text contains any case sensitive characters depending on dialect."""
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
return any(unsafe(char) for char in text)
def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool:
"""Checks if text should be identified given an identify option.
Args:
text: the text to check.
identify: "always" | True - always returns true, "safe" - true if no upper case
identify:
"always" or `True`: always returns True.
"safe": returns True only if `text` contains no case sensitive characters (uppercase or lowercase, depending on `dialect`).
dialect: the dialect to use in order to decide whether a text should be identified.
Returns:
Whether or not a string should be identified.
@ -443,5 +454,5 @@ def should_identify(text: str, identify: str | bool) -> bool:
if identify is True or identify == "always":
return True
if identify == "safe":
return not any(char.isupper() for char in text)
return not case_sensitive(text, dialect)
return False
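# Illustrative sketch, assuming the default dialect (which resolves unquoted
# identifiers as lowercase, so uppercase characters are case sensitive):
from sqlglot.helper import should_identify
assert should_identify("foo", "always")
assert should_identify("foo", "safe")      # no case sensitive characters
assert not should_identify("Foo", "safe")  # "F" is case sensitive here
assert not should_identify("foo", False)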


@ -5,10 +5,8 @@ import typing as t
from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
from sqlglot.optimizer import Scope, build_scope, optimize
from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.errors import SqlglotError
from sqlglot.optimizer import Scope, build_scope, qualify
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
@ -40,8 +38,8 @@ def lineage(
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
dialect: DialectType = None,
**kwargs,
) -> Node:
"""Build the lineage graph for a column of a SQL query.
@ -50,8 +48,8 @@ def lineage(
sql: The SQL string or expression.
schema: The schema of tables.
sources: A mapping of queries which will be used to continue building lineage.
rules: Optimizer rules to apply, by default only qualifying tables and columns.
dialect: The dialect of input SQL.
**kwargs: Qualification optimizer kwargs.
Returns:
A lineage node.
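# A hedged usage sketch, assuming the column of interest is the first positional
# argument (its parameter is not shown in this hunk):
from sqlglot.lineage import lineage
node = lineage("col", "SELECT t.col FROM t", schema={"t": {"col": "int"}})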
@ -68,8 +66,17 @@ def lineage(
},
)
optimized = optimize(expression, schema=schema, rules=rules)
scope = build_scope(optimized)
qualified = qualify.qualify(
expression,
dialect=dialect,
schema=schema,
**{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore
)
scope = build_scope(qualified)
if not scope:
raise SqlglotError("Cannot build lineage, sql must be SELECT")
def to_node(
column_name: str,
@ -109,10 +116,7 @@ def lineage(
# a version that has only the column we care about.
# "x", SELECT x, y FROM foo
# => "x", SELECT x FROM foo
source = optimize(
scope.expression.select(select, append=False), schema=schema, rules=rules
)
select = source.selects[0]
source = t.cast(exp.Expression, scope.expression.select(select, append=False))
else:
source = scope.expression


@ -3,10 +3,9 @@ from __future__ import annotations
import itertools
from sqlglot import exp
from sqlglot.helper import should_identify
def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
def canonicalize(expression: exp.Expression) -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
@ -14,19 +13,14 @@ def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expr
Args:
expression: The expression to canonicalize.
identify: Whether or not to force identify identifier.
"""
exp.replace_children(expression, canonicalize, identify=identify)
exp.replace_children(expression, canonicalize)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)
if isinstance(expression, exp.Identifier):
if should_identify(expression.this, identify):
expression.set("quoted", True)
return expression


@ -19,24 +19,25 @@ def eliminate_ctes(expression):
"""
root = build_scope(expression)
ref_count = root.ref_count()
if root:
ref_count = root.ref_count()
# Traverse the scope tree in reverse so we can remove chains of unused CTEs
for scope in reversed(list(root.traverse())):
if scope.is_cte:
count = ref_count[id(scope)]
if count <= 0:
cte_node = scope.expression.parent
with_node = cte_node.parent
cte_node.pop()
# Traverse the scope tree in reverse so we can remove chains of unused CTEs
for scope in reversed(list(root.traverse())):
if scope.is_cte:
count = ref_count[id(scope)]
if count <= 0:
cte_node = scope.expression.parent
with_node = cte_node.parent
cte_node.pop()
# Pop the entire WITH clause if this is the last CTE
if len(with_node.expressions) <= 0:
with_node.pop()
# Pop the entire WITH clause if this is the last CTE
if len(with_node.expressions) <= 0:
with_node.pop()
# Decrement the ref count for all sources this CTE selects from
for _, source in scope.selected_sources.values():
if isinstance(source, Scope):
ref_count[id(source)] -= 1
# Decrement the ref count for all sources this CTE selects from
for _, source in scope.selected_sources.values():
if isinstance(source, Scope):
ref_count[id(source)] -= 1
return expression


@ -16,9 +16,9 @@ def eliminate_subqueries(expression):
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
Args:
expression (sqlglot.Expression): expression
@ -32,6 +32,9 @@ def eliminate_subqueries(expression):
root = build_scope(expression)
if not root:
return expression
# Map of alias->Scope|Table
# These are all aliases that are already used in the expression.
# We don't want to create new CTEs that conflict with these names.
@ -112,7 +115,7 @@ def _eliminate_union(scope, existing_ctes, taken):
# Try to maintain the selections
expressions = scope.selects
selects = [
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
for e in expressions
if e.alias_or_name
]
@ -120,7 +123,9 @@ def _eliminate_union(scope, existing_ctes, taken):
if len(selects) != len(expressions):
selects = ["*"]
scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))
scope.expression.replace(
exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False))
)
if not duplicate_cte_alias:
existing_ctes[scope.expression] = alias
@ -131,6 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
# This ensures we don't drop the "pivot" arg from a pivoted subquery
if scope.parent.pivots:
return None
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
@ -153,7 +162,7 @@ def _eliminate_cte(scope, existing_ctes, taken):
for child_scope in scope.parent.traverse():
for table, source in child_scope.selected_sources.values():
if source is scope:
new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False)
table.replace(new_table)
return cte


@ -1,34 +0,0 @@
from __future__ import annotations
import typing as t
from sqlglot import exp
def expand_laterals(expression: exp.Expression) -> exp.Expression:
"""
Expand lateral column alias references.
This assumes `qualify_columns` as already run.
Example:
>>> import sqlglot
>>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
>>> expression = sqlglot.parse_one(sql)
>>> expand_laterals(expression).sql()
'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
Args:
expression: expression to optimize
Returns:
optimized expression
"""
for select in expression.find_all(exp.Select):
alias_to_expression: t.Dict[str, exp.Expression] = {}
for projection in select.expressions:
for column in projection.find_all(exp.Column):
if not column.table and column.name in alias_to_expression:
column.replace(alias_to_expression[column.name].copy())
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = projection.this
return expression


@ -1,24 +0,0 @@
from sqlglot import exp
def expand_multi_table_selects(expression):
"""
Replace multiple FROM expressions with JOINs.
Example:
>>> from sqlglot import parse_one
>>> expand_multi_table_selects(parse_one("SELECT * FROM x, y")).sql()
'SELECT * FROM x CROSS JOIN y'
"""
for from_ in expression.find_all(exp.From):
parent = from_.parent
for query in from_.expressions[1:]:
parent.join(
query,
join_type="CROSS",
copy=False,
)
from_.expressions.remove(query)
return expression


@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
source.replace(
exp.select("*")
.from_(
alias(source.copy(), source.name or source.alias, table=True),
alias(source, source.name or source.alias, table=True),
copy=False,
)
.subquery(source.alias, copy=False)


@ -1,88 +0,0 @@
from sqlglot import exp
def lower_identities(expression):
"""
Convert all unquoted identifiers to lower case.
Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> lower_identities(expression).sql()
'SELECT bar.a AS A FROM "Foo".bar'
Args:
expression (sqlglot.Expression): expression to quote
Returns:
sqlglot.Expression: quoted expression
"""
# We need to leave the output aliases unchanged, so the selects need special handling
_lower_selects(expression)
# These clauses can reference output aliases and also need special handling
_lower_order(expression)
_lower_having(expression)
# We've already handled these args, so don't traverse into them
traversed = {"expressions", "order", "having"}
if isinstance(expression, exp.Subquery):
# Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
lower_identities(expression.this)
traversed |= {"this"}
if isinstance(expression, exp.Union):
# Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
lower_identities(expression.left)
lower_identities(expression.right)
traversed |= {"this", "expression"}
for k, v in expression.iter_expressions():
if k in traversed:
continue
v.transform(_lower, copy=False)
return expression
def _lower_selects(expression):
for e in expression.expressions:
# Leave output aliases as-is
e.unalias().transform(_lower, copy=False)
def _lower_order(expression):
order = expression.args.get("order")
if not order:
return
output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
for ordered in order.expressions:
# Don't lower references to output aliases
if not (
isinstance(ordered.this, exp.Column)
and not ordered.this.table
and ordered.this.name in output_aliases
):
ordered.transform(_lower, copy=False)
def _lower_having(expression):
having = expression.args.get("having")
if not having:
return
# Don't lower references to output aliases
for agg in having.find_all(exp.AggFunc):
agg.transform(_lower, copy=False)
def _lower(node):
if isinstance(node, exp.Identifier) and not node.quoted:
node.set("this", node.this.lower())
return node


@ -13,15 +13,15 @@ def merge_subqueries(expression, leave_tables_isolated=False):
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x JOIN y'
'SELECT x.a FROM x CROSS JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer
queries if it would result in multiple table selects in a single query:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) JOIN y'
'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
@ -154,7 +154,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
inner_from = inner_scope.expression.args.get("from")
if not inner_from:
return False
inner_from_table = inner_from.expressions[0].alias_or_name
inner_from_table = inner_from.alias_or_name
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
return any(
col.table != inner_from_table
@ -167,6 +167,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not outer_scope.pivots
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
and not (
@ -210,7 +211,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
elif isinstance(source, exp.Table) and source.alias:
source.set("alias", new_alias)
elif isinstance(source, exp.Table):
source.replace(exp.alias_(source.copy(), new_alias))
source.replace(exp.alias_(source, new_alias))
for column in inner_scope.source_columns(conflict):
column.set("table", exp.to_identifier(new_name))
@ -228,7 +229,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
node_to_replace (exp.Subquery|exp.Table)
alias (str)
"""
new_subquery = inner_scope.expression.args.get("from").expressions[0]
new_subquery = inner_scope.expression.args["from"].this
node_to_replace.replace(new_subquery)
for join_hint in outer_scope.join_hints:
tables = join_hint.find_all(exp.Table)
@ -319,7 +320,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
# Merge predicates from an outer join to the ON clause
# if it only has columns that are already joined
from_ = expression.args.get("from")
sources = {table.alias_or_name for table in from_.expressions} if from_ else {}
sources = {from_.alias_or_name} if from_ else {}
for join in expression.args["joins"]:
source = join.alias_or_name


@ -1,12 +1,12 @@
from __future__ import annotations
import logging
import typing as t
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
from sqlglot.optimizer.simplify import flatten, uniq_sort
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
logger = logging.getLogger("sqlglot")
@ -28,13 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
cache: t.Dict[int, str] = {}
generate = cached_generator()
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
continue
root = node is expression
original = node.copy()
node.transform(rewrite_between, copy=False)
distance = normalization_distance(node, dnf=dnf)
if distance > max_distance:
@ -43,11 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
)
return expression
root = node is expression
original = node.copy()
try:
node = node.replace(
while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
)
except OptimizeError as e:
logger.info(e)
@ -111,7 +112,7 @@ def _predicate_lengths(expression, dnf):
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
def distributive_law(expression, dnf, max_distance, cache=None):
def distributive_law(expression, dnf, max_distance, generate):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@ -124,7 +125,7 @@ def distributive_law(expression, dnf, max_distance, cache=None):
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
@ -135,30 +136,30 @@ def distributive_law(expression, dnf, max_distance, cache=None):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
return _distribute(a, b, from_func, to_func, cache)
return _distribute(b, a, from_func, to_func, cache)
return _distribute(a, b, from_func, to_func, generate)
return _distribute(b, a, from_func, to_func, generate)
if isinstance(a, to_exp):
return _distribute(b, a, from_func, to_func, cache)
return _distribute(b, a, from_func, to_func, generate)
if isinstance(b, to_exp):
return _distribute(a, b, from_func, to_func, cache)
return _distribute(a, b, from_func, to_func, generate)
return expression
def _distribute(a, b, from_func, to_func, cache):
def _distribute(a, b, from_func, to_func, generate):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
uniq_sort(flatten(from_func(c, b.left)), cache),
uniq_sort(flatten(from_func(c, b.right)), cache),
uniq_sort(flatten(from_func(c, b.left)), generate),
uniq_sort(flatten(from_func(c, b.right)), generate),
copy=False,
),
)
else:
a = to_func(
uniq_sort(flatten(from_func(a, b.left)), cache),
uniq_sort(flatten(from_func(a, b.right)), cache),
uniq_sort(flatten(from_func(a, b.left)), generate),
uniq_sort(flatten(from_func(a, b.right)), generate),
copy=False,
)


@ -0,0 +1,36 @@
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
"""
Normalize all unquoted identifiers to either lower or upper case, depending on
the dialect. This essentially makes those identifiers case-insensitive.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> normalize_identifiers(expression).sql()
'SELECT bar.a AS a FROM "Foo".bar'
Args:
expression: The expression to transform.
dialect: The dialect to use in order to decide how to normalize identifiers.
Returns:
The transformed expression.
"""
return expression.transform(_normalize, dialect, copy=False)
def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression:
if isinstance(node, exp.Identifier) and not node.quoted:
node.set(
"this",
node.this.upper()
if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE
else node.this.lower(),
)
return node


@ -1,6 +1,8 @@
from sqlglot import exp
from sqlglot.helper import tsort
JOIN_ATTRS = ("on", "side", "kind", "using", "natural")
def optimize_joins(expression):
"""
@ -45,7 +47,7 @@ def reorder_joins(expression):
Reorder joins by topological sort order based on predicate references.
"""
for from_ in expression.find_all(exp.From):
head = from_.expressions[0]
head = from_.this
parent = from_.parent
joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
dag = {head.alias_or_name: []}
@ -65,6 +67,9 @@ def normalize(expression):
Remove INNER and OUTER from joins as they are optional.
"""
for join in expression.find_all(exp.Join):
if not any(join.args.get(k) for k in JOIN_ATTRS):
join.set("kind", "CROSS")
if join.kind != "CROSS":
join.set("kind", None)
return expression
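# Illustrative effect of the two branches above:
#   "SELECT * FROM a JOIN b"                    -> "SELECT * FROM a CROSS JOIN b"
#   "SELECT * FROM a INNER JOIN b ON a.x = b.x" -> "SELECT * FROM a JOIN b ON a.x = b.x"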


@ -10,36 +10,29 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.qualify import qualify
from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
RULES = (
lower_identities,
qualify_tables,
isolate_table_selects,
qualify_columns,
qualify,
pushdown_projections,
validate_qualify_columns,
normalize,
unnest_subqueries,
expand_multi_table_selects,
pushdown_predicates,
optimize_joins,
eliminate_subqueries,
merge_subqueries,
eliminate_joins,
eliminate_ctes,
quote_identifiers,
annotate_types,
canonicalize,
simplify,
@ -54,7 +47,7 @@ def optimize(
dialect: DialectType = None,
rules: t.Sequence[t.Callable] = RULES,
**kwargs,
):
) -> exp.Expression:
"""
Rewrite a sqlglot AST into an optimized form.
@ -72,14 +65,23 @@ def optimize(
dialect: The dialect to parse the sql string.
rules: sequence of optimizer rules to use.
Many of the rules require tables and columns to be qualified.
Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
what you're doing!
Do not remove `qualify` from the sequence of rules unless you know what you're doing!
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
Returns:
sqlglot.Expression: optimized expression
The optimized expression.
"""
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
possible_kwargs = {
"db": db,
"catalog": catalog,
"schema": schema,
"dialect": dialect,
"isolate_tables": True, # needed for other optimizations to perform well
"quote_identifiers": False, # this happens in canonicalize
**kwargs,
}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
@ -88,4 +90,5 @@ def optimize(
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
expression = rule(expression, **rule_kwargs)
return expression
return t.cast(exp.Expression, expression)


@ -21,26 +21,28 @@ def pushdown_predicates(expression):
sqlglot.Expression: optimized expression
"""
root = build_scope(expression)
scope_ref_count = root.ref_count()
for scope in reversed(list(root.traverse())):
select = scope.expression
where = select.args.get("where")
if where:
selected_sources = scope.selected_sources
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
pushdown(where.this, selected_sources, scope_ref_count)
if root:
scope_ref_count = root.ref_count()
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
for join in select.args.get("joins") or []:
name = join.this.alias_or_name
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
for scope in reversed(list(root.traverse())):
select = scope.expression
where = select.args.get("where")
if where:
selected_sources = scope.selected_sources
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
pushdown(where.this, selected_sources, scope_ref_count)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
for join in select.args.get("joins") or []:
name = join.this.alias_or_name
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
return expression


@ -39,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
for scope in reversed(traverse_scope(expression)):
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
if scope.expression.args.get("distinct"):
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT
if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
# we select from a pivoted source in the parent scope.
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
@ -105,7 +106,9 @@ def _remove_unused_selections(scope, parent_selections, schema):
for name in sorted(parent_selections):
if name not in names:
new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))
new_selections.append(
alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
)
# If there are no remaining selections, just select a single constant
if not new_selections:


@ -0,0 +1,80 @@
from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
qualify_columns as qualify_columns_func,
quote_identifiers as quote_identifiers_func,
validate_qualify_columns as validate_qualify_columns_func,
)
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.schema import Schema, ensure_schema
def qualify(
expression: exp.Expression,
dialect: DialectType = None,
db: t.Optional[str] = None,
catalog: t.Optional[str] = None,
schema: t.Optional[dict | Schema] = None,
expand_alias_refs: bool = True,
infer_schema: t.Optional[bool] = None,
isolate_tables: bool = False,
qualify_columns: bool = True,
validate_qualify_columns: bool = True,
quote_identifiers: bool = True,
identify: bool = True,
) -> exp.Expression:
"""
Rewrite sqlglot AST to have normalized and qualified tables and columns.
This step is necessary for all further SQLGlot optimizations.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify(expression, schema=schema).sql()
'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"'
Args:
expression: Expression to qualify.
db: Default database name for tables.
catalog: Default catalog name for tables.
schema: Schema to infer column names and types.
expand_alias_refs: Whether or not to expand references to aliases.
infer_schema: Whether or not to infer the schema if missing.
isolate_tables: Whether or not to isolate table selects.
qualify_columns: Whether or not to qualify columns.
validate_qualify_columns: Whether or not to validate columns.
quote_identifiers: Whether or not to run the quote_identifiers step.
This step is necessary to ensure correctness for case sensitive queries, but the
flag is provided in case this step should be performed at a later time.
identify: If True, quote all identifiers, else only necessary ones.
Returns:
The qualified expression.
"""
schema = ensure_schema(schema, dialect=dialect)
expression = normalize_identifiers(expression, dialect=dialect)
expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
if isolate_tables:
expression = isolate_table_selects(expression, schema=schema)
if qualify_columns:
expression = qualify_columns_func(
expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
)
if quote_identifiers:
expression = quote_identifiers_func(expression, dialect=dialect, identify=identify)
if validate_qualify_columns:
validate_qualify_columns_func(expression)
return expression


@ -1,14 +1,23 @@
from __future__ import annotations
import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
from sqlglot.helper import case_sensitive, seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.schema import Schema, ensure_schema
def qualify_columns(expression, schema, expand_laterals=True):
def qualify_columns(
expression: exp.Expression,
schema: dict | Schema,
expand_alias_refs: bool = True,
infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
"""
Rewrite sqlglot AST to have fully qualified columns.
@ -20,32 +29,36 @@ def qualify_columns(expression, schema, expand_laterals=True):
'SELECT tbl.col AS col FROM tbl'
Args:
expression (sqlglot.Expression): expression to qualify
schema (dict|sqlglot.optimizer.Schema): Database schema
expression: expression to qualify
schema: Database schema
expand_alias_refs: whether or not to expand references to aliases
infer_schema: whether or not to infer the schema if missing
Returns:
sqlglot.Expression: qualified expression
"""
schema = ensure_schema(schema)
if not schema.mapping and expand_laterals:
expression = _expand_laterals(expression)
infer_schema = schema.empty if infer_schema is None else infer_schema
for scope in traverse_scope(expression):
resolver = Resolver(scope, schema)
resolver = Resolver(scope, schema, infer_schema=infer_schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
using_column_tables = _expand_using(scope, resolver)
if schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver)
_qualify_columns(scope, resolver)
if not schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
_expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
if schema.mapping and expand_laterals:
expression = _expand_laterals(expression)
return expression
@ -55,9 +68,11 @@ def validate_qualify_columns(expression):
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
unqualified_columns.extend(scope.unqualified_columns)
if scope.external_columns and not scope.is_correlated_subquery:
if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
column = scope.external_columns[0]
raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
raise OptimizeError(
f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
)
if unqualified_columns:
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
@ -142,52 +157,48 @@ def _expand_using(scope, resolver):
# Ensure selects keep their output name
if isinstance(column.parent, exp.Select):
replacement = exp.alias_(replacement, alias=column.name)
replacement = alias(replacement, alias=column.name, copy=False)
scope.replace(column, replacement)
return column_tables
def _expand_alias_refs(scope, resolver):
selects = {}
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
expression = scope.expression
# Replace references to select aliases
def transform(node, source_first=True):
if isinstance(node, exp.Column) and not node.table:
table = resolver.get_table(node.name)
if not isinstance(expression, exp.Select):
return
# Source columns get priority over select aliases
if source_first and table:
node.set("table", table)
return node
alias_to_expression: t.Dict[str, exp.Expression] = {}
if not selects:
for s in scope.selects:
selects[s.alias_or_name] = s
select = selects.get(node.name)
if select:
scope.clear_cache()
if isinstance(select, exp.Alias):
select = select.this
return select.copy()
node.set("table", table)
elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable):
exp.replace_children(node, transform, source_first)
return node
for select in scope.expression.selects:
transform(select)
for modifier, source_first in (
("where", True),
("group", True),
("having", False),
def replace_columns(
node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
):
transform(scope.expression.args.get(modifier), source_first=source_first)
if not node:
return
for column, *_ in walk_in_scope(node):
if not isinstance(column, exp.Column):
continue
table = resolver.get_table(column.name) if resolve_agg and not column.table else None
if table and column.find_ancestor(exp.AggFunc):
column.set("table", table)
elif expand and not column.table and column.name in alias_to_expression:
column.replace(alias_to_expression[column.name].copy())
for projection in scope.selects:
replace_columns(projection)
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = projection.this
replace_columns(expression.args.get("where"))
replace_columns(expression.args.get("group"))
replace_columns(expression.args.get("having"), resolve_agg=True)
replace_columns(expression.args.get("qualify"), resolve_agg=True)
replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
scope.clear_cache()
def _expand_group_by(scope, resolver):
@ -242,6 +253,12 @@ def _qualify_columns(scope, resolver):
raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
if scope.pivots and not column.find_ancestor(exp.Pivot):
# If the column is under the Pivot expression, we need to qualify it
# using the name of the pivoted source instead of the pivot's alias
column.set("table", exp.to_identifier(scope.pivots[0].alias))
continue
column_table = resolver.get_table(column_name)
# column_table can be a '' because bigquery unnest has no table alias
@ -265,38 +282,12 @@ def _qualify_columns(scope, resolver):
if column_table:
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
columns_missing_from_scope = []
# Determine whether each reference in the order by clause is to a column or an alias.
order = scope.expression.args.get("order")
if order:
for ordered in order.expressions:
for column in ordered.find_all(exp.Column):
if (
not column.table
and column.parent is not ordered
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
having = scope.expression.args.get("having")
if having:
for column in having.find_all(exp.Column):
if (
not column.table
and column.find_ancestor(exp.AggFunc)
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)
for column in columns_missing_from_scope:
column_table = resolver.get_table(column.name)
if column_table:
column.set("table", column_table)
for pivot in scope.pivots:
for column in pivot.find_all(exp.Column):
if not column.table and column.name in resolver.all_columns:
column_table = resolver.get_table(column.name)
if column_table:
column.set("table", column_table)
def _expand_stars(scope, resolver, using_column_tables):
@ -307,6 +298,19 @@ def _expand_stars(scope, resolver, using_column_tables):
replace_columns = {}
coalesced_columns = set()
# TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
pivot_columns = None
pivot_output_columns = None
pivot = seq_get(scope.pivots, 0)
has_pivoted_source = pivot and not pivot.args.get("unpivot")
if has_pivoted_source:
pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
if not pivot_output_columns:
pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
for expression in scope.selects:
if isinstance(expression, exp.Star):
tables = list(scope.selected_sources)
@ -323,9 +327,18 @@ def _expand_stars(scope, resolver, using_column_tables):
for table in tables:
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
if columns and "*" not in columns:
if has_pivoted_source:
implicit_columns = [col for col in columns if col not in pivot_columns]
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
for name in implicit_columns + pivot_output_columns
)
continue
table_id = id(table)
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
@ -337,16 +350,21 @@ def _expand_stars(scope, resolver, using_column_tables):
coalesce = [exp.column(name, table=table) for table in tables]
new_selections.append(
exp.alias_(
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
alias(
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
alias=name,
copy=False,
)
)
elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
new_selections.append(alias(column, alias_) if alias_ != name else column)
column = exp.column(name, table=table)
new_selections.append(
alias(column, alias_, copy=False) if alias_ != name else column
)
else:
return
scope.expression.set("expressions", new_selections)
@ -388,9 +406,6 @@ def _qualify_outputs(scope):
selection = alias(
selection,
alias=selection.output_name or f"_col_{i}",
quoted=True
if isinstance(selection, exp.Column) and selection.this.quoted
else None,
)
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))
@ -400,6 +415,23 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
"""Makes sure all identifiers that need to be quoted are quoted."""
def _quote(expression: E) -> E:
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(
"quoted",
identify
or case_sensitive(name, dialect=dialect)
or not exp.SAFE_IDENTIFIER_RE.match(name),
)
return expression
return expression.transform(_quote, copy=False)
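# A small sketch, assuming the default identify=True (every identifier gets quoted):
from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import quote_identifiers
assert quote_identifiers(parse_one("SELECT a FROM tbl")).sql() == 'SELECT "a" FROM "tbl"'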
class Resolver:
"""
Helper for resolving columns.
@ -407,12 +439,13 @@ class Resolver:
This is a class so we can lazily load some things and easily share them across functions.
"""
def __init__(self, scope, schema):
def __init__(self, scope, schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
self._source_columns = None
self._unambiguous_columns = None
self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
self._all_columns = None
self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
"""
@ -430,7 +463,7 @@ class Resolver:
table_name = self._unambiguous_columns.get(column_name)
if not table_name:
if not table_name and self._infer_schema:
sources_without_schema = tuple(
source
for source, columns in self._get_all_source_columns().items()
@ -450,11 +483,9 @@ class Resolver:
node_alias = node.args.get("alias")
if node_alias:
return node_alias.this
return exp.to_identifier(node_alias.this)
return exp.to_identifier(
table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
)
return exp.to_identifier(table_name)
@property
def all_columns(self):


@ -1,11 +1,19 @@
import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot.helper import csv_reader
from sqlglot._typing import E
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
def qualify_tables(expression, db=None, catalog=None, schema=None):
def qualify_tables(
expression: E,
db: t.Optional[str] = None,
catalog: t.Optional[str] = None,
schema: t.Optional[Schema] = None,
) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Additionally, this
replaces "join constructs" (*) by equivalent SELECT * subqueries.
@ -21,19 +29,17 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
Args:
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
expression: Expression to qualify
db: Database name
catalog: Catalog name
schema: A schema to populate
Returns:
sqlglot.Expression: qualified expression
The qualified expression.
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
sequence = itertools.count()
next_name = lambda: f"_q_{next(sequence)}"
next_alias_name = name_sequence("_q_")
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
@ -44,10 +50,14 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
alias_ = next_alias_name()
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)
pivots = derived_table.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
@ -59,12 +69,19 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
if not source.alias:
source = source.replace(
alias(
source.copy(),
name if name else next_name(),
source,
name or source.name or next_alias_name(),
copy=True,
table=True,
)
)
pivots = source.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set(
"alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
)
if schema and isinstance(source.this, exp.ReadCSV):
with csv_reader(source.this) as reader:
header = next(reader)
@ -74,11 +91,11 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
udtf.set("alias", table_alias)
if not table_alias.name:
table_alias.set("this", next_name())
table_alias.set("this", next_alias_name())
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))

View file

@ -1,4 +1,5 @@
import itertools
import typing as t
from collections import defaultdict
from enum import Enum, auto
@ -83,6 +84,7 @@ class Scope:
self._columns = None
self._external_columns = None
self._join_hints = None
self._pivots = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
@ -261,12 +263,14 @@ class Scope:
self._columns = []
for column in columns + external_columns:
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
ancestor = column.find_ancestor(
exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
)
if (
not ancestor
# Window functions can have an ORDER BY clause
or not isinstance(ancestor.parent, exp.Select)
or column.table
or isinstance(ancestor, exp.Select)
or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)
@ -370,6 +374,17 @@ class Scope:
return []
return self._join_hints
@property
def pivots(self):
if not self._pivots:
self._pivots = [
pivot
for node in self.tables + self.derived_tables
for pivot in node.args.get("pivots") or []
]
return self._pivots
def source_columns(self, source_name):
"""
Get all columns in the current scope for a particular source.
@ -463,7 +478,7 @@ class Scope:
return scope_ref_count
def traverse_scope(expression):
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
"""
Traverse an expression by its "scopes".
@ -488,10 +503,12 @@ def traverse_scope(expression):
Returns:
list[Scope]: scope instances
"""
if not isinstance(expression, exp.Unionable):
return []
return list(_traverse_scope(Scope(expression)))
def build_scope(expression):
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
"""
Build a scope tree.
@ -500,7 +517,10 @@ def build_scope(expression):
Returns:
Scope: root scope
"""
return traverse_scope(expression)[-1]
scopes = traverse_scope(expression)
if scopes:
return scopes[-1]
return None
def _traverse_scope(scope):
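
A minimal usage sketch of the new Optional return, assuming a plain SELECT; callers now need to guard against None for non-Unionable expressions:

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope

scope = build_scope(parse_one("SELECT a FROM t"))
if scope is not None:
    print(list(scope.sources))  # source names visible in the root scope, e.g. ['t']
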
@ -585,7 +605,7 @@ def _traverse_tables(scope):
expressions = []
from_ = scope.expression.args.get("from")
if from_:
expressions.extend(from_.expressions)
expressions.append(from_.this)
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)
@ -601,8 +621,13 @@ def _traverse_tables(scope):
source_name = expression.alias_or_name
if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
sources[source_name] = scope.sources[table_name]
# This is a reference to a parent source (e.g. a CTE), not an actual table, unless
# it is pivoted, because then we get back a new table and hence a new source.
pivots = expression.args.get("pivots")
if pivots:
sources[pivots[0].alias] = expression
else:
sources[source_name] = scope.sources[table_name]
elif source_name in sources:
sources[find_new_name(sources, table_name)] = expression
else:

View file

@ -5,11 +5,9 @@ from collections import deque
from decimal import Decimal
from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing
GENERATOR = Generator(normalize=True, identify="safe")
def simplify(expression):
"""
@ -27,12 +25,12 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""
cache = {}
generate = cached_generator()
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
node = uniq_sort(node, cache, root)
node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
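
cached_generator is imported from sqlglot.generator in place of the module-level GENERATOR; a sketch of the shape it presumably has, consistent with the GENERATOR.generate(e, cache) call it replaces below:

import typing as t
from sqlglot import exp
from sqlglot.generator import Generator

def cached_generator(cache: t.Optional[t.Dict] = None) -> t.Callable[[exp.Expression], str]:
    # closure that memoizes generated SQL through a shared cache dict
    cache = {} if cache is None else cache
    generator = Generator(normalize=True, identify="safe")
    return lambda e: generator.generate(e, cache)
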
@ -247,7 +245,7 @@ def remove_compliments(expression, root=True):
return expression
def uniq_sort(expression, cache=None, root=True):
def uniq_sort(expression, generate, root=True):
"""
Uniq and sort a connector.
@ -256,7 +254,7 @@ def uniq_sort(expression, cache=None, root=True):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
deduped = {GENERATOR.generate(e, cache): e for e in flattened}
deduped = {generate(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
@ -388,14 +386,18 @@ def _simplify_binary(expression, a, b):
def simplify_parens(expression):
if (
isinstance(expression, exp.Paren)
and not isinstance(expression.this, exp.Select)
and (
not isinstance(expression.parent, (exp.Condition, exp.Binary))
or isinstance(expression.this, exp.Predicate)
or not isinstance(expression.this, exp.Binary)
)
if not isinstance(expression, exp.Paren):
return expression
this = expression.this
parent = expression.parent
if not isinstance(this, exp.Select) and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(this, exp.Predicate)
or not isinstance(this, exp.Binary)
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
):
return expression.this
return expression
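
The net effect of the extra Add/Mul checks is that parentheses around same-operator chains are now dropped as well; a small sketch (column names are made up):

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("(x + y) + z")).sql())  # roughly: x + y + z   (same operator, paren dropped)
print(simplify(parse_one("(x + y) * z")).sql())  # roughly: (x + y) * z (precedence differs, paren kept)
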

View file

@ -1,6 +1,5 @@
import itertools
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope
@ -22,7 +21,7 @@ def unnest_subqueries(expression):
Returns:
sqlglot.Expression: unnested expression
"""
sequence = itertools.count()
next_alias_name = name_sequence("_u_")
for scope in traverse_scope(expression):
select = scope.expression
@ -30,19 +29,19 @@ def unnest_subqueries(expression):
if not parent:
continue
if scope.external_columns:
decorrelate(select, parent, scope.external_columns, sequence)
decorrelate(select, parent, scope.external_columns, next_alias_name)
elif scope.scope_type == ScopeType.SUBQUERY:
unnest(select, parent, sequence)
unnest(select, parent, next_alias_name)
return expression
def unnest(select, parent_select, sequence):
def unnest(select, parent_select, next_alias_name):
if len(select.selects) > 1:
return
predicate = select.find_ancestor(exp.Condition)
alias = _alias(sequence)
alias = next_alias_name()
if not predicate or parent_select is not predicate.parent_select:
return
@ -87,13 +86,13 @@ def unnest(select, parent_select, sequence):
)
def decorrelate(select, parent_select, external_columns, sequence):
def decorrelate(select, parent_select, external_columns, next_alias_name):
where = select.args.get("where")
if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
return
table_alias = _alias(sequence)
table_alias = next_alias_name()
keys = []
# for all external columns in the where statement, find the relevant predicate
@ -136,7 +135,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
group_by.append(key)
else:
if key not in key_aliases:
key_aliases[key] = _alias(sequence)
key_aliases[key] = next_alias_name()
# all predicates that are equalities must also be in the unique
# so that we don't do a many to many join
if isinstance(predicate, exp.EQ) and key not in group_by:
@ -244,10 +243,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
)
def _alias(sequence):
return f"_u_{next(sequence)}"
def _replace(expression, condition):
return expression.replace(exp.condition(condition))

File diff suppressed because it is too large

View file

@ -1,11 +1,10 @@
from __future__ import annotations
import itertools
import math
import typing as t
from sqlglot import alias, exp
from sqlglot.errors import UnsupportedError
from sqlglot.helper import name_sequence
from sqlglot.optimizer.eliminate_joins import join_condition
@ -105,13 +104,7 @@ class Step:
from_ = expression.args.get("from")
if isinstance(expression, exp.Select) and from_:
from_ = from_.expressions
if len(from_) > 1:
raise UnsupportedError(
"Multi-from statements are unsupported. Run it through the optimizer"
)
step = Scan.from_expression(from_[0], ctes)
step = Scan.from_expression(from_.this, ctes)
elif isinstance(expression, exp.Union):
step = SetOperation.from_expression(expression, ctes)
else:
@ -128,7 +121,7 @@ class Step:
projections = [] # final selects in this chain of steps representing a select
operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
aggregations = []
sequence = itertools.count()
next_operand_name = name_sequence("_a_")
def extract_agg_operands(expression):
for agg in expression.find_all(exp.AggFunc):
@ -136,7 +129,7 @@ class Step:
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operands[operand] = next_operand_name()
operand.replace(exp.column(operands[operand], quoted=True))
for e in expression.expressions:
@ -310,7 +303,7 @@ class Join(Step):
for join in joins:
source_key, join_key, condition = join_condition(join)
step.joins[join.this.alias_or_name] = {
"side": join.side,
"side": join.side, # type: ignore
"join_key": join_key,
"source_key": source_key,
"condition": condition,

View file

@ -5,6 +5,8 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot._typing import T
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
@ -17,62 +19,83 @@ if t.TYPE_CHECKING:
TABLE_ARGS = ("this", "db", "catalog")
T = t.TypeVar("T")
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
dialect: DialectType
@abc.abstractmethod
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
self,
table: exp.Table | str,
column_mapping: t.Optional[ColumnMapping] = None,
dialect: DialectType = None,
) -> None:
"""
Register or update a table. Some implementing classes may require column information to also be provided.
Args:
table: table expression instance or string representing the table.
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
"""
@abc.abstractmethod
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
def column_names(
self,
table: exp.Table | str,
only_visible: bool = False,
dialect: DialectType = None,
) -> t.List[str]:
"""
Get the column names for a table.
Args:
table: the `Table` expression instance.
only_visible: whether to include invisible columns.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
Returns:
The list of column names.
"""
@abc.abstractmethod
def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
def get_column_type(
self,
table: exp.Table | str,
column: exp.Column,
dialect: DialectType = None,
) -> exp.DataType:
"""
Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
Get the `sqlglot.exp.DataType` type of a column in the schema.
Args:
table: the source table.
column: the target column.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
Returns:
The resulting column type.
"""
@property
@abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
"""
Table arguments this schema supports, e.g. `("this", "db", "catalog")`
"""
raise NotImplementedError
@property
def empty(self) -> bool:
"""Returns whether or not the schema is empty."""
return True
class AbstractMappingSchema(t.Generic[T]):
def __init__(
self,
mapping: dict | None = None,
mapping: t.Optional[t.Dict] = None,
) -> None:
self.mapping = mapping or {}
self.mapping_trie = new_trie(
@ -80,6 +103,10 @@ class AbstractMappingSchema(t.Generic[T]):
)
self._supported_table_args: t.Tuple[str, ...] = tuple()
@property
def empty(self) -> bool:
return not self.mapping
def _depth(self) -> int:
return dict_depth(self.mapping)
@ -110,8 +137,10 @@ class AbstractMappingSchema(t.Generic[T]):
if value == 0:
return None
elif value == 1:
if value == 1:
possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
@ -119,12 +148,13 @@ class AbstractMappingSchema(t.Generic[T]):
if raise_on_missing:
raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
return None
return self._nested_get(parts, raise_on_missing=raise_on_missing)
def _nested_get(
return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get(
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
) -> t.Optional[t.Any]:
return _nested_get(
return nested_get(
d or self.mapping,
*zip(self.supported_table_args, reversed(parts)),
raise_on_missing=raise_on_missing,
@ -136,17 +166,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Schema based on a nested mapping.
Args:
schema (dict): Mapping in one of the following forms:
schema: Mapping in one of the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
4. None - Tables will be added later
visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
1. {table: set(*cols)}}
2. {db: {table: set(*cols)}}}
3. {catalog: {db: {table: set(*cols)}}}}
dialect (str): The dialect to be used for custom type mappings.
dialect: The dialect to be used for custom type mappings & parsing string arguments.
normalize: Whether to normalize identifier names according to the given dialect or not.
"""
def __init__(
@ -154,10 +185,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
schema: t.Optional[t.Dict] = None,
visible: t.Optional[t.Dict] = None,
dialect: DialectType = None,
normalize: bool = True,
) -> None:
self.dialect = dialect
self.visible = visible or {}
self.normalize = normalize
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
super().__init__(self._normalize(schema or {}))
@classmethod
@ -179,7 +213,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
)
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
self,
table: exp.Table | str,
column_mapping: t.Optional[ColumnMapping] = None,
dialect: DialectType = None,
) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
@ -187,10 +224,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
"""
normalized_table = self._normalize_table(self._ensure_table(table))
normalized_table = self._normalize_table(
self._ensure_table(table, dialect=dialect), dialect=dialect
)
normalized_column_mapping = {
self._normalize_name(key): value
self._normalize_name(key, dialect=dialect): value
for key, value in ensure_column_mapping(column_mapping).items()
}
@ -200,38 +240,51 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
parts = self.table_parts(normalized_table)
_nested_set(
self.mapping,
tuple(reversed(parts)),
normalized_column_mapping,
)
nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
new_trie([parts], self.mapping_trie)
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table_ = self._normalize_table(self._ensure_table(table))
schema = self.find(table_)
def column_names(
self,
table: exp.Table | str,
only_visible: bool = False,
dialect: DialectType = None,
) -> t.List[str]:
normalized_table = self._normalize_table(
self._ensure_table(table, dialect=dialect), dialect=dialect
)
schema = self.find(normalized_table)
if schema is None:
return []
if not only_visible or not self.visible:
return list(schema)
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
return [col for col in schema if col in visible]
def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
column_name = self._normalize_name(column if isinstance(column, str) else column.this)
table_ = self._normalize_table(self._ensure_table(table))
def get_column_type(
self,
table: exp.Table | str,
column: exp.Column,
dialect: DialectType = None,
) -> exp.DataType:
normalized_table = self._normalize_table(
self._ensure_table(table, dialect=dialect), dialect=dialect
)
normalized_column_name = self._normalize_name(
column if isinstance(column, str) else column.this, dialect=dialect
)
table_schema = self.find(table_, raise_on_missing=False)
table_schema = self.find(normalized_table, raise_on_missing=False)
if table_schema:
column_type = table_schema.get(column_name)
column_type = table_schema.get(normalized_column_name)
if isinstance(column_type, exp.DataType):
return column_type
elif isinstance(column_type, str):
return self._to_data_type(column_type.upper())
return self._to_data_type(column_type.upper(), dialect=dialect)
raise SchemaError(f"Unknown column type '{column_type}'")
return exp.DataType.build("unknown")
@ -250,81 +303,88 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
normalized_mapping: t.Dict = {}
for keys in flattened_schema:
columns = _nested_get(schema, *zip(keys, keys))
columns = nested_get(schema, *zip(keys, keys))
assert columns is not None
normalized_keys = [self._normalize_name(key) for key in keys]
normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
for column_name, column_type in columns.items():
_nested_set(
nested_set(
normalized_mapping,
normalized_keys + [self._normalize_name(column_name)],
normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
column_type,
)
return normalized_mapping
def _normalize_table(self, table: exp.Table) -> exp.Table:
def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
normalized_table = table.copy()
for arg in TABLE_ARGS:
value = normalized_table.args.get(arg)
if isinstance(value, (str, exp.Identifier)):
normalized_table.set(arg, self._normalize_name(value))
normalized_table.set(
arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
)
return normalized_table
def _normalize_name(self, name: str | exp.Identifier) -> str:
def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
dialect = dialect or self.dialect
try:
identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
except ParseError:
return name if isinstance(name, str) else name.name
return identifier.name if identifier.quoted else identifier.name.lower()
name = identifier.name
if not self.normalize or identifier.quoted:
return name
return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
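
A small usage sketch of the dialect-aware normalization (assuming Snowflake is among the dialects that resolve unquoted identifiers as uppercase):

from sqlglot.schema import MappingSchema

schema = MappingSchema({"Foo": {"Bar": "INT"}}, dialect="snowflake")
schema.column_names("foo")  # unquoted lookups are case-folded the same way, here yielding ['BAR']
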
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
def _ensure_table(self, table: exp.Table | str) -> exp.Table:
if isinstance(table, exp.Table):
return table
def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)
table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
if not table_:
raise SchemaError(f"Not a valid table '{table}'")
return table_
def _to_data_type(self, schema_type: str) -> exp.DataType:
def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
Args:
schema_type: the type we want to convert.
dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
Returns:
The resulting expression type.
"""
if schema_type not in self._type_mapping_cache:
dialect = dialect or self.dialect
try:
expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
if expression is None:
raise ValueError(f"Could not parse {schema_type}")
self._type_mapping_cache[schema_type] = expression # type: ignore
expression = exp.DataType.build(schema_type, dialect=dialect)
self._type_mapping_cache[schema_type] = expression
except AttributeError:
raise SchemaError(f"Failed to convert type {schema_type}")
in_dialect = f" in dialect {dialect}" if dialect else ""
raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
return self._type_mapping_cache[schema_type]
def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
if isinstance(schema, Schema):
return schema
return MappingSchema(schema, dialect=dialect)
return MappingSchema(schema, **kwargs)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
if isinstance(mapping, dict):
if mapping is None:
return {}
elif isinstance(mapping, dict):
return mapping
elif isinstance(mapping, str):
col_name_type_strs = [x.strip() for x in mapping.split(",")]
@ -334,11 +394,10 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
}
# Check if mapping looks like a DataFrame StructType
elif hasattr(mapping, "simpleString"):
return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} # type: ignore
return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
elif isinstance(mapping, list):
return {x.strip(): None for x in mapping}
elif mapping is None:
return {}
raise ValueError(f"Invalid mapping provided: {type(mapping)}")
@ -353,10 +412,11 @@ def flatten_schema(
tables.extend(flatten_schema(v, depth - 1, keys + [k]))
elif depth == 1:
tables.append(keys + [k])
return tables
def _nested_get(
def nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
"""
@ -378,18 +438,19 @@ def _nested_get(
name = "table" if name == "this" else name
raise ValueError(f"Unknown {name}: {key}")
return None
return d
def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
"""
In-place set a value for a nested dictionary
Example:
>>> _nested_set({}, ["top_key", "second_key"], "value")
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Args:

View file

@ -51,7 +51,6 @@ class TokenType(AutoName):
DOLLAR = auto()
PARAMETER = auto()
SESSION_PARAMETER = auto()
NATIONAL = auto()
DAMP = auto()
BLOCK_START = auto()
@ -72,6 +71,8 @@ class TokenType(AutoName):
BIT_STRING = auto()
HEX_STRING = auto()
BYTE_STRING = auto()
NATIONAL_STRING = auto()
RAW_STRING = auto()
# types
BIT = auto()
@ -110,6 +111,7 @@ class TokenType(AutoName):
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
DATETIME = auto()
DATETIME64 = auto()
DATE = auto()
UUID = auto()
GEOGRAPHY = auto()
@ -142,30 +144,22 @@ class TokenType(AutoName):
ARRAY = auto()
ASC = auto()
ASOF = auto()
AT_TIME_ZONE = auto()
AUTO_INCREMENT = auto()
BEGIN = auto()
BETWEEN = auto()
BOTH = auto()
BUCKET = auto()
BY_DEFAULT = auto()
CACHE = auto()
CASCADE = auto()
CASE = auto()
CHARACTER_SET = auto()
CLUSTER_BY = auto()
COLLATE = auto()
COMMAND = auto()
COMMENT = auto()
COMMIT = auto()
COMPOUND = auto()
CONSTRAINT = auto()
CREATE = auto()
CROSS = auto()
CUBE = auto()
CURRENT_DATE = auto()
CURRENT_DATETIME = auto()
CURRENT_ROW = auto()
CURRENT_TIME = auto()
CURRENT_TIMESTAMP = auto()
CURRENT_USER = auto()
@ -174,8 +168,6 @@ class TokenType(AutoName):
DESC = auto()
DESCRIBE = auto()
DISTINCT = auto()
DISTINCT_FROM = auto()
DISTRIBUTE_BY = auto()
DIV = auto()
DROP = auto()
ELSE = auto()
@ -189,7 +181,6 @@ class TokenType(AutoName):
FILTER = auto()
FINAL = auto()
FIRST = auto()
FOLLOWING = auto()
FOR = auto()
FOREIGN_KEY = auto()
FORMAT = auto()
@ -203,7 +194,6 @@ class TokenType(AutoName):
HAVING = auto()
HINT = auto()
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
ILIKE_ANY = auto()
IN = auto()
@ -222,36 +212,27 @@ class TokenType(AutoName):
KEEP = auto()
LANGUAGE = auto()
LATERAL = auto()
LAZY = auto()
LEADING = auto()
LEFT = auto()
LIKE = auto()
LIKE_ANY = auto()
LIMIT = auto()
LOAD_DATA = auto()
LOCAL = auto()
LOAD = auto()
LOCK = auto()
MAP = auto()
MATCH_RECOGNIZE = auto()
MATERIALIZED = auto()
MERGE = auto()
MOD = auto()
NATURAL = auto()
NEXT = auto()
NEXT_VALUE_FOR = auto()
NO_ACTION = auto()
NOTNULL = auto()
NULL = auto()
NULLS_FIRST = auto()
NULLS_LAST = auto()
OFFSET = auto()
ON = auto()
ONLY = auto()
OPTIONS = auto()
ORDER_BY = auto()
ORDERED = auto()
ORDINALITY = auto()
OUTER = auto()
OUT_OF = auto()
OVER = auto()
OVERLAPS = auto()
OVERWRITE = auto()
@ -261,7 +242,6 @@ class TokenType(AutoName):
PIVOT = auto()
PLACEHOLDER = auto()
PRAGMA = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
PROCEDURE = auto()
PROPERTIES = auto()
@ -271,7 +251,6 @@ class TokenType(AutoName):
RANGE = auto()
RECURSIVE = auto()
REPLACE = auto()
RESPECT_NULLS = auto()
RETURNING = auto()
REFERENCES = auto()
RIGHT = auto()
@ -280,28 +259,23 @@ class TokenType(AutoName):
ROLLUP = auto()
ROW = auto()
ROWS = auto()
SEED = auto()
SELECT = auto()
SEMI = auto()
SEPARATOR = auto()
SERDE_PROPERTIES = auto()
SET = auto()
SETTINGS = auto()
SHOW = auto()
SIMILAR_TO = auto()
SOME = auto()
SORTKEY = auto()
SORT_BY = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
TOP = auto()
THEN = auto()
TRAILING = auto()
TRUE = auto()
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
UNLOGGED = auto()
UNNEST = auto()
UNPIVOT = auto()
UPDATE = auto()
@ -314,15 +288,11 @@ class TokenType(AutoName):
WHERE = auto()
WINDOW = auto()
WITH = auto()
WITH_TIME_ZONE = auto()
WITH_LOCAL_TIME_ZONE = auto()
WITHIN_GROUP = auto()
WITHOUT_TIME_ZONE = auto()
UNIQUE = auto()
class Token:
__slots__ = ("token_type", "text", "line", "col", "end", "comments")
__slots__ = ("token_type", "text", "line", "col", "start", "end", "comments")
@classmethod
def number(cls, number: int) -> Token:
@ -350,22 +320,28 @@ class Token:
text: str,
line: int = 1,
col: int = 1,
start: int = 0,
end: int = 0,
comments: t.List[str] = [],
) -> None:
"""Token initializer.
Args:
token_type: The TokenType Enum.
text: The text of the token.
line: The line that the token ends on.
col: The column that the token ends on.
start: The start index of the token.
end: The ending index of the token.
"""
self.token_type = token_type
self.text = text
self.line = line
size = len(text)
self.col = col
self.end = end if end else size
self.start = start
self.end = end
self.comments = comments
@property
def start(self) -> int:
"""Returns the start of the token."""
return self.end - len(self.text)
def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
return f"<Token {attributes}>"
@ -375,15 +351,31 @@ class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
klass._QUOTES = {
f"{prefix}{s}": e
for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items()
for prefix in (("",) if s[0].isalpha() else ("", "n", "N"))
def _convert_quotes(arr: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
return dict(
(item, item) if isinstance(item, str) else (item[0], item[1]) for item in arr
)
def _quotes_to_format(
token_type: TokenType, arr: t.List[str | t.Tuple[str, str]]
) -> t.Dict[str, t.Tuple[str, TokenType]]:
return {k: (v, token_type) for k, v in _convert_quotes(arr).items()}
klass._QUOTES = _convert_quotes(klass.QUOTES)
klass._IDENTIFIERS = _convert_quotes(klass.IDENTIFIERS)
klass._FORMAT_STRINGS = {
**{
p + s: (e, TokenType.NATIONAL_STRING)
for s, e in klass._QUOTES.items()
for p in ("n", "N")
},
**_quotes_to_format(TokenType.BIT_STRING, klass.BIT_STRINGS),
**_quotes_to_format(TokenType.BYTE_STRING, klass.BYTE_STRINGS),
**_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
**_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
}
klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
klass._COMMENTS = dict(
@ -393,23 +385,17 @@ class _Tokenizer(type):
klass.KEYWORD_TRIE = new_trie(
key.upper()
for key in {
**klass.KEYWORDS,
**{comment: TokenType.COMMENT for comment in klass._COMMENTS},
**{quote: TokenType.QUOTE for quote in klass._QUOTES},
**{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
**{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
**{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
}
for key in (
*klass.KEYWORDS,
*klass._COMMENTS,
*klass._QUOTES,
*klass._FORMAT_STRINGS,
)
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
return klass
@staticmethod
def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
class Tokenizer(metaclass=_Tokenizer):
SINGLE_TOKENS = {
@ -450,6 +436,7 @@ class Tokenizer(metaclass=_Tokenizer):
BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
IDENTIFIER_ESCAPES = ['"']
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
@ -457,9 +444,7 @@ class Tokenizer(metaclass=_Tokenizer):
VAR_SINGLE_TOKENS: t.Set[str] = set()
_COMMENTS: t.Dict[str, str] = {}
_BIT_STRINGS: t.Dict[str, str] = {}
_BYTE_STRINGS: t.Dict[str, str] = {}
_HEX_STRINGS: t.Dict[str, str] = {}
_FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
_IDENTIFIERS: t.Dict[str, str] = {}
_IDENTIFIER_ESCAPES: t.Set[str] = set()
_QUOTES: t.Dict[str, str] = {}
@ -495,30 +480,22 @@ class Tokenizer(metaclass=_Tokenizer):
"ANY": TokenType.ANY,
"ASC": TokenType.ASC,
"AS": TokenType.ALIAS,
"AT TIME ZONE": TokenType.AT_TIME_ZONE,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
"AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
"BEGIN": TokenType.BEGIN,
"BETWEEN": TokenType.BETWEEN,
"BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET,
"BY DEFAULT": TokenType.BY_DEFAULT,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
"CASCADE": TokenType.CASCADE,
"CHARACTER SET": TokenType.CHARACTER_SET,
"CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMIT": TokenType.COMMIT,
"COMPOUND": TokenType.COMPOUND,
"CONSTRAINT": TokenType.CONSTRAINT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
"CUBE": TokenType.CUBE,
"CURRENT_DATE": TokenType.CURRENT_DATE,
"CURRENT ROW": TokenType.CURRENT_ROW,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
"CURRENT_USER": TokenType.CURRENT_USER,
@ -528,8 +505,6 @@ class Tokenizer(metaclass=_Tokenizer):
"DESC": TokenType.DESC,
"DESCRIBE": TokenType.DESCRIBE,
"DISTINCT": TokenType.DISTINCT,
"DISTINCT FROM": TokenType.DISTINCT_FROM,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,
@ -544,18 +519,18 @@ class Tokenizer(metaclass=_Tokenizer):
"FIRST": TokenType.FIRST,
"FULL": TokenType.FULL,
"FUNCTION": TokenType.FUNCTION,
"FOLLOWING": TokenType.FOLLOWING,
"FOR": TokenType.FOR,
"FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"GEOMETRY": TokenType.GEOMETRY,
"GLOB": TokenType.GLOB,
"GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IGNORE NULLS": TokenType.IGNORE_NULLS,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
"INET": TokenType.INET,
@ -569,34 +544,25 @@ class Tokenizer(metaclass=_Tokenizer):
"JOIN": TokenType.JOIN,
"KEEP": TokenType.KEEP,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
"LEADING": TokenType.LEADING,
"LEFT": TokenType.LEFT,
"LIKE": TokenType.LIKE,
"LIMIT": TokenType.LIMIT,
"LOAD DATA": TokenType.LOAD_DATA,
"LOCAL": TokenType.LOCAL,
"MATERIALIZED": TokenType.MATERIALIZED,
"LOAD": TokenType.LOAD,
"LOCK": TokenType.LOCK,
"MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR,
"NO ACTION": TokenType.NO_ACTION,
"NOT": TokenType.NOT,
"NOTNULL": TokenType.NOTNULL,
"NULL": TokenType.NULL,
"NULLS FIRST": TokenType.NULLS_FIRST,
"NULLS LAST": TokenType.NULLS_LAST,
"OBJECT": TokenType.OBJECT,
"OFFSET": TokenType.OFFSET,
"ON": TokenType.ON,
"ONLY": TokenType.ONLY,
"OPTIONS": TokenType.OPTIONS,
"OR": TokenType.OR,
"ORDER BY": TokenType.ORDER_BY,
"ORDINALITY": TokenType.ORDINALITY,
"OUTER": TokenType.OUTER,
"OUT OF": TokenType.OUT_OF,
"OVER": TokenType.OVER,
"OVERLAPS": TokenType.OVERLAPS,
"OVERWRITE": TokenType.OVERWRITE,
@ -607,7 +573,6 @@ class Tokenizer(metaclass=_Tokenizer):
"PERCENT": TokenType.PERCENT,
"PIVOT": TokenType.PIVOT,
"PRAGMA": TokenType.PRAGMA,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"PROCEDURE": TokenType.PROCEDURE,
"QUALIFY": TokenType.QUALIFY,
@ -615,7 +580,6 @@ class Tokenizer(metaclass=_Tokenizer):
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
"REPLACE": TokenType.REPLACE,
"RESPECT NULLS": TokenType.RESPECT_NULLS,
"REFERENCES": TokenType.REFERENCES,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
@ -624,25 +588,20 @@ class Tokenizer(metaclass=_Tokenizer):
"ROW": TokenType.ROW,
"ROWS": TokenType.ROWS,
"SCHEMA": TokenType.SCHEMA,
"SEED": TokenType.SEED,
"SELECT": TokenType.SELECT,
"SEMI": TokenType.SEMI,
"SET": TokenType.SET,
"SETTINGS": TokenType.SETTINGS,
"SHOW": TokenType.SHOW,
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
"SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
"TEMPORARY": TokenType.TEMPORARY,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"TRAILING": TokenType.TRAILING,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNLOGGED": TokenType.UNLOGGED,
"UNNEST": TokenType.UNNEST,
"UNPIVOT": TokenType.UNPIVOT,
"UPDATE": TokenType.UPDATE,
@ -656,10 +615,6 @@ class Tokenizer(metaclass=_Tokenizer):
"WHERE": TokenType.WHERE,
"WINDOW": TokenType.WINDOW,
"WITH": TokenType.WITH,
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
"WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
"WITHIN GROUP": TokenType.WITHIN_GROUP,
"WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
"APPLY": TokenType.APPLY,
"ARRAY": TokenType.ARRAY,
"BIT": TokenType.BIT,
@ -718,15 +673,6 @@ class Tokenizer(metaclass=_Tokenizer):
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
"ALTER": TokenType.ALTER,
"ALTER AGGREGATE": TokenType.COMMAND,
"ALTER DEFAULT": TokenType.COMMAND,
"ALTER DOMAIN": TokenType.COMMAND,
"ALTER ROLE": TokenType.COMMAND,
"ALTER RULE": TokenType.COMMAND,
"ALTER SEQUENCE": TokenType.COMMAND,
"ALTER TYPE": TokenType.COMMAND,
"ALTER USER": TokenType.COMMAND,
"ALTER VIEW": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
"COMMENT": TokenType.COMMENT,
@ -790,7 +736,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._start = 0
self._current = 0
self._line = 1
self._col = 1
self._col = 0
self._comments: t.List[str] = []
self._char = ""
@ -803,13 +749,12 @@ class Tokenizer(metaclass=_Tokenizer):
self.reset()
self.sql = sql
self.size = len(sql)
try:
self._scan()
except Exception as e:
start = self._current - 50
end = self._current + 50
start = start if start > 0 else 0
end = end if end < self.size else self.size - 1
start = max(self._current - 50, 0)
end = min(self._current + 50, self.size - 1)
context = self.sql[start:end]
raise ValueError(f"Error tokenizing '{context}'") from e
@ -834,17 +779,17 @@ class Tokenizer(metaclass=_Tokenizer):
if until and until():
break
if self.tokens:
if self.tokens and self._comments:
self.tokens[-1].comments.extend(self._comments)
def _chars(self, size: int) -> str:
if size == 1:
return self._char
start = self._current - 1
end = start + size
if end <= self.size:
return self.sql[start:end]
return ""
return self.sql[start:end] if end <= self.size else ""
def _advance(self, i: int = 1, alnum: bool = False) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
@ -859,6 +804,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._peek = "" if self._end else self.sql[self._current]
if alnum and self._char.isalnum():
# Here we use local variables instead of attributes for better performance
_col = self._col
_current = self._current
_end = self._end
@ -885,11 +831,12 @@ class Tokenizer(metaclass=_Tokenizer):
self.tokens.append(
Token(
token_type,
self._text if text is None else text,
self._line,
self._col,
self._current,
self._comments,
text=self._text if text is None else text,
line=self._line,
col=self._col,
start=self._start,
end=self._current - 1,
comments=self._comments,
)
)
self._comments = []
@ -929,6 +876,7 @@ class Tokenizer(metaclass=_Tokenizer):
break
if result == 2:
word = chars
size += 1
end = self._current - 1 + size
@ -946,6 +894,7 @@ class Tokenizer(metaclass=_Tokenizer):
else:
skip = True
else:
char = ""
chars = " "
word = None if not single_token and chars[-1] not in self.WHITE_SPACE else word
@ -959,8 +908,6 @@ class Tokenizer(metaclass=_Tokenizer):
if self._scan_string(word):
return
if self._scan_formatted_string(word):
return
if self._scan_comment(word):
return
@ -1004,9 +951,9 @@ class Tokenizer(metaclass=_Tokenizer):
if self._char == "0":
peek = self._peek.upper()
if peek == "B":
return self._scan_bits() if self._BIT_STRINGS else self._add(TokenType.NUMBER)
return self._scan_bits() if self.BIT_STRINGS else self._add(TokenType.NUMBER)
elif peek == "X":
return self._scan_hex() if self._HEX_STRINGS else self._add(TokenType.NUMBER)
return self._scan_hex() if self.HEX_STRINGS else self._add(TokenType.NUMBER)
decimal = False
scientific = 0
@ -1075,37 +1022,24 @@ class Tokenizer(metaclass=_Tokenizer):
return self._text
def _scan_string(self, quote: str) -> bool:
quote_end = self._QUOTES.get(quote)
if quote_end is None:
return False
def _scan_string(self, start: str) -> bool:
base = None
token_type = TokenType.STRING
self._advance(len(quote))
text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
if start in self._QUOTES:
end = self._QUOTES[start]
elif start in self._FORMAT_STRINGS:
end, token_type = self._FORMAT_STRINGS[start]
# X'1234', b'0110', E'\\\\\' etc.
def _scan_formatted_string(self, string_start: str) -> bool:
if string_start in self._HEX_STRINGS:
delimiters = self._HEX_STRINGS
token_type = TokenType.HEX_STRING
base = 16
elif string_start in self._BIT_STRINGS:
delimiters = self._BIT_STRINGS
token_type = TokenType.BIT_STRING
base = 2
elif string_start in self._BYTE_STRINGS:
delimiters = self._BYTE_STRINGS
token_type = TokenType.BYTE_STRING
base = None
if token_type == TokenType.HEX_STRING:
base = 16
elif token_type == TokenType.BIT_STRING:
base = 2
else:
return False
self._advance(len(string_start))
string_end = delimiters[string_start]
text = self._extract_string(string_end)
self._advance(len(start))
text = self._extract_string(end)
if base:
try:
@ -1114,6 +1048,8 @@ class Tokenizer(metaclass=_Tokenizer):
raise RuntimeError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
else:
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
self._add(token_type, text)
return True

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.helper import find_new_name, name_sequence
if t.TYPE_CHECKING:
from sqlglot.generator import Generator
@ -63,16 +63,17 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
distinct_cols = expression.args["distinct"].pop().args["on"].expressions
outer_selects = expression.selects
row_number = find_new_name(expression.named_selects, "_row_number")
window = exp.Window(
this=exp.RowNumber(),
partition_by=distinct_cols,
)
window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
order = expression.args.get("order")
if order:
window.set("order", order.pop().copy())
window = exp.alias_(window, row_number)
expression.select(window, copy=False)
return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
return expression
@ -93,7 +94,7 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
for select in expression.selects:
if not select.alias_or_name:
alias = find_new_name(taken, "_c")
select.replace(exp.alias_(select.copy(), alias))
select.replace(exp.alias_(select, alias))
taken.add(alias)
outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
@ -102,8 +103,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
for expr in qualify_filters.find_all((exp.Window, exp.Column)):
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
expression.select(exp.alias_(expr.copy(), alias), copy=False)
expression.select(exp.alias_(expr, alias), copy=False)
column = exp.column(alias)
if isinstance(expr.parent, exp.Qualify):
qualify_filters = column
else:
@ -123,6 +125,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
"""
for node in expression.find_all(exp.DataType):
node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])
return expression
@ -147,6 +150,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
),
)
return expression
@ -156,7 +160,10 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
from sqlglot.optimizer.scope import build_scope
taken_select_names = set(expression.named_selects)
taken_source_names = set(build_scope(expression).selected_sources)
scope = build_scope(expression)
if not scope:
return expression
taken_source_names = set(scope.selected_sources)
for select in expression.selects:
to_replace = select
@ -226,6 +233,7 @@ def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
else node,
copy=False,
)
return expression
@ -242,12 +250,20 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre
return expression
def unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Pivot):
expression.args["field"].transform(
lambda node: exp.column(node.output_name) if isinstance(node, exp.Column) else node,
copy=False,
)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.With) and expression.recursive:
next_name = name_sequence("_c_")
for cte in expression.expressions:
if not cte.args["alias"].columns:
query = cte.this
if isinstance(query, exp.Union):
query = query.this
cte.args["alias"].set(
"columns",
[exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
)
return expression
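
An illustrative sketch of the new transform (assumed output shape): a recursive CTE with no explicit column list gets one derived from the anchor SELECT's output names.

from sqlglot import parse_one
from sqlglot.transforms import add_recursive_cte_column_names

sql = "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM t) SELECT n FROM t"
print(parse_one(sql).transform(add_recursive_cte_column_names).sql())
# roughly: WITH RECURSIVE t(n) AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM t) SELECT n FROM t
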