Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent 8deb804d23
commit fc63828ee4
167 changed files with 58268 additions and 51337 deletions
@@ -6,6 +6,7 @@
from __future__ import annotations

import logging
import typing as t

from sqlglot import expressions as exp

@@ -45,12 +46,19 @@ from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema
from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType

if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from sqlglot.dialects.dialect import DialectType as DialectType

T = t.TypeVar("T", bound=Expression)

logger = logging.getLogger("sqlglot")


__version__ = "12.2.0"
try:
    from sqlglot._version import __version__, __version_tuple__
except ImportError:
    logger.error(
        "Unable to set __version__, run `pip install -e .` or `python setup.py develop` first."
    )


pretty = False
"""Whether to format generated SQL by default."""

@@ -79,9 +87,9 @@ def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expre
def parse_one(
    sql: str,
    read: None = None,
    into: t.Type[T] = ...,
    into: t.Type[E] = ...,
    **opts,
) -> T:
) -> E:
    ...


@@ -89,9 +97,9 @@ def parse_one(
def parse_one(
    sql: str,
    read: DialectType,
    into: t.Type[T],
    into: t.Type[E],
    **opts,
) -> T:
) -> E:
    ...


sqlglot/_typing.py (Normal file, 8 additions)
@@ -0,0 +1,8 @@
from __future__ import annotations

import typing as t

import sqlglot

E = t.TypeVar("E", bound="sqlglot.exp.Expression")
T = t.TypeVar("T")
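As a minimal sketch (mine, not part of the diff) of what the new E type variable is for, assuming the merged 15.0.0 API: the overloads of parse_one typed with E let the return type follow the expression class requested via into.

    import sqlglot
    from sqlglot import exp

    # With into=exp.Select, type checkers can see exp.Select here instead of a bare Expression.
    select = sqlglot.parse_one("SELECT 1", into=exp.Select)
    assert isinstance(select, exp.Select)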
@@ -11,6 +11,8 @@ if t.TYPE_CHECKING:

    ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
    ColumnOrName = t.Union[Column, str]
    ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
    ColumnOrLiteral = t.Union[
        Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime
    ]
    SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
    OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]
@@ -127,7 +127,7 @@ class DataFrame:
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self.spark._random_name
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]

@@ -263,7 +263,7 @@ class DataFrame:
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Select):
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        value = expression.sql(dialect="spark").encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]
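The hash used above is plain zlib.crc32 over the generated Spark SQL, so CTE names become short, deterministic and stable across runs; a quick standard-library-only illustration (the SQL text is just an example):

    import zlib

    value = "SELECT a FROM t".encode("utf-8")
    name = f"t{zlib.crc32(value)}"[:6]  # 't' followed by the first digits of the checksum
    print(name)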
@@ -299,7 +299,7 @@ class DataFrame:
        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                select_expression = optimize_func(select_expression, identify="always")
                select_expression = t.cast(exp.Select, optimize_func(select_expression))
            select_expression = df._replace_cte_names_with_hashes(select_expression)
            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:

@@ -570,9 +570,9 @@ class DataFrame:
                r_expressions.append(l_column)
                r_columns_unused.remove(l_column)
            else:
                r_expressions.append(exp.alias_(exp.Null(), l_column))
                r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
        for r_column in r_columns_unused:
            l_expressions.append(exp.alias_(exp.Null(), r_column))
            l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
            r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))

@@ -761,7 +761,7 @@ class DataFrame:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column.copy(), new))
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)
@@ -41,7 +41,7 @@ def operation(op: Operation):
            self.last_op = Operation.NO_OP
        last_op = self.last_op
        new_op = op if op != Operation.NO_OP else last_op
        if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT):
        if new_op < last_op or (last_op == new_op == Operation.SELECT):
            self = self._convert_leaf_to_cte()
        df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
        df.last_op = new_op  # type: ignore
@@ -87,15 +87,13 @@ class SparkSession:
        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                expressions=[
                    exp.Values(
                        expressions=data_expressions,
                        alias=exp.TableAlias(
                            this=exp.to_identifier(self._auto_incrementing_name),
                            columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                        ),
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ],
            ),
            ),
        }

@@ -127,10 +125,6 @@ class SparkSession:
        self.incrementing_id += 1
        return name

    @property
    def _random_name(self) -> str:
        return "r" + uuid.uuid4().hex

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id

@@ -145,7 +139,7 @@ class SparkSession:

    @property
    def _random_id(self) -> str:
        id = self._random_name
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id
@@ -27,6 +27,6 @@ def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.T
    if not expression.args.get("joins"):
        return []

    left_table = expression.args["from"].args["expressions"][0]
    left_table = expression.args["from"].this
    other_tables = [join.this for join in expression.args["joins"]]
    return [left_table] + other_tables
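A rough illustration (mine, assuming the merged version where a FROM clause wraps a single source under this) of what the helper above now returns for a joined SELECT:

    import sqlglot

    expr = sqlglot.parse_one("SELECT * FROM a JOIN b ON a.id = b.id")
    left = expr.args["from"].this                         # the left table, no longer expressions[0]
    others = [join.this for join in expr.args["joins"]]   # the joined tables
    print([tbl.name for tbl in [left, *others]])          # expected: ['a', 'b']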
@@ -1,5 +1,3 @@
"""Supports BigQuery Standard SQL."""

from __future__ import annotations

import re

@@ -18,11 +16,9 @@ from sqlglot.dialects.dialect import (
    timestrtotime_sql,
    ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType

E = t.TypeVar("E", bound=exp.Expression)


def _date_add_sql(
    data_type: str, kind: str

@@ -96,19 +92,12 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    These are added by the optimizer's qualify_column step.
    """
    if isinstance(expression, exp.Select):
        unnests = {
            unnest.alias
            for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
            if isinstance(unnest, exp.Unnest) and unnest.alias
        }

        if unnests:
            expression = expression.copy()

            for select in expression.expressions:
                for column in select.find_all(exp.Column):
                    if column.table in unnests:
                        column.set("table", None)
    for unnest in expression.find_all(exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias:
            for select in expression.selects:
                for column in select.find_all(exp.Column):
                    if column.table == unnest.alias:
                        column.set("table", None)

    return expression
@@ -127,16 +116,20 @@ class BigQuery(Dialect):
    }

    class Tokenizer(tokens.Tokenizer):
        QUOTES = [
            (prefix + quote, quote) if prefix else quote
            for quote in ["'", '"', '"""', "'''"]
            for prefix in ["", "r", "R"]
        ]
        QUOTES = ["'", '"', '"""', "'''"]
        COMMENTS = ["--", "#", ("/*", "*/")]
        IDENTIFIERS = ["`"]
        STRING_ESCAPES = ["\\"]

        HEX_STRINGS = [("0x", ""), ("0X", "")]
        BYTE_STRINGS = [("b'", "'"), ("B'", "'")]

        BYTE_STRINGS = [
            (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("b", "B")
        ]

        RAW_STRINGS = [
            (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("r", "R")
        ]

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
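If I read the new tokenizer definitions right, every BigQuery quote style now accepts r/R and b/B prefixes via the RAW_STRINGS and BYTE_STRINGS comprehensions; a hypothetical round-trip (my example, output formatting depends on the generator):

    import sqlglot

    e = sqlglot.parse_one("SELECT r'''raw''' AS a, b'bytes' AS b", read="bigquery")
    print(e.sql(dialect="bigquery"))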
@@ -144,11 +137,11 @@ class BigQuery(Dialect):
            "BEGIN": TokenType.COMMAND,
            "BEGIN TRANSACTION": TokenType.BEGIN,
            "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
            "BYTES": TokenType.BINARY,
            "DECLARE": TokenType.COMMAND,
            "GEOGRAPHY": TokenType.GEOGRAPHY,
            "FLOAT64": TokenType.DOUBLE,
            "INT64": TokenType.BIGINT,
            "BYTES": TokenType.BINARY,
            "RECORD": TokenType.STRUCT,
            "NOT DETERMINISTIC": TokenType.VOLATILE,
            "UNKNOWN": TokenType.NULL,
        }

@@ -161,7 +154,7 @@ class BigQuery(Dialect):
        LOG_DEFAULTS_TO_LN = True

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,  # type: ignore
            **parser.Parser.FUNCTIONS,
            "DATE_TRUNC": lambda args: exp.DateTrunc(
                unit=exp.Literal.string(str(seq_get(args, 1))),
                this=seq_get(args, 0),

@@ -191,28 +184,28 @@ class BigQuery(Dialect):
        }

        FUNCTION_PARSERS = {
            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
            **parser.Parser.FUNCTION_PARSERS,
            "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
        }
        FUNCTION_PARSERS.pop("TRIM")

        NO_PAREN_FUNCTIONS = {
            **parser.Parser.NO_PAREN_FUNCTIONS,  # type: ignore
            **parser.Parser.NO_PAREN_FUNCTIONS,
            TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
        }

        NESTED_TYPE_TOKENS = {
            *parser.Parser.NESTED_TYPE_TOKENS,  # type: ignore
            *parser.Parser.NESTED_TYPE_TOKENS,
            TokenType.TABLE,
        }

        ID_VAR_TOKENS = {
            *parser.Parser.ID_VAR_TOKENS,  # type: ignore
            *parser.Parser.ID_VAR_TOKENS,
            TokenType.VALUES,
        }

        PROPERTY_PARSERS = {
            **parser.Parser.PROPERTY_PARSERS,  # type: ignore
            **parser.Parser.PROPERTY_PARSERS,
            "NOT DETERMINISTIC": lambda self: self.expression(
                exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
            ),
@@ -220,19 +213,50 @@ class BigQuery(Dialect):
        }

        CONSTRAINT_PARSERS = {
            **parser.Parser.CONSTRAINT_PARSERS,  # type: ignore
            **parser.Parser.CONSTRAINT_PARSERS,
            "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
        }

        def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
            this = super()._parse_table_part(schema=schema)

            # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names
            if isinstance(this, exp.Identifier):
                table_name = this.name
                while self._match(TokenType.DASH, advance=False) and self._next:
                    self._advance(2)
                    table_name += f"-{self._prev.text}"

                this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))

            return this

        def _parse_table_parts(self, schema: bool = False) -> exp.Table:
            table = super()._parse_table_parts(schema=schema)
            if isinstance(table.this, exp.Identifier) and "." in table.name:
                catalog, db, this, *rest = (
                    t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
                    for x in split_num_words(table.name, ".", 3)
                )

                if rest and this:
                    this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))

                table = exp.Table(this=this, db=db, catalog=catalog)

            return table

    class Generator(generator.Generator):
        EXPLICIT_UNION = True
        INTERVAL_ALLOWS_PLURAL_FORM = False
        JOIN_HINTS = False
        TABLE_HINTS = False
        LIMIT_FETCH = "LIMIT"
        RENAME_TABLE_WITH_DB = False

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,  # type: ignore
            **generator.Generator.TRANSFORMS,
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            exp.ArraySize: rename_func("ARRAY_LENGTH"),
            exp.AtTimeZone: lambda self, e: self.func(
                "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
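A sketch of what the new _parse_table_part / _parse_table_parts pair enables (my example, assuming the merged behavior): BigQuery project ids may contain hyphens, and the dash-separated pieces are stitched back into a single project.dataset.table reference.

    import sqlglot

    e = sqlglot.parse_one("SELECT * FROM my-project.mydataset.mytable", read="bigquery")
    print(e.sql(dialect="bigquery"))  # the dashed project id should survive the round trip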
@@ -259,6 +283,7 @@ class BigQuery(Dialect):
            exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
            exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
            exp.TimeStrToTime: timestrtotime_sql,
            exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})",
            exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
            exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",

@@ -274,7 +299,7 @@ class BigQuery(Dialect):
        }

        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,  # type: ignore
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
            exp.DataType.Type.BIGINT: "INT64",
            exp.DataType.Type.BINARY: "BYTES",

@@ -297,7 +322,7 @@ class BigQuery(Dialect):
        }

        PROPERTIES_LOCATION = {
            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
            **generator.Generator.PROPERTIES_LOCATION,
            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }
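As a rough illustration of the TryCast transform and type mapping listed above (my example, not taken from the commit; the exact output is an assumption):

    import sqlglot

    print(sqlglot.transpile("SELECT TRY_CAST(x AS BIGINT)", write="bigquery")[0])
    # expected: SELECT SAFE_CAST(x AS INT64)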
|
@ -3,11 +3,16 @@ from __future__ import annotations
|
|||
import typing as t
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens
|
||||
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
inline_array_sql,
|
||||
no_pivot_sql,
|
||||
rename_func,
|
||||
var_map_sql,
|
||||
)
|
||||
from sqlglot.errors import ParseError
|
||||
from sqlglot.helper import ensure_list, seq_get
|
||||
from sqlglot.parser import parse_var_map
|
||||
from sqlglot.tokens import TokenType
|
||||
from sqlglot.tokens import Token, TokenType
|
||||
|
||||
|
||||
def _lower_func(sql: str) -> str:
|
||||
|
@ -28,65 +33,122 @@ class ClickHouse(Dialect):
|
|||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"ASOF": TokenType.ASOF,
|
||||
"GLOBAL": TokenType.GLOBAL,
|
||||
"DATETIME64": TokenType.DATETIME,
|
||||
"ATTACH": TokenType.COMMAND,
|
||||
"DATETIME64": TokenType.DATETIME64,
|
||||
"FINAL": TokenType.FINAL,
|
||||
"FLOAT32": TokenType.FLOAT,
|
||||
"FLOAT64": TokenType.DOUBLE,
|
||||
"INT8": TokenType.TINYINT,
|
||||
"UINT8": TokenType.UTINYINT,
|
||||
"INT16": TokenType.SMALLINT,
|
||||
"UINT16": TokenType.USMALLINT,
|
||||
"INT32": TokenType.INT,
|
||||
"UINT32": TokenType.UINT,
|
||||
"INT64": TokenType.BIGINT,
|
||||
"UINT64": TokenType.UBIGINT,
|
||||
"GLOBAL": TokenType.GLOBAL,
|
||||
"INT128": TokenType.INT128,
|
||||
"UINT128": TokenType.UINT128,
|
||||
"INT16": TokenType.SMALLINT,
|
||||
"INT256": TokenType.INT256,
|
||||
"UINT256": TokenType.UINT256,
|
||||
"INT32": TokenType.INT,
|
||||
"INT64": TokenType.BIGINT,
|
||||
"INT8": TokenType.TINYINT,
|
||||
"MAP": TokenType.MAP,
|
||||
"TUPLE": TokenType.STRUCT,
|
||||
"UINT128": TokenType.UINT128,
|
||||
"UINT16": TokenType.USMALLINT,
|
||||
"UINT256": TokenType.UINT256,
|
||||
"UINT32": TokenType.UINT,
|
||||
"UINT64": TokenType.UBIGINT,
|
||||
"UINT8": TokenType.UTINYINT,
|
||||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
"EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg(
|
||||
this=seq_get(args, 0),
|
||||
time=seq_get(args, 1),
|
||||
decay=seq_get(params, 0),
|
||||
),
|
||||
"GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray(
|
||||
this=seq_get(args, 0), size=seq_get(params, 0)
|
||||
),
|
||||
"HISTOGRAM": lambda params, args: exp.Histogram(
|
||||
this=seq_get(args, 0), bins=seq_get(params, 0)
|
||||
),
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"ANY": exp.AnyValue.from_arg_list,
|
||||
"MAP": parse_var_map,
|
||||
"MATCH": exp.RegexpLike.from_arg_list,
|
||||
"QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
|
||||
"QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
|
||||
"QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
|
||||
"UNIQ": exp.ApproxDistinct.from_arg_list,
|
||||
}
|
||||
|
||||
FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
**parser.Parser.FUNCTION_PARSERS,
|
||||
"QUANTILE": lambda self: self._parse_quantile(),
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
|
||||
FUNCTION_PARSERS.pop("MATCH")
|
||||
|
||||
NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy()
|
||||
NO_PAREN_FUNCTION_PARSERS.pop(TokenType.ANY)
|
||||
|
||||
RANGE_PARSERS = {
|
||||
**parser.Parser.RANGE_PARSERS,
|
||||
TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN)
|
||||
and self._parse_in(this, is_global=True),
|
||||
}
|
||||
|
||||
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
|
||||
# The PLACEHOLDER entry is popped because 1) it doesn't affect Clickhouse (it corresponds to
|
||||
# the postgres-specific JSONBContains parser) and 2) it makes parsing the ternary op simpler.
|
||||
COLUMN_OPERATORS = parser.Parser.COLUMN_OPERATORS.copy()
|
||||
COLUMN_OPERATORS.pop(TokenType.PLACEHOLDER)
|
||||
|
||||
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
|
||||
JOIN_KINDS = {
|
||||
*parser.Parser.JOIN_KINDS,
|
||||
TokenType.ANY,
|
||||
TokenType.ASOF,
|
||||
TokenType.ANTI,
|
||||
TokenType.SEMI,
|
||||
}
|
||||
|
||||
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
|
||||
TokenType.ANY,
|
||||
TokenType.ASOF,
|
||||
TokenType.SEMI,
|
||||
TokenType.ANTI,
|
||||
TokenType.SETTINGS,
|
||||
TokenType.FORMAT,
|
||||
}
|
||||
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
|
||||
def _parse_in(
|
||||
self, this: t.Optional[exp.Expression], is_global: bool = False
|
||||
) -> exp.Expression:
|
||||
QUERY_MODIFIER_PARSERS = {
|
||||
**parser.Parser.QUERY_MODIFIER_PARSERS,
|
||||
"settings": lambda self: self._parse_csv(self._parse_conjunction)
|
||||
if self._match(TokenType.SETTINGS)
|
||||
else None,
|
||||
"format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None,
|
||||
}
|
||||
|
||||
def _parse_conjunction(self) -> t.Optional[exp.Expression]:
|
||||
this = super()._parse_conjunction()
|
||||
|
||||
if self._match(TokenType.PLACEHOLDER):
|
||||
return self.expression(
|
||||
exp.If,
|
||||
this=this,
|
||||
true=self._parse_conjunction(),
|
||||
false=self._match(TokenType.COLON) and self._parse_conjunction(),
|
||||
)
|
||||
|
||||
return this
|
||||
|
||||
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
|
||||
"""
|
||||
Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier}
|
||||
https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters
|
||||
"""
|
||||
if not self._match(TokenType.L_BRACE):
|
||||
return None
|
||||
|
||||
this = self._parse_id_var()
|
||||
self._match(TokenType.COLON)
|
||||
kind = self._parse_types(check_func=False) or (
|
||||
self._match_text_seq("IDENTIFIER") and "Identifier"
|
||||
)
|
||||
|
||||
if not kind:
|
||||
self.raise_error("Expecting a placeholder type or 'Identifier' for tables")
|
||||
elif not self._match(TokenType.R_BRACE):
|
||||
self.raise_error("Expecting }")
|
||||
|
||||
return self.expression(exp.Placeholder, this=this, kind=kind)
|
||||
|
||||
def _parse_in(self, this: t.Optional[exp.Expression], is_global: bool = False) -> exp.In:
|
||||
this = super()._parse_in(this)
|
||||
this.set("is_global", is_global)
|
||||
return this
|
||||
|
@ -120,81 +182,142 @@ class ClickHouse(Dialect):
|
|||
|
||||
return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
|
||||
|
||||
def _parse_join_side_and_kind(
|
||||
self,
|
||||
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
|
||||
is_global = self._match(TokenType.GLOBAL) and self._prev
|
||||
kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev
|
||||
if kind_pre:
|
||||
kind = self._match_set(self.JOIN_KINDS) and self._prev
|
||||
side = self._match_set(self.JOIN_SIDES) and self._prev
|
||||
return is_global, side, kind
|
||||
return (
|
||||
is_global,
|
||||
self._match_set(self.JOIN_SIDES) and self._prev,
|
||||
self._match_set(self.JOIN_KINDS) and self._prev,
|
||||
)
|
||||
|
||||
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
|
||||
join = super()._parse_join(skip_join_token)
|
||||
|
||||
if join:
|
||||
join.set("global", join.args.pop("natural", None))
|
||||
return join
|
||||
|
||||
def _parse_function(
|
||||
self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
|
||||
) -> t.Optional[exp.Expression]:
|
||||
func = super()._parse_function(functions, anonymous)
|
||||
|
||||
if isinstance(func, exp.Anonymous):
|
||||
params = self._parse_func_params(func)
|
||||
|
||||
if params:
|
||||
return self.expression(
|
||||
exp.ParameterizedAgg,
|
||||
this=func.this,
|
||||
expressions=func.expressions,
|
||||
params=params,
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
def _parse_func_params(
|
||||
self, this: t.Optional[exp.Func] = None
|
||||
) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
|
||||
if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
|
||||
return self._parse_csv(self._parse_lambda)
|
||||
if self._match(TokenType.L_PAREN):
|
||||
params = self._parse_csv(self._parse_lambda)
|
||||
self._match_r_paren(this)
|
||||
return params
|
||||
return None
|
||||
|
||||
def _parse_quantile(self) -> exp.Quantile:
|
||||
this = self._parse_lambda()
|
||||
params = self._parse_func_params()
|
||||
if params:
|
||||
return self.expression(exp.Quantile, this=params[0], quantile=this)
|
||||
return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5))
|
||||
|
||||
def _parse_wrapped_id_vars(
|
||||
self, optional: bool = False
|
||||
) -> t.List[t.Optional[exp.Expression]]:
|
||||
return super()._parse_wrapped_id_vars(optional=True)
|
||||
|
||||
class Generator(generator.Generator):
|
||||
STRUCT_DELIMITER = ("(", ")")
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
exp.DataType.Type.NULLABLE: "Nullable",
|
||||
exp.DataType.Type.DATETIME: "DateTime64",
|
||||
exp.DataType.Type.MAP: "Map",
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.ARRAY: "Array",
|
||||
exp.DataType.Type.BIGINT: "Int64",
|
||||
exp.DataType.Type.DATETIME64: "DateTime64",
|
||||
exp.DataType.Type.DOUBLE: "Float64",
|
||||
exp.DataType.Type.FLOAT: "Float32",
|
||||
exp.DataType.Type.INT: "Int32",
|
||||
exp.DataType.Type.INT128: "Int128",
|
||||
exp.DataType.Type.INT256: "Int256",
|
||||
exp.DataType.Type.MAP: "Map",
|
||||
exp.DataType.Type.NULLABLE: "Nullable",
|
||||
exp.DataType.Type.SMALLINT: "Int16",
|
||||
exp.DataType.Type.STRUCT: "Tuple",
|
||||
exp.DataType.Type.TINYINT: "Int8",
|
||||
exp.DataType.Type.UTINYINT: "UInt8",
|
||||
exp.DataType.Type.SMALLINT: "Int16",
|
||||
exp.DataType.Type.USMALLINT: "UInt16",
|
||||
exp.DataType.Type.INT: "Int32",
|
||||
exp.DataType.Type.UINT: "UInt32",
|
||||
exp.DataType.Type.BIGINT: "Int64",
|
||||
exp.DataType.Type.UBIGINT: "UInt64",
|
||||
exp.DataType.Type.INT128: "Int128",
|
||||
exp.DataType.Type.UINT: "UInt32",
|
||||
exp.DataType.Type.UINT128: "UInt128",
|
||||
exp.DataType.Type.INT256: "Int256",
|
||||
exp.DataType.Type.UINT256: "UInt256",
|
||||
exp.DataType.Type.FLOAT: "Float32",
|
||||
exp.DataType.Type.DOUBLE: "Float64",
|
||||
exp.DataType.Type.USMALLINT: "UInt16",
|
||||
exp.DataType.Type.UTINYINT: "UInt8",
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.AnyValue: rename_func("any"),
|
||||
exp.ApproxDistinct: rename_func("uniq"),
|
||||
exp.Array: inline_array_sql,
|
||||
exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}",
|
||||
exp.CastToStrType: rename_func("CAST"),
|
||||
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
|
||||
exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}",
|
||||
exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}",
|
||||
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
|
||||
exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
|
||||
exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
|
||||
exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
|
||||
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
|
||||
exp.Pivot: no_pivot_sql,
|
||||
exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile"))
|
||||
+ f"({self.sql(e, 'this')})",
|
||||
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
|
||||
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
|
||||
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
}
|
||||
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
EXPLICIT_UNION = True
|
||||
|
||||
def _param_args_sql(
|
||||
self,
|
||||
expression: exp.Expression,
|
||||
param_names: str | t.List[str],
|
||||
arg_names: str | t.List[str],
|
||||
) -> str:
|
||||
params = self.format_args(
|
||||
*(
|
||||
arg
|
||||
for name in ensure_list(param_names)
|
||||
for arg in ensure_list(expression.args.get(name))
|
||||
)
|
||||
)
|
||||
args = self.format_args(
|
||||
*(
|
||||
arg
|
||||
for name in ensure_list(arg_names)
|
||||
for arg in ensure_list(expression.args.get(name))
|
||||
)
|
||||
)
|
||||
return f"({params})({args})"
|
||||
GROUPINGS_SEP = ""
|
||||
|
||||
def cte_sql(self, expression: exp.CTE) -> str:
|
||||
if isinstance(expression.this, exp.Alias):
|
||||
return self.sql(expression, "this")
|
||||
|
||||
return super().cte_sql(expression)
|
||||
|
||||
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
|
||||
return super().after_limit_modifiers(expression) + [
|
||||
self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
|
||||
if expression.args.get("settings")
|
||||
else "",
|
||||
self.seg("FORMAT ") + self.sql(expression, "format")
|
||||
if expression.args.get("format")
|
||||
else "",
|
||||
]
|
||||
|
||||
def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
|
||||
params = self.expressions(expression, "params", flat=True)
|
||||
return self.func(expression.name, *expression.expressions) + f"({params})"
|
||||
|
||||
def placeholder_sql(self, expression: exp.Placeholder) -> str:
|
||||
return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"
|
||||
|
|
|
@ -25,7 +25,7 @@ class Databricks(Spark):
|
|||
|
||||
class Generator(Spark.Generator):
|
||||
TRANSFORMS = {
|
||||
**Spark.Generator.TRANSFORMS, # type: ignore
|
||||
**Spark.Generator.TRANSFORMS,
|
||||
exp.DateAdd: generate_date_delta_with_unit_sql,
|
||||
exp.DateDiff: generate_date_delta_with_unit_sql,
|
||||
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
|
||||
|
|
|
@ -8,10 +8,16 @@ from sqlglot.generator import Generator
|
|||
from sqlglot.helper import flatten, seq_get
|
||||
from sqlglot.parser import Parser
|
||||
from sqlglot.time import format_time
|
||||
from sqlglot.tokens import Token, Tokenizer
|
||||
from sqlglot.tokens import Token, Tokenizer, TokenType
|
||||
from sqlglot.trie import new_trie
|
||||
|
||||
E = t.TypeVar("E", bound=exp.Expression)
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import E
|
||||
|
||||
|
||||
# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
|
||||
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
|
||||
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
|
||||
|
||||
|
||||
class Dialects(str, Enum):
|
||||
|
@ -42,6 +48,19 @@ class Dialects(str, Enum):
|
|||
class _Dialect(type):
|
||||
classes: t.Dict[str, t.Type[Dialect]] = {}
|
||||
|
||||
def __eq__(cls, other: t.Any) -> bool:
|
||||
if cls is other:
|
||||
return True
|
||||
if isinstance(other, str):
|
||||
return cls is cls.get(other)
|
||||
if isinstance(other, Dialect):
|
||||
return cls is type(other)
|
||||
|
||||
return False
|
||||
|
||||
def __hash__(cls) -> int:
|
||||
return hash(cls.__name__.lower())
|
||||
|
||||
@classmethod
|
||||
def __getitem__(cls, key: str) -> t.Type[Dialect]:
|
||||
return cls.classes[key]
|
||||
|
@ -70,17 +89,20 @@ class _Dialect(type):
|
|||
klass.tokenizer_class._IDENTIFIERS.items()
|
||||
)[0]
|
||||
|
||||
klass.bit_start, klass.bit_end = seq_get(
|
||||
list(klass.tokenizer_class._BIT_STRINGS.items()), 0
|
||||
) or (None, None)
|
||||
def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
|
||||
return next(
|
||||
(
|
||||
(s, e)
|
||||
for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
|
||||
if t == token_type
|
||||
),
|
||||
(None, None),
|
||||
)
|
||||
|
||||
klass.hex_start, klass.hex_end = seq_get(
|
||||
list(klass.tokenizer_class._HEX_STRINGS.items()), 0
|
||||
) or (None, None)
|
||||
|
||||
klass.byte_start, klass.byte_end = seq_get(
|
||||
list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
|
||||
) or (None, None)
|
||||
klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
|
||||
klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
|
||||
klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
|
||||
klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
|
||||
|
||||
return klass
|
||||
|
||||
|
@ -110,6 +132,12 @@ class Dialect(metaclass=_Dialect):
|
|||
parser_class = None
|
||||
generator_class = None
|
||||
|
||||
def __eq__(self, other: t.Any) -> bool:
|
||||
return type(self) == other
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(type(self))
|
||||
|
||||
@classmethod
|
||||
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
|
||||
if not dialect:
|
||||
|
@ -192,6 +220,8 @@ class Dialect(metaclass=_Dialect):
|
|||
"hex_end": self.hex_end,
|
||||
"byte_start": self.byte_start,
|
||||
"byte_end": self.byte_end,
|
||||
"raw_start": self.raw_start,
|
||||
"raw_end": self.raw_end,
|
||||
"identifier_start": self.identifier_start,
|
||||
"identifier_end": self.identifier_end,
|
||||
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
|
||||
|
@ -275,7 +305,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
|
|||
|
||||
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
|
||||
self.unsupported("PIVOT unsupported")
|
||||
return self.sql(expression)
|
||||
return ""
|
||||
|
||||
|
||||
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
|
||||
|
@ -328,7 +358,7 @@ def var_map_sql(
|
|||
|
||||
def format_time_lambda(
|
||||
exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
|
||||
) -> t.Callable[[t.Sequence], E]:
|
||||
) -> t.Callable[[t.List], E]:
|
||||
"""Helper used for time expressions.
|
||||
|
||||
Args:
|
||||
|
@ -340,7 +370,7 @@ def format_time_lambda(
|
|||
A callable that can be used to return the appropriately formatted time expression.
|
||||
"""
|
||||
|
||||
def _format_time(args: t.Sequence):
|
||||
def _format_time(args: t.List):
|
||||
return exp_class(
|
||||
this=seq_get(args, 0),
|
||||
format=Dialect[dialect].format_time(
|
||||
|
@ -377,12 +407,12 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
|
|||
|
||||
def parse_date_delta(
|
||||
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
|
||||
) -> t.Callable[[t.Sequence], E]:
|
||||
def inner_func(args: t.Sequence) -> E:
|
||||
) -> t.Callable[[t.List], E]:
|
||||
def inner_func(args: t.List) -> E:
|
||||
unit_based = len(args) == 3
|
||||
this = args[2] if unit_based else seq_get(args, 0)
|
||||
unit = args[0] if unit_based else exp.Literal.string("DAY")
|
||||
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
|
||||
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
|
||||
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
|
||||
|
||||
return inner_func
|
||||
|
@ -390,8 +420,8 @@ def parse_date_delta(
|
|||
|
||||
def parse_date_delta_with_interval(
|
||||
expression_class: t.Type[E],
|
||||
) -> t.Callable[[t.Sequence], t.Optional[E]]:
|
||||
def func(args: t.Sequence) -> t.Optional[E]:
|
||||
) -> t.Callable[[t.List], t.Optional[E]]:
|
||||
def func(args: t.List) -> t.Optional[E]:
|
||||
if len(args) < 2:
|
||||
return None
|
||||
|
||||
|
@ -409,7 +439,7 @@ def parse_date_delta_with_interval(
|
|||
return func
|
||||
|
||||
|
||||
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
|
||||
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
|
||||
unit = seq_get(args, 0)
|
||||
this = seq_get(args, 1)
|
||||
|
||||
|
@ -424,7 +454,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
|
|||
)
|
||||
|
||||
|
||||
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
|
||||
def locate_to_strposition(args: t.List) -> exp.Expression:
|
||||
return exp.StrPosition(
|
||||
this=seq_get(args, 1),
|
||||
substr=seq_get(args, 0),
|
||||
|
@ -483,7 +513,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str:
|
|||
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
|
||||
|
||||
|
||||
def str_to_time_sql(self, expression: exp.Expression) -> str:
|
||||
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
|
||||
return self.func("STRPTIME", expression.this, self.format_time(expression))
|
||||
|
||||
|
||||
|
@ -496,3 +526,26 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
|
|||
return f"CAST({self.sql(expression, 'this')} AS DATE)"
|
||||
|
||||
return _ts_or_ds_to_date_sql
|
||||
|
||||
|
||||
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
|
||||
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
|
||||
names = []
|
||||
for agg in aggregations:
|
||||
if isinstance(agg, exp.Alias):
|
||||
names.append(agg.alias)
|
||||
else:
|
||||
"""
|
||||
This case corresponds to aggregations without aliases being used as suffixes
|
||||
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
|
||||
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
|
||||
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
|
||||
"""
|
||||
agg_all_unquoted = agg.transform(
|
||||
lambda node: exp.Identifier(this=node.name, quoted=False)
|
||||
if isinstance(node, exp.Identifier)
|
||||
else node
|
||||
)
|
||||
names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
|
||||
|
||||
return names
|
||||
|
|
|
@ -95,7 +95,7 @@ class Drill(Dialect):
|
|||
STRICT_CAST = False
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"),
|
||||
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
|
||||
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
|
||||
|
@ -108,7 +108,7 @@ class Drill(Dialect):
|
|||
TABLE_HINTS = False
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.INT: "INTEGER",
|
||||
exp.DataType.Type.SMALLINT: "INTEGER",
|
||||
exp.DataType.Type.TINYINT: "INTEGER",
|
||||
|
@ -121,13 +121,13 @@ class Drill(Dialect):
|
|||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
|
||||
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
|
||||
exp.ArraySize: rename_func("REPEATED_COUNT"),
|
||||
|
|
|
@ -11,9 +11,9 @@ from sqlglot.dialects.dialect import (
|
|||
datestrtodate_sql,
|
||||
format_time_lambda,
|
||||
no_comment_column_constraint_sql,
|
||||
no_pivot_sql,
|
||||
no_properties_sql,
|
||||
no_safe_divide_sql,
|
||||
pivot_column_names,
|
||||
rename_func,
|
||||
str_position_sql,
|
||||
str_to_time_sql,
|
||||
|
@ -31,10 +31,11 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
|
|||
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
|
||||
|
||||
|
||||
def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
|
||||
def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
unit = self.sql(expression, "unit").strip("'") or "DAY"
|
||||
return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
|
||||
op = "+" if isinstance(expression, exp.DateAdd) else "-"
|
||||
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
|
||||
|
||||
|
||||
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
|
||||
|
@ -50,11 +51,11 @@ def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str
|
|||
return f"ARRAY_SORT({this})"
|
||||
|
||||
|
||||
def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
|
||||
def _sort_array_reverse(args: t.List) -> exp.Expression:
|
||||
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
|
||||
|
||||
|
||||
def _parse_date_diff(args: t.Sequence) -> exp.Expression:
|
||||
def _parse_date_diff(args: t.List) -> exp.Expression:
|
||||
return exp.DateDiff(
|
||||
this=seq_get(args, 2),
|
||||
expression=seq_get(args, 1),
|
||||
|
@ -89,11 +90,14 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract
|
|||
|
||||
|
||||
class DuckDB(Dialect):
|
||||
null_ordering = "nulls_are_last"
|
||||
|
||||
class Tokenizer(tokens.Tokenizer):
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"~": TokenType.RLIKE,
|
||||
":=": TokenType.EQ,
|
||||
"//": TokenType.DIV,
|
||||
"ATTACH": TokenType.COMMAND,
|
||||
"BINARY": TokenType.VARBINARY,
|
||||
"BPCHAR": TokenType.TEXT,
|
||||
|
@ -104,6 +108,7 @@ class DuckDB(Dialect):
|
|||
"INT1": TokenType.TINYINT,
|
||||
"LOGICAL": TokenType.BOOLEAN,
|
||||
"NUMERIC": TokenType.DOUBLE,
|
||||
"PIVOT_WIDER": TokenType.PIVOT,
|
||||
"SIGNED": TokenType.INT,
|
||||
"STRING": TokenType.VARCHAR,
|
||||
"UBIGINT": TokenType.UBIGINT,
|
||||
|
@ -114,8 +119,7 @@ class DuckDB(Dialect):
|
|||
|
||||
class Parser(parser.Parser):
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
|
||||
"ARRAY_SORT": exp.SortArray.from_arg_list,
|
||||
"ARRAY_REVERSE_SORT": _sort_array_reverse,
|
||||
|
@ -152,11 +156,17 @@ class DuckDB(Dialect):
|
|||
TokenType.UTINYINT,
|
||||
}
|
||||
|
||||
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
|
||||
if len(aggregations) == 1:
|
||||
return super()._pivot_column_names(aggregations)
|
||||
return pivot_column_names(aggregations, dialect="duckdb")
|
||||
|
||||
class Generator(generator.Generator):
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
LIMIT_FETCH = "LIMIT"
|
||||
STRUCT_DELIMITER = ("(", ")")
|
||||
RENAME_TABLE_WITH_DB = False
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
|
@ -175,7 +185,8 @@ class DuckDB(Dialect):
|
|||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
exp.DayOfYear: rename_func("DAYOFYEAR"),
|
||||
exp.DataType: _datatype_sql,
|
||||
exp.DateAdd: _date_add_sql,
|
||||
exp.DateAdd: _date_delta_sql,
|
||||
exp.DateSub: _date_delta_sql,
|
||||
exp.DateDiff: lambda self, e: self.func(
|
||||
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
|
||||
),
|
||||
|
@ -183,13 +194,13 @@ class DuckDB(Dialect):
|
|||
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
|
||||
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
|
||||
exp.Explode: rename_func("UNNEST"),
|
||||
exp.IntDiv: lambda self, e: self.binary(e, "//"),
|
||||
exp.JSONExtract: arrow_json_extract_sql,
|
||||
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
|
||||
exp.JSONBExtract: arrow_json_extract_sql,
|
||||
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
|
||||
exp.LogicalOr: rename_func("BOOL_OR"),
|
||||
exp.LogicalAnd: rename_func("BOOL_AND"),
|
||||
exp.Pivot: no_pivot_sql,
|
||||
exp.Properties: no_properties_sql,
|
||||
exp.RegexpExtract: _regexp_extract_sql,
|
||||
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
|
||||
|
@ -232,11 +243,11 @@ class DuckDB(Dialect):
|
|||
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
def tablesample_sql(
|
||||
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
|
||||
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
|
||||
) -> str:
|
||||
return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
|
||||
|
|
|
@ -147,13 +147,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
|
|||
return f"TO_DATE({this})"
|
||||
|
||||
|
||||
def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
table = self.sql(expression, "table")
|
||||
columns = self.sql(expression, "columns")
|
||||
return f"{this} ON TABLE {table} {columns}"
|
||||
|
||||
|
||||
class Hive(Dialect):
|
||||
alias_post_tablesample = True
|
||||
|
||||
|
@ -225,8 +218,7 @@ class Hive(Dialect):
|
|||
STRICT_CAST = False
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"BASE64": exp.ToBase64.from_arg_list,
|
||||
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
|
||||
"DATE_ADD": lambda args: exp.TsOrDsAdd(
|
||||
|
@ -271,21 +263,29 @@ class Hive(Dialect):
|
|||
}
|
||||
|
||||
PROPERTY_PARSERS = {
|
||||
**parser.Parser.PROPERTY_PARSERS, # type: ignore
|
||||
**parser.Parser.PROPERTY_PARSERS,
|
||||
"WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties(
|
||||
expressions=self._parse_wrapped_csv(self._parse_property)
|
||||
),
|
||||
}
|
||||
|
||||
QUERY_MODIFIER_PARSERS = {
|
||||
**parser.Parser.QUERY_MODIFIER_PARSERS,
|
||||
"distribute": lambda self: self._parse_sort(exp.Distribute, "DISTRIBUTE", "BY"),
|
||||
"sort": lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
|
||||
"cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
|
||||
}
|
||||
|
||||
class Generator(generator.Generator):
|
||||
LIMIT_FETCH = "LIMIT"
|
||||
TABLESAMPLE_WITH_METHOD = False
|
||||
TABLESAMPLE_SIZE_IS_PERCENT = True
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
INDEX_ON = "ON TABLE"
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.TEXT: "STRING",
|
||||
exp.DataType.Type.DATETIME: "TIMESTAMP",
|
||||
exp.DataType.Type.VARBINARY: "BINARY",
|
||||
|
@ -294,7 +294,7 @@ class Hive(Dialect):
|
|||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.Group: transforms.preprocess([transforms.unalias_group]),
|
||||
exp.Select: transforms.preprocess(
|
||||
[
|
||||
|
@ -319,7 +319,6 @@ class Hive(Dialect):
|
|||
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
|
||||
exp.FromBase64: rename_func("UNBASE64"),
|
||||
exp.If: if_sql,
|
||||
exp.Index: _index_sql,
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
|
||||
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
|
||||
|
@ -342,7 +341,6 @@ class Hive(Dialect):
|
|||
exp.StrToTime: _str_to_time_sql,
|
||||
exp.StrToUnix: _str_to_unix_sql,
|
||||
exp.StructExtract: struct_extract_sql,
|
||||
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
|
||||
exp.TimeStrToDate: rename_func("TO_DATE"),
|
||||
exp.TimeStrToTime: timestrtotime_sql,
|
||||
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
|
||||
|
@ -363,14 +361,13 @@ class Hive(Dialect):
|
|||
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
|
||||
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
|
||||
exp.LastDateOfMonth: rename_func("LAST_DAY"),
|
||||
exp.National: lambda self, e: self.sql(e, "this"),
|
||||
exp.National: lambda self, e: self.national_sql(e, prefix=""),
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
|
@ -396,3 +393,10 @@ class Hive(Dialect):
|
|||
expression = exp.DataType.build(expression.this)
|
||||
|
||||
return super().datatype_sql(expression)
|
||||
|
||||
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
|
||||
return super().after_having_modifiers(expression) + [
|
||||
self.sql(expression, "distribute"),
|
||||
self.sql(expression, "sort"),
|
||||
self.sql(expression, "cluster"),
|
||||
]
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens, transforms
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
|
@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import (
|
|||
min_or_least,
|
||||
no_ilike_sql,
|
||||
no_paren_current_date_sql,
|
||||
no_pivot_sql,
|
||||
no_tablesample_sql,
|
||||
no_trycast_sql,
|
||||
parse_date_delta_with_interval,
|
||||
|
@ -21,14 +24,14 @@ from sqlglot.helper import seq_get
|
|||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
def _show_parser(*args, **kwargs):
|
||||
def _parse(self):
|
||||
def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], exp.Show]:
|
||||
def _parse(self: MySQL.Parser) -> exp.Show:
|
||||
return self._parse_show_mysql(*args, **kwargs)
|
||||
|
||||
return _parse
|
||||
|
||||
|
||||
def _date_trunc_sql(self, expression):
|
||||
def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str:
|
||||
expr = self.sql(expression, "this")
|
||||
unit = expression.text("unit")
|
||||
|
||||
|
@ -54,17 +57,17 @@ def _date_trunc_sql(self, expression):
|
|||
return f"STR_TO_DATE({concat}, '{date_format}')"
|
||||
|
||||
|
||||
def _str_to_date(args):
|
||||
def _str_to_date(args: t.List) -> exp.StrToDate:
|
||||
date_format = MySQL.format_time(seq_get(args, 1))
|
||||
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
|
||||
|
||||
|
||||
def _str_to_date_sql(self, expression):
|
||||
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
|
||||
date_format = self.format_time(expression)
|
||||
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
|
||||
|
||||
|
||||
def _trim_sql(self, expression):
|
||||
def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
|
||||
target = self.sql(expression, "this")
|
||||
trim_type = self.sql(expression, "position")
|
||||
remove_chars = self.sql(expression, "expression")
|
||||
|
@ -79,8 +82,8 @@ def _trim_sql(self, expression):
|
|||
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
|
||||
|
||||
|
||||
def _date_add_sql(kind):
|
||||
def func(self, expression):
|
||||
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
|
||||
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
unit = expression.text("unit").upper() or "DAY"
|
||||
return (
|
||||
|
@ -175,10 +178,10 @@ class MySQL(Dialect):
|
|||
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore
|
||||
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE}
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
|
||||
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
|
||||
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
|
||||
|
@ -191,7 +194,7 @@ class MySQL(Dialect):
|
|||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
**parser.Parser.FUNCTION_PARSERS, # type: ignore
|
||||
**parser.Parser.FUNCTION_PARSERS,
|
||||
"GROUP_CONCAT": lambda self: self.expression(
|
||||
exp.GroupConcat,
|
||||
this=self._parse_lambda(),
|
||||
|
@ -199,13 +202,8 @@ class MySQL(Dialect):
|
|||
),
|
||||
}
|
||||
|
||||
PROPERTY_PARSERS = {
|
||||
**parser.Parser.PROPERTY_PARSERS, # type: ignore
|
||||
"ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
|
||||
}
|
||||
|
||||
STATEMENT_PARSERS = {
|
||||
**parser.Parser.STATEMENT_PARSERS, # type: ignore
|
||||
**parser.Parser.STATEMENT_PARSERS,
|
||||
TokenType.SHOW: lambda self: self._parse_show(),
|
||||
}
|
||||
|
||||
|
@ -286,7 +284,13 @@ class MySQL(Dialect):
|
|||
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
|
||||
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
|
||||
def _parse_show_mysql(
|
||||
self,
|
||||
this: str,
|
||||
target: bool | str = False,
|
||||
full: t.Optional[bool] = None,
|
||||
global_: t.Optional[bool] = None,
|
||||
) -> exp.Show:
|
||||
if target:
|
||||
if isinstance(target, str):
|
||||
self._match_text_seq(target)
|
||||
|
@ -342,10 +346,12 @@ class MySQL(Dialect):
|
|||
offset=offset,
|
||||
limit=limit,
|
||||
mutex=mutex,
|
||||
**{"global": global_},
|
||||
**{"global": global_}, # type: ignore
|
||||
)
|
||||
|
||||
def _parse_oldstyle_limit(self):
|
||||
def _parse_oldstyle_limit(
|
||||
self,
|
||||
) -> t.Tuple[t.Optional[exp.Expression], t.Optional[exp.Expression]]:
|
||||
limit = None
|
||||
offset = None
|
||||
if self._match_text_seq("LIMIT"):
|
||||
|
@ -355,23 +361,20 @@ class MySQL(Dialect):
|
|||
elif len(parts) == 2:
|
||||
limit = parts[1]
|
||||
offset = parts[0]
|
||||
|
||||
return offset, limit
|
||||
|
||||
def _parse_set_item_charset(self, kind):
|
||||
def _parse_set_item_charset(self, kind: str) -> exp.Expression:
|
||||
this = self._parse_string() or self._parse_id_var()
|
||||
return self.expression(exp.SetItem, this=this, kind=kind)
|
||||
|
||||
return self.expression(
|
||||
exp.SetItem,
|
||||
this=this,
|
||||
kind=kind,
|
||||
)
|
||||
|
||||
def _parse_set_item_names(self):
|
||||
def _parse_set_item_names(self) -> exp.Expression:
|
||||
charset = self._parse_string() or self._parse_id_var()
|
||||
if self._match_text_seq("COLLATE"):
|
||||
collate = self._parse_string() or self._parse_id_var()
|
||||
else:
|
||||
collate = None
|
||||
|
||||
return self.expression(
|
||||
exp.SetItem,
|
||||
this=charset,
|
||||
|
@ -386,7 +389,7 @@ class MySQL(Dialect):
|
|||
TABLE_HINTS = False
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.CurrentDate: no_paren_current_date_sql,
|
||||
exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
|
||||
exp.DateAdd: _date_add_sql("ADD"),
|
||||
|
@ -403,6 +406,7 @@ class MySQL(Dialect):
|
|||
exp.Min: min_or_least,
|
||||
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
|
||||
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
|
||||
exp.Pivot: no_pivot_sql,
|
||||
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
|
||||
exp.StrPosition: strposition_to_locate_sql,
|
||||
exp.StrToDate: _str_to_date_sql,
|
||||
|
@ -422,7 +426,7 @@ class MySQL(Dialect):
|
|||
TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType


def _parse_xml_table(self) -> exp.XMLTable:
def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
this = self._parse_string()

passing = None
@ -66,7 +66,7 @@ class Oracle(Dialect):
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}

FUNCTIONS = {
**parser.Parser.FUNCTIONS,  # type: ignore
**parser.Parser.FUNCTIONS,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
}

@ -107,7 +107,7 @@ class Oracle(Dialect):
TABLE_HINTS = False

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,  # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@ -122,7 +122,7 @@ class Oracle(Dialect):
}

TRANSFORMS = {
**generator.Generator.TRANSFORMS,  # type: ignore
**generator.Generator.TRANSFORMS,
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
@ -143,7 +143,7 @@ class Oracle(Dialect):
}

PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,  # type: ignore
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
|
|||
max_or_greatest,
|
||||
min_or_least,
|
||||
no_paren_current_date_sql,
|
||||
no_pivot_sql,
|
||||
no_tablesample_sql,
|
||||
no_trycast_sql,
|
||||
rename_func,
|
||||
|
@ -33,8 +34,8 @@ DATE_DIFF_FACTOR = {
}


def _date_add_sql(kind):
def func(self, expression):
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
from sqlglot.optimizer.simplify import simplify

this = self.sql(expression, "this")
@ -51,7 +52,7 @@ def _date_add_sql(kind):
|
|||
return func
|
||||
|
||||
|
||||
def _date_diff_sql(self, expression):
|
||||
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
|
||||
unit = expression.text("unit").upper()
|
||||
factor = DATE_DIFF_FACTOR.get(unit)
|
||||
|
||||
|
@ -77,7 +78,7 @@ def _date_diff_sql(self, expression):
|
|||
return f"CAST({unit} AS BIGINT)"
|
||||
|
||||
|
||||
def _substring_sql(self, expression):
|
||||
def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
start = self.sql(expression, "start")
|
||||
length = self.sql(expression, "length")
|
||||
|
@ -88,7 +89,7 @@ def _substring_sql(self, expression):
|
|||
return f"SUBSTRING({this}{from_part}{for_part})"
|
||||
|
||||
|
||||
def _string_agg_sql(self, expression):
|
||||
def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
|
||||
expression = expression.copy()
|
||||
separator = expression.args.get("separator") or exp.Literal.string(",")
|
||||
|
||||
|
@ -102,13 +103,13 @@ def _string_agg_sql(self, expression):
|
|||
return f"STRING_AGG({self.format_args(this, separator)}{order})"
|
||||
|
||||
|
||||
def _datatype_sql(self, expression):
|
||||
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
|
||||
if expression.this == exp.DataType.Type.ARRAY:
|
||||
return f"{self.expressions(expression, flat=True)}[]"
|
||||
return self.datatype_sql(expression)
|
||||
|
||||
|
||||
def _auto_increment_to_serial(expression):
|
||||
def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
|
||||
auto = expression.find(exp.AutoIncrementColumnConstraint)
|
||||
|
||||
if auto:
|
||||
|
@ -126,7 +127,7 @@ def _auto_increment_to_serial(expression):
|
|||
return expression
|
||||
|
||||
|
||||
def _serial_to_generated(expression):
|
||||
def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
|
||||
kind = expression.args["kind"]
|
||||
|
||||
if kind.this == exp.DataType.Type.SERIAL:
|
||||
|
@ -144,6 +145,7 @@ def _serial_to_generated(expression):
|
|||
constraints = expression.args["constraints"]
|
||||
generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
|
||||
notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint())
|
||||
|
||||
if notnull not in constraints:
|
||||
constraints.insert(0, notnull)
|
||||
if generated not in constraints:
|
||||
|
@ -152,7 +154,7 @@ def _serial_to_generated(expression):
|
|||
return expression
|
||||
|
||||
|
||||
def _generate_series(args):
|
||||
def _generate_series(args: t.List) -> exp.Expression:
|
||||
# The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
|
||||
step = seq_get(args, 2)
|
||||
|
||||
|
@ -168,11 +170,12 @@ def _generate_series(args):
|
|||
return exp.GenerateSeries.from_arg_list(args)
|
||||
|
||||
|
||||
def _to_timestamp(args):
|
||||
def _to_timestamp(args: t.List) -> exp.Expression:
|
||||
# TO_TIMESTAMP accepts either a single double argument or (text, text)
|
||||
if len(args) == 1:
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
|
||||
return exp.UnixToTime.from_arg_list(args)
|
||||
|
||||
# https://www.postgresql.org/docs/current/functions-formatting.html
|
||||
return format_time_lambda(exp.StrToTime, "postgres")(args)
|
||||
|
||||
|
@ -255,7 +258,7 @@ class Postgres(Dialect):
|
|||
STRICT_CAST = False
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
|
||||
this=seq_get(args, 1), unit=seq_get(args, 0)
|
||||
),
|
||||
|
@ -271,7 +274,7 @@ class Postgres(Dialect):
|
|||
}
|
||||
|
||||
BITWISE = {
|
||||
**parser.Parser.BITWISE, # type: ignore
|
||||
**parser.Parser.BITWISE,
|
||||
TokenType.HASH: exp.BitwiseXor,
|
||||
}
|
||||
|
||||
|
@ -280,7 +283,7 @@ class Postgres(Dialect):
|
|||
}
|
||||
|
||||
RANGE_PARSERS = {
|
||||
**parser.Parser.RANGE_PARSERS, # type: ignore
|
||||
**parser.Parser.RANGE_PARSERS,
|
||||
TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
|
||||
TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
|
||||
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
|
||||
|
@ -303,14 +306,14 @@ class Postgres(Dialect):
|
|||
return self.expression(exp.Extract, this=part, expression=value)
|
||||
|
||||
class Generator(generator.Generator):
|
||||
INTERVAL_ALLOWS_PLURAL_FORM = False
|
||||
SINGLE_STRING_INTERVAL = True
|
||||
LOCKING_READS_SUPPORTED = True
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
PARAMETER_TOKEN = "$"
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.TINYINT: "SMALLINT",
|
||||
exp.DataType.Type.FLOAT: "REAL",
|
||||
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
|
||||
|
@ -320,14 +323,9 @@ class Postgres(Dialect):
|
|||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
|
||||
exp.ColumnDef: transforms.preprocess(
|
||||
[
|
||||
_auto_increment_to_serial,
|
||||
_serial_to_generated,
|
||||
],
|
||||
),
|
||||
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
|
||||
exp.JSONExtract: arrow_json_extract_sql,
|
||||
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
|
||||
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
|
||||
|
@ -348,6 +346,7 @@ class Postgres(Dialect):
|
|||
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
|
||||
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
|
||||
exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
|
||||
exp.Pivot: no_pivot_sql,
|
||||
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
|
||||
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
|
||||
exp.StrPosition: str_position_sql,
|
||||
|
@ -369,7 +368,7 @@ class Postgres(Dialect):
|
|||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
|
|||
format_time_lambda,
|
||||
if_sql,
|
||||
no_ilike_sql,
|
||||
no_pivot_sql,
|
||||
no_safe_divide_sql,
|
||||
rename_func,
|
||||
struct_extract_sql,
|
||||
|
@ -127,39 +128,12 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
|
|||
)
|
||||
|
||||
|
||||
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
|
||||
start = expression.args["start"]
|
||||
end = expression.args["end"]
|
||||
step = expression.args.get("step")
|
||||
|
||||
target_type = None
|
||||
|
||||
if isinstance(start, exp.Cast):
|
||||
target_type = start.to
|
||||
elif isinstance(end, exp.Cast):
|
||||
target_type = end.to
|
||||
|
||||
if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
|
||||
to = target_type.copy()
|
||||
|
||||
if target_type is start.to:
|
||||
end = exp.Cast(this=end, to=to)
|
||||
else:
|
||||
start = exp.Cast(this=start, to=to)
|
||||
|
||||
sql = self.func("SEQUENCE", start, end, step)
|
||||
if isinstance(expression.parent, exp.Table):
|
||||
sql = f"UNNEST({sql})"
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
def _ensure_utf8(charset: exp.Literal) -> None:
|
||||
if charset.name.lower() != "utf-8":
|
||||
raise UnsupportedError(f"Unsupported charset {charset}")
|
||||
|
||||
|
||||
def _approx_percentile(args: t.Sequence) -> exp.Expression:
|
||||
def _approx_percentile(args: t.List) -> exp.Expression:
|
||||
if len(args) == 4:
|
||||
return exp.ApproxQuantile(
|
||||
this=seq_get(args, 0),
|
||||
|
@ -176,7 +150,7 @@ def _approx_percentile(args: t.Sequence) -> exp.Expression:
|
|||
return exp.ApproxQuantile.from_arg_list(args)
|
||||
|
||||
|
||||
def _from_unixtime(args: t.Sequence) -> exp.Expression:
|
||||
def _from_unixtime(args: t.List) -> exp.Expression:
|
||||
if len(args) == 3:
|
||||
return exp.UnixToTime(
|
||||
this=seq_get(args, 0),
|
||||
|
@ -191,22 +165,39 @@ def _from_unixtime(args: t.Sequence) -> exp.Expression:
|
|||
return exp.UnixToTime.from_arg_list(args)
|
||||
|
||||
|
||||
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
|
||||
if isinstance(expression, exp.Table):
|
||||
if isinstance(expression.this, exp.GenerateSeries):
|
||||
unnest = exp.Unnest(expressions=[expression.this])
|
||||
|
||||
if expression.alias:
|
||||
return exp.alias_(
|
||||
unnest,
|
||||
alias="_u",
|
||||
table=[expression.alias],
|
||||
copy=False,
|
||||
)
|
||||
return unnest
|
||||
return expression
|
||||
|
||||
|
||||
class Presto(Dialect):
|
||||
index_offset = 1
|
||||
null_ordering = "nulls_are_last"
|
||||
time_format = MySQL.time_format # type: ignore
|
||||
time_mapping = MySQL.time_mapping # type: ignore
|
||||
time_format = MySQL.time_format
|
||||
time_mapping = MySQL.time_mapping
|
||||
|
||||
class Tokenizer(tokens.Tokenizer):
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"START": TokenType.BEGIN,
|
||||
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
|
||||
"ROW": TokenType.STRUCT,
|
||||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
|
||||
"APPROX_PERCENTILE": _approx_percentile,
|
||||
"CARDINALITY": exp.ArraySize.from_arg_list,
|
||||
|
@ -252,13 +243,13 @@ class Presto(Dialect):
|
|||
STRUCT_DELIMITER = ("(", ")")
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.INT: "INTEGER",
|
||||
exp.DataType.Type.FLOAT: "REAL",
|
||||
exp.DataType.Type.BINARY: "VARBINARY",
|
||||
|
@ -268,8 +259,9 @@ class Presto(Dialect):
|
|||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.ApproxDistinct: _approx_distinct_sql,
|
||||
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
|
||||
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
|
||||
exp.ArrayConcat: rename_func("CONCAT"),
|
||||
exp.ArrayContains: rename_func("CONTAINS"),
|
||||
|
@ -293,7 +285,7 @@ class Presto(Dialect):
|
|||
exp.Decode: _decode_sql,
|
||||
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
|
||||
exp.Encode: _encode_sql,
|
||||
exp.GenerateSeries: _sequence_sql,
|
||||
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
|
||||
exp.Group: transforms.preprocess([transforms.unalias_group]),
|
||||
exp.Hex: rename_func("TO_HEX"),
|
||||
exp.If: if_sql,
|
||||
|
@ -301,10 +293,10 @@ class Presto(Dialect):
|
|||
exp.Initcap: _initcap_sql,
|
||||
exp.Lateral: _explode_to_unnest_sql,
|
||||
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
|
||||
exp.LogicalOr: rename_func("BOOL_OR"),
|
||||
exp.LogicalAnd: rename_func("BOOL_AND"),
|
||||
exp.LogicalOr: rename_func("BOOL_OR"),
|
||||
exp.Pivot: no_pivot_sql,
|
||||
exp.Quantile: _quantile_sql,
|
||||
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
|
||||
exp.SafeDivide: no_safe_divide_sql,
|
||||
exp.Schema: _schema_sql,
|
||||
exp.Select: transforms.preprocess(
|
||||
|
@ -320,8 +312,7 @@ class Presto(Dialect):
|
|||
exp.StrToTime: _str_to_time_sql,
|
||||
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
|
||||
exp.StructExtract: struct_extract_sql,
|
||||
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
|
||||
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
|
||||
exp.Table: transforms.preprocess([_unnest_sequence]),
|
||||
exp.TimestampTrunc: timestamptrunc_sql,
|
||||
exp.TimeStrToDate: timestrtotime_sql,
|
||||
exp.TimeStrToTime: timestrtotime_sql,
|
||||
|
@ -336,6 +327,7 @@ class Presto(Dialect):
|
|||
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
|
||||
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
|
||||
exp.VariancePop: rename_func("VAR_POP"),
|
||||
exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
|
||||
exp.WithinGroup: transforms.preprocess(
|
||||
[transforms.remove_within_group_for_percentiles]
|
||||
),
|
||||
|
@ -351,3 +343,25 @@ class Presto(Dialect):
|
|||
modes = expression.args.get("modes")
|
||||
modes = f" {', '.join(modes)}" if modes else ""
|
||||
return f"START TRANSACTION{modes}"
|
||||
|
||||
def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
|
||||
start = expression.args["start"]
|
||||
end = expression.args["end"]
|
||||
step = expression.args.get("step")
|
||||
|
||||
if isinstance(start, exp.Cast):
|
||||
target_type = start.to
|
||||
elif isinstance(end, exp.Cast):
|
||||
target_type = end.to
|
||||
else:
|
||||
target_type = None
|
||||
|
||||
if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP):
|
||||
to = target_type.copy()
|
||||
|
||||
if target_type is start.to:
|
||||
end = exp.Cast(this=end, to=to)
|
||||
else:
|
||||
start = exp.Cast(this=start, to=to)
|
||||
|
||||
return self.func("SEQUENCE", start, end, step)
@ -8,21 +8,21 @@ from sqlglot.helper import seq_get
|
|||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
def _json_sql(self, e) -> str:
|
||||
return f'{self.sql(e, "this")}."{e.expression.name}"'
|
||||
def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
|
||||
return f'{self.sql(expression, "this")}."{expression.expression.name}"'
|
||||
|
||||
|
||||
class Redshift(Postgres):
|
||||
time_format = "'YYYY-MM-DD HH:MI:SS'"
|
||||
time_mapping = {
|
||||
**Postgres.time_mapping, # type: ignore
|
||||
**Postgres.time_mapping,
|
||||
"MON": "%b",
|
||||
"HH": "%H",
|
||||
}
|
||||
|
||||
class Parser(Postgres.Parser):
|
||||
FUNCTIONS = {
|
||||
**Postgres.Parser.FUNCTIONS, # type: ignore
|
||||
**Postgres.Parser.FUNCTIONS,
|
||||
"DATEADD": lambda args: exp.DateAdd(
|
||||
this=seq_get(args, 2),
|
||||
expression=seq_get(args, 1),
|
||||
|
@ -45,7 +45,7 @@ class Redshift(Postgres):
|
|||
isinstance(this, exp.DataType)
|
||||
and this.this == exp.DataType.Type.VARCHAR
|
||||
and this.expressions
|
||||
and this.expressions[0] == exp.column("MAX")
|
||||
and this.expressions[0].this == exp.column("MAX")
|
||||
):
|
||||
this.set("expressions", [exp.Var(this="MAX")])
|
||||
|
||||
|
@ -57,9 +57,7 @@ class Redshift(Postgres):
|
|||
STRING_ESCAPES = ["\\"]
|
||||
|
||||
KEYWORDS = {
|
||||
**Postgres.Tokenizer.KEYWORDS, # type: ignore
|
||||
"GEOMETRY": TokenType.GEOMETRY,
|
||||
"GEOGRAPHY": TokenType.GEOGRAPHY,
|
||||
**Postgres.Tokenizer.KEYWORDS,
|
||||
"HLLSKETCH": TokenType.HLLSKETCH,
|
||||
"SUPER": TokenType.SUPER,
|
||||
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
|
||||
|
@ -76,22 +74,22 @@ class Redshift(Postgres):
|
|||
|
||||
class Generator(Postgres.Generator):
|
||||
LOCKING_READS_SUPPORTED = False
|
||||
SINGLE_STRING_INTERVAL = True
|
||||
RENAME_TABLE_WITH_DB = False
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**Postgres.Generator.TYPE_MAPPING, # type: ignore
|
||||
**Postgres.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.BINARY: "VARBYTE",
|
||||
exp.DataType.Type.VARBINARY: "VARBYTE",
|
||||
exp.DataType.Type.INT: "INTEGER",
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**Postgres.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**Postgres.Generator.PROPERTIES_LOCATION,
|
||||
exp.LikeProperty: exp.Properties.Location.POST_WITH,
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**Postgres.Generator.TRANSFORMS, # type: ignore
|
||||
**Postgres.Generator.TRANSFORMS,
|
||||
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
|
||||
exp.DateAdd: lambda self, e: self.func(
|
||||
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
|
||||
|
@ -107,10 +105,13 @@ class Redshift(Postgres):
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}

# Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
TRANSFORMS.pop(exp.Pivot)

# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)

RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"}
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}

def values_sql(self, expression: exp.Values) -> str:
|
||||
"""
|
||||
|
@ -120,37 +121,36 @@ class Redshift(Postgres):
|
|||
evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
|
||||
very slow.
|
||||
"""
|
||||
if not isinstance(expression.unnest().parent, exp.From):
|
||||
|
||||
# The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
|
||||
if not expression.find_ancestor(exp.From, exp.Join):
|
||||
return super().values_sql(expression)
|
||||
rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
|
||||
|
||||
column_names = expression.alias and expression.args["alias"].columns
|
||||
|
||||
selects = []
|
||||
rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
|
||||
|
||||
for i, row in enumerate(rows):
|
||||
if i == 0 and expression.alias:
|
||||
if i == 0 and column_names:
|
||||
row = [
|
||||
exp.alias_(value, column_name)
|
||||
for value, column_name in zip(row, expression.args["alias"].args["columns"])
|
||||
for value, column_name in zip(row, column_names)
|
||||
]
|
||||
|
||||
selects.append(exp.Select(expressions=row))
|
||||
subquery_expression = selects[0]
|
||||
|
||||
subquery_expression: exp.Select | exp.Union = selects[0]
|
||||
if len(selects) > 1:
|
||||
for select in selects[1:]:
|
||||
subquery_expression = exp.union(subquery_expression, select, distinct=False)
|
||||
|
||||
return self.subquery_sql(subquery_expression.subquery(expression.alias))
|
||||
|
||||
def with_properties(self, properties: exp.Properties) -> str:
|
||||
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
|
||||
return self.properties(properties, prefix=" ", suffix="")
|
||||
|
||||
def renametable_sql(self, expression: exp.RenameTable) -> str:
|
||||
"""Redshift only supports defining the table name itself (not the db) when renaming tables"""
|
||||
expression = expression.copy()
|
||||
target_table = expression.this
|
||||
for arg in target_table.args:
|
||||
if arg != "this":
|
||||
target_table.set(arg, None)
|
||||
this = self.sql(expression, "this")
|
||||
return f"RENAME TO {this}"
|
||||
|
||||
def datatype_sql(self, expression: exp.DataType) -> str:
|
||||
"""
|
||||
Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
|
||||
|
@ -162,6 +162,8 @@ class Redshift(Postgres):
|
|||
expression = expression.copy()
|
||||
expression.set("this", exp.DataType.Type.VARCHAR)
|
||||
precision = expression.args.get("expressions")
|
||||
|
||||
if not precision:
|
||||
expression.append("expressions", exp.Var(this="MAX"))
|
||||
|
||||
return super().datatype_sql(expression)
@ -18,7 +18,7 @@ from sqlglot.dialects.dialect import (
|
|||
var_map_sql,
|
||||
)
|
||||
from sqlglot.expressions import Literal
|
||||
from sqlglot.helper import flatten, seq_get
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.parser import binary_range_parser
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
@ -30,7 +30,7 @@ def _check_int(s: str) -> bool:
|
|||
|
||||
|
||||
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
|
||||
def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
|
||||
def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
|
||||
if len(args) == 2:
|
||||
first_arg, second_arg = args
|
||||
if second_arg.is_string:
|
||||
|
@ -52,8 +52,12 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
|
|||
|
||||
return exp.UnixToTime(this=first_arg, scale=timescale)
|
||||
|
||||
from sqlglot.optimizer.simplify import simplify_literals
|
||||
|
||||
# The first argument might be an expression like 40 * 365 * 86400, so we try to
|
||||
# reduce it using `simplify_literals` first and then check if it's a Literal.
|
||||
first_arg = seq_get(args, 0)
|
||||
if not isinstance(first_arg, Literal):
|
||||
if not isinstance(simplify_literals(first_arg, root=True), Literal):
|
||||
# case: <variant_expr>
|
||||
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
|
||||
|
||||
|
@ -69,6 +73,19 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
|
|||
return exp.UnixToTime.from_arg_list(args)


def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
expression = parser.parse_var_map(args)

if isinstance(expression, exp.StarMap):
return expression

return exp.Struct(
expressions=[
t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values)
]
)


def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
|
||||
scale = expression.args.get("scale")
|
||||
timestamp = self.sql(expression, "this")
|
||||
|
@ -116,7 +133,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
|
|||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/div0
|
||||
def _div0_to_if(args: t.Sequence) -> exp.Expression:
|
||||
def _div0_to_if(args: t.List) -> exp.Expression:
|
||||
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
|
||||
true = exp.Literal.number(0)
|
||||
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
|
||||
|
@ -124,13 +141,13 @@ def _div0_to_if(args: t.Sequence) -> exp.Expression:
|
|||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
|
||||
def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
|
||||
def _zeroifnull_to_if(args: t.List) -> exp.Expression:
|
||||
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
|
||||
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
|
||||
def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
|
||||
def _nullifzero_to_if(args: t.List) -> exp.Expression:
|
||||
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
|
||||
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
|
||||
|
||||
|
@ -143,6 +160,12 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
|
|||
return self.datatype_sql(expression)
|
||||
|
||||
|
||||
def _parse_convert_timezone(args: t.List) -> exp.Expression:
|
||||
if len(args) == 3:
|
||||
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
|
||||
return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
|
||||
|
||||
|
||||
class Snowflake(Dialect):
|
||||
null_ordering = "nulls_are_large"
|
||||
time_format = "'yyyy-mm-dd hh24:mi:ss'"
|
||||
|
@ -177,17 +200,14 @@ class Snowflake(Dialect):
|
|||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
QUOTED_PIVOT_COLUMNS = True
|
||||
IDENTIFY_PIVOT_STRINGS = True
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
|
||||
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
|
||||
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
|
||||
"CONVERT_TIMEZONE": lambda args: exp.AtTimeZone(
|
||||
this=seq_get(args, 1),
|
||||
zone=seq_get(args, 0),
|
||||
),
|
||||
"CONVERT_TIMEZONE": _parse_convert_timezone,
|
||||
"DATE_TRUNC": date_trunc_to_time,
|
||||
"DATEADD": lambda args: exp.DateAdd(
|
||||
this=seq_get(args, 2),
|
||||
|
@ -202,7 +222,7 @@ class Snowflake(Dialect):
|
|||
"DIV0": _div0_to_if,
|
||||
"IFF": exp.If.from_arg_list,
|
||||
"NULLIFZERO": _nullifzero_to_if,
|
||||
"OBJECT_CONSTRUCT": parser.parse_var_map,
|
||||
"OBJECT_CONSTRUCT": _parse_object_construct,
|
||||
"RLIKE": exp.RegexpLike.from_arg_list,
|
||||
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
|
||||
"TO_ARRAY": exp.Array.from_arg_list,
|
||||
|
@ -224,7 +244,7 @@ class Snowflake(Dialect):
|
|||
}
|
||||
|
||||
COLUMN_OPERATORS = {
|
||||
**parser.Parser.COLUMN_OPERATORS, # type: ignore
|
||||
**parser.Parser.COLUMN_OPERATORS,
|
||||
TokenType.COLON: lambda self, this, path: self.expression(
|
||||
exp.Bracket,
|
||||
this=this,
|
||||
|
@ -232,14 +252,16 @@ class Snowflake(Dialect):
|
|||
),
|
||||
}
|
||||
|
||||
TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME}
|
||||
|
||||
RANGE_PARSERS = {
|
||||
**parser.Parser.RANGE_PARSERS, # type: ignore
|
||||
**parser.Parser.RANGE_PARSERS,
|
||||
TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny),
|
||||
TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny),
|
||||
}
|
||||
|
||||
ALTER_PARSERS = {
|
||||
**parser.Parser.ALTER_PARSERS, # type: ignore
|
||||
**parser.Parser.ALTER_PARSERS,
|
||||
"UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
|
||||
"SET": lambda self: self._parse_alter_table_set_tag(),
|
||||
}
|
||||
|
@ -256,17 +278,20 @@ class Snowflake(Dialect):
|
|||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"CHAR VARYING": TokenType.VARCHAR,
|
||||
"CHARACTER VARYING": TokenType.VARCHAR,
|
||||
"EXCLUDE": TokenType.EXCEPT,
|
||||
"ILIKE ANY": TokenType.ILIKE_ANY,
|
||||
"LIKE ANY": TokenType.LIKE_ANY,
|
||||
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
|
||||
"MINUS": TokenType.EXCEPT,
|
||||
"NCHAR VARYING": TokenType.VARCHAR,
|
||||
"PUT": TokenType.COMMAND,
|
||||
"RENAME": TokenType.REPLACE,
|
||||
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
|
||||
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
|
||||
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
|
||||
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
|
||||
"MINUS": TokenType.EXCEPT,
|
||||
"SAMPLE": TokenType.TABLE_SAMPLE,
|
||||
}
|
||||
|
||||
|
@ -285,7 +310,7 @@ class Snowflake(Dialect):
|
|||
TABLE_HINTS = False
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.Array: inline_array_sql,
|
||||
exp.ArrayConcat: rename_func("ARRAY_CAT"),
|
||||
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
|
||||
|
@ -299,6 +324,7 @@ class Snowflake(Dialect):
|
|||
exp.DateStrToDate: datestrtodate_sql,
|
||||
exp.DataType: _datatype_sql,
|
||||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
exp.Extract: rename_func("DATE_PART"),
|
||||
exp.If: rename_func("IFF"),
|
||||
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
|
||||
exp.LogicalOr: rename_func("BOOLOR_AGG"),
|
||||
|
@ -312,6 +338,10 @@ class Snowflake(Dialect):
|
|||
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
|
||||
),
|
||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.Struct: lambda self, e: self.func(
|
||||
"OBJECT_CONSTRUCT",
|
||||
*(arg for expression in e.expressions for arg in expression.flatten()),
|
||||
),
|
||||
exp.TimeStrToTime: timestrtotime_sql,
|
||||
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
|
||||
exp.TimeToStr: lambda self, e: self.func(
|
||||
|
@ -326,7 +356,7 @@ class Snowflake(Dialect):
|
|||
}
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
|
||||
}
|
||||
|
||||
|
@ -336,7 +366,7 @@ class Snowflake(Dialect):
|
|||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
@ -351,53 +381,10 @@ class Snowflake(Dialect):
|
|||
self.unsupported("INTERSECT with All is not supported in Snowflake")
|
||||
return super().intersect_op(expression)
|
||||
|
||||
def values_sql(self, expression: exp.Values) -> str:
|
||||
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
|
||||
|
||||
We also want to make sure that after we find matches where we need to unquote a column that we prevent users
|
||||
from adding quotes to the column by using the `identify` argument when generating the SQL.
|
||||
"""
|
||||
alias = expression.args.get("alias")
|
||||
if alias and alias.args.get("columns"):
|
||||
expression = expression.transform(
|
||||
lambda node: exp.Identifier(**{**node.args, "quoted": False})
|
||||
if isinstance(node, exp.Identifier)
|
||||
and isinstance(node.parent, exp.TableAlias)
|
||||
and node.arg_key == "columns"
|
||||
else node,
|
||||
)
|
||||
return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
|
||||
return super().values_sql(expression)
|
||||
|
||||
def settag_sql(self, expression: exp.SetTag) -> str:
|
||||
action = "UNSET" if expression.args.get("unset") else "SET"
|
||||
return f"{action} TAG {self.expressions(expression)}"
|
||||
|
||||
def select_sql(self, expression: exp.Select) -> str:
|
||||
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
|
||||
that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
|
||||
to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
|
||||
generating the SQL.
|
||||
|
||||
Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
|
||||
expression. This might not be true in a case where the same column name can be sourced from another table that can
|
||||
properly quote but should be true in most cases.
|
||||
"""
|
||||
values_identifiers = set(
|
||||
flatten(
|
||||
(v.args.get("alias") or exp.Alias()).args.get("columns", [])
|
||||
for v in expression.find_all(exp.Values)
|
||||
)
|
||||
)
|
||||
if values_identifiers:
|
||||
expression = expression.transform(
|
||||
lambda node: exp.Identifier(**{**node.args, "quoted": False})
|
||||
if isinstance(node, exp.Identifier) and node in values_identifiers
|
||||
else node,
|
||||
)
|
||||
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
|
||||
return super().select_sql(expression)
|
||||
|
||||
def describe_sql(self, expression: exp.Describe) -> str:
|
||||
# Default to table if kind is unknown
|
||||
kind_value = expression.args.get("kind") or "TABLE"
|
||||
|
|
|
@ -7,10 +7,10 @@ from sqlglot.dialects.spark2 import Spark2
|
|||
from sqlglot.helper import seq_get
|
||||
|
||||
|
||||
def _parse_datediff(args: t.Sequence) -> exp.Expression:
|
||||
def _parse_datediff(args: t.List) -> exp.Expression:
|
||||
"""
|
||||
Although Spark docs don't mention the "unit" argument, Spark3 added support for
|
||||
it at some point. Databricks also supports this variation (see below).
|
||||
it at some point. Databricks also supports this variant (see below).
|
||||
|
||||
For example, in spark-sql (v3.3.1):
|
||||
- SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
|
||||
|
@ -36,7 +36,7 @@ def _parse_datediff(args: t.Sequence) -> exp.Expression:
|
|||
class Spark(Spark2):
|
||||
class Parser(Spark2.Parser):
|
||||
FUNCTIONS = {
|
||||
**Spark2.Parser.FUNCTIONS, # type: ignore
|
||||
**Spark2.Parser.FUNCTIONS,
|
||||
"DATEDIFF": _parse_datediff,
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,12 @@ from __future__ import annotations
|
|||
import typing as t
|
||||
|
||||
from sqlglot import exp, parser, transforms
|
||||
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
|
||||
from sqlglot.dialects.dialect import (
|
||||
create_with_partitions_sql,
|
||||
pivot_column_names,
|
||||
rename_func,
|
||||
trim_sql,
|
||||
)
|
||||
from sqlglot.dialects.hive import Hive
|
||||
from sqlglot.helper import seq_get
|
||||
|
||||
|
@ -26,7 +31,7 @@ def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
|
|||
return f"MAP_FROM_ARRAYS({keys}, {values})"
|
||||
|
||||
|
||||
def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
|
||||
def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
|
||||
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
|
||||
|
||||
|
||||
|
@ -53,10 +58,56 @@ def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
|
|||
raise ValueError("Improper scale for timestamp")
|
||||
|
||||
|
||||
def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a
|
||||
pivoted source in a subquery with the same alias to preserve the query's semantics.
|
||||
|
||||
Example:
|
||||
>>> from sqlglot import parse_one
|
||||
>>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
|
||||
>>> print(_unalias_pivot(expr).sql(dialect="spark"))
|
||||
SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
|
||||
"""
|
||||
if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
|
||||
pivot = expression.this.args["pivots"][0]
|
||||
if pivot.alias:
|
||||
alias = pivot.args["alias"].pop()
|
||||
return exp.From(
|
||||
this=expression.this.replace(
|
||||
exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
|
||||
)
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
|
||||
so we need to unqualify it.
|
||||
|
||||
Example:
|
||||
>>> from sqlglot import parse_one
|
||||
>>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
>>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
"""
if isinstance(expression, exp.Pivot):
|
||||
expression.args["field"].transform(
|
||||
lambda node: exp.column(node.output_name, quoted=node.this.quoted)
|
||||
if isinstance(node, exp.Column)
|
||||
else node,
|
||||
copy=False,
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
class Spark2(Hive):
|
||||
class Parser(Hive.Parser):
|
||||
FUNCTIONS = {
|
||||
**Hive.Parser.FUNCTIONS, # type: ignore
|
||||
**Hive.Parser.FUNCTIONS,
|
||||
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
|
||||
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
|
||||
"LEFT": lambda args: exp.Substring(
|
||||
|
@ -110,7 +161,7 @@ class Spark2(Hive):
|
|||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
**parser.Parser.FUNCTION_PARSERS, # type: ignore
|
||||
**parser.Parser.FUNCTION_PARSERS,
|
||||
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
|
||||
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
|
||||
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
|
||||
|
@ -131,43 +182,21 @@ class Spark2(Hive):
|
|||
kind="COLUMNS",
|
||||
)
|
||||
|
||||
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
|
||||
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
|
||||
if len(pivot_columns) == 1:
|
||||
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
|
||||
if len(aggregations) == 1:
|
||||
return [""]
|
||||
|
||||
names = []
|
||||
for agg in pivot_columns:
|
||||
if isinstance(agg, exp.Alias):
|
||||
names.append(agg.alias)
|
||||
else:
|
||||
"""
|
||||
This case corresponds to aggregations without aliases being used as suffixes
|
||||
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
|
||||
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
|
||||
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
|
||||
|
||||
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
|
||||
"""
|
||||
agg_all_unquoted = agg.transform(
|
||||
lambda node: exp.Identifier(this=node.name, quoted=False)
|
||||
if isinstance(node, exp.Identifier)
|
||||
else node
|
||||
)
|
||||
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
|
||||
|
||||
return names
|
||||
return pivot_column_names(aggregations, dialect="spark")
|
||||
|
||||
class Generator(Hive.Generator):
|
||||
TYPE_MAPPING = {
|
||||
**Hive.Generator.TYPE_MAPPING, # type: ignore
|
||||
**Hive.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.TINYINT: "BYTE",
|
||||
exp.DataType.Type.SMALLINT: "SHORT",
|
||||
exp.DataType.Type.BIGINT: "LONG",
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**Hive.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**Hive.Generator.PROPERTIES_LOCATION,
|
||||
exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
|
@ -175,7 +204,7 @@ class Spark2(Hive):
|
|||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**Hive.Generator.TRANSFORMS, # type: ignore
|
||||
**Hive.Generator.TRANSFORMS,
|
||||
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
|
||||
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
|
||||
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
|
||||
|
@ -188,11 +217,12 @@ class Spark2(Hive):
|
|||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
exp.DayOfYear: rename_func("DAYOFYEAR"),
|
||||
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
|
||||
exp.From: transforms.preprocess([_unalias_pivot]),
|
||||
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
|
||||
exp.LogicalAnd: rename_func("BOOL_AND"),
|
||||
exp.LogicalOr: rename_func("BOOL_OR"),
|
||||
exp.Map: _map_sql,
|
||||
exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
|
||||
exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
|
||||
exp.Reduce: rename_func("AGGREGATE"),
|
||||
exp.StrToDate: _str_to_date,
|
||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
|
|||
arrow_json_extract_sql,
|
||||
count_if_to_sum,
|
||||
no_ilike_sql,
|
||||
no_pivot_sql,
|
||||
no_tablesample_sql,
|
||||
no_trycast_sql,
|
||||
rename_func,
|
||||
|
@ -14,7 +15,7 @@ from sqlglot.dialects.dialect import (
|
|||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
def _date_add_sql(self, expression):
|
||||
def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
|
||||
modifier = expression.expression
|
||||
modifier = modifier.name if modifier.is_string else self.sql(modifier)
|
||||
unit = expression.args.get("unit")
|
||||
|
@ -67,7 +68,7 @@ class SQLite(Dialect):
|
|||
|
||||
class Parser(parser.Parser):
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"EDITDIST3": exp.Levenshtein.from_arg_list,
|
||||
}
|
||||
|
||||
|
@ -76,7 +77,7 @@ class SQLite(Dialect):
|
|||
TABLE_HINTS = False
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.BOOLEAN: "INTEGER",
|
||||
exp.DataType.Type.TINYINT: "INTEGER",
|
||||
exp.DataType.Type.SMALLINT: "INTEGER",
|
||||
|
@ -98,7 +99,7 @@ class SQLite(Dialect):
|
|||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.CountIf: count_if_to_sum,
|
||||
exp.Create: transforms.preprocess([_transform_create]),
|
||||
exp.CurrentDate: lambda *_: "CURRENT_DATE",
|
||||
|
@ -114,6 +115,7 @@ class SQLite(Dialect):
|
|||
exp.Levenshtein: rename_func("EDITDIST3"),
|
||||
exp.LogicalOr: rename_func("MAX"),
|
||||
exp.LogicalAnd: rename_func("MIN"),
|
||||
exp.Pivot: no_pivot_sql,
|
||||
exp.Select: transforms.preprocess(
|
||||
[transforms.eliminate_distinct_on, transforms.eliminate_qualify]
|
||||
),
|
||||
|
@ -163,12 +165,15 @@ class SQLite(Dialect):
|
|||
return f"CAST({sql} AS INTEGER)"
|
||||
|
||||
# https://www.sqlite.org/lang_aggfunc.html#group_concat
|
||||
def groupconcat_sql(self, expression):
|
||||
def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
|
||||
this = expression.this
|
||||
distinct = expression.find(exp.Distinct)
|
||||
|
||||
if distinct:
|
||||
this = distinct.expressions[0]
|
||||
distinct = "DISTINCT "
|
||||
distinct_sql = "DISTINCT "
|
||||
else:
|
||||
distinct_sql = ""
|
||||
|
||||
if isinstance(expression.this, exp.Order):
|
||||
self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
|
||||
|
@ -176,7 +181,7 @@ class SQLite(Dialect):
|
|||
this = expression.this.this
|
||||
|
||||
separator = expression.args.get("separator")
|
||||
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
|
||||
return f"GROUP_CONCAT({distinct_sql}{self.format_args(this, separator)})"
|
||||
|
||||
def least_sql(self, expression: exp.Least) -> str:
|
||||
if len(expression.expressions) > 1:
@ -11,25 +11,24 @@ from sqlglot.helper import seq_get
|
|||
|
||||
|
||||
class StarRocks(MySQL):
|
||||
class Parser(MySQL.Parser): # type: ignore
|
||||
class Parser(MySQL.Parser):
|
||||
FUNCTIONS = {
|
||||
**MySQL.Parser.FUNCTIONS,
|
||||
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
|
||||
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
|
||||
this=seq_get(args, 1), unit=seq_get(args, 0)
|
||||
),
|
||||
}
|
||||
|
||||
class Generator(MySQL.Generator): # type: ignore
|
||||
class Generator(MySQL.Generator):
|
||||
TYPE_MAPPING = {
|
||||
**MySQL.Generator.TYPE_MAPPING, # type: ignore
|
||||
**MySQL.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.TEXT: "STRING",
|
||||
exp.DataType.Type.TIMESTAMP: "DATETIME",
|
||||
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**MySQL.Generator.TRANSFORMS, # type: ignore
|
||||
**MySQL.Generator.TRANSFORMS,
|
||||
exp.ApproxDistinct: approx_count_distinct_sql,
|
||||
exp.JSONExtractScalar: arrow_json_extract_sql,
|
||||
exp.JSONExtract: arrow_json_extract_sql,
|
||||
|
@ -43,4 +42,5 @@ class StarRocks(MySQL):
|
|||
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
|
||||
}
|
||||
|
||||
TRANSFORMS.pop(exp.DateTrunc)
@ -4,41 +4,38 @@ from sqlglot import exp, generator, parser, transforms
|
|||
from sqlglot.dialects.dialect import Dialect
|
||||
|
||||
|
||||
def _if_sql(self, expression):
|
||||
return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END"
|
||||
|
||||
|
||||
def _coalesce_sql(self, expression):
|
||||
return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
|
||||
|
||||
|
||||
def _count_sql(self, expression):
|
||||
this = expression.this
|
||||
if isinstance(this, exp.Distinct):
|
||||
return f"COUNTD({self.expressions(this, flat=True)})"
|
||||
return f"COUNT({self.sql(expression, 'this')})"
|
||||
|
||||
|
||||
class Tableau(Dialect):
|
||||
class Generator(generator.Generator):
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
exp.If: _if_sql,
|
||||
exp.Coalesce: _coalesce_sql,
|
||||
exp.Count: _count_sql,
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
def if_sql(self, expression: exp.If) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
true = self.sql(expression, "true")
|
||||
false = self.sql(expression, "false")
|
||||
return f"IF {this} THEN {true} ELSE {false} END"
|
||||
|
||||
def coalesce_sql(self, expression: exp.Coalesce) -> str:
|
||||
return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
|
||||
|
||||
def count_sql(self, expression: exp.Count) -> str:
|
||||
this = expression.this
|
||||
if isinstance(this, exp.Distinct):
|
||||
return f"COUNTD({self.expressions(this, flat=True)})"
|
||||
return f"COUNT({self.sql(expression, 'this')})"
|
||||
|
||||
class Parser(parser.Parser):
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
|
||||
}
@ -75,12 +75,12 @@ class Teradata(Dialect):
|
|||
FUNC_TOKENS.remove(TokenType.REPLACE)
|
||||
|
||||
STATEMENT_PARSERS = {
|
||||
**parser.Parser.STATEMENT_PARSERS, # type: ignore
|
||||
**parser.Parser.STATEMENT_PARSERS,
|
||||
TokenType.REPLACE: lambda self: self._parse_create(),
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
**parser.Parser.FUNCTION_PARSERS, # type: ignore
|
||||
**parser.Parser.FUNCTION_PARSERS,
|
||||
"RANGE_N": lambda self: self._parse_rangen(),
|
||||
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
|
||||
}
|
||||
|
@ -106,7 +106,7 @@ class Teradata(Dialect):
|
|||
exp.Update,
|
||||
**{ # type: ignore
|
||||
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
|
||||
"from": self._parse_from(),
|
||||
"from": self._parse_from(modifiers=True),
|
||||
"expressions": self._match(TokenType.SET)
|
||||
and self._parse_csv(self._parse_equality),
|
||||
"where": self._parse_where(),
|
||||
|
@ -135,13 +135,15 @@ class Teradata(Dialect):
|
|||
TABLE_HINTS = False
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.OnCommitProperty: exp.Properties.Location.POST_INDEX,
|
||||
exp.PartitionedByProperty: exp.Properties.Location.POST_EXPRESSION,
|
||||
exp.StabilityProperty: exp.Properties.Location.POST_CREATE,
|
||||
}
|
||||
|
||||
TRANSFORMS = {
@ -7,7 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,  # type: ignore
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}

@ -16,6 +16,9 @@ from sqlglot.helper import seq_get
|
|||
from sqlglot.time import format_time
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import E
|
||||
|
||||
FULL_FORMAT_TIME_MAPPING = {
|
||||
"weekday": "%A",
|
||||
"dw": "%A",
|
||||
|
@ -50,13 +53,17 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{
|
|||
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
|
||||
|
||||
|
||||
def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
|
||||
def _format_time(args):
|
||||
def _format_time_lambda(
|
||||
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
|
||||
) -> t.Callable[[t.List], E]:
|
||||
def _format_time(args: t.List) -> E:
|
||||
assert len(args) == 2
|
||||
|
||||
return exp_class(
|
||||
this=seq_get(args, 1),
|
||||
this=args[1],
|
||||
format=exp.Literal.string(
|
||||
format_time(
|
||||
seq_get(args, 0).name or (TSQL.time_format if default is True else default),
|
||||
args[0].name,
|
||||
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
|
||||
if full_format_mapping
|
||||
else TSQL.time_mapping,
|
||||
|
@ -67,13 +74,17 @@ def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
|
|||
return _format_time
|
||||
|
||||
|
||||
def _parse_format(args):
|
||||
fmt = seq_get(args, 1)
|
||||
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
|
||||
def _parse_format(args: t.List) -> exp.Expression:
|
||||
assert len(args) == 2
|
||||
|
||||
fmt = args[1]
|
||||
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)
|
||||
|
||||
if number_fmt:
|
||||
return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
|
||||
return exp.NumberToStr(this=args[0], format=fmt)
|
||||
|
||||
return exp.TimeToStr(
|
||||
this=seq_get(args, 0),
|
||||
this=args[0],
|
||||
format=exp.Literal.string(
|
||||
format_time(fmt.name, TSQL.format_time_mapping)
|
||||
if len(fmt.name) == 1
|
||||
|
@ -82,7 +93,7 @@ def _parse_format(args):
|
|||
)
|
||||
|
||||
|
||||
def _parse_eomonth(args):
|
||||
def _parse_eomonth(args: t.List) -> exp.Expression:
|
||||
date = seq_get(args, 0)
|
||||
month_lag = seq_get(args, 1)
|
||||
unit = DATE_DELTA_INTERVAL.get("month")
|
||||
|
@ -96,7 +107,7 @@ def _parse_eomonth(args):
|
|||
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
|
||||
|
||||
|
||||
def _parse_hashbytes(args):
|
||||
def _parse_hashbytes(args: t.List) -> exp.Expression:
|
||||
kind, data = args
|
||||
kind = kind.name.upper() if kind.is_string else ""
|
||||
|
||||
|
@ -110,40 +121,47 @@ def _parse_hashbytes(args):
|
|||
return exp.SHA2(this=data, length=exp.Literal.number(256))
|
||||
if kind == "SHA2_512":
|
||||
return exp.SHA2(this=data, length=exp.Literal.number(512))
|
||||
|
||||
return exp.func("HASHBYTES", *args)


def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return self.func(func, e.text("unit"), e.expression, e.this)
def generate_date_delta_with_unit_sql(
self: generator.Generator, expression: exp.DateAdd | exp.DateDiff
) -> str:
func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
return self.func(func, expression.text("unit"), expression.expression, expression.this)


def _format_sql(self, e):
|
||||
def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
|
||||
fmt = (
|
||||
e.args["format"]
|
||||
if isinstance(e, exp.NumberToStr)
|
||||
else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping))
|
||||
expression.args["format"]
|
||||
if isinstance(expression, exp.NumberToStr)
|
||||
else exp.Literal.string(
|
||||
format_time(
|
||||
expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping)
|
||||
)
|
||||
)
|
||||
)
|
||||
return self.func("FORMAT", e.this, fmt)
|
||||
return self.func("FORMAT", expression.this, fmt)
|
||||
|
||||
|
||||
def _string_agg_sql(self, e):
|
||||
e = e.copy()
|
||||
def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
|
||||
expression = expression.copy()
|
||||
|
||||
this = e.this
|
||||
distinct = e.find(exp.Distinct)
|
||||
this = expression.this
|
||||
distinct = expression.find(exp.Distinct)
|
||||
if distinct:
|
||||
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
|
||||
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
|
||||
this = distinct.pop().expressions[0]
|
||||
|
||||
order = ""
|
||||
if isinstance(e.this, exp.Order):
|
||||
if e.this.this:
|
||||
this = e.this.this.pop()
|
||||
order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
|
||||
if isinstance(expression.this, exp.Order):
|
||||
if expression.this.this:
|
||||
this = expression.this.this.pop()
|
||||
order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})" # Order has a leading space
|
||||
|
||||
separator = e.args.get("separator") or exp.Literal.string(",")
|
||||
separator = expression.args.get("separator") or exp.Literal.string(",")
|
||||
return f"STRING_AGG({self.format_args(this, separator)}){order}"
|
||||
|
||||
|
||||
|
@ -292,7 +310,7 @@ class TSQL(Dialect):
|
|||
|
||||
class Parser(parser.Parser):
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"CHARINDEX": lambda args: exp.StrPosition(
|
||||
this=seq_get(args, 1),
|
||||
substr=seq_get(args, 0),
|
||||
|
@ -332,13 +350,13 @@ class TSQL(Dialect):
|
|||
DataType.Type.NCHAR,
|
||||
}
|
||||
|
||||
RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore
|
||||
RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
|
||||
TokenType.TABLE,
|
||||
*parser.Parser.TYPE_TOKENS, # type: ignore
|
||||
*parser.Parser.TYPE_TOKENS,
|
||||
}
|
||||
|
||||
STATEMENT_PARSERS = {
|
||||
**parser.Parser.STATEMENT_PARSERS, # type: ignore
|
||||
**parser.Parser.STATEMENT_PARSERS,
|
||||
TokenType.END: lambda self: self._parse_command(),
|
||||
}
|
||||
|
||||
|
@ -377,7 +395,7 @@ class TSQL(Dialect):
|
|||
|
||||
return system_time
|
||||
|
||||
def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
|
||||
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
|
||||
table = super()._parse_table_parts(schema=schema)
|
||||
table.set("system_time", self._parse_system_time())
|
||||
return table
|
||||
|
@ -450,7 +468,7 @@ class TSQL(Dialect):
|
|||
LOCKING_READS_SUPPORTED = True
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.INT: "INTEGER",
|
||||
exp.DataType.Type.DECIMAL: "NUMERIC",
|
||||
exp.DataType.Type.DATETIME: "DATETIME2",
|
||||
|
@ -458,7 +476,7 @@ class TSQL(Dialect):
|
|||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.DateAdd: generate_date_delta_with_unit_sql,
|
||||
exp.DateDiff: generate_date_delta_with_unit_sql,
|
||||
exp.CurrentDate: rename_func("GETDATE"),
|
||||
|
@ -480,7 +498,7 @@ class TSQL(Dialect):
|
|||
TRANSFORMS.pop(exp.ReturnsProperty)
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
@ -53,7 +53,8 @@ class Keep:


if t.TYPE_CHECKING:
T = t.TypeVar("T")
from sqlglot._typing import T

Edit = t.Union[Insert, Remove, Move, Update, Keep]

@ -240,7 +241,7 @@ class ChangeDistiller:
|
|||
return matching_set
|
||||
|
||||
def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
|
||||
candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = []
|
||||
candidate_matchings: t.List[t.Tuple[float, int, int, exp.Expression, exp.Expression]] = []
|
||||
source_leaves = list(_get_leaves(self._source))
|
||||
target_leaves = list(_get_leaves(self._target))
|
||||
for source_leaf in source_leaves:
|
||||
|
@ -252,6 +253,7 @@ class ChangeDistiller:
|
|||
candidate_matchings,
|
||||
(
|
||||
-similarity_score,
|
||||
-_parent_similarity_score(source_leaf, target_leaf),
|
||||
len(candidate_matchings),
|
||||
source_leaf,
|
||||
target_leaf,
|
||||
|
@ -261,7 +263,7 @@ class ChangeDistiller:
|
|||
# Pick best matchings based on the highest score
|
||||
matching_set = set()
|
||||
while candidate_matchings:
|
||||
_, _, source_leaf, target_leaf = heappop(candidate_matchings)
|
||||
_, _, _, source_leaf, target_leaf = heappop(candidate_matchings)
|
||||
if (
|
||||
id(source_leaf) in self._unmatched_source_nodes
|
||||
and id(target_leaf) in self._unmatched_target_nodes
|
||||
|
@ -327,6 +329,15 @@ def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _parent_similarity_score(
|
||||
source: t.Optional[exp.Expression], target: t.Optional[exp.Expression]
|
||||
) -> int:
|
||||
if source is None or target is None or type(source) is not type(target):
|
||||
return 0
|
||||
|
||||
return 1 + _parent_similarity_score(source.parent, target.parent)
|
||||
|
||||
|
||||
def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
|
||||
args: t.List[t.Union[exp.Expression, t.List]] = []
|
||||
if expression:
|
||||
|
|
|
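(Aside, not part of the diff: the heap entries above now carry a parent-similarity score in addition to the raw similarity and the insertion index, so ties on similarity are broken by how deep the matching ancestry goes. A minimal standalone sketch of that ordering, using made-up leaf values rather than sqlglot expressions:)

from heapq import heappush, heappop

candidates = []
entries = [
    # (similarity, parent_similarity, source, target)
    (0.8, 1, "src_a", "tgt_a"),
    (0.8, 3, "src_b", "tgt_b"),  # same similarity, deeper matching ancestry
]
for sim, parent_sim, src, tgt in entries:
    # Scores are negated so heappop returns the best candidate first; the
    # insertion index keeps ordering stable when both scores tie.
    heappush(candidates, (-sim, -parent_sim, len(candidates), src, tgt))

print(heappop(candidates))  # (-0.8, -3, 1, 'src_b', 'tgt_b')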
@@ -14,9 +14,10 @@ from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
from sqlglot.helper import dict_depth
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
from sqlglot.schema import ensure_schema
from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set

logger = logging.getLogger("sqlglot")


@@ -52,10 +53,15 @@ def execute(
    tables_ = ensure_tables(tables)

    if not schema:
        schema = {
            name: {column: type(table[0][column]).__name__ for column in table.columns}
            for name, table in tables_.mapping.items()
        }
        schema = {}
        flattened_tables = flatten_schema(tables_.mapping, depth=dict_depth(tables_.mapping))

        for keys in flattened_tables:
            table = nested_get(tables_.mapping, *zip(keys, keys))
            assert table is not None

            for column in table.columns:
                nested_set(schema, [*keys, column], type(table[0][column]).__name__)

    schema = ensure_schema(schema, dialect=read)
@@ -5,6 +5,7 @@ import statistics
from functools import wraps

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import PYTHON_VERSION


@@ -102,6 +103,8 @@ def cast(this, to):
        return datetime.date.fromisoformat(this)
    if to == exp.DataType.Type.DATETIME:
        return datetime.datetime.fromisoformat(this)
    if to == exp.DataType.Type.BOOLEAN:
        return bool(this)
    if to in exp.DataType.TEXT_TYPES:
        return str(this)
    if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:

@@ -119,9 +122,11 @@ def ordered(this, desc, nulls_first):

@null_if_any
def interval(this, unit):
    if unit == "DAY":
        return datetime.timedelta(days=float(this))
    raise NotImplementedError
    unit = unit.lower()
    plural = unit + "s"
    if plural in Generator.TIME_PART_SINGULARS:
        unit = plural
    return datetime.timedelta(**{unit: float(this)})


ENV = {
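(Aside, not part of the diff: a quick standalone sketch of the unit-pluralising logic the rewritten interval helper relies on. The plural-name lookup is stubbed with a plain set here instead of Generator.TIME_PART_SINGULARS.)

import datetime

PLURAL_UNITS = {"days", "hours", "minutes", "seconds", "weeks"}  # stand-in lookup

def to_timedelta(value: float, unit: str) -> datetime.timedelta:
    unit = unit.lower()
    plural = unit + "s"
    if plural in PLURAL_UNITS:
        unit = plural  # timedelta expects plural keywords, e.g. days=...
    return datetime.timedelta(**{unit: float(value)})

print(to_timedelta(2, "DAY"))      # 2 days, 0:00:00
print(to_timedelta(90, "minute"))  # 1:30:00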
@@ -147,7 +152,9 @@ ENV = {
    "COALESCE": lambda *args: next((a for a in args if a is not None), None),
    "CONCAT": null_if_any(lambda *args: "".join(args)),
    "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
    "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
    "DIV": null_if_any(lambda e, this: e / this),
    "DOT": null_if_any(lambda e, this: e[this]),
    "EQ": null_if_any(lambda this, e: this == e),
    "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
    "GT": null_if_any(lambda this, e: this > e),

@@ -162,6 +169,7 @@ ENV = {
    "LOWER": null_if_any(lambda arg: arg.lower()),
    "LT": null_if_any(lambda this, e: this < e),
    "LTE": null_if_any(lambda this, e: this <= e),
    "MAP": null_if_any(lambda *args: dict(zip(*args))), # type: ignore
    "MOD": null_if_any(lambda e, this: e % this),
    "MUL": null_if_any(lambda e, this: e * this),
    "NEQ": null_if_any(lambda this, e: this != e),

@@ -180,4 +188,5 @@ ENV = {
    "CURRENTTIMESTAMP": datetime.datetime.now,
    "CURRENTTIME": datetime.datetime.now,
    "CURRENTDATE": datetime.date.today,
    "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
}
@@ -360,11 +360,19 @@ def _ordered_py(self, expression):

def _rename(self, e):
    try:
        if "expressions" in e.args:
            this = self.sql(e, "this")
            this = f"{this}, " if this else ""
            return f"{e.key.upper()}({this}{self.expressions(e)})"
        return self.func(e.key, *e.args.values())
        values = list(e.args.values())

        if len(values) == 1:
            values = values[0]
            if not isinstance(values, list):
                return self.func(e.key, values)
            return self.func(e.key, *values)

        if isinstance(e, exp.Func) and e.is_var_len_args:
            *head, tail = values
            return self.func(e.key, *head, *tail)

        return self.func(e.key, *values)
    except Exception as ex:
        raise Exception(f"Could not rename {repr(e)}") from ex


@@ -413,6 +421,7 @@ class Python(Dialect):
            exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
            exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
            exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
            exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
            exp.Is: lambda self, e: self.binary(e, "is"),
            exp.Lambda: _lambda_sql,
            exp.Not: lambda self, e: f"not {self.sql(e.this)}",
File diff suppressed because it is too large
@@ -31,6 +31,8 @@ class Generator:
        hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
        byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
        byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
        raw_start (str): specifies which starting character to use to delimit raw literals. Default: None.
        raw_end (str): specifies which ending character to use to delimit raw literals. Default: None.
        identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always.
        normalize (bool): if set to True all identifiers will lower cased
        string_escape (str): specifies a string escape character. Default: '.
@ -76,11 +78,12 @@ class Generator:
|
|||
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
|
||||
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
|
||||
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
|
||||
exp.OnCommitProperty: lambda self, e: "ON COMMIT PRESERVE ROWS",
|
||||
exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
|
||||
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
|
||||
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
|
||||
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
|
||||
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
|
||||
exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY",
|
||||
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
|
||||
exp.TransientProperty: lambda self, e: "TRANSIENT",
|
||||
exp.StabilityProperty: lambda self, e: e.name,
|
||||
exp.VolatileProperty: lambda self, e: "VOLATILE",
|
||||
|
@ -133,6 +136,15 @@ class Generator:
|
|||
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
|
||||
LIMIT_FETCH = "ALL"
|
||||
|
||||
# Whether a table is allowed to be renamed with a db
|
||||
RENAME_TABLE_WITH_DB = True
|
||||
|
||||
# The separator for grouping sets and rollups
|
||||
GROUPINGS_SEP = ","
|
||||
|
||||
# The string used for creating index on a table
|
||||
INDEX_ON = "ON"
|
||||
|
||||
TYPE_MAPPING = {
|
||||
exp.DataType.Type.NCHAR: "CHAR",
|
||||
exp.DataType.Type.NVARCHAR: "VARCHAR",
|
||||
|
@ -167,7 +179,6 @@ class Generator:
|
|||
PARAMETER_TOKEN = "@"
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
exp.AfterJournalProperty: exp.Properties.Location.POST_NAME,
|
||||
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
|
||||
|
@ -196,7 +207,9 @@ class Generator:
|
|||
exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
|
||||
exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION,
|
||||
exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
|
||||
exp.Order: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
|
||||
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.Property: exp.Properties.Location.POST_WITH,
|
||||
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
|
@ -204,13 +217,15 @@ class Generator:
|
|||
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.Set: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.SetProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
|
||||
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
|
||||
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
|
||||
|
@ -221,7 +236,7 @@ class Generator:
|
|||
|
||||
RESERVED_KEYWORDS: t.Set[str] = set()
|
||||
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
|
||||
UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column)
|
||||
UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren)
|
||||
|
||||
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
|
||||
|
||||
|
@ -239,6 +254,8 @@ class Generator:
|
|||
"hex_end",
|
||||
"byte_start",
|
||||
"byte_end",
|
||||
"raw_start",
|
||||
"raw_end",
|
||||
"identify",
|
||||
"normalize",
|
||||
"string_escape",
|
||||
|
@ -276,6 +293,8 @@ class Generator:
|
|||
hex_end=None,
|
||||
byte_start=None,
|
||||
byte_end=None,
|
||||
raw_start=None,
|
||||
raw_end=None,
|
||||
identify=False,
|
||||
normalize=False,
|
||||
string_escape=None,
|
||||
|
@ -308,6 +327,8 @@ class Generator:
|
|||
self.hex_end = hex_end
|
||||
self.byte_start = byte_start
|
||||
self.byte_end = byte_end
|
||||
self.raw_start = raw_start
|
||||
self.raw_end = raw_end
|
||||
self.identify = identify
|
||||
self.normalize = normalize
|
||||
self.string_escape = string_escape or "'"
|
||||
|
@ -399,7 +420,11 @@ class Generator:
|
|||
return sql
|
||||
|
||||
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
|
||||
return f"{comments_sql}{self.sep()}{sql}"
|
||||
return (
|
||||
f"{self.sep()}{comments_sql}{sql}"
|
||||
if sql[0].isspace()
|
||||
else f"{comments_sql}{self.sep()}{sql}"
|
||||
)
|
||||
|
||||
return f"{sql} {comments_sql}"
|
||||
|
||||
|
@ -567,7 +592,9 @@ class Generator:
|
|||
) -> str:
|
||||
this = ""
|
||||
if expression.this is not None:
|
||||
this = " ALWAYS " if expression.this else " BY DEFAULT "
|
||||
on_null = "ON NULL " if expression.args.get("on_null") else ""
|
||||
this = " ALWAYS " if expression.this else f" BY DEFAULT {on_null}"
|
||||
|
||||
start = expression.args.get("start")
|
||||
start = f"START WITH {start}" if start else ""
|
||||
increment = expression.args.get("increment")
|
||||
|
@ -578,14 +605,20 @@ class Generator:
|
|||
maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else ""
|
||||
cycle = expression.args.get("cycle")
|
||||
cycle_sql = ""
|
||||
|
||||
if cycle is not None:
|
||||
cycle_sql = f"{' NO' if not cycle else ''} CYCLE"
|
||||
cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql
|
||||
|
||||
sequence_opts = ""
|
||||
if start or increment or cycle_sql:
|
||||
sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}"
|
||||
sequence_opts = f" ({sequence_opts.strip()})"
|
||||
return f"GENERATED{this}AS IDENTITY{sequence_opts}"
|
||||
|
||||
expr = self.sql(expression, "expression")
|
||||
expr = f"({expr})" if expr else "IDENTITY"
|
||||
|
||||
return f"GENERATED{this}AS {expr}{sequence_opts}"
|
||||
|
||||
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
|
||||
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
|
||||
|
@ -596,8 +629,10 @@ class Generator:
|
|||
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
|
||||
return f"PRIMARY KEY"
|
||||
|
||||
def uniquecolumnconstraint_sql(self, _) -> str:
|
||||
return "UNIQUE"
|
||||
def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
this = f" {this}" if this else ""
|
||||
return f"UNIQUE{this}"
|
||||
|
||||
def create_sql(self, expression: exp.Create) -> str:
|
||||
kind = self.sql(expression, "kind").upper()
|
||||
|
@ -653,33 +688,9 @@ class Generator:
|
|||
prefix=" ",
|
||||
)
|
||||
|
||||
indexes = expression.args.get("indexes")
|
||||
if indexes:
|
||||
indexes_sql: t.List[str] = []
|
||||
for index in indexes:
|
||||
ind_unique = " UNIQUE" if index.args.get("unique") else ""
|
||||
ind_primary = " PRIMARY" if index.args.get("primary") else ""
|
||||
ind_amp = " AMP" if index.args.get("amp") else ""
|
||||
ind_name = f" {index.name}" if index.name else ""
|
||||
ind_columns = (
|
||||
f' ({self.expressions(index, key="columns", flat=True)})'
|
||||
if index.args.get("columns")
|
||||
else ""
|
||||
)
|
||||
ind_sql = f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
|
||||
|
||||
if indexes_sql:
|
||||
indexes_sql.append(ind_sql)
|
||||
else:
|
||||
indexes_sql.append(
|
||||
f"{ind_sql}{postindex_props_sql}"
|
||||
if index.args.get("primary")
|
||||
else f"{postindex_props_sql}{ind_sql}"
|
||||
)
|
||||
|
||||
index_sql = "".join(indexes_sql)
|
||||
else:
|
||||
index_sql = postindex_props_sql
|
||||
indexes = self.expressions(expression, key="indexes", indent=False, sep=" ")
|
||||
indexes = f" {indexes}" if indexes else ""
|
||||
index_sql = indexes + postindex_props_sql
|
||||
|
||||
replace = " OR REPLACE" if expression.args.get("replace") else ""
|
||||
unique = " UNIQUE" if expression.args.get("unique") else ""
|
||||
|
@ -711,9 +722,23 @@ class Generator:
|
|||
" WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
|
||||
)
|
||||
|
||||
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}"
|
||||
clone = self.sql(expression, "clone")
|
||||
clone = f" {clone}" if clone else ""
|
||||
|
||||
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
|
||||
return self.prepend_ctes(expression, expression_sql)
|
||||
|
||||
def clone_sql(self, expression: exp.Clone) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
when = self.sql(expression, "when")
|
||||
|
||||
if when:
|
||||
kind = self.sql(expression, "kind")
|
||||
expr = self.sql(expression, "expression")
|
||||
return f"CLONE {this} {when} ({kind} => {expr})"
|
||||
|
||||
return f"CLONE {this}"
|
||||
|
||||
def describe_sql(self, expression: exp.Describe) -> str:
|
||||
return f"DESCRIBE {self.sql(expression, 'this')}"
|
||||
|
||||
|
@ -757,6 +782,17 @@ class Generator:
|
|||
return f"{self.byte_start}{this}{self.byte_end}"
|
||||
return this
|
||||
|
||||
def rawstring_sql(self, expression: exp.RawString) -> str:
|
||||
if self.raw_start:
|
||||
return f"{self.raw_start}{expression.name}{self.raw_end}"
|
||||
return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))
|
||||
|
||||
def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
specifier = self.sql(expression, "expression")
|
||||
specifier = f" {specifier}" if specifier else ""
|
||||
return f"{this}{specifier}"
|
||||
|
||||
def datatype_sql(self, expression: exp.DataType) -> str:
|
||||
type_value = expression.this
|
||||
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
|
||||
|
@ -768,7 +804,8 @@ class Generator:
|
|||
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
|
||||
if expression.args.get("values") is not None:
|
||||
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
|
||||
values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}"
|
||||
values = self.expressions(expression, key="values", flat=True)
|
||||
values = f"{delimiters[0]}{values}{delimiters[1]}"
|
||||
else:
|
||||
nested = f"({interior})"
|
||||
|
||||
|
@ -836,10 +873,17 @@ class Generator:
|
|||
return ""
|
||||
|
||||
def index_sql(self, expression: exp.Index) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
unique = "UNIQUE " if expression.args.get("unique") else ""
|
||||
primary = "PRIMARY " if expression.args.get("primary") else ""
|
||||
amp = "AMP " if expression.args.get("amp") else ""
|
||||
name = f"{expression.name} " if expression.name else ""
|
||||
table = self.sql(expression, "table")
|
||||
columns = self.sql(expression, "columns")
|
||||
return f"{this} ON {table} {columns}"
|
||||
table = f"{self.INDEX_ON} {table} " if table else ""
|
||||
index = "INDEX " if not table else ""
|
||||
columns = self.expressions(expression, key="columns", flat=True)
|
||||
partition_by = self.expressions(expression, key="partition_by", flat=True)
|
||||
partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
|
||||
return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}"
|
||||
|
||||
def identifier_sql(self, expression: exp.Identifier) -> str:
|
||||
text = expression.name
|
||||
|
@ -861,8 +905,9 @@ class Generator:
|
|||
output_format = f"OUTPUTFORMAT {output_format}" if output_format else ""
|
||||
return self.sep().join((input_format, output_format))
|
||||
|
||||
def national_sql(self, expression: exp.National) -> str:
|
||||
return f"N{self.sql(expression, 'this')}"
|
||||
def national_sql(self, expression: exp.National, prefix: str = "N") -> str:
|
||||
string = self.sql(exp.Literal.string(expression.name))
|
||||
return f"{prefix}{string}"
|
||||
|
||||
def partition_sql(self, expression: exp.Partition) -> str:
|
||||
return f"PARTITION({self.expressions(expression)})"
|
||||
|
@ -955,23 +1000,18 @@ class Generator:
|
|||
|
||||
def journalproperty_sql(self, expression: exp.JournalProperty) -> str:
|
||||
no = "NO " if expression.args.get("no") else ""
|
||||
local = expression.args.get("local")
|
||||
local = f"{local} " if local else ""
|
||||
dual = "DUAL " if expression.args.get("dual") else ""
|
||||
before = "BEFORE " if expression.args.get("before") else ""
|
||||
return f"{no}{dual}{before}JOURNAL"
|
||||
after = "AFTER " if expression.args.get("after") else ""
|
||||
return f"{no}{local}{dual}{before}{after}JOURNAL"
|
||||
|
||||
def freespaceproperty_sql(self, expression: exp.FreespaceProperty) -> str:
|
||||
freespace = self.sql(expression, "this")
|
||||
percent = " PERCENT" if expression.args.get("percent") else ""
|
||||
return f"FREESPACE={freespace}{percent}"
|
||||
|
||||
def afterjournalproperty_sql(self, expression: exp.AfterJournalProperty) -> str:
|
||||
no = "NO " if expression.args.get("no") else ""
|
||||
dual = "DUAL " if expression.args.get("dual") else ""
|
||||
local = ""
|
||||
if expression.args.get("local") is not None:
|
||||
local = "LOCAL " if expression.args.get("local") else "NOT LOCAL "
|
||||
return f"{no}{dual}{local}AFTER JOURNAL"
|
||||
|
||||
def checksumproperty_sql(self, expression: exp.ChecksumProperty) -> str:
|
||||
if expression.args.get("default"):
|
||||
property = "DEFAULT"
|
||||
|
@ -992,19 +1032,19 @@ class Generator:
|
|||
|
||||
def datablocksizeproperty_sql(self, expression: exp.DataBlocksizeProperty) -> str:
|
||||
default = expression.args.get("default")
|
||||
min = expression.args.get("min")
|
||||
if default is not None or min is not None:
|
||||
minimum = expression.args.get("minimum")
|
||||
maximum = expression.args.get("maximum")
|
||||
if default or minimum or maximum:
|
||||
if default:
|
||||
property = "DEFAULT"
|
||||
elif min:
|
||||
property = "MINIMUM"
|
||||
prop = "DEFAULT"
|
||||
elif minimum:
|
||||
prop = "MINIMUM"
|
||||
else:
|
||||
property = "MAXIMUM"
|
||||
return f"{property} DATABLOCKSIZE"
|
||||
else:
|
||||
units = expression.args.get("units")
|
||||
units = f" {units}" if units else ""
|
||||
return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}"
|
||||
prop = "MAXIMUM"
|
||||
return f"{prop} DATABLOCKSIZE"
|
||||
units = expression.args.get("units")
|
||||
units = f" {units}" if units else ""
|
||||
return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}"
|
||||
|
||||
def blockcompressionproperty_sql(self, expression: exp.BlockCompressionProperty) -> str:
|
||||
autotemp = expression.args.get("autotemp")
|
||||
|
@ -1014,16 +1054,16 @@ class Generator:
|
|||
never = expression.args.get("never")
|
||||
|
||||
if autotemp is not None:
|
||||
property = f"AUTOTEMP({self.expressions(autotemp)})"
|
||||
prop = f"AUTOTEMP({self.expressions(autotemp)})"
|
||||
elif always:
|
||||
property = "ALWAYS"
|
||||
prop = "ALWAYS"
|
||||
elif default:
|
||||
property = "DEFAULT"
|
||||
prop = "DEFAULT"
|
||||
elif manual:
|
||||
property = "MANUAL"
|
||||
prop = "MANUAL"
|
||||
elif never:
|
||||
property = "NEVER"
|
||||
return f"BLOCKCOMPRESSION={property}"
|
||||
prop = "NEVER"
|
||||
return f"BLOCKCOMPRESSION={prop}"
|
||||
|
||||
def isolatedloadingproperty_sql(self, expression: exp.IsolatedLoadingProperty) -> str:
|
||||
no = expression.args.get("no")
|
||||
|
@ -1138,21 +1178,24 @@ class Generator:
|
|||
|
||||
alias = self.sql(expression, "alias")
|
||||
alias = f"{sep}{alias}" if alias else ""
|
||||
hints = self.expressions(expression, key="hints", sep=", ", flat=True)
|
||||
hints = self.expressions(expression, key="hints", flat=True)
|
||||
hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
|
||||
laterals = self.expressions(expression, key="laterals", sep="")
|
||||
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
|
||||
pivots = f" {pivots}" if pivots else ""
|
||||
joins = self.expressions(expression, key="joins", sep="")
|
||||
pivots = self.expressions(expression, key="pivots", sep="")
|
||||
laterals = self.expressions(expression, key="laterals", sep="")
|
||||
system_time = expression.args.get("system_time")
|
||||
system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
|
||||
|
||||
return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"
|
||||
return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
|
||||
|
||||
def tablesample_sql(
|
||||
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
|
||||
) -> str:
|
||||
if self.alias_post_tablesample and expression.this.alias:
|
||||
this = self.sql(expression.this, "this")
|
||||
table = expression.this.copy()
|
||||
table.set("alias", None)
|
||||
this = self.sql(table)
|
||||
alias = f"{sep}{self.sql(expression.this, 'alias')}"
|
||||
else:
|
||||
this = self.sql(expression, "this")
|
||||
|
@ -1177,14 +1220,22 @@ class Generator:
|
|||
return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"
|
||||
|
||||
def pivot_sql(self, expression: exp.Pivot) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
expressions = self.expressions(expression, flat=True)
|
||||
|
||||
if expression.this:
|
||||
this = self.sql(expression, "this")
|
||||
on = f"{self.seg('ON')} {expressions}"
|
||||
using = self.expressions(expression, key="using", flat=True)
|
||||
using = f"{self.seg('USING')} {using}" if using else ""
|
||||
group = self.sql(expression, "group")
|
||||
return f"PIVOT {this}{on}{using}{group}"
|
||||
|
||||
alias = self.sql(expression, "alias")
|
||||
alias = f" AS {alias}" if alias else ""
|
||||
unpivot = expression.args.get("unpivot")
|
||||
direction = "UNPIVOT" if unpivot else "PIVOT"
|
||||
expressions = self.expressions(expression, key="expressions")
|
||||
field = self.sql(expression, "field")
|
||||
return f"{this} {direction}({expressions} FOR {field}){alias}"
|
||||
return f"{direction}({expressions} FOR {field}){alias}"
|
||||
|
||||
def tuple_sql(self, expression: exp.Tuple) -> str:
|
||||
return f"({self.expressions(expression, flat=True)})"
|
||||
|
@ -1218,8 +1269,7 @@ class Generator:
|
|||
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
|
||||
|
||||
def from_sql(self, expression: exp.From) -> str:
|
||||
expressions = self.expressions(expression, flat=True)
|
||||
return f"{self.seg('FROM')} {expressions}"
|
||||
return f"{self.seg('FROM')} {self.sql(expression, 'this')}"
|
||||
|
||||
def group_sql(self, expression: exp.Group) -> str:
|
||||
group_by = self.op_expressions("GROUP BY", expression)
|
||||
|
@ -1242,10 +1292,16 @@ class Generator:
|
|||
rollup_sql = self.expressions(expression, key="rollup", indent=False)
|
||||
rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else ""
|
||||
|
||||
groupings = csv(grouping_sets, cube_sql, rollup_sql, sep=",")
|
||||
groupings = csv(
|
||||
grouping_sets,
|
||||
cube_sql,
|
||||
rollup_sql,
|
||||
self.seg("WITH TOTALS") if expression.args.get("totals") else "",
|
||||
sep=self.GROUPINGS_SEP,
|
||||
)
|
||||
|
||||
if expression.args.get("expressions") and groupings:
|
||||
group_by = f"{group_by},"
|
||||
group_by = f"{group_by}{self.GROUPINGS_SEP}"
|
||||
|
||||
return f"{group_by}{groupings}"
|
||||
|
||||
|
@ -1254,18 +1310,16 @@ class Generator:
|
|||
return f"{self.seg('HAVING')}{self.sep()}{this}"
|
||||
|
||||
def join_sql(self, expression: exp.Join) -> str:
|
||||
op_sql = self.seg(
|
||||
" ".join(
|
||||
op
|
||||
for op in (
|
||||
"NATURAL" if expression.args.get("natural") else None,
|
||||
expression.side,
|
||||
expression.kind,
|
||||
expression.hint if self.JOIN_HINTS else None,
|
||||
"JOIN",
|
||||
)
|
||||
if op
|
||||
op_sql = " ".join(
|
||||
op
|
||||
for op in (
|
||||
"NATURAL" if expression.args.get("natural") else None,
|
||||
"GLOBAL" if expression.args.get("global") else None,
|
||||
expression.side,
|
||||
expression.kind,
|
||||
expression.hint if self.JOIN_HINTS else None,
|
||||
)
|
||||
if op
|
||||
)
|
||||
on_sql = self.sql(expression, "on")
|
||||
using = expression.args.get("using")
|
||||
|
@ -1273,6 +1327,8 @@ class Generator:
|
|||
if not on_sql and using:
|
||||
on_sql = csv(*(self.sql(column) for column in using))
|
||||
|
||||
this_sql = self.sql(expression, "this")
|
||||
|
||||
if on_sql:
|
||||
on_sql = self.indent(on_sql, skip_first=True)
|
||||
space = self.seg(" " * self.pad) if self.pretty else " "
|
||||
|
@ -1280,10 +1336,11 @@ class Generator:
|
|||
on_sql = f"{space}USING ({on_sql})"
|
||||
else:
|
||||
on_sql = f"{space}ON {on_sql}"
|
||||
elif not op_sql:
|
||||
return f", {this_sql}"
|
||||
|
||||
expression_sql = self.sql(expression, "expression")
|
||||
this_sql = self.sql(expression, "this")
|
||||
return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
|
||||
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
|
||||
return f"{self.seg(op_sql)} {this_sql}{on_sql}"
|
||||
|
||||
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
|
||||
args = self.expressions(expression, flat=True)
|
||||
|
@ -1336,12 +1393,22 @@ class Generator:
|
|||
return f"PRAGMA {self.sql(expression, 'this')}"
|
||||
|
||||
def lock_sql(self, expression: exp.Lock) -> str:
|
||||
if self.LOCKING_READS_SUPPORTED:
|
||||
lock_type = "UPDATE" if expression.args["update"] else "SHARE"
|
||||
return self.seg(f"FOR {lock_type}")
|
||||
if not self.LOCKING_READS_SUPPORTED:
|
||||
self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
|
||||
return ""
|
||||
|
||||
self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
|
||||
return ""
|
||||
lock_type = "FOR UPDATE" if expression.args["update"] else "FOR SHARE"
|
||||
expressions = self.expressions(expression, flat=True)
|
||||
expressions = f" OF {expressions}" if expressions else ""
|
||||
wait = expression.args.get("wait")
|
||||
|
||||
if wait is not None:
|
||||
if isinstance(wait, exp.Literal):
|
||||
wait = f" WAIT {self.sql(wait)}"
|
||||
else:
|
||||
wait = " NOWAIT" if wait else " SKIP LOCKED"
|
||||
|
||||
return f"{lock_type}{expressions}{wait or ''}"
|
||||
|
||||
def literal_sql(self, expression: exp.Literal) -> str:
|
||||
text = expression.this or ""
|
||||
|
@ -1460,26 +1527,32 @@ class Generator:
|
|||
|
||||
return csv(
|
||||
*sqls,
|
||||
*[self.sql(sql) for sql in expression.args.get("joins") or []],
|
||||
*[self.sql(join) for join in expression.args.get("joins") or []],
|
||||
self.sql(expression, "match"),
|
||||
*[self.sql(sql) for sql in expression.args.get("laterals") or []],
|
||||
*[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
|
||||
self.sql(expression, "where"),
|
||||
self.sql(expression, "group"),
|
||||
self.sql(expression, "having"),
|
||||
*self.after_having_modifiers(expression),
|
||||
self.sql(expression, "order"),
|
||||
self.sql(expression, "offset") if fetch else self.sql(limit),
|
||||
self.sql(limit) if fetch else self.sql(expression, "offset"),
|
||||
*self.after_limit_modifiers(expression),
|
||||
sep="",
|
||||
)
|
||||
|
||||
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
|
||||
return [
|
||||
self.sql(expression, "qualify"),
|
||||
self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
|
||||
if expression.args.get("windows")
|
||||
else "",
|
||||
self.sql(expression, "distribute"),
|
||||
self.sql(expression, "sort"),
|
||||
self.sql(expression, "cluster"),
|
||||
self.sql(expression, "order"),
|
||||
self.sql(expression, "offset") if fetch else self.sql(limit),
|
||||
self.sql(limit) if fetch else self.sql(expression, "offset"),
|
||||
self.sql(expression, "lock"),
|
||||
self.sql(expression, "sample"),
|
||||
sep="",
|
||||
)
|
||||
]
|
||||
|
||||
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
|
||||
locks = self.expressions(expression, key="locks", sep=" ")
|
||||
locks = f" {locks}" if locks else ""
|
||||
return [locks, self.sql(expression, "sample")]
|
||||
|
||||
def select_sql(self, expression: exp.Select) -> str:
|
||||
hint = self.sql(expression, "hint")
|
||||
|
@ -1529,13 +1602,10 @@ class Generator:
|
|||
alias = self.sql(expression, "alias")
|
||||
alias = f"{sep}{alias}" if alias else ""
|
||||
|
||||
sql = self.query_modifiers(
|
||||
expression,
|
||||
self.wrap(expression),
|
||||
alias,
|
||||
self.expressions(expression, key="pivots", sep=" "),
|
||||
)
|
||||
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
|
||||
pivots = f" {pivots}" if pivots else ""
|
||||
|
||||
sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots)
|
||||
return self.prepend_ctes(expression, sql)
|
||||
|
||||
def qualify_sql(self, expression: exp.Qualify) -> str:
|
||||
|
@ -1712,10 +1782,6 @@ class Generator:
|
|||
options = f" {options}" if options else ""
|
||||
return f"PRIMARY KEY ({expressions}){options}"
|
||||
|
||||
def unique_sql(self, expression: exp.Unique) -> str:
|
||||
columns = self.expressions(expression, key="expressions")
|
||||
return f"UNIQUE ({columns})"
|
||||
|
||||
def if_sql(self, expression: exp.If) -> str:
|
||||
return self.case_sql(
|
||||
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
|
||||
|
@ -1745,6 +1811,26 @@ class Generator:
|
|||
encoding = f" ENCODING {encoding}" if encoding else ""
|
||||
return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
|
||||
|
||||
def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
kind = self.sql(expression, "kind")
|
||||
path = self.sql(expression, "path")
|
||||
path = f" {path}" if path else ""
|
||||
as_json = " AS JSON" if expression.args.get("as_json") else ""
|
||||
return f"{this} {kind}{path}{as_json}"
|
||||
|
||||
def openjson_sql(self, expression: exp.OpenJSON) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
path = self.sql(expression, "path")
|
||||
path = f", {path}" if path else ""
|
||||
expressions = self.expressions(expression)
|
||||
with_ = (
|
||||
f" WITH ({self.seg(self.indent(expressions), sep='')}{self.seg(')', sep='')}"
|
||||
if expressions
|
||||
else ""
|
||||
)
|
||||
return f"OPENJSON({this}{path}){with_}"
|
||||
|
||||
def in_sql(self, expression: exp.In) -> str:
|
||||
query = expression.args.get("query")
|
||||
unnest = expression.args.get("unnest")
|
||||
|
@ -1773,7 +1859,7 @@ class Generator:
|
|||
|
||||
if self.SINGLE_STRING_INTERVAL:
|
||||
this = expression.this.name if expression.this else ""
|
||||
return f"INTERVAL '{this}{unit}'"
|
||||
return f"INTERVAL '{this}{unit}'" if this else f"INTERVAL{unit}"
|
||||
|
||||
this = self.sql(expression, "this")
|
||||
if this:
|
||||
|
@ -1883,6 +1969,28 @@ class Generator:
|
|||
expression_sql = self.sql(expression, "expression")
|
||||
return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}"
|
||||
|
||||
def mergetreettlaction_sql(self, expression: exp.MergeTreeTTLAction) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
delete = " DELETE" if expression.args.get("delete") else ""
|
||||
recompress = self.sql(expression, "recompress")
|
||||
recompress = f" RECOMPRESS {recompress}" if recompress else ""
|
||||
to_disk = self.sql(expression, "to_disk")
|
||||
to_disk = f" TO DISK {to_disk}" if to_disk else ""
|
||||
to_volume = self.sql(expression, "to_volume")
|
||||
to_volume = f" TO VOLUME {to_volume}" if to_volume else ""
|
||||
return f"{this}{delete}{recompress}{to_disk}{to_volume}"
|
||||
|
||||
def mergetreettl_sql(self, expression: exp.MergeTreeTTL) -> str:
|
||||
where = self.sql(expression, "where")
|
||||
group = self.sql(expression, "group")
|
||||
aggregates = self.expressions(expression, key="aggregates")
|
||||
aggregates = self.seg("SET") + self.seg(aggregates) if aggregates else ""
|
||||
|
||||
if not (where or group or aggregates) and len(expression.expressions) == 1:
|
||||
return f"TTL {self.expressions(expression, flat=True)}"
|
||||
|
||||
return f"TTL{self.seg(self.expressions(expression))}{where}{group}{aggregates}"
|
||||
|
||||
def transaction_sql(self, expression: exp.Transaction) -> str:
|
||||
return "BEGIN"
|
||||
|
||||
|
@ -1919,6 +2027,11 @@ class Generator:
|
|||
return f"ALTER COLUMN {this} DROP DEFAULT"
|
||||
|
||||
def renametable_sql(self, expression: exp.RenameTable) -> str:
|
||||
if not self.RENAME_TABLE_WITH_DB:
|
||||
# Remove db from tables
|
||||
expression = expression.transform(
|
||||
lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n
|
||||
)
|
||||
this = self.sql(expression, "this")
|
||||
return f"RENAME TO {this}"
|
||||
|
||||
|
@@ -2208,3 +2321,12 @@ class Generator:
            self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function")

        return self.sql(exp.cast(expression.this, "text"))


def cached_generator(
    cache: t.Optional[t.Dict[int, str]] = None
) -> t.Callable[[exp.Expression], str]:
    """Returns a cached generator."""
    cache = {} if cache is None else cache
    generator = Generator(normalize=True, identify="safe")
    return lambda e: generator.generate(e, cache)
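(Aside, not part of the diff: a sketch of how the new module-level cached_generator helper might be used. The generate(expression, cache) call shape is taken from the lambda above; the rest is an assumption.)

import sqlglot
from sqlglot.generator import cached_generator

generate = cached_generator()  # one normalizing Generator shared across calls

expression = sqlglot.parse_one("SELECT a FROM t")
print(generate(expression))
print(generate(expression))  # the second call can be served from the shared cache dict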
@@ -9,14 +9,14 @@ from collections.abc import Collection
from contextlib import contextmanager
from copy import copy
from enum import Enum
from itertools import count

if t.TYPE_CHECKING:
    from sqlglot import exp
    from sqlglot._typing import E, T
    from sqlglot.dialects.dialect import DialectType
    from sqlglot.expressions import Expression

T = t.TypeVar("T")
E = t.TypeVar("E", bound=Expression)

CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")
@@ -25,7 +25,7 @@ logger = logging.getLogger("sqlglot")
class AutoName(Enum):
    """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""

    def _generate_next_value_(name, _start, _count, _last_values): # type: ignore
    def _generate_next_value_(name, _start, _count, _last_values):
        return name


@@ -92,7 +92,7 @@ def ensure_collection(value):
    )


def csv(*args, sep: str = ", ") -> str:
def csv(*args: str, sep: str = ", ") -> str:
    """
    Formats any number of string arguments as CSV.
@@ -304,9 +304,18 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
    return new


def name_sequence(prefix: str) -> t.Callable[[], str]:
    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
    sequence = count()
    return lambda: f"{prefix}{next(sequence)}"


def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """Returns a dictionary created from an object's attributes."""
    return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}
    return {
        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
        **kwargs,
    }


def split_num_words(
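(Aside, not part of the diff: a small illustration of how the new name_sequence helper behaves, based purely on its definition above.)

from sqlglot.helper import name_sequence

next_alias = name_sequence("_q_")
print(next_alias())  # _q_0
print(next_alias())  # _q_1

other = name_sequence("cte_")  # independent sequences do not share state
print(other())  # cte_0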
@@ -381,15 +390,6 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
        yield value


def count_params(function: t.Callable) -> int:
    """
    Returns the number of formal parameters expected by a function, without counting "self"
    and "cls", in case of instance and class methods, respectively.
    """
    count = function.__code__.co_argcount
    return count - 1 if inspect.ismethod(function) else count


def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.
@@ -430,12 +430,23 @@ def first(it: t.Iterable[T]) -> T:
    return next(i for i in it)


def should_identify(text: str, identify: str | bool) -> bool:
def case_sensitive(text: str, dialect: DialectType) -> bool:
    """Checks if text contains any case sensitive characters depending on dialect."""
    from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE

    unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
    return any(unsafe(char) for char in text)


def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool:
    """Checks if text should be identified given an identify option.

    Args:
        text: the text to check.
        identify: "always" | True - always returns true, "safe" - true if no upper case
        identify:
            "always" or `True`: always returns true.
            "safe": true if there is no uppercase or lowercase character in `text`, depending on `dialect`.
        dialect: the dialect to use in order to decide whether a text should be identified.

    Returns:
        Whether or not a string should be identified.

@@ -443,5 +454,5 @@ def should_identify(text: str, identify: str | bool) -> bool:
    if identify is True or identify == "always":
        return True
    if identify == "safe":
        return not any(char.isupper() for char in text)
        return not case_sensitive(text, dialect)
    return False
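(Aside, not part of the diff: a stubbed sketch of the new case_sensitive check. The dialect set is a hypothetical stand-in for RESOLVES_IDENTIFIERS_AS_UPPERCASE, whose exact contents are not shown here.)

from __future__ import annotations

UPPERCASE_DIALECTS = {"snowflake", "oracle"}  # hypothetical stand-in

def case_sensitive(text: str, dialect: str | None) -> bool:
    # A text is "case sensitive" if it contains characters the dialect would
    # fold the other way when the identifier is left unquoted.
    unsafe = str.islower if dialect in UPPERCASE_DIALECTS else str.isupper
    return any(unsafe(char) for char in text)

print(case_sensitive("Foo", None))         # True  -> quoted under identify="safe"
print(case_sensitive("foo", None))         # False -> left unquoted
print(case_sensitive("foo", "snowflake"))  # True  -> lowercase is unsafe when identifiers fold to uppercase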
@ -5,10 +5,8 @@ import typing as t
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlglot import Schema, exp, maybe_parse
|
||||
from sqlglot.optimizer import Scope, build_scope, optimize
|
||||
from sqlglot.optimizer.expand_laterals import expand_laterals
|
||||
from sqlglot.optimizer.qualify_columns import qualify_columns
|
||||
from sqlglot.optimizer.qualify_tables import qualify_tables
|
||||
from sqlglot.errors import SqlglotError
|
||||
from sqlglot.optimizer import Scope, build_scope, qualify
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot.dialects.dialect import DialectType
|
||||
|
@ -40,8 +38,8 @@ def lineage(
|
|||
sql: str | exp.Expression,
|
||||
schema: t.Optional[t.Dict | Schema] = None,
|
||||
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
|
||||
rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
|
||||
dialect: DialectType = None,
|
||||
**kwargs,
|
||||
) -> Node:
|
||||
"""Build the lineage graph for a column of a SQL query.
|
||||
|
||||
|
@ -50,8 +48,8 @@ def lineage(
|
|||
sql: The SQL string or expression.
|
||||
schema: The schema of tables.
|
||||
sources: A mapping of queries which will be used to continue building lineage.
|
||||
rules: Optimizer rules to apply, by default only qualifying tables and columns.
|
||||
dialect: The dialect of input SQL.
|
||||
**kwargs: Qualification optimizer kwargs.
|
||||
|
||||
Returns:
|
||||
A lineage node.
|
||||
|
@ -68,8 +66,17 @@ def lineage(
|
|||
},
|
||||
)
|
||||
|
||||
optimized = optimize(expression, schema=schema, rules=rules)
|
||||
scope = build_scope(optimized)
|
||||
qualified = qualify.qualify(
|
||||
expression,
|
||||
dialect=dialect,
|
||||
schema=schema,
|
||||
**{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore
|
||||
)
|
||||
|
||||
scope = build_scope(qualified)
|
||||
|
||||
if not scope:
|
||||
raise SqlglotError("Cannot build lineage, sql must be SELECT")
|
||||
|
||||
def to_node(
|
||||
column_name: str,
|
||||
|
@ -109,10 +116,7 @@ def lineage(
|
|||
# a version that has only the column we care about.
|
||||
# "x", SELECT x, y FROM foo
|
||||
# => "x", SELECT x FROM foo
|
||||
source = optimize(
|
||||
scope.expression.select(select, append=False), schema=schema, rules=rules
|
||||
)
|
||||
select = source.selects[0]
|
||||
source = t.cast(exp.Expression, scope.expression.select(select, append=False))
|
||||
else:
|
||||
source = scope.expression
|
||||
|
||||
|
|
|
@ -3,10 +3,9 @@ from __future__ import annotations
|
|||
import itertools
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.helper import should_identify
|
||||
|
||||
|
||||
def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
|
||||
def canonicalize(expression: exp.Expression) -> exp.Expression:
|
||||
"""Converts a sql expression into a standard form.
|
||||
|
||||
This method relies on annotate_types because many of the
|
||||
|
@ -14,19 +13,14 @@ def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expr
|
|||
|
||||
Args:
|
||||
expression: The expression to canonicalize.
|
||||
identify: Whether or not to force identify identifier.
|
||||
"""
|
||||
exp.replace_children(expression, canonicalize, identify=identify)
|
||||
exp.replace_children(expression, canonicalize)
|
||||
|
||||
expression = add_text_to_concat(expression)
|
||||
expression = coerce_type(expression)
|
||||
expression = remove_redundant_casts(expression)
|
||||
expression = ensure_bool_predicates(expression)
|
||||
|
||||
if isinstance(expression, exp.Identifier):
|
||||
if should_identify(expression.this, identify):
|
||||
expression.set("quoted", True)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
|
|
|
@ -19,24 +19,25 @@ def eliminate_ctes(expression):
|
|||
"""
|
||||
root = build_scope(expression)
|
||||
|
||||
ref_count = root.ref_count()
|
||||
if root:
|
||||
ref_count = root.ref_count()
|
||||
|
||||
# Traverse the scope tree in reverse so we can remove chains of unused CTEs
|
||||
for scope in reversed(list(root.traverse())):
|
||||
if scope.is_cte:
|
||||
count = ref_count[id(scope)]
|
||||
if count <= 0:
|
||||
cte_node = scope.expression.parent
|
||||
with_node = cte_node.parent
|
||||
cte_node.pop()
|
||||
# Traverse the scope tree in reverse so we can remove chains of unused CTEs
|
||||
for scope in reversed(list(root.traverse())):
|
||||
if scope.is_cte:
|
||||
count = ref_count[id(scope)]
|
||||
if count <= 0:
|
||||
cte_node = scope.expression.parent
|
||||
with_node = cte_node.parent
|
||||
cte_node.pop()
|
||||
|
||||
# Pop the entire WITH clause if this is the last CTE
|
||||
if len(with_node.expressions) <= 0:
|
||||
with_node.pop()
|
||||
# Pop the entire WITH clause if this is the last CTE
|
||||
if len(with_node.expressions) <= 0:
|
||||
with_node.pop()
|
||||
|
||||
# Decrement the ref count for all sources this CTE selects from
|
||||
for _, source in scope.selected_sources.values():
|
||||
if isinstance(source, Scope):
|
||||
ref_count[id(source)] -= 1
|
||||
# Decrement the ref count for all sources this CTE selects from
|
||||
for _, source in scope.selected_sources.values():
|
||||
if isinstance(source, Scope):
|
||||
ref_count[id(source)] -= 1
|
||||
|
||||
return expression
|
||||
|
|
|
@ -16,9 +16,9 @@ def eliminate_subqueries(expression):
|
|||
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
|
||||
|
||||
This also deduplicates common subqueries:
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
|
||||
>>> eliminate_subqueries(expression).sql()
|
||||
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
|
||||
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression
|
||||
|
@ -32,6 +32,9 @@ def eliminate_subqueries(expression):
|
|||
|
||||
root = build_scope(expression)
|
||||
|
||||
if not root:
|
||||
return expression
|
||||
|
||||
# Map of alias->Scope|Table
|
||||
# These are all aliases that are already used in the expression.
|
||||
# We don't want to create new CTEs that conflict with these names.
|
||||
|
@ -112,7 +115,7 @@ def _eliminate_union(scope, existing_ctes, taken):
|
|||
# Try to maintain the selections
|
||||
expressions = scope.selects
|
||||
selects = [
|
||||
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
|
||||
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
|
||||
for e in expressions
|
||||
if e.alias_or_name
|
||||
]
|
||||
|
@ -120,7 +123,9 @@ def _eliminate_union(scope, existing_ctes, taken):
|
|||
if len(selects) != len(expressions):
|
||||
selects = ["*"]
|
||||
|
||||
scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))
|
||||
scope.expression.replace(
|
||||
exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False))
|
||||
)
|
||||
|
||||
if not duplicate_cte_alias:
|
||||
existing_ctes[scope.expression] = alias
|
||||
|
@ -131,6 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken):
|
|||
|
||||
|
||||
def _eliminate_derived_table(scope, existing_ctes, taken):
|
||||
# This ensures we don't drop the "pivot" arg from a pivoted subquery
|
||||
if scope.parent.pivots:
|
||||
return None
|
||||
|
||||
parent = scope.expression.parent
|
||||
name, cte = _new_cte(scope, existing_ctes, taken)
|
||||
|
||||
|
@ -153,7 +162,7 @@ def _eliminate_cte(scope, existing_ctes, taken):
|
|||
for child_scope in scope.parent.traverse():
|
||||
for table, source in child_scope.selected_sources.values():
|
||||
if source is scope:
|
||||
new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
|
||||
new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False)
|
||||
table.replace(new_table)
|
||||
|
||||
return cte
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp
|
||||
|
||||
|
||||
def expand_laterals(expression: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
Expand lateral column alias references.
|
||||
|
||||
This assumes `qualify_columns` as already run.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
|
||||
>>> expression = sqlglot.parse_one(sql)
|
||||
>>> expand_laterals(expression).sql()
|
||||
'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
|
||||
|
||||
Args:
|
||||
expression: expression to optimize
|
||||
Returns:
|
||||
optimized expression
|
||||
"""
|
||||
for select in expression.find_all(exp.Select):
|
||||
alias_to_expression: t.Dict[str, exp.Expression] = {}
|
||||
for projection in select.expressions:
|
||||
for column in projection.find_all(exp.Column):
|
||||
if not column.table and column.name in alias_to_expression:
|
||||
column.replace(alias_to_expression[column.name].copy())
|
||||
if isinstance(projection, exp.Alias):
|
||||
alias_to_expression[projection.alias] = projection.this
|
||||
return expression
|
|
@ -1,24 +0,0 @@
|
|||
from sqlglot import exp
|
||||
|
||||
|
||||
def expand_multi_table_selects(expression):
|
||||
"""
|
||||
Replace multiple FROM expressions with JOINs.
|
||||
|
||||
Example:
|
||||
>>> from sqlglot import parse_one
|
||||
>>> expand_multi_table_selects(parse_one("SELECT * FROM x, y")).sql()
|
||||
'SELECT * FROM x CROSS JOIN y'
|
||||
"""
|
||||
for from_ in expression.find_all(exp.From):
|
||||
parent = from_.parent
|
||||
|
||||
for query in from_.expressions[1:]:
|
||||
parent.join(
|
||||
query,
|
||||
join_type="CROSS",
|
||||
copy=False,
|
||||
)
|
||||
from_.expressions.remove(query)
|
||||
|
||||
return expression
|
|
@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
|
|||
source.replace(
|
||||
exp.select("*")
|
||||
.from_(
|
||||
alias(source.copy(), source.name or source.alias, table=True),
|
||||
alias(source, source.name or source.alias, table=True),
|
||||
copy=False,
|
||||
)
|
||||
.subquery(source.alias, copy=False)
|
||||
|
|
|
@ -1,88 +0,0 @@
|
|||
from sqlglot import exp
|
||||
|
||||
|
||||
def lower_identities(expression):
|
||||
"""
|
||||
Convert all unquoted identifiers to lower case.
|
||||
|
||||
Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
|
||||
>>> lower_identities(expression).sql()
|
||||
'SELECT bar.a AS A FROM "Foo".bar'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to quote
|
||||
Returns:
|
||||
sqlglot.Expression: quoted expression
|
||||
"""
|
||||
# We need to leave the output aliases unchanged, so the selects need special handling
|
||||
_lower_selects(expression)
|
||||
|
||||
# These clauses can reference output aliases and also need special handling
|
||||
_lower_order(expression)
|
||||
_lower_having(expression)
|
||||
|
||||
# We've already handled these args, so don't traverse into them
|
||||
traversed = {"expressions", "order", "having"}
|
||||
|
||||
if isinstance(expression, exp.Subquery):
|
||||
# Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
|
||||
lower_identities(expression.this)
|
||||
traversed |= {"this"}
|
||||
|
||||
if isinstance(expression, exp.Union):
|
||||
# Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
|
||||
lower_identities(expression.left)
|
||||
lower_identities(expression.right)
|
||||
traversed |= {"this", "expression"}
|
||||
|
||||
for k, v in expression.iter_expressions():
|
||||
if k in traversed:
|
||||
continue
|
||||
v.transform(_lower, copy=False)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def _lower_selects(expression):
|
||||
for e in expression.expressions:
|
||||
# Leave output aliases as-is
|
||||
e.unalias().transform(_lower, copy=False)
|
||||
|
||||
|
||||
def _lower_order(expression):
|
||||
order = expression.args.get("order")
|
||||
|
||||
if not order:
|
||||
return
|
||||
|
||||
output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
|
||||
|
||||
for ordered in order.expressions:
|
||||
# Don't lower references to output aliases
|
||||
if not (
|
||||
isinstance(ordered.this, exp.Column)
|
||||
and not ordered.this.table
|
||||
and ordered.this.name in output_aliases
|
||||
):
|
||||
ordered.transform(_lower, copy=False)
|
||||
|
||||
|
||||
def _lower_having(expression):
|
||||
having = expression.args.get("having")
|
||||
|
||||
if not having:
|
||||
return
|
||||
|
||||
# Don't lower references to output aliases
|
||||
for agg in having.find_all(exp.AggFunc):
|
||||
agg.transform(_lower, copy=False)
|
||||
|
||||
|
||||
def _lower(node):
|
||||
if isinstance(node, exp.Identifier) and not node.quoted:
|
||||
node.set("this", node.this.lower())
|
||||
return node
|
|
@ -13,15 +13,15 @@ def merge_subqueries(expression, leave_tables_isolated=False):
|
|||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
|
||||
>>> merge_subqueries(expression).sql()
|
||||
'SELECT x.a FROM x JOIN y'
|
||||
'SELECT x.a FROM x CROSS JOIN y'
|
||||
|
||||
If `leave_tables_isolated` is True, this will not merge inner queries into outer
|
||||
queries if it would result in multiple table selects in a single query:
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
|
||||
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
|
||||
'SELECT a FROM (SELECT x.a FROM x) JOIN y'
|
||||
'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
|
||||
|
||||
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
|
||||
|
||||
|
@ -154,7 +154,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
|||
inner_from = inner_scope.expression.args.get("from")
|
||||
if not inner_from:
|
||||
return False
|
||||
inner_from_table = inner_from.expressions[0].alias_or_name
|
||||
inner_from_table = inner_from.alias_or_name
|
||||
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
|
||||
return any(
|
||||
col.table != inner_from_table
|
||||
|
@ -167,6 +167,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
|||
and isinstance(inner_select, exp.Select)
|
||||
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
|
||||
and inner_select.args.get("from")
|
||||
and not outer_scope.pivots
|
||||
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
|
||||
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
|
||||
and not (
|
||||
|
@ -210,7 +211,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
|
|||
elif isinstance(source, exp.Table) and source.alias:
|
||||
source.set("alias", new_alias)
|
||||
elif isinstance(source, exp.Table):
|
||||
source.replace(exp.alias_(source.copy(), new_alias))
|
||||
source.replace(exp.alias_(source, new_alias))
|
||||
|
||||
for column in inner_scope.source_columns(conflict):
|
||||
column.set("table", exp.to_identifier(new_name))
|
||||
|
@ -228,7 +229,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
|
|||
node_to_replace (exp.Subquery|exp.Table)
|
||||
alias (str)
|
||||
"""
|
||||
new_subquery = inner_scope.expression.args.get("from").expressions[0]
|
||||
new_subquery = inner_scope.expression.args["from"].this
|
||||
node_to_replace.replace(new_subquery)
|
||||
for join_hint in outer_scope.join_hints:
|
||||
tables = join_hint.find_all(exp.Table)
|
||||
|
@ -319,7 +320,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
|
|||
# Merge predicates from an outer join to the ON clause
|
||||
# if it only has columns that are already joined
|
||||
from_ = expression.args.get("from")
|
||||
sources = {table.alias_or_name for table in from_.expressions} if from_ else {}
|
||||
sources = {from_.alias_or_name} if from_ else {}
|
||||
|
||||
for join in expression.args["joins"]:
|
||||
source = join.alias_or_name
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.generator import cached_generator
|
||||
from sqlglot.helper import while_changing
|
||||
from sqlglot.optimizer.simplify import flatten, uniq_sort
|
||||
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
|
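The switch from threading a raw cache dict into every helper to passing around a `generate` callable (built by the new cached_generator import above) is only visible here through the import and call sites. A hedged sketch of the presumed shape of such a helper, not the actual implementation:

from sqlglot.generator import Generator

def cached_generator_sketch(cache=None):
    """Illustrative only: memoize generated SQL per expression so repeated
    normalization passes don't re-render identical subtrees."""
    cache = {} if cache is None else cache
    generator = Generator(normalize=True, identify="safe")

    def generate(expression):
        key = id(expression)  # hypothetical cache key; the real helper may key differently
        if key not in cache:
            cache[key] = generator.generate(expression)
        return cache[key]

    return generate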
@ -28,13 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
|
|||
Returns:
|
||||
sqlglot.Expression: normalized expression
|
||||
"""
|
||||
cache: t.Dict[int, str] = {}
|
||||
generate = cached_generator()
|
||||
|
||||
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
|
||||
if isinstance(node, exp.Connector):
|
||||
if normalized(node, dnf=dnf):
|
||||
continue
|
||||
root = node is expression
|
||||
original = node.copy()
|
||||
|
||||
node.transform(rewrite_between, copy=False)
|
||||
distance = normalization_distance(node, dnf=dnf)
|
||||
|
||||
if distance > max_distance:
|
||||
|
@ -43,11 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
|
|||
)
|
||||
return expression
|
||||
|
||||
root = node is expression
|
||||
original = node.copy()
|
||||
try:
|
||||
node = node.replace(
|
||||
while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
|
||||
while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
|
||||
)
|
||||
except OptimizeError as e:
|
||||
logger.info(e)
|
||||
|
@ -111,7 +112,7 @@ def _predicate_lengths(expression, dnf):
|
|||
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
|
||||
|
||||
|
||||
def distributive_law(expression, dnf, max_distance, cache=None):
|
||||
def distributive_law(expression, dnf, max_distance, generate):
|
||||
"""
|
||||
x OR (y AND z) -> (x OR y) AND (x OR z)
|
||||
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
|
||||
|
@ -124,7 +125,7 @@ def distributive_law(expression, dnf, max_distance, cache=None):
|
|||
if distance > max_distance:
|
||||
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
|
||||
|
||||
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
|
||||
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
|
||||
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
|
||||
|
||||
if isinstance(expression, from_exp):
|
||||
|
@ -135,30 +136,30 @@ def distributive_law(expression, dnf, max_distance, cache=None):
|
|||
|
||||
if isinstance(a, to_exp) and isinstance(b, to_exp):
|
||||
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
|
||||
return _distribute(a, b, from_func, to_func, cache)
|
||||
return _distribute(b, a, from_func, to_func, cache)
|
||||
return _distribute(a, b, from_func, to_func, generate)
|
||||
return _distribute(b, a, from_func, to_func, generate)
|
||||
if isinstance(a, to_exp):
|
||||
return _distribute(b, a, from_func, to_func, cache)
|
||||
return _distribute(b, a, from_func, to_func, generate)
|
||||
if isinstance(b, to_exp):
|
||||
return _distribute(a, b, from_func, to_func, cache)
|
||||
return _distribute(a, b, from_func, to_func, generate)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def _distribute(a, b, from_func, to_func, cache):
|
||||
def _distribute(a, b, from_func, to_func, generate):
|
||||
if isinstance(a, exp.Connector):
|
||||
exp.replace_children(
|
||||
a,
|
||||
lambda c: to_func(
|
||||
uniq_sort(flatten(from_func(c, b.left)), cache),
|
||||
uniq_sort(flatten(from_func(c, b.right)), cache),
|
||||
uniq_sort(flatten(from_func(c, b.left)), generate),
|
||||
uniq_sort(flatten(from_func(c, b.right)), generate),
|
||||
copy=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
a = to_func(
|
||||
uniq_sort(flatten(from_func(a, b.left)), cache),
|
||||
uniq_sort(flatten(from_func(a, b.right)), cache),
|
||||
uniq_sort(flatten(from_func(a, b.left)), generate),
|
||||
uniq_sort(flatten(from_func(a, b.right)), generate),
|
||||
copy=False,
|
||||
)
|
||||
|
||||
|
|
36 sqlglot/optimizer/normalize_identifiers.py (new file)
|
@ -0,0 +1,36 @@
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType


def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
    """
    Normalize all unquoted identifiers to either lower or upper case, depending on
    the dialect. This essentially makes those identifiers case-insensitive.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
        >>> normalize_identifiers(expression).sql()
        'SELECT bar.a AS a FROM "Foo".bar'

    Args:
        expression: The expression to transform.
        dialect: The dialect to use in order to decide how to normalize identifiers.

    Returns:
        The transformed expression.
    """
    return expression.transform(_normalize, dialect, copy=False)


def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    if isinstance(node, exp.Identifier) and not node.quoted:
        node.set(
            "this",
            node.this.upper()
            if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE
            else node.this.lower(),
        )

    return node
|
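For context, a small hedged sketch of how the dialect argument changes the result; the exact membership of RESOLVES_IDENTIFIERS_AS_UPPERCASE (e.g. Snowflake) is an assumption here, not something this diff states:

import sqlglot
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

ast = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')

# default: unquoted identifiers are lowercased, quoted ones are left alone
print(normalize_identifiers(ast.copy()).sql())
# assuming "snowflake" resolves identifiers as uppercase, the same unquoted names are uppercased instead
print(normalize_identifiers(ast.copy(), dialect="snowflake").sql())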
@ -1,6 +1,8 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.helper import tsort
|
||||
|
||||
JOIN_ATTRS = ("on", "side", "kind", "using", "natural")
|
||||
|
||||
|
||||
def optimize_joins(expression):
|
||||
"""
|
||||
|
@ -45,7 +47,7 @@ def reorder_joins(expression):
|
|||
Reorder joins by topological sort order based on predicate references.
|
||||
"""
|
||||
for from_ in expression.find_all(exp.From):
|
||||
head = from_.expressions[0]
|
||||
head = from_.this
|
||||
parent = from_.parent
|
||||
joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
|
||||
dag = {head.alias_or_name: []}
|
||||
|
@ -65,6 +67,9 @@ def normalize(expression):
|
|||
Remove INNER and OUTER from joins as they are optional.
|
||||
"""
|
||||
for join in expression.find_all(exp.Join):
|
||||
if not any(join.args.get(k) for k in JOIN_ATTRS):
|
||||
join.set("kind", "CROSS")
|
||||
|
||||
if join.kind != "CROSS":
|
||||
join.set("kind", None)
|
||||
return expression
|
||||
|
|
|
@ -10,36 +10,29 @@ from sqlglot.optimizer.canonicalize import canonicalize
|
|||
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
|
||||
from sqlglot.optimizer.eliminate_joins import eliminate_joins
|
||||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
|
||||
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
|
||||
from sqlglot.optimizer.lower_identities import lower_identities
|
||||
from sqlglot.optimizer.merge_subqueries import merge_subqueries
|
||||
from sqlglot.optimizer.normalize import normalize
|
||||
from sqlglot.optimizer.optimize_joins import optimize_joins
|
||||
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
||||
from sqlglot.optimizer.pushdown_projections import pushdown_projections
|
||||
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
|
||||
from sqlglot.optimizer.qualify_tables import qualify_tables
|
||||
from sqlglot.optimizer.qualify import qualify
|
||||
from sqlglot.optimizer.qualify_columns import quote_identifiers
|
||||
from sqlglot.optimizer.simplify import simplify
|
||||
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
||||
from sqlglot.schema import ensure_schema
|
||||
|
||||
RULES = (
|
||||
lower_identities,
|
||||
qualify_tables,
|
||||
isolate_table_selects,
|
||||
qualify_columns,
|
||||
qualify,
|
||||
pushdown_projections,
|
||||
validate_qualify_columns,
|
||||
normalize,
|
||||
unnest_subqueries,
|
||||
expand_multi_table_selects,
|
||||
pushdown_predicates,
|
||||
optimize_joins,
|
||||
eliminate_subqueries,
|
||||
merge_subqueries,
|
||||
eliminate_joins,
|
||||
eliminate_ctes,
|
||||
quote_identifiers,
|
||||
annotate_types,
|
||||
canonicalize,
|
||||
simplify,
|
||||
|
@ -54,7 +47,7 @@ def optimize(
|
|||
dialect: DialectType = None,
|
||||
rules: t.Sequence[t.Callable] = RULES,
|
||||
**kwargs,
|
||||
):
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Rewrite a sqlglot AST into an optimized form.
|
||||
|
||||
|
@ -72,14 +65,23 @@ def optimize(
|
|||
dialect: The dialect to parse the sql string.
|
||||
rules: sequence of optimizer rules to use.
|
||||
Many of the rules require tables and columns to be qualified.
|
||||
Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
|
||||
what you're doing!
|
||||
Do not remove `qualify` from the sequence of rules unless you know what you're doing!
|
||||
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
|
||||
|
||||
Returns:
|
||||
sqlglot.Expression: optimized expression
|
||||
The optimized expression.
|
||||
"""
|
||||
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
|
||||
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
|
||||
possible_kwargs = {
|
||||
"db": db,
|
||||
"catalog": catalog,
|
||||
"schema": schema,
|
||||
"dialect": dialect,
|
||||
"isolate_tables": True, # needed for other optimizations to perform well
|
||||
"quote_identifiers": False, # this happens in canonicalize
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
|
||||
for rule in rules:
|
||||
# Find any additional rule parameters, beyond `expression`
|
||||
|
@ -88,4 +90,5 @@ def optimize(
|
|||
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
|
||||
}
|
||||
expression = rule(expression, **rule_kwargs)
|
||||
return expression
|
||||
|
||||
return t.cast(exp.Expression, expression)
|
||||
|
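A hedged usage sketch of the reworked optimize() entry point; the schema below is invented for illustration and the exact optimized SQL depends on the rules and dialect:

import sqlglot
from sqlglot.optimizer import optimize

schema = {"x": {"a": "INT", "b": "INT"}}
optimized = optimize(sqlglot.parse_one("SELECT a FROM x WHERE b > 1"), schema=schema)
print(optimized.sql())  # fully qualified, quoted, simplified SQL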
|
|
@ -21,26 +21,28 @@ def pushdown_predicates(expression):
|
|||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
root = build_scope(expression)
|
||||
scope_ref_count = root.ref_count()
|
||||
|
||||
for scope in reversed(list(root.traverse())):
|
||||
select = scope.expression
|
||||
where = select.args.get("where")
|
||||
if where:
|
||||
selected_sources = scope.selected_sources
|
||||
# a right join can only push down to itself and not the source FROM table
|
||||
for k, (node, source) in selected_sources.items():
|
||||
parent = node.find_ancestor(exp.Join, exp.From)
|
||||
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
|
||||
selected_sources = {k: (node, source)}
|
||||
break
|
||||
pushdown(where.this, selected_sources, scope_ref_count)
|
||||
if root:
|
||||
scope_ref_count = root.ref_count()
|
||||
|
||||
# joins should only pushdown into itself, not to other joins
|
||||
# so we limit the selected sources to only itself
|
||||
for join in select.args.get("joins") or []:
|
||||
name = join.this.alias_or_name
|
||||
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
|
||||
for scope in reversed(list(root.traverse())):
|
||||
select = scope.expression
|
||||
where = select.args.get("where")
|
||||
if where:
|
||||
selected_sources = scope.selected_sources
|
||||
# a right join can only push down to itself and not the source FROM table
|
||||
for k, (node, source) in selected_sources.items():
|
||||
parent = node.find_ancestor(exp.Join, exp.From)
|
||||
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
|
||||
selected_sources = {k: (node, source)}
|
||||
break
|
||||
pushdown(where.this, selected_sources, scope_ref_count)
|
||||
|
||||
# joins should only pushdown into itself, not to other joins
|
||||
# so we limit the selected sources to only itself
|
||||
for join in select.args.get("joins") or []:
|
||||
name = join.this.alias_or_name
|
||||
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
|
||||
|
||||
return expression
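A hedged illustration of what the new `if root:` guard buys: build_scope can return None for expressions that have no scope, and pushdown_predicates now returns such expressions untouched:

import sqlglot
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

expr = sqlglot.parse_one("x = 1")  # a bare condition, so there is no scope to traverse
print(pushdown_predicates(expr).sql())  # x = 1, returned unchanged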
|
||||
|
||||
|
|
|
@ -39,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
for scope in reversed(traverse_scope(expression)):
|
||||
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
|
||||
|
||||
if scope.expression.args.get("distinct"):
|
||||
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT
|
||||
if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
|
||||
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
|
||||
# we select from a pivoted source in the parent scope.
|
||||
parent_selections = {SELECT_ALL}
|
||||
|
||||
if isinstance(scope.expression, exp.Union):
|
||||
|
@ -105,7 +106,9 @@ def _remove_unused_selections(scope, parent_selections, schema):
|
|||
|
||||
for name in sorted(parent_selections):
|
||||
if name not in names:
|
||||
new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))
|
||||
new_selections.append(
|
||||
alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
|
||||
)
|
||||
|
||||
# If there are no remaining selections, just select a single constant
|
||||
if not new_selections:
|
||||
|
|
80 sqlglot/optimizer/qualify.py (new file)
|
@ -0,0 +1,80 @@
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
    qualify_columns as qualify_columns_func,
    quote_identifiers as quote_identifiers_func,
    validate_qualify_columns as validate_qualify_columns_func,
)
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.schema import Schema, ensure_schema


def qualify(
    expression: exp.Expression,
    dialect: DialectType = None,
    db: t.Optional[str] = None,
    catalog: t.Optional[str] = None,
    schema: t.Optional[dict | Schema] = None,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
    isolate_tables: bool = False,
    qualify_columns: bool = True,
    validate_qualify_columns: bool = True,
    quote_identifiers: bool = True,
    identify: bool = True,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have normalized and qualified tables and columns.

    This step is necessary for all further SQLGlot optimizations.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify(expression, schema=schema).sql()
        'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"'

    Args:
        expression: Expression to qualify.
        db: Default database name for tables.
        catalog: Default catalog name for tables.
        schema: Schema to infer column names and types.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing.
        isolate_tables: Whether or not to isolate table selects.
        qualify_columns: Whether or not to qualify columns.
        validate_qualify_columns: Whether or not to validate columns.
        quote_identifiers: Whether or not to run the quote_identifiers step.
            This step is necessary to ensure correctness for case sensitive queries.
            But this flag is provided in case this step is performed at a later time.
        identify: If True, quote all identifiers, else only necessary ones.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema, dialect=dialect)
    expression = normalize_identifiers(expression, dialect=dialect)
    expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)

    if isolate_tables:
        expression = isolate_table_selects(expression, schema=schema)

    if qualify_columns:
        expression = qualify_columns_func(
            expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
        )

    if quote_identifiers:
        expression = quote_identifiers_func(expression, dialect=dialect, identify=identify)

    if validate_qualify_columns:
        validate_qualify_columns_func(expression)

    return expression
|
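A hedged usage sketch of this new single entry point; the schema, database name and printed output are illustrative assumptions:

import sqlglot
from sqlglot.optimizer.qualify import qualify

schema = {"db": {"tbl": {"id": "INT", "name": "TEXT"}}}
expr = sqlglot.parse_one("SELECT name FROM tbl")
print(qualify(expr, schema=schema, db="db").sql())
# roughly: SELECT "tbl"."name" AS "name" FROM "db"."tbl" AS "tbl"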
@ -1,14 +1,23 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import typing as t
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot._typing import E
|
||||
from sqlglot.dialects.dialect import DialectType
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import ensure_schema
|
||||
from sqlglot.helper import case_sensitive, seq_get
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
|
||||
from sqlglot.schema import Schema, ensure_schema
|
||||
|
||||
|
||||
def qualify_columns(expression, schema, expand_laterals=True):
|
||||
def qualify_columns(
|
||||
expression: exp.Expression,
|
||||
schema: dict | Schema,
|
||||
expand_alias_refs: bool = True,
|
||||
infer_schema: t.Optional[bool] = None,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Rewrite sqlglot AST to have fully qualified columns.
|
||||
|
||||
|
@ -20,32 +29,36 @@ def qualify_columns(expression, schema, expand_laterals=True):
|
|||
'SELECT tbl.col AS col FROM tbl'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to qualify
|
||||
schema (dict|sqlglot.optimizer.Schema): Database schema
|
||||
expression: expression to qualify
|
||||
schema: Database schema
|
||||
expand_alias_refs: whether or not to expand references to aliases
|
||||
infer_schema: whether or not to infer the schema if missing
|
||||
Returns:
|
||||
sqlglot.Expression: qualified expression
|
||||
"""
|
||||
schema = ensure_schema(schema)
|
||||
|
||||
if not schema.mapping and expand_laterals:
|
||||
expression = _expand_laterals(expression)
|
||||
infer_schema = schema.empty if infer_schema is None else infer_schema
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
resolver = Resolver(scope, schema)
|
||||
resolver = Resolver(scope, schema, infer_schema=infer_schema)
|
||||
_pop_table_column_aliases(scope.ctes)
|
||||
_pop_table_column_aliases(scope.derived_tables)
|
||||
using_column_tables = _expand_using(scope, resolver)
|
||||
|
||||
if schema.empty and expand_alias_refs:
|
||||
_expand_alias_refs(scope, resolver)
|
||||
|
||||
_qualify_columns(scope, resolver)
|
||||
|
||||
if not schema.empty and expand_alias_refs:
|
||||
_expand_alias_refs(scope, resolver)
|
||||
|
||||
if not isinstance(scope.expression, exp.UDTF):
|
||||
_expand_stars(scope, resolver, using_column_tables)
|
||||
_qualify_outputs(scope)
|
||||
_expand_alias_refs(scope, resolver)
|
||||
_expand_group_by(scope, resolver)
|
||||
_expand_order_by(scope)
|
||||
|
||||
if schema.mapping and expand_laterals:
|
||||
expression = _expand_laterals(expression)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -55,9 +68,11 @@ def validate_qualify_columns(expression):
|
|||
for scope in traverse_scope(expression):
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
unqualified_columns.extend(scope.unqualified_columns)
|
||||
if scope.external_columns and not scope.is_correlated_subquery:
|
||||
if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
|
||||
column = scope.external_columns[0]
|
||||
raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
|
||||
raise OptimizeError(
|
||||
f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
|
||||
)
|
||||
|
||||
if unqualified_columns:
|
||||
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
|
||||
|
@ -142,52 +157,48 @@ def _expand_using(scope, resolver):
|
|||
|
||||
# Ensure selects keep their output name
|
||||
if isinstance(column.parent, exp.Select):
|
||||
replacement = exp.alias_(replacement, alias=column.name)
|
||||
replacement = alias(replacement, alias=column.name, copy=False)
|
||||
|
||||
scope.replace(column, replacement)
|
||||
|
||||
return column_tables
|
||||
|
||||
|
||||
def _expand_alias_refs(scope, resolver):
|
||||
selects = {}
|
||||
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
||||
expression = scope.expression
|
||||
|
||||
# Replace references to select aliases
|
||||
def transform(node, source_first=True):
|
||||
if isinstance(node, exp.Column) and not node.table:
|
||||
table = resolver.get_table(node.name)
|
||||
if not isinstance(expression, exp.Select):
|
||||
return
|
||||
|
||||
# Source columns get priority over select aliases
|
||||
if source_first and table:
|
||||
node.set("table", table)
|
||||
return node
|
||||
alias_to_expression: t.Dict[str, exp.Expression] = {}
|
||||
|
||||
if not selects:
|
||||
for s in scope.selects:
|
||||
selects[s.alias_or_name] = s
|
||||
select = selects.get(node.name)
|
||||
|
||||
if select:
|
||||
scope.clear_cache()
|
||||
if isinstance(select, exp.Alias):
|
||||
select = select.this
|
||||
return select.copy()
|
||||
|
||||
node.set("table", table)
|
||||
elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable):
|
||||
exp.replace_children(node, transform, source_first)
|
||||
|
||||
return node
|
||||
|
||||
for select in scope.expression.selects:
|
||||
transform(select)
|
||||
|
||||
for modifier, source_first in (
|
||||
("where", True),
|
||||
("group", True),
|
||||
("having", False),
|
||||
def replace_columns(
|
||||
node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
|
||||
):
|
||||
transform(scope.expression.args.get(modifier), source_first=source_first)
|
||||
if not node:
|
||||
return
|
||||
|
||||
for column, *_ in walk_in_scope(node):
|
||||
if not isinstance(column, exp.Column):
|
||||
continue
|
||||
table = resolver.get_table(column.name) if resolve_agg and not column.table else None
|
||||
if table and column.find_ancestor(exp.AggFunc):
|
||||
column.set("table", table)
|
||||
elif expand and not column.table and column.name in alias_to_expression:
|
||||
column.replace(alias_to_expression[column.name].copy())
|
||||
|
||||
for projection in scope.selects:
|
||||
replace_columns(projection)
|
||||
|
||||
if isinstance(projection, exp.Alias):
|
||||
alias_to_expression[projection.alias] = projection.this
|
||||
|
||||
replace_columns(expression.args.get("where"))
|
||||
replace_columns(expression.args.get("group"))
|
||||
replace_columns(expression.args.get("having"), resolve_agg=True)
|
||||
replace_columns(expression.args.get("qualify"), resolve_agg=True)
|
||||
replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
|
||||
scope.clear_cache()
|
||||
|
||||
|
||||
def _expand_group_by(scope, resolver):
|
||||
|
@ -242,6 +253,12 @@ def _qualify_columns(scope, resolver):
|
|||
raise OptimizeError(f"Unknown column: {column_name}")
|
||||
|
||||
if not column_table:
|
||||
if scope.pivots and not column.find_ancestor(exp.Pivot):
|
||||
# If the column is under the Pivot expression, we need to qualify it
|
||||
# using the name of the pivoted source instead of the pivot's alias
|
||||
column.set("table", exp.to_identifier(scope.pivots[0].alias))
|
||||
continue
|
||||
|
||||
column_table = resolver.get_table(column_name)
|
||||
|
||||
# column_table can be a '' because bigquery unnest has no table alias
|
||||
|
@ -265,38 +282,12 @@ def _qualify_columns(scope, resolver):
|
|||
if column_table:
|
||||
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
|
||||
|
||||
columns_missing_from_scope = []
|
||||
|
||||
# Determine whether each reference in the order by clause is to a column or an alias.
|
||||
order = scope.expression.args.get("order")
|
||||
|
||||
if order:
|
||||
for ordered in order.expressions:
|
||||
for column in ordered.find_all(exp.Column):
|
||||
if (
|
||||
not column.table
|
||||
and column.parent is not ordered
|
||||
and column.name in resolver.all_columns
|
||||
):
|
||||
columns_missing_from_scope.append(column)
|
||||
|
||||
# Determine whether each reference in the having clause is to a column or an alias.
|
||||
having = scope.expression.args.get("having")
|
||||
|
||||
if having:
|
||||
for column in having.find_all(exp.Column):
|
||||
if (
|
||||
not column.table
|
||||
and column.find_ancestor(exp.AggFunc)
|
||||
and column.name in resolver.all_columns
|
||||
):
|
||||
columns_missing_from_scope.append(column)
|
||||
|
||||
for column in columns_missing_from_scope:
|
||||
column_table = resolver.get_table(column.name)
|
||||
|
||||
if column_table:
|
||||
column.set("table", column_table)
|
||||
for pivot in scope.pivots:
|
||||
for column in pivot.find_all(exp.Column):
|
||||
if not column.table and column.name in resolver.all_columns:
|
||||
column_table = resolver.get_table(column.name)
|
||||
if column_table:
|
||||
column.set("table", column_table)
|
||||
|
||||
|
||||
def _expand_stars(scope, resolver, using_column_tables):
|
||||
|
@ -307,6 +298,19 @@ def _expand_stars(scope, resolver, using_column_tables):
|
|||
replace_columns = {}
|
||||
coalesced_columns = set()
|
||||
|
||||
# TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
|
||||
pivot_columns = None
|
||||
pivot_output_columns = None
|
||||
pivot = seq_get(scope.pivots, 0)
|
||||
|
||||
has_pivoted_source = pivot and not pivot.args.get("unpivot")
|
||||
if has_pivoted_source:
|
||||
pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
|
||||
|
||||
pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
|
||||
if not pivot_output_columns:
|
||||
pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
|
||||
|
||||
for expression in scope.selects:
|
||||
if isinstance(expression, exp.Star):
|
||||
tables = list(scope.selected_sources)
|
||||
|
@ -323,9 +327,18 @@ def _expand_stars(scope, resolver, using_column_tables):
|
|||
for table in tables:
|
||||
if table not in scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {table}")
|
||||
|
||||
columns = resolver.get_source_columns(table, only_visible=True)
|
||||
|
||||
if columns and "*" not in columns:
|
||||
if has_pivoted_source:
|
||||
implicit_columns = [col for col in columns if col not in pivot_columns]
|
||||
new_selections.extend(
|
||||
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
|
||||
for name in implicit_columns + pivot_output_columns
|
||||
)
|
||||
continue
|
||||
|
||||
table_id = id(table)
|
||||
for name in columns:
|
||||
if name in using_column_tables and table in using_column_tables[name]:
|
||||
|
@ -337,16 +350,21 @@ def _expand_stars(scope, resolver, using_column_tables):
|
|||
coalesce = [exp.column(name, table=table) for table in tables]
|
||||
|
||||
new_selections.append(
|
||||
exp.alias_(
|
||||
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
|
||||
alias(
|
||||
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
|
||||
alias=name,
|
||||
copy=False,
|
||||
)
|
||||
)
|
||||
elif name not in except_columns.get(table_id, set()):
|
||||
alias_ = replace_columns.get(table_id, {}).get(name, name)
|
||||
column = exp.column(name, table)
|
||||
new_selections.append(alias(column, alias_) if alias_ != name else column)
|
||||
column = exp.column(name, table=table)
|
||||
new_selections.append(
|
||||
alias(column, alias_, copy=False) if alias_ != name else column
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
|
@ -388,9 +406,6 @@ def _qualify_outputs(scope):
|
|||
selection = alias(
|
||||
selection,
|
||||
alias=selection.output_name or f"_col_{i}",
|
||||
quoted=True
|
||||
if isinstance(selection, exp.Column) and selection.this.quoted
|
||||
else None,
|
||||
)
|
||||
if aliased_column:
|
||||
selection.set("alias", exp.to_identifier(aliased_column))
|
||||
|
@ -400,6 +415,23 @@ def _qualify_outputs(scope):
|
|||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
|
||||
"""Makes sure all identifiers that need to be quoted are quoted."""
|
||||
|
||||
def _quote(expression: E) -> E:
|
||||
if isinstance(expression, exp.Identifier):
|
||||
name = expression.this
|
||||
expression.set(
|
||||
"quoted",
|
||||
identify
|
||||
or case_sensitive(name, dialect=dialect)
|
||||
or not exp.SAFE_IDENTIFIER_RE.match(name),
|
||||
)
|
||||
return expression
|
||||
|
||||
return expression.transform(_quote, copy=False)
|
||||
|
||||
|
||||
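A hedged sketch of the new standalone quoting step; the exact quoting depends on the dialect's case-sensitivity rules:

import sqlglot
from sqlglot.optimizer.qualify_columns import quote_identifiers

expr = sqlglot.parse_one("SELECT mixedCase, plain FROM some_table")
print(quote_identifiers(expr, identify=True).sql())
# all identifiers quoted, e.g. SELECT "mixedCase", "plain" FROM "some_table"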
class Resolver:
|
||||
"""
|
||||
Helper for resolving columns.
|
||||
|
@ -407,12 +439,13 @@ class Resolver:
|
|||
This is a class so we can lazily load some things and easily share them across functions.
|
||||
"""
|
||||
|
||||
def __init__(self, scope, schema):
|
||||
def __init__(self, scope, schema, infer_schema: bool = True):
|
||||
self.scope = scope
|
||||
self.schema = schema
|
||||
self._source_columns = None
|
||||
self._unambiguous_columns = None
|
||||
self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
|
||||
self._all_columns = None
|
||||
self._infer_schema = infer_schema
|
||||
|
||||
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
|
||||
"""
|
||||
|
@ -430,7 +463,7 @@ class Resolver:
|
|||
|
||||
table_name = self._unambiguous_columns.get(column_name)
|
||||
|
||||
if not table_name:
|
||||
if not table_name and self._infer_schema:
|
||||
sources_without_schema = tuple(
|
||||
source
|
||||
for source, columns in self._get_all_source_columns().items()
|
||||
|
@ -450,11 +483,9 @@ class Resolver:
|
|||
|
||||
node_alias = node.args.get("alias")
|
||||
if node_alias:
|
||||
return node_alias.this
|
||||
return exp.to_identifier(node_alias.this)
|
||||
|
||||
return exp.to_identifier(
|
||||
table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
|
||||
)
|
||||
return exp.to_identifier(table_name)
|
||||
|
||||
@property
|
||||
def all_columns(self):
|
||||
|
|
|
@ -1,11 +1,19 @@
|
|||
import itertools
|
||||
import typing as t
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.helper import csv_reader
|
||||
from sqlglot._typing import E
|
||||
from sqlglot.helper import csv_reader, name_sequence
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import Schema
|
||||
|
||||
|
||||
def qualify_tables(expression, db=None, catalog=None, schema=None):
|
||||
def qualify_tables(
|
||||
expression: E,
|
||||
db: t.Optional[str] = None,
|
||||
catalog: t.Optional[str] = None,
|
||||
schema: t.Optional[Schema] = None,
|
||||
) -> E:
|
||||
"""
|
||||
Rewrite sqlglot AST to have fully qualified tables. Additionally, this
|
||||
replaces "join constructs" (*) by equivalent SELECT * subqueries.
|
||||
|
@ -21,19 +29,17 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to qualify
|
||||
db (str): Database name
|
||||
catalog (str): Catalog name
|
||||
expression: Expression to qualify
|
||||
db: Database name
|
||||
catalog: Catalog name
|
||||
schema: A schema to populate
|
||||
|
||||
Returns:
|
||||
sqlglot.Expression: qualified expression
|
||||
The qualified expression.
|
||||
|
||||
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
|
||||
"""
|
||||
sequence = itertools.count()
|
||||
|
||||
next_name = lambda: f"_q_{next(sequence)}"
|
||||
next_alias_name = name_sequence("_q_")
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
|
||||
|
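The repeated f"_q_{next(sequence)}" pattern is replaced by the name_sequence helper above; a hedged sketch of its presumed behaviour, a counter closed over a prefix:

from sqlglot.helper import name_sequence

next_alias_name = name_sequence("_q_")
print(next_alias_name(), next_alias_name())  # presumably _q_0 _q_1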
@ -44,10 +50,14 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
|
||||
|
||||
if not derived_table.args.get("alias"):
|
||||
alias_ = f"_q_{next(sequence)}"
|
||||
alias_ = next_alias_name()
|
||||
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
|
||||
scope.rename_source(None, alias_)
|
||||
|
||||
pivots = derived_table.args.get("pivots")
|
||||
if pivots and not pivots[0].alias:
|
||||
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
|
||||
|
||||
for name, source in scope.sources.items():
|
||||
if isinstance(source, exp.Table):
|
||||
if isinstance(source.this, exp.Identifier):
|
||||
|
@ -59,12 +69,19 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
if not source.alias:
|
||||
source = source.replace(
|
||||
alias(
|
||||
source.copy(),
|
||||
name if name else next_name(),
|
||||
source,
|
||||
name or source.name or next_alias_name(),
|
||||
copy=True,
|
||||
table=True,
|
||||
)
|
||||
)
|
||||
|
||||
pivots = source.args.get("pivots")
|
||||
if pivots and not pivots[0].alias:
|
||||
pivots[0].set(
|
||||
"alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
|
||||
)
|
||||
|
||||
if schema and isinstance(source.this, exp.ReadCSV):
|
||||
with csv_reader(source.this) as reader:
|
||||
header = next(reader)
|
||||
|
@ -74,11 +91,11 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
)
|
||||
elif isinstance(source, Scope) and source.is_udtf:
|
||||
udtf = source.expression
|
||||
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
|
||||
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
|
||||
udtf.set("alias", table_alias)
|
||||
|
||||
if not table_alias.name:
|
||||
table_alias.set("this", next_name())
|
||||
table_alias.set("this", next_alias_name())
|
||||
if isinstance(udtf, exp.Values) and not table_alias.columns:
|
||||
for i, e in enumerate(udtf.expressions[0].expressions):
|
||||
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import itertools
|
||||
import typing as t
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
|
||||
|
@ -83,6 +84,7 @@ class Scope:
|
|||
self._columns = None
|
||||
self._external_columns = None
|
||||
self._join_hints = None
|
||||
self._pivots = None
|
||||
|
||||
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
|
||||
"""Branch from the current scope to a new, inner scope"""
|
||||
|
@ -261,12 +263,14 @@ class Scope:
|
|||
|
||||
self._columns = []
|
||||
for column in columns + external_columns:
|
||||
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
|
||||
ancestor = column.find_ancestor(
|
||||
exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
|
||||
)
|
||||
if (
|
||||
not ancestor
|
||||
# Window functions can have an ORDER BY clause
|
||||
or not isinstance(ancestor.parent, exp.Select)
|
||||
or column.table
|
||||
or isinstance(ancestor, exp.Select)
|
||||
or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
|
||||
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
|
||||
):
|
||||
self._columns.append(column)
|
||||
|
@ -370,6 +374,17 @@ class Scope:
|
|||
return []
|
||||
return self._join_hints
|
||||
|
||||
@property
|
||||
def pivots(self):
|
||||
if not self._pivots:
|
||||
self._pivots = [
|
||||
pivot
|
||||
for node in self.tables + self.derived_tables
|
||||
for pivot in node.args.get("pivots") or []
|
||||
]
|
||||
|
||||
return self._pivots
|
||||
|
||||
def source_columns(self, source_name):
|
||||
"""
|
||||
Get all columns in the current scope for a particular source.
|
||||
|
@ -463,7 +478,7 @@ class Scope:
|
|||
return scope_ref_count
|
||||
|
||||
|
||||
def traverse_scope(expression):
|
||||
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
|
||||
"""
|
||||
Traverse an expression by its "scopes".
|
||||
|
||||
|
@ -488,10 +503,12 @@ def traverse_scope(expression):
|
|||
Returns:
|
||||
list[Scope]: scope instances
|
||||
"""
|
||||
if not isinstance(expression, exp.Unionable):
|
||||
return []
|
||||
return list(_traverse_scope(Scope(expression)))
|
||||
|
||||
|
||||
def build_scope(expression):
|
||||
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
|
||||
"""
|
||||
Build a scope tree.
|
||||
|
||||
|
@ -500,7 +517,10 @@ def build_scope(expression):
|
|||
Returns:
|
||||
Scope: root scope
|
||||
"""
|
||||
return traverse_scope(expression)[-1]
|
||||
scopes = traverse_scope(expression)
|
||||
if scopes:
|
||||
return scopes[-1]
|
||||
return None
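A small example of the new Optional return: expressions that are not Unionable now yield no scopes, and build_scope returns None for them instead of indexing into an empty list:

import sqlglot
from sqlglot.optimizer.scope import build_scope, traverse_scope

assert build_scope(sqlglot.parse_one("SELECT a FROM x")) is not None
assert traverse_scope(sqlglot.parse_one("a + 1")) == []  # not a Unionable expression
assert build_scope(sqlglot.parse_one("a + 1")) is None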
|
||||
|
||||
|
||||
def _traverse_scope(scope):
|
||||
|
@ -585,7 +605,7 @@ def _traverse_tables(scope):
|
|||
expressions = []
|
||||
from_ = scope.expression.args.get("from")
|
||||
if from_:
|
||||
expressions.extend(from_.expressions)
|
||||
expressions.append(from_.this)
|
||||
|
||||
for join in scope.expression.args.get("joins") or []:
|
||||
expressions.append(join.this)
|
||||
|
@ -601,8 +621,13 @@ def _traverse_tables(scope):
|
|||
source_name = expression.alias_or_name
|
||||
|
||||
if table_name in scope.sources:
|
||||
# This is a reference to a parent source (e.g. a CTE), not an actual table.
|
||||
sources[source_name] = scope.sources[table_name]
|
||||
# This is a reference to a parent source (e.g. a CTE), not an actual table, unless
|
||||
# it is pivoted, because then we get back a new table and hence a new source.
|
||||
pivots = expression.args.get("pivots")
|
||||
if pivots:
|
||||
sources[pivots[0].alias] = expression
|
||||
else:
|
||||
sources[source_name] = scope.sources[table_name]
|
||||
elif source_name in sources:
|
||||
sources[find_new_name(sources, table_name)] = expression
|
||||
else:
|
||||
|
|
|
@ -5,11 +5,9 @@ from collections import deque
|
|||
from decimal import Decimal
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.generator import Generator
|
||||
from sqlglot.generator import cached_generator
|
||||
from sqlglot.helper import first, while_changing
|
||||
|
||||
GENERATOR = Generator(normalize=True, identify="safe")
|
||||
|
||||
|
||||
def simplify(expression):
|
||||
"""
|
||||
|
@ -27,12 +25,12 @@ def simplify(expression):
|
|||
sqlglot.Expression: simplified expression
|
||||
"""
|
||||
|
||||
cache = {}
|
||||
generate = cached_generator()
|
||||
|
||||
def _simplify(expression, root=True):
|
||||
node = expression
|
||||
node = rewrite_between(node)
|
||||
node = uniq_sort(node, cache, root)
|
||||
node = uniq_sort(node, generate, root)
|
||||
node = absorb_and_eliminate(node, root)
|
||||
exp.replace_children(node, lambda e: _simplify(e, False))
|
||||
node = simplify_not(node)
|
||||
|
@ -247,7 +245,7 @@ def remove_compliments(expression, root=True):
|
|||
return expression
|
||||
|
||||
|
||||
def uniq_sort(expression, cache=None, root=True):
|
||||
def uniq_sort(expression, generate, root=True):
|
||||
"""
|
||||
Uniq and sort a connector.
|
||||
|
||||
|
@ -256,7 +254,7 @@ def uniq_sort(expression, cache=None, root=True):
|
|||
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
|
||||
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
|
||||
flattened = tuple(expression.flatten())
|
||||
deduped = {GENERATOR.generate(e, cache): e for e in flattened}
|
||||
deduped = {generate(e): e for e in flattened}
|
||||
arr = tuple(deduped.items())
|
||||
|
||||
# check if the operands are already sorted, if not sort them
|
||||
|
@ -388,14 +386,18 @@ def _simplify_binary(expression, a, b):
|
|||
|
||||
|
||||
def simplify_parens(expression):
|
||||
if (
|
||||
isinstance(expression, exp.Paren)
|
||||
and not isinstance(expression.this, exp.Select)
|
||||
and (
|
||||
not isinstance(expression.parent, (exp.Condition, exp.Binary))
|
||||
or isinstance(expression.this, exp.Predicate)
|
||||
or not isinstance(expression.this, exp.Binary)
|
||||
)
|
||||
if not isinstance(expression, exp.Paren):
|
||||
return expression
|
||||
|
||||
this = expression.this
|
||||
parent = expression.parent
|
||||
|
||||
if not isinstance(this, exp.Select) and (
|
||||
not isinstance(parent, (exp.Condition, exp.Binary))
|
||||
or isinstance(this, exp.Predicate)
|
||||
or not isinstance(this, exp.Binary)
|
||||
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
||||
):
|
||||
return expression.this
|
||||
return expression
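A hedged example of the extra flattening the rewritten check enables for chained additions and multiplications; the exact output string is an assumption:

import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.parse_one("(x + 1) + 2")).sql())  # e.g. x + 1 + 2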
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import itertools
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.helper import name_sequence
|
||||
from sqlglot.optimizer.scope import ScopeType, traverse_scope
|
||||
|
||||
|
||||
|
@ -22,7 +21,7 @@ def unnest_subqueries(expression):
|
|||
Returns:
|
||||
sqlglot.Expression: unnested expression
|
||||
"""
|
||||
sequence = itertools.count()
|
||||
next_alias_name = name_sequence("_u_")
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
select = scope.expression
|
||||
|
@ -30,19 +29,19 @@ def unnest_subqueries(expression):
|
|||
if not parent:
|
||||
continue
|
||||
if scope.external_columns:
|
||||
decorrelate(select, parent, scope.external_columns, sequence)
|
||||
decorrelate(select, parent, scope.external_columns, next_alias_name)
|
||||
elif scope.scope_type == ScopeType.SUBQUERY:
|
||||
unnest(select, parent, sequence)
|
||||
unnest(select, parent, next_alias_name)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def unnest(select, parent_select, sequence):
|
||||
def unnest(select, parent_select, next_alias_name):
|
||||
if len(select.selects) > 1:
|
||||
return
|
||||
|
||||
predicate = select.find_ancestor(exp.Condition)
|
||||
alias = _alias(sequence)
|
||||
alias = next_alias_name()
|
||||
|
||||
if not predicate or parent_select is not predicate.parent_select:
|
||||
return
|
||||
|
@ -87,13 +86,13 @@ def unnest(select, parent_select, sequence):
|
|||
)
|
||||
|
||||
|
||||
def decorrelate(select, parent_select, external_columns, sequence):
|
||||
def decorrelate(select, parent_select, external_columns, next_alias_name):
|
||||
where = select.args.get("where")
|
||||
|
||||
if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
|
||||
return
|
||||
|
||||
table_alias = _alias(sequence)
|
||||
table_alias = next_alias_name()
|
||||
keys = []
|
||||
|
||||
# for all external columns in the where statement, find the relevant predicate
|
||||
|
@ -136,7 +135,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
group_by.append(key)
|
||||
else:
|
||||
if key not in key_aliases:
|
||||
key_aliases[key] = _alias(sequence)
|
||||
key_aliases[key] = next_alias_name()
|
||||
# all predicates that are equalities must also be in the unique
|
||||
# so that we don't do a many to many join
|
||||
if isinstance(predicate, exp.EQ) and key not in group_by:
|
||||
|
@ -244,10 +243,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
)
|
||||
|
||||
|
||||
def _alias(sequence):
|
||||
return f"_u_{next(sequence)}"
|
||||
|
||||
|
||||
def _replace(expression, condition):
|
||||
return expression.replace(exp.condition(condition))
|
||||
|
||||
|
|
File diff suppressed because it is too large
|
@ -1,11 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import typing as t
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.errors import UnsupportedError
|
||||
from sqlglot.helper import name_sequence
|
||||
from sqlglot.optimizer.eliminate_joins import join_condition
|
||||
|
||||
|
||||
|
@ -105,13 +104,7 @@ class Step:
|
|||
from_ = expression.args.get("from")
|
||||
|
||||
if isinstance(expression, exp.Select) and from_:
|
||||
from_ = from_.expressions
|
||||
if len(from_) > 1:
|
||||
raise UnsupportedError(
|
||||
"Multi-from statements are unsupported. Run it through the optimizer"
|
||||
)
|
||||
|
||||
step = Scan.from_expression(from_[0], ctes)
|
||||
step = Scan.from_expression(from_.this, ctes)
|
||||
elif isinstance(expression, exp.Union):
|
||||
step = SetOperation.from_expression(expression, ctes)
|
||||
else:
|
||||
|
@ -128,7 +121,7 @@ class Step:
|
|||
projections = [] # final selects in this chain of steps representing a select
|
||||
operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
|
||||
aggregations = []
|
||||
sequence = itertools.count()
|
||||
next_operand_name = name_sequence("_a_")
|
||||
|
||||
def extract_agg_operands(expression):
|
||||
for agg in expression.find_all(exp.AggFunc):
|
||||
|
@ -136,7 +129,7 @@ class Step:
|
|||
if isinstance(operand, exp.Column):
|
||||
continue
|
||||
if operand not in operands:
|
||||
operands[operand] = f"_a_{next(sequence)}"
|
||||
operands[operand] = next_operand_name()
|
||||
operand.replace(exp.column(operands[operand], quoted=True))
|
||||
|
||||
for e in expression.expressions:
|
||||
|
@ -310,7 +303,7 @@ class Join(Step):
|
|||
for join in joins:
|
||||
source_key, join_key, condition = join_condition(join)
|
||||
step.joins[join.this.alias_or_name] = {
|
||||
"side": join.side,
|
||||
"side": join.side, # type: ignore
|
||||
"join_key": join_key,
|
||||
"source_key": source_key,
|
||||
"condition": condition,
|
||||
|
|
|
@ -5,6 +5,8 @@ import typing as t
|
|||
|
||||
import sqlglot
|
||||
from sqlglot import expressions as exp
|
||||
from sqlglot._typing import T
|
||||
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
|
||||
from sqlglot.errors import ParseError, SchemaError
|
||||
from sqlglot.helper import dict_depth
|
||||
from sqlglot.trie import in_trie, new_trie
|
||||
|
@ -17,62 +19,83 @@ if t.TYPE_CHECKING:
|
|||
|
||||
TABLE_ARGS = ("this", "db", "catalog")
|
||||
|
||||
T = t.TypeVar("T")
|
||||
|
||||
|
||||
class Schema(abc.ABC):
|
||||
"""Abstract base class for database schemas"""
|
||||
|
||||
dialect: DialectType
|
||||
|
||||
@abc.abstractmethod
|
||||
def add_table(
|
||||
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
|
||||
self,
|
||||
table: exp.Table | str,
|
||||
column_mapping: t.Optional[ColumnMapping] = None,
|
||||
dialect: DialectType = None,
|
||||
) -> None:
|
||||
"""
|
||||
Register or update a table. Some implementing classes may require column information to also be provided.
|
||||
|
||||
Args:
|
||||
table: table expression instance or string representing the table.
|
||||
table: the `Table` expression instance or string representing the table.
|
||||
column_mapping: a column mapping that describes the structure of the table.
|
||||
dialect: the SQL dialect that will be used to parse `table` if it's a string.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
|
||||
def column_names(
|
||||
self,
|
||||
table: exp.Table | str,
|
||||
only_visible: bool = False,
|
||||
dialect: DialectType = None,
|
||||
) -> t.List[str]:
|
||||
"""
|
||||
Get the column names for a table.
|
||||
|
||||
Args:
|
||||
table: the `Table` expression instance.
|
||||
only_visible: whether to include invisible columns.
|
||||
dialect: the SQL dialect that will be used to parse `table` if it's a string.
|
||||
|
||||
Returns:
|
||||
The list of column names.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
|
||||
def get_column_type(
|
||||
self,
|
||||
table: exp.Table | str,
|
||||
column: exp.Column,
|
||||
dialect: DialectType = None,
|
||||
) -> exp.DataType:
|
||||
"""
|
||||
Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
|
||||
Get the `sqlglot.exp.DataType` type of a column in the schema.
|
||||
|
||||
Args:
|
||||
table: the source table.
|
||||
column: the target column.
|
||||
dialect: the SQL dialect that will be used to parse `table` if it's a string.
|
||||
|
||||
Returns:
|
||||
The resulting column type.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def supported_table_args(self) -> t.Tuple[str, ...]:
|
||||
"""
|
||||
Table arguments this schema supports, e.g. `("this", "db", "catalog")`
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def empty(self) -> bool:
|
||||
"""Returns whether or not the schema is empty."""
|
||||
return True
|
||||
|
||||
|
||||
class AbstractMappingSchema(t.Generic[T]):
|
||||
def __init__(
|
||||
self,
|
||||
mapping: dict | None = None,
|
||||
mapping: t.Optional[t.Dict] = None,
|
||||
) -> None:
|
||||
self.mapping = mapping or {}
|
||||
self.mapping_trie = new_trie(
|
||||
|
@ -80,6 +103,10 @@ class AbstractMappingSchema(t.Generic[T]):
|
|||
)
|
||||
self._supported_table_args: t.Tuple[str, ...] = tuple()
|
||||
|
||||
@property
|
||||
def empty(self) -> bool:
|
||||
return not self.mapping
|
||||
|
||||
def _depth(self) -> int:
|
||||
return dict_depth(self.mapping)
|
||||
|
||||
|
@ -110,8 +137,10 @@ class AbstractMappingSchema(t.Generic[T]):
|
|||
|
||||
if value == 0:
|
||||
return None
|
||||
elif value == 1:
|
||||
|
||||
if value == 1:
|
||||
possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
|
||||
|
||||
if len(possibilities) == 1:
|
||||
parts.extend(possibilities[0])
|
||||
else:
|
||||
|
@ -119,12 +148,13 @@ class AbstractMappingSchema(t.Generic[T]):
|
|||
if raise_on_missing:
|
||||
raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
|
||||
return None
|
||||
return self._nested_get(parts, raise_on_missing=raise_on_missing)
|
||||
|
||||
def _nested_get(
|
||||
return self.nested_get(parts, raise_on_missing=raise_on_missing)
|
||||
|
||||
def nested_get(
|
||||
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
|
||||
) -> t.Optional[t.Any]:
|
||||
return _nested_get(
|
||||
return nested_get(
|
||||
d or self.mapping,
|
||||
*zip(self.supported_table_args, reversed(parts)),
|
||||
raise_on_missing=raise_on_missing,
|
||||
|
@ -136,17 +166,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
|
|||
Schema based on a nested mapping.
|
||||
|
||||
Args:
|
||||
schema (dict): Mapping in one of the following forms:
|
||||
schema: Mapping in one of the following forms:
|
||||
1. {table: {col: type}}
|
||||
2. {db: {table: {col: type}}}
|
||||
3. {catalog: {db: {table: {col: type}}}}
|
||||
4. None - Tables will be added later
|
||||
visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
|
||||
visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
1. {table: set(*cols)}}
2. {db: {table: set(*cols)}}}
3. {catalog: {db: {table: set(*cols)}}}}
dialect (str): The dialect to be used for custom type mappings.
dialect: The dialect to be used for custom type mappings & parsing string arguments.
normalize: Whether to normalize identifier names according to the given dialect or not.
"""

def __init__(

@ -154,10 +185,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
schema: t.Optional[t.Dict] = None,
visible: t.Optional[t.Dict] = None,
dialect: DialectType = None,
normalize: bool = True,
) -> None:
self.dialect = dialect
self.visible = visible or {}
self.normalize = normalize
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

super().__init__(self._normalize(schema or {}))

@classmethod

@ -179,7 +213,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
)

def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
self,
table: exp.Table | str,
column_mapping: t.Optional[ColumnMapping] = None,
dialect: DialectType = None,
) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.

@ -187,10 +224,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
"""
normalized_table = self._normalize_table(self._ensure_table(table))
normalized_table = self._normalize_table(
self._ensure_table(table, dialect=dialect), dialect=dialect
)
normalized_column_mapping = {
self._normalize_name(key): value
self._normalize_name(key, dialect=dialect): value
for key, value in ensure_column_mapping(column_mapping).items()
}

@ -200,38 +240,51 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):

parts = self.table_parts(normalized_table)

_nested_set(
self.mapping,
tuple(reversed(parts)),
normalized_column_mapping,
)
nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
new_trie([parts], self.mapping_trie)

def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table_ = self._normalize_table(self._ensure_table(table))
schema = self.find(table_)
def column_names(
self,
table: exp.Table | str,
only_visible: bool = False,
dialect: DialectType = None,
) -> t.List[str]:
normalized_table = self._normalize_table(
self._ensure_table(table, dialect=dialect), dialect=dialect
)

schema = self.find(normalized_table)
if schema is None:
return []

if not only_visible or not self.visible:
return list(schema)

visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible]  # type: ignore
visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
return [col for col in schema if col in visible]

def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
column_name = self._normalize_name(column if isinstance(column, str) else column.this)
table_ = self._normalize_table(self._ensure_table(table))
def get_column_type(
self,
table: exp.Table | str,
column: exp.Column,
dialect: DialectType = None,
) -> exp.DataType:
normalized_table = self._normalize_table(
self._ensure_table(table, dialect=dialect), dialect=dialect
)
normalized_column_name = self._normalize_name(
column if isinstance(column, str) else column.this, dialect=dialect
)

table_schema = self.find(table_, raise_on_missing=False)
table_schema = self.find(normalized_table, raise_on_missing=False)
if table_schema:
column_type = table_schema.get(column_name)
column_type = table_schema.get(normalized_column_name)

if isinstance(column_type, exp.DataType):
return column_type
elif isinstance(column_type, str):
return self._to_data_type(column_type.upper())
return self._to_data_type(column_type.upper(), dialect=dialect)

raise SchemaError(f"Unknown column type '{column_type}'")

return exp.DataType.build("unknown")

@ -250,81 +303,88 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):

normalized_mapping: t.Dict = {}
for keys in flattened_schema:
columns = _nested_get(schema, *zip(keys, keys))
columns = nested_get(schema, *zip(keys, keys))
assert columns is not None

normalized_keys = [self._normalize_name(key) for key in keys]
normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
for column_name, column_type in columns.items():
_nested_set(
nested_set(
normalized_mapping,
normalized_keys + [self._normalize_name(column_name)],
normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
column_type,
)

return normalized_mapping

def _normalize_table(self, table: exp.Table) -> exp.Table:
def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
normalized_table = table.copy()

for arg in TABLE_ARGS:
value = normalized_table.args.get(arg)
if isinstance(value, (str, exp.Identifier)):
normalized_table.set(arg, self._normalize_name(value))
normalized_table.set(
arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
)

return normalized_table

def _normalize_name(self, name: str | exp.Identifier) -> str:
def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
dialect = dialect or self.dialect

try:
identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
except ParseError:
return name if isinstance(name, str) else name.name

return identifier.name if identifier.quoted else identifier.name.lower()
name = identifier.name

if not self.normalize or identifier.quoted:
return name

return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()

def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1

def _ensure_table(self, table: exp.Table | str) -> exp.Table:
if isinstance(table, exp.Table):
return table
def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)

table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
if not table_:
raise SchemaError(f"Not a valid table '{table}'")

return table_

def _to_data_type(self, schema_type: str) -> exp.DataType:
def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

Args:
schema_type: the type we want to convert.
dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

Returns:
The resulting expression type.
"""
if schema_type not in self._type_mapping_cache:
dialect = dialect or self.dialect

try:
expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
if expression is None:
raise ValueError(f"Could not parse {schema_type}")
self._type_mapping_cache[schema_type] = expression  # type: ignore
expression = exp.DataType.build(schema_type, dialect=dialect)
self._type_mapping_cache[schema_type] = expression
except AttributeError:
raise SchemaError(f"Failed to convert type {schema_type}")
in_dialect = f" in dialect {dialect}" if dialect else ""
raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

return self._type_mapping_cache[schema_type]


def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
if isinstance(schema, Schema):
return schema

return MappingSchema(schema, dialect=dialect)
return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
if isinstance(mapping, dict):
if mapping is None:
return {}
elif isinstance(mapping, dict):
return mapping
elif isinstance(mapping, str):
col_name_type_strs = [x.strip() for x in mapping.split(",")]

@ -334,11 +394,10 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
}
# Check if mapping looks like a DataFrame StructType
elif hasattr(mapping, "simpleString"):
return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
elif isinstance(mapping, list):
return {x.strip(): None for x in mapping}
elif mapping is None:
return {}

raise ValueError(f"Invalid mapping provided: {type(mapping)}")


@ -353,10 +412,11 @@ def flatten_schema(
tables.extend(flatten_schema(v, depth - 1, keys + [k]))
elif depth == 1:
tables.append(keys + [k])

return tables


def _nested_get(
def nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
"""

@ -378,18 +438,19 @@ def _nested_get(
name = "table" if name == "this" else name
raise ValueError(f"Unknown {name}: {key}")
return None

return d


def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
"""
In-place set a value for a nested dictionary

Example:
>>> _nested_set({}, ["top_key", "second_key"], "value")
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}

>>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

Args:
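The schema hunks above give MappingSchema dialect-aware identifier normalization: new dialect and normalize constructor arguments, plus dialect= parameters threaded through add_table, column_names and get_column_type. The snippet below is an illustrative sketch only, not part of this commit; it assumes sqlglot 15.x and simply shows how the updated API is meant to be called.

from sqlglot import exp
from sqlglot.schema import MappingSchema

# Sketch only: identifier names are normalized according to the dialect's
# resolution rules before they are stored or looked up.
schema = MappingSchema(dialect="snowflake", normalize=True)
schema.add_table("db.users", {"id": "int", "name": "text"})

columns = schema.column_names("db.users")                       # normalized column names
dtype = schema.get_column_type("db.users", exp.column("name"))  # an exp.DataType instance
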
@ -51,7 +51,6 @@ class TokenType(AutoName):
DOLLAR = auto()
PARAMETER = auto()
SESSION_PARAMETER = auto()
NATIONAL = auto()
DAMP = auto()

BLOCK_START = auto()

@ -72,6 +71,8 @@ class TokenType(AutoName):
BIT_STRING = auto()
HEX_STRING = auto()
BYTE_STRING = auto()
NATIONAL_STRING = auto()
RAW_STRING = auto()

# types
BIT = auto()

@ -110,6 +111,7 @@ class TokenType(AutoName):
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
DATETIME = auto()
DATETIME64 = auto()
DATE = auto()
UUID = auto()
GEOGRAPHY = auto()

@ -142,30 +144,22 @@ class TokenType(AutoName):
ARRAY = auto()
ASC = auto()
ASOF = auto()
AT_TIME_ZONE = auto()
AUTO_INCREMENT = auto()
BEGIN = auto()
BETWEEN = auto()
BOTH = auto()
BUCKET = auto()
BY_DEFAULT = auto()
CACHE = auto()
CASCADE = auto()
CASE = auto()
CHARACTER_SET = auto()
CLUSTER_BY = auto()
COLLATE = auto()
COMMAND = auto()
COMMENT = auto()
COMMIT = auto()
COMPOUND = auto()
CONSTRAINT = auto()
CREATE = auto()
CROSS = auto()
CUBE = auto()
CURRENT_DATE = auto()
CURRENT_DATETIME = auto()
CURRENT_ROW = auto()
CURRENT_TIME = auto()
CURRENT_TIMESTAMP = auto()
CURRENT_USER = auto()

@ -174,8 +168,6 @@ class TokenType(AutoName):
DESC = auto()
DESCRIBE = auto()
DISTINCT = auto()
DISTINCT_FROM = auto()
DISTRIBUTE_BY = auto()
DIV = auto()
DROP = auto()
ELSE = auto()

@ -189,7 +181,6 @@ class TokenType(AutoName):
FILTER = auto()
FINAL = auto()
FIRST = auto()
FOLLOWING = auto()
FOR = auto()
FOREIGN_KEY = auto()
FORMAT = auto()

@ -203,7 +194,6 @@ class TokenType(AutoName):
HAVING = auto()
HINT = auto()
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
ILIKE_ANY = auto()
IN = auto()

@ -222,36 +212,27 @@ class TokenType(AutoName):
KEEP = auto()
LANGUAGE = auto()
LATERAL = auto()
LAZY = auto()
LEADING = auto()
LEFT = auto()
LIKE = auto()
LIKE_ANY = auto()
LIMIT = auto()
LOAD_DATA = auto()
LOCAL = auto()
LOAD = auto()
LOCK = auto()
MAP = auto()
MATCH_RECOGNIZE = auto()
MATERIALIZED = auto()
MERGE = auto()
MOD = auto()
NATURAL = auto()
NEXT = auto()
NEXT_VALUE_FOR = auto()
NO_ACTION = auto()
NOTNULL = auto()
NULL = auto()
NULLS_FIRST = auto()
NULLS_LAST = auto()
OFFSET = auto()
ON = auto()
ONLY = auto()
OPTIONS = auto()
ORDER_BY = auto()
ORDERED = auto()
ORDINALITY = auto()
OUTER = auto()
OUT_OF = auto()
OVER = auto()
OVERLAPS = auto()
OVERWRITE = auto()

@ -261,7 +242,6 @@ class TokenType(AutoName):
PIVOT = auto()
PLACEHOLDER = auto()
PRAGMA = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
PROCEDURE = auto()
PROPERTIES = auto()

@ -271,7 +251,6 @@ class TokenType(AutoName):
RANGE = auto()
RECURSIVE = auto()
REPLACE = auto()
RESPECT_NULLS = auto()
RETURNING = auto()
REFERENCES = auto()
RIGHT = auto()

@ -280,28 +259,23 @@ class TokenType(AutoName):
ROLLUP = auto()
ROW = auto()
ROWS = auto()
SEED = auto()
SELECT = auto()
SEMI = auto()
SEPARATOR = auto()
SERDE_PROPERTIES = auto()
SET = auto()
SETTINGS = auto()
SHOW = auto()
SIMILAR_TO = auto()
SOME = auto()
SORTKEY = auto()
SORT_BY = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
TOP = auto()
THEN = auto()
TRAILING = auto()
TRUE = auto()
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
UNLOGGED = auto()
UNNEST = auto()
UNPIVOT = auto()
UPDATE = auto()

@ -314,15 +288,11 @@ class TokenType(AutoName):
WHERE = auto()
WINDOW = auto()
WITH = auto()
WITH_TIME_ZONE = auto()
WITH_LOCAL_TIME_ZONE = auto()
WITHIN_GROUP = auto()
WITHOUT_TIME_ZONE = auto()
UNIQUE = auto()

class Token:
__slots__ = ("token_type", "text", "line", "col", "end", "comments")
__slots__ = ("token_type", "text", "line", "col", "start", "end", "comments")

@classmethod
def number(cls, number: int) -> Token:

@ -350,22 +320,28 @@ class Token:
text: str,
line: int = 1,
col: int = 1,
start: int = 0,
end: int = 0,
comments: t.List[str] = [],
) -> None:
"""Token initializer.

Args:
token_type: The TokenType Enum.
text: The text of the token.
line: The line that the token ends on.
col: The column that the token ends on.
start: The start index of the token.
end: The ending index of the token.
"""
self.token_type = token_type
self.text = text
self.line = line
size = len(text)
self.col = col
self.end = end if end else size
self.start = start
self.end = end
self.comments = comments

@property
def start(self) -> int:
"""Returns the start of the token."""
return self.end - len(self.text)

def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
return f"<Token {attributes}>"

@ -375,15 +351,31 @@ class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)

klass._QUOTES = {
f"{prefix}{s}": e
for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items()
for prefix in (("",) if s[0].isalpha() else ("", "n", "N"))
def _convert_quotes(arr: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
return dict(
(item, item) if isinstance(item, str) else (item[0], item[1]) for item in arr
)

def _quotes_to_format(
token_type: TokenType, arr: t.List[str | t.Tuple[str, str]]
) -> t.Dict[str, t.Tuple[str, TokenType]]:
return {k: (v, token_type) for k, v in _convert_quotes(arr).items()}

klass._QUOTES = _convert_quotes(klass.QUOTES)
klass._IDENTIFIERS = _convert_quotes(klass.IDENTIFIERS)

klass._FORMAT_STRINGS = {
**{
p + s: (e, TokenType.NATIONAL_STRING)
for s, e in klass._QUOTES.items()
for p in ("n", "N")
},
**_quotes_to_format(TokenType.BIT_STRING, klass.BIT_STRINGS),
**_quotes_to_format(TokenType.BYTE_STRING, klass.BYTE_STRINGS),
**_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
**_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
}
klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)

klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
klass._COMMENTS = dict(

@ -393,23 +385,17 @@ class _Tokenizer(type):

klass.KEYWORD_TRIE = new_trie(
key.upper()
for key in {
**klass.KEYWORDS,
**{comment: TokenType.COMMENT for comment in klass._COMMENTS},
**{quote: TokenType.QUOTE for quote in klass._QUOTES},
**{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
**{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
**{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
}
for key in (
*klass.KEYWORDS,
*klass._COMMENTS,
*klass._QUOTES,
*klass._FORMAT_STRINGS,
)
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)

return klass

@staticmethod
def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)

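The metaclass changes above collapse the per-format delimiter dictionaries into a single _FORMAT_STRINGS mapping keyed by the opening delimiter. A small sketch, not part of this commit, of what that mapping would contain for a hypothetical tokenizer subclass (assuming sqlglot 15.x):

from sqlglot.tokens import Tokenizer, TokenType

class MyTokenizer(Tokenizer):
    # Hypothetical delimiters, for illustration only.
    BIT_STRINGS = [("b'", "'")]
    HEX_STRINGS = [("x'", "'")]

# Each opening delimiter maps to (closing delimiter, token type).
print(MyTokenizer._FORMAT_STRINGS.get("b'"))  # expected: ("'", TokenType.BIT_STRING)
print(MyTokenizer._FORMAT_STRINGS.get("x'"))  # expected: ("'", TokenType.HEX_STRING)
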
class Tokenizer(metaclass=_Tokenizer):
SINGLE_TOKENS = {
@ -450,6 +436,7 @@ class Tokenizer(metaclass=_Tokenizer):
BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
IDENTIFIER_ESCAPES = ['"']
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]

@ -457,9 +444,7 @@ class Tokenizer(metaclass=_Tokenizer):
VAR_SINGLE_TOKENS: t.Set[str] = set()

_COMMENTS: t.Dict[str, str] = {}
_BIT_STRINGS: t.Dict[str, str] = {}
_BYTE_STRINGS: t.Dict[str, str] = {}
_HEX_STRINGS: t.Dict[str, str] = {}
_FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
_IDENTIFIERS: t.Dict[str, str] = {}
_IDENTIFIER_ESCAPES: t.Set[str] = set()
_QUOTES: t.Dict[str, str] = {}

@ -495,30 +480,22 @@ class Tokenizer(metaclass=_Tokenizer):
"ANY": TokenType.ANY,
"ASC": TokenType.ASC,
"AS": TokenType.ALIAS,
"AT TIME ZONE": TokenType.AT_TIME_ZONE,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
"AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
"BEGIN": TokenType.BEGIN,
"BETWEEN": TokenType.BETWEEN,
"BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET,
"BY DEFAULT": TokenType.BY_DEFAULT,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
"CASCADE": TokenType.CASCADE,
"CHARACTER SET": TokenType.CHARACTER_SET,
"CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMIT": TokenType.COMMIT,
"COMPOUND": TokenType.COMPOUND,
"CONSTRAINT": TokenType.CONSTRAINT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
"CUBE": TokenType.CUBE,
"CURRENT_DATE": TokenType.CURRENT_DATE,
"CURRENT ROW": TokenType.CURRENT_ROW,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
"CURRENT_USER": TokenType.CURRENT_USER,

@ -528,8 +505,6 @@ class Tokenizer(metaclass=_Tokenizer):
"DESC": TokenType.DESC,
"DESCRIBE": TokenType.DESCRIBE,
"DISTINCT": TokenType.DISTINCT,
"DISTINCT FROM": TokenType.DISTINCT_FROM,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,

@ -544,18 +519,18 @@ class Tokenizer(metaclass=_Tokenizer):
"FIRST": TokenType.FIRST,
"FULL": TokenType.FULL,
"FUNCTION": TokenType.FUNCTION,
"FOLLOWING": TokenType.FOLLOWING,
"FOR": TokenType.FOR,
"FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"GEOMETRY": TokenType.GEOMETRY,
"GLOB": TokenType.GLOB,
"GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IGNORE NULLS": TokenType.IGNORE_NULLS,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
"INET": TokenType.INET,

@ -569,34 +544,25 @@ class Tokenizer(metaclass=_Tokenizer):
"JOIN": TokenType.JOIN,
"KEEP": TokenType.KEEP,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
"LEADING": TokenType.LEADING,
"LEFT": TokenType.LEFT,
"LIKE": TokenType.LIKE,
"LIMIT": TokenType.LIMIT,
"LOAD DATA": TokenType.LOAD_DATA,
"LOCAL": TokenType.LOCAL,
"MATERIALIZED": TokenType.MATERIALIZED,
"LOAD": TokenType.LOAD,
"LOCK": TokenType.LOCK,
"MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR,
"NO ACTION": TokenType.NO_ACTION,
"NOT": TokenType.NOT,
"NOTNULL": TokenType.NOTNULL,
"NULL": TokenType.NULL,
"NULLS FIRST": TokenType.NULLS_FIRST,
"NULLS LAST": TokenType.NULLS_LAST,
"OBJECT": TokenType.OBJECT,
"OFFSET": TokenType.OFFSET,
"ON": TokenType.ON,
"ONLY": TokenType.ONLY,
"OPTIONS": TokenType.OPTIONS,
"OR": TokenType.OR,
"ORDER BY": TokenType.ORDER_BY,
"ORDINALITY": TokenType.ORDINALITY,
"OUTER": TokenType.OUTER,
"OUT OF": TokenType.OUT_OF,
"OVER": TokenType.OVER,
"OVERLAPS": TokenType.OVERLAPS,
"OVERWRITE": TokenType.OVERWRITE,

@ -607,7 +573,6 @@ class Tokenizer(metaclass=_Tokenizer):
"PERCENT": TokenType.PERCENT,
"PIVOT": TokenType.PIVOT,
"PRAGMA": TokenType.PRAGMA,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"PROCEDURE": TokenType.PROCEDURE,
"QUALIFY": TokenType.QUALIFY,

@ -615,7 +580,6 @@ class Tokenizer(metaclass=_Tokenizer):
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
"REPLACE": TokenType.REPLACE,
"RESPECT NULLS": TokenType.RESPECT_NULLS,
"REFERENCES": TokenType.REFERENCES,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,

@ -624,25 +588,20 @@ class Tokenizer(metaclass=_Tokenizer):
"ROW": TokenType.ROW,
"ROWS": TokenType.ROWS,
"SCHEMA": TokenType.SCHEMA,
"SEED": TokenType.SEED,
"SELECT": TokenType.SELECT,
"SEMI": TokenType.SEMI,
"SET": TokenType.SET,
"SETTINGS": TokenType.SETTINGS,
"SHOW": TokenType.SHOW,
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
"SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
"TEMPORARY": TokenType.TEMPORARY,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"TRAILING": TokenType.TRAILING,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNLOGGED": TokenType.UNLOGGED,
"UNNEST": TokenType.UNNEST,
"UNPIVOT": TokenType.UNPIVOT,
"UPDATE": TokenType.UPDATE,

@ -656,10 +615,6 @@ class Tokenizer(metaclass=_Tokenizer):
"WHERE": TokenType.WHERE,
"WINDOW": TokenType.WINDOW,
"WITH": TokenType.WITH,
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
"WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
"WITHIN GROUP": TokenType.WITHIN_GROUP,
"WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
"APPLY": TokenType.APPLY,
"ARRAY": TokenType.ARRAY,
"BIT": TokenType.BIT,

@ -718,15 +673,6 @@ class Tokenizer(metaclass=_Tokenizer):
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
"ALTER": TokenType.ALTER,
"ALTER AGGREGATE": TokenType.COMMAND,
"ALTER DEFAULT": TokenType.COMMAND,
"ALTER DOMAIN": TokenType.COMMAND,
"ALTER ROLE": TokenType.COMMAND,
"ALTER RULE": TokenType.COMMAND,
"ALTER SEQUENCE": TokenType.COMMAND,
"ALTER TYPE": TokenType.COMMAND,
"ALTER USER": TokenType.COMMAND,
"ALTER VIEW": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
"COMMENT": TokenType.COMMENT,

@ -790,7 +736,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._start = 0
self._current = 0
self._line = 1
self._col = 1
self._col = 0
self._comments: t.List[str] = []

self._char = ""

@ -803,13 +749,12 @@ class Tokenizer(metaclass=_Tokenizer):
self.reset()
self.sql = sql
self.size = len(sql)

try:
self._scan()
except Exception as e:
start = self._current - 50
end = self._current + 50
start = start if start > 0 else 0
end = end if end < self.size else self.size - 1
start = max(self._current - 50, 0)
end = min(self._current + 50, self.size - 1)
context = self.sql[start:end]
raise ValueError(f"Error tokenizing '{context}'") from e

@ -834,17 +779,17 @@ class Tokenizer(metaclass=_Tokenizer):
if until and until():
break

if self.tokens:
if self.tokens and self._comments:
self.tokens[-1].comments.extend(self._comments)

def _chars(self, size: int) -> str:
if size == 1:
return self._char

start = self._current - 1
end = start + size
if end <= self.size:
return self.sql[start:end]
return ""

return self.sql[start:end] if end <= self.size else ""

def _advance(self, i: int = 1, alnum: bool = False) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:

@ -859,6 +804,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._peek = "" if self._end else self.sql[self._current]

if alnum and self._char.isalnum():
# Here we use local variables instead of attributes for better performance
_col = self._col
_current = self._current
_end = self._end

@ -885,11 +831,12 @@ class Tokenizer(metaclass=_Tokenizer):
self.tokens.append(
Token(
token_type,
self._text if text is None else text,
self._line,
self._col,
self._current,
self._comments,
text=self._text if text is None else text,
line=self._line,
col=self._col,
start=self._start,
end=self._current - 1,
comments=self._comments,
)
)
self._comments = []

@ -929,6 +876,7 @@ class Tokenizer(metaclass=_Tokenizer):
break
if result == 2:
word = chars

size += 1
end = self._current - 1 + size

@ -946,6 +894,7 @@ class Tokenizer(metaclass=_Tokenizer):
else:
skip = True
else:
char = ""
chars = " "

word = None if not single_token and chars[-1] not in self.WHITE_SPACE else word

@ -959,8 +908,6 @@ class Tokenizer(metaclass=_Tokenizer):

if self._scan_string(word):
return
if self._scan_formatted_string(word):
return
if self._scan_comment(word):
return

@ -1004,9 +951,9 @@ class Tokenizer(metaclass=_Tokenizer):
if self._char == "0":
peek = self._peek.upper()
if peek == "B":
return self._scan_bits() if self._BIT_STRINGS else self._add(TokenType.NUMBER)
return self._scan_bits() if self.BIT_STRINGS else self._add(TokenType.NUMBER)
elif peek == "X":
return self._scan_hex() if self._HEX_STRINGS else self._add(TokenType.NUMBER)
return self._scan_hex() if self.HEX_STRINGS else self._add(TokenType.NUMBER)

decimal = False
scientific = 0

@ -1075,37 +1022,24 @@ class Tokenizer(metaclass=_Tokenizer):

return self._text

def _scan_string(self, quote: str) -> bool:
quote_end = self._QUOTES.get(quote)
if quote_end is None:
return False
def _scan_string(self, start: str) -> bool:
base = None
token_type = TokenType.STRING

self._advance(len(quote))
text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
if start in self._QUOTES:
end = self._QUOTES[start]
elif start in self._FORMAT_STRINGS:
end, token_type = self._FORMAT_STRINGS[start]

# X'1234', b'0110', E'\\\\\' etc.
def _scan_formatted_string(self, string_start: str) -> bool:
if string_start in self._HEX_STRINGS:
delimiters = self._HEX_STRINGS
token_type = TokenType.HEX_STRING
base = 16
elif string_start in self._BIT_STRINGS:
delimiters = self._BIT_STRINGS
token_type = TokenType.BIT_STRING
base = 2
elif string_start in self._BYTE_STRINGS:
delimiters = self._BYTE_STRINGS
token_type = TokenType.BYTE_STRING
base = None
if token_type == TokenType.HEX_STRING:
base = 16
elif token_type == TokenType.BIT_STRING:
base = 2
else:
return False

self._advance(len(string_start))
string_end = delimiters[string_start]
text = self._extract_string(string_end)
self._advance(len(start))
text = self._extract_string(end)

if base:
try:

@ -1114,6 +1048,8 @@ class Tokenizer(metaclass=_Tokenizer):
raise RuntimeError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
else:
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text

self._add(token_type, text)
return True

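The Token changes above replace the computed start property with explicit start/end offsets recorded by the tokenizer (end=self._current - 1). A minimal sketch, not part of this commit, of how those offsets could be used, assuming they are inclusive character positions into the original SQL string:

from sqlglot.tokens import Tokenizer

sql = "SELECT 1 AS x"
for token in Tokenizer().tokenize(sql):
    # Slicing with the recorded offsets should reproduce each token's text.
    print(token.token_type, repr(sql[token.start : token.end + 1]))
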
@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
from sqlglot.generator import Generator

@ -63,16 +63,17 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
distinct_cols = expression.args["distinct"].pop().args["on"].expressions
outer_selects = expression.selects
row_number = find_new_name(expression.named_selects, "_row_number")
window = exp.Window(
this=exp.RowNumber(),
partition_by=distinct_cols,
)
window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
order = expression.args.get("order")

if order:
window.set("order", order.pop().copy())

window = exp.alias_(window, row_number)
expression.select(window, copy=False)

return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')

return expression


@ -93,7 +94,7 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
for select in expression.selects:
if not select.alias_or_name:
alias = find_new_name(taken, "_c")
select.replace(exp.alias_(select.copy(), alias))
select.replace(exp.alias_(select, alias))
taken.add(alias)

outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])

@ -102,8 +103,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
for expr in qualify_filters.find_all((exp.Window, exp.Column)):
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
expression.select(exp.alias_(expr.copy(), alias), copy=False)
expression.select(exp.alias_(expr, alias), copy=False)
column = exp.column(alias)

if isinstance(expr.parent, exp.Qualify):
qualify_filters = column
else:

@ -123,6 +125,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
"""
for node in expression.find_all(exp.DataType):
node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])

return expression


@ -147,6 +150,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
),
)

return expression


@ -156,7 +160,10 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
from sqlglot.optimizer.scope import build_scope

taken_select_names = set(expression.named_selects)
taken_source_names = set(build_scope(expression).selected_sources)
scope = build_scope(expression)
if not scope:
return expression
taken_source_names = set(scope.selected_sources)

for select in expression.selects:
to_replace = select

@ -226,6 +233,7 @@ def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
else node,
copy=False,
)

return expression


@ -242,12 +250,20 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre
return expression


def unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Pivot):
expression.args["field"].transform(
lambda node: exp.column(node.output_name) if isinstance(node, exp.Column) else node,
copy=False,
)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.With) and expression.recursive:
next_name = name_sequence("_c_")

for cte in expression.expressions:
if not cte.args["alias"].columns:
query = cte.this
if isinstance(query, exp.Union):
query = query.this

cte.args["alias"].set(
"columns",
[exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
)

return expression

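The transform hunks above introduce add_recursive_cte_column_names, which derives an explicit column list for a recursive CTE from the anchor query's projections. An illustrative sketch only, not part of this commit (assumes sqlglot 15.x):

import sqlglot
from sqlglot import transforms

sql = (
    "WITH RECURSIVE t AS ("
    "SELECT 1 AS n UNION ALL SELECT n + 1 FROM t WHERE n < 5"
    ") SELECT * FROM t"
)

# Applying the transform to every node should attach the column list (here just "n")
# to the recursive CTE alias, i.e. something like WITH RECURSIVE t(n) AS (...).
expression = sqlglot.parse_one(sql).transform(transforms.add_recursive_cte_column_names)
print(expression.sql())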