Merging upstream version 10.5.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 77197f1e44
commit e0f3bbb5f3

58 changed files with 1480 additions and 383 deletions
sqlglot/__init__.py

@@ -32,7 +32,7 @@ from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema
 from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "10.4.2"
+__version__ = "10.5.2"

 pretty = False

@@ -60,9 +60,9 @@ def parse(
 def parse_one(
     sql: str,
     read: t.Optional[str | Dialect] = None,
-    into: t.Optional[Expression | str] = None,
+    into: t.Optional[t.Type[Expression] | str] = None,
     **opts,
-) -> t.Optional[Expression]:
+) -> Expression:
     """
     Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.

@@ -83,7 +83,12 @@ def parse_one(
     else:
         result = dialect.parse(sql, **opts)

-    return result[0] if result else None
+    for expression in result:
+        if not expression:
+            raise ParseError(f"No expression was parsed from '{sql}'")
+        return expression
+    else:
+        raise ParseError(f"No expression was parsed from '{sql}'")


 def transpile(
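Annotation: with this change parse_one always returns an Expression and raises instead of returning None. A minimal sketch of the new contract (behavioral assumption based on the hunk above, not part of the upstream commit):

import sqlglot
from sqlglot.errors import ParseError

ast = sqlglot.parse_one("SELECT 1")  # returns an Expression, never None
print(type(ast).__name__)            # Select

try:
    sqlglot.parse_one("")            # nothing to parse -> ParseError
except ParseError as e:
    print(e)                         # No expression was parsed from ''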
sqlglot/dialects/bigquery.py

@@ -2,7 +2,7 @@
 from __future__ import annotations

-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     datestrtodate_sql,

@@ -46,8 +46,9 @@ def _date_add_sql(data_type, kind):

 def _derived_table_values_to_unnest(self, expression):
     if not isinstance(expression.unnest().parent, exp.From):
+        expression = transforms.remove_precision_parameterized_types(expression)
         return self.values_sql(expression)
-    rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)]
+    rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
     structs = []
     for row in rows:
         aliases = [

@@ -118,6 +119,7 @@ class BigQuery(Dialect):
             "BEGIN TRANSACTION": TokenType.BEGIN,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
             "CURRENT_TIME": TokenType.CURRENT_TIME,
+            "DECLARE": TokenType.COMMAND,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "FLOAT64": TokenType.DOUBLE,
             "INT64": TokenType.BIGINT,

@@ -166,6 +168,7 @@ class BigQuery(Dialect):
     class Generator(generator.Generator):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
+            **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES,  # type: ignore
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
             exp.DateSub: _date_add_sql("DATE", "SUB"),
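Annotation: the _derived_table_values_to_unnest hook rewrites inline VALUES into an UNNEST of STRUCTs (BigQuery has no derived-table VALUES), and using the raw tuple expressions rather than only literals lets non-literal rows survive. A hedged sketch of the observable behavior — exact output formatting is an assumption:

import sqlglot

sql = "SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t(x, y)"
print(sqlglot.transpile(sql, write="bigquery")[0])
# Expected shape (approximate):
# SELECT * FROM UNNEST([STRUCT(1 AS x, 'a' AS y), STRUCT(2 AS x, 'b' AS y)]) AS t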
sqlglot/dialects/clickhouse.py

@@ -1,5 +1,7 @@
 from __future__ import annotations

+import typing as t
+
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
 from sqlglot.parser import parse_var_map

@@ -22,6 +24,7 @@ class ClickHouse(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "ASOF": TokenType.ASOF,
+            "GLOBAL": TokenType.GLOBAL,
             "DATETIME64": TokenType.DATETIME,
             "FINAL": TokenType.FINAL,
             "FLOAT32": TokenType.FLOAT,

@@ -37,14 +40,32 @@ class ClickHouse(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
             "MAP": parse_var_map,
             "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
+            "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
+            "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
         }

+        RANGE_PARSERS = {
+            **parser.Parser.RANGE_PARSERS,
+            TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN)
+            and self._parse_in(this, is_global=True),
+        }
+
         JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}  # type: ignore

         TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}  # type: ignore

-        def _parse_table(self, schema=False):
-            this = super()._parse_table(schema)
+        def _parse_in(
+            self, this: t.Optional[exp.Expression], is_global: bool = False
+        ) -> exp.Expression:
+            this = super()._parse_in(this)
+            this.set("is_global", is_global)
+            return this
+
+        def _parse_table(
+            self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
+        ) -> t.Optional[exp.Expression]:
+            this = super()._parse_table(schema=schema, alias_tokens=alias_tokens)

             if self._match(TokenType.FINAL):
                 this = self.expression(exp.Final, this=this)

@@ -76,6 +97,16 @@ class ClickHouse(Dialect):
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
             exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
             exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
+            exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
+            exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
+            exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
         }

         EXPLICIT_UNION = True

+        def _param_args_sql(
+            self, expression: exp.Expression, params_name: str, args_name: str
+        ) -> str:
+            params = self.format_args(self.expressions(expression, params_name))
+            args = self.format_args(self.expressions(expression, args_name))
+            return f"({params})({args})"
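Annotation: taken together, the GLOBAL range parser, the is_global flag on exp.In, and the quantile generator lambdas let these ClickHouse-specific constructs round-trip. A hedged sketch — output shapes are assumptions:

import sqlglot

# GLOBAL IN survives a ClickHouse round trip via the new is_global arg
print(sqlglot.transpile("SELECT * FROM t WHERE a GLOBAL IN (1, 2)",
                        read="clickhouse", write="clickhouse")[0])

# parameterized aggregates map to exp.Quantiles via the new FUNCTIONS entries
print(sqlglot.transpile("SELECT quantiles(0.5, 0.9)(x) FROM t",
                        read="clickhouse", write="clickhouse")[0])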
sqlglot/dialects/dialect.py

@@ -381,3 +381,20 @@ def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:

 def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
     return f"CAST({self.sql(expression, 'this')} AS DATE)"
+
+
+def trim_sql(self, expression):
+    target = self.sql(expression, "this")
+    trim_type = self.sql(expression, "position")
+    remove_chars = self.sql(expression, "expression")
+    collation = self.sql(expression, "collation")
+
+    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
+    if not remove_chars and not collation:
+        return self.trim_sql(expression)
+
+    trim_type = f"{trim_type} " if trim_type else ""
+    remove_chars = f"{remove_chars} " if remove_chars else ""
+    from_part = "FROM " if trim_type or remove_chars else ""
+    collation = f" COLLATE {collation}" if collation else ""
+    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
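Annotation: this shared helper (previously a private copy inside the Postgres dialect) emits the verbose TRIM form only when trim characters or a collation are present; the Oracle and Postgres hunks below wire exp.Trim to it. A hedged usage sketch, outputs assumed:

import sqlglot

# plain trims keep the short form
print(sqlglot.transpile("SELECT TRIM(col)", write="oracle")[0])
# Expected shape: SELECT TRIM(col)

# dialect-specific pieces force the TRIM(... FROM ...) form
print(sqlglot.transpile("SELECT TRIM(LEADING 'x' FROM col)", write="postgres")[0])
# Expected shape: SELECT TRIM(LEADING 'x' FROM col)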
sqlglot/dialects/hive.py

@@ -175,14 +175,6 @@ class Hive(Dialect):
         ESCAPES = ["\\"]
         ENCODE = "utf-8"

-        NUMERIC_LITERALS = {
-            "L": "BIGINT",
-            "S": "SMALLINT",
-            "Y": "TINYINT",
-            "D": "DOUBLE",
-            "F": "FLOAT",
-            "BD": "DECIMAL",
-        }
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "ADD ARCHIVE": TokenType.COMMAND,

@@ -191,9 +183,21 @@ class Hive(Dialect):
             "ADD FILES": TokenType.COMMAND,
             "ADD JAR": TokenType.COMMAND,
             "ADD JARS": TokenType.COMMAND,
             "MSCK REPAIR": TokenType.COMMAND,
+            "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
         }

+        NUMERIC_LITERALS = {
+            "L": "BIGINT",
+            "S": "SMALLINT",
+            "Y": "TINYINT",
+            "D": "DOUBLE",
+            "F": "FLOAT",
+            "BD": "DECIMAL",
+        }
+
+        IDENTIFIER_CAN_START_WITH_DIGIT = True
+
     class Parser(parser.Parser):
         STRICT_CAST = False

@@ -315,6 +319,7 @@ class Hive(Dialect):
             exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
             exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
             exp.NumberToStr: rename_func("FORMAT_NUMBER"),
+            exp.LastDateOfMonth: rename_func("LAST_DAY"),
         }

         WITH_PROPERTIES = {exp.Property}

@@ -342,4 +347,6 @@ class Hive(Dialect):
             and not expression.expressions
         ):
             expression = exp.DataType.build("text")
+        elif expression.this in exp.DataType.TEMPORAL_TYPES:
+            expression = exp.DataType.build(expression.this)
         return super().datatype_sql(expression)
sqlglot/dialects/oracle.py

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
+from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
 from sqlglot.helper import csv
 from sqlglot.tokens import TokenType

@@ -64,6 +64,7 @@ class Oracle(Dialect):
             **transforms.UNALIAS_GROUP,  # type: ignore
             exp.ILike: no_ilike_sql,
             exp.Limit: _limit_sql,
+            exp.Trim: trim_sql,
             exp.Matches: rename_func("DECODE"),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
sqlglot/dialects/postgres.py

@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
     no_tablesample_sql,
     no_trycast_sql,
     str_position_sql,
+    trim_sql,
 )
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType

@@ -81,23 +82,6 @@ def _substring_sql(self, expression):
     return f"SUBSTRING({this}{from_part}{for_part})"


-def _trim_sql(self, expression):
-    target = self.sql(expression, "this")
-    trim_type = self.sql(expression, "position")
-    remove_chars = self.sql(expression, "expression")
-    collation = self.sql(expression, "collation")
-
-    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
-    if not remove_chars and not collation:
-        return self.trim_sql(expression)
-
-    trim_type = f"{trim_type} " if trim_type else ""
-    remove_chars = f"{remove_chars} " if remove_chars else ""
-    from_part = "FROM " if trim_type or remove_chars else ""
-    collation = f" COLLATE {collation}" if collation else ""
-    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
-
-
 def _string_agg_sql(self, expression):
     expression = expression.copy()
     separator = expression.args.get("separator") or exp.Literal.string(",")

@@ -248,7 +232,6 @@ class Postgres(Dialect):
             "COMMENT ON": TokenType.COMMAND,
             "DECLARE": TokenType.COMMAND,
             "DO": TokenType.COMMAND,
-            "DOUBLE PRECISION": TokenType.DOUBLE,
             "GENERATED": TokenType.GENERATED,
             "GRANT": TokenType.COMMAND,
             "HSTORE": TokenType.HSTORE,

@@ -318,7 +301,7 @@ class Postgres(Dialect):
             exp.Substring: _substring_sql,
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TableSample: no_tablesample_sql,
-            exp.Trim: _trim_sql,
+            exp.Trim: trim_sql,
             exp.TryCast: no_trycast_sql,
             exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
             exp.DataType: _datatype_sql,
sqlglot/dialects/snowflake.py

@@ -195,7 +195,6 @@ class Snowflake(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "QUALIFY": TokenType.QUALIFY,
-            "DOUBLE PRECISION": TokenType.DOUBLE,
             "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
             "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
             "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,

@@ -294,3 +293,10 @@ class Snowflake(Dialect):
             )
             return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
         return super().select_sql(expression)
+
+    def describe_sql(self, expression: exp.Describe) -> str:
+        # Default to table if kind is unknown
+        kind_value = expression.args.get("kind") or "TABLE"
+        kind = f" {kind_value}" if kind_value else ""
+        this = f" {self.sql(expression, 'this')}"
+        return f"DESCRIBE{kind}{this}"
sqlglot/dialects/tsql.py

@@ -75,6 +75,20 @@ def _parse_format(args):
     )


+def _parse_eomonth(args):
+    date = seq_get(args, 0)
+    month_lag = seq_get(args, 1)
+    unit = DATE_DELTA_INTERVAL.get("month")
+
+    if month_lag is None:
+        return exp.LastDateOfMonth(this=date)
+
+    # Remove month lag argument in parser as its compared with the number of arguments of the resulting class
+    args.remove(month_lag)
+
+    return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
+
+
 def generate_date_delta_with_unit_sql(self, e):
     func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
     return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"

@@ -256,12 +270,14 @@ class TSQL(Dialect):
             "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
             "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
             "DATEPART": _format_time_lambda(exp.TimeToStr),
-            "GETDATE": exp.CurrentDate.from_arg_list,
+            "GETDATE": exp.CurrentTimestamp.from_arg_list,
+            "SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
             "IIF": exp.If.from_arg_list,
             "LEN": exp.Length.from_arg_list,
             "REPLICATE": exp.Repeat.from_arg_list,
             "JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
             "FORMAT": _parse_format,
+            "EOMONTH": _parse_eomonth,
         }

         VAR_LENGTH_DATATYPES = {

@@ -271,6 +287,9 @@ class TSQL(Dialect):
             DataType.Type.NCHAR,
         }

+        # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
+        TABLE_PREFIX_TOKENS = {TokenType.HASH}
+
         def _parse_convert(self, strict):
             to = self._parse_types()
             self._match(TokenType.COMMA)

@@ -323,6 +342,7 @@ class TSQL(Dialect):
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
             exp.CurrentDate: rename_func("GETDATE"),
+            exp.CurrentTimestamp: rename_func("GETDATE"),
             exp.If: rename_func("IIF"),
             exp.NumberToStr: _format_sql,
             exp.TimeToStr: _format_sql,
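Annotation: EOMONTH with a lag is parsed as LastDateOfMonth over a DateAdd, so other dialects can render it with their own end-of-month functions (Hive maps exp.LastDateOfMonth to LAST_DAY per the hive.py hunk above). A hedged sketch — the second output shape is an assumption:

import sqlglot

print(sqlglot.transpile("SELECT EOMONTH(d)", read="tsql", write="hive")[0])
# Expected shape: SELECT LAST_DAY(d)

print(sqlglot.transpile("SELECT EOMONTH(d, 1)", read="tsql", write="hive")[0])
# Expected shape: LAST_DAY over a one-month date addition, e.g. a DATE_ADD/ADD_MONTHS form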
sqlglot/expressions.py

@@ -22,6 +22,7 @@ from sqlglot.helper import (
     split_num_words,
     subclasses,
 )
 from sqlglot.tokens import Token

 if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import Dialect

@@ -457,6 +458,23 @@ class Expression(metaclass=_Expression):
         assert isinstance(self, type_)
         return self

+    def dump(self):
+        """
+        Dump this Expression to a JSON-serializable dict.
+        """
+        from sqlglot.serde import dump
+
+        return dump(self)
+
+    @classmethod
+    def load(cls, obj):
+        """
+        Load a dict (as returned by `Expression.dump`) into an Expression instance.
+        """
+        from sqlglot.serde import load
+
+        return load(obj)
+

 class Condition(Expression):
     def and_(self, *expressions, dialect=None, **opts):
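Annotation: these thin wrappers delegate to the new sqlglot.serde module (added later in this commit), so an AST can round-trip through plain JSON. A hedged sketch of the intended usage:

import json
import sqlglot
from sqlglot.expressions import Expression

ast = sqlglot.parse_one("SELECT a FROM t WHERE a > 1")
payload = json.dumps(ast.dump())                # JSON-serializable dict
restored = Expression.load(json.loads(payload)) # rebuild the AST
assert restored.sql() == ast.sql()              # same generated SQL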
@@ -631,11 +649,15 @@ class Create(Expression):
         "replace": False,
         "unique": False,
         "materialized": False,
+        "data": False,
+        "statistics": False,
+        "no_primary_index": False,
+        "indexes": False,
     }


 class Describe(Expression):
-    pass
+    arg_types = {"this": True, "kind": False}


 class Set(Expression):

@@ -731,7 +753,7 @@ class Column(Condition):
 class ColumnDef(Expression):
     arg_types = {
         "this": True,
-        "kind": True,
+        "kind": False,
         "constraints": False,
         "exists": False,
     }

@@ -879,7 +901,15 @@ class Identifier(Expression):


 class Index(Expression):
-    arg_types = {"this": False, "table": False, "where": False, "columns": False}
+    arg_types = {
+        "this": False,
+        "table": False,
+        "where": False,
+        "columns": False,
+        "unique": False,
+        "primary": False,
+        "amp": False,  # teradata
+    }


 class Insert(Expression):

@@ -1361,6 +1391,7 @@ class Table(Expression):
         "laterals": False,
         "joins": False,
         "pivots": False,
+        "hints": False,
     }


@@ -1818,7 +1849,12 @@ class Select(Subqueryable):
             join.this.replace(join.this.subquery())

         if join_type:
+            natural: t.Optional[Token]
+            side: t.Optional[Token]
+            kind: t.Optional[Token]
+
             natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)  # type: ignore
+
             if natural:
                 join.set("natural", True)
             if side:

@@ -2111,6 +2147,7 @@ class DataType(Expression):
         JSON = auto()
         JSONB = auto()
+        INTERVAL = auto()
         TIME = auto()
         TIMESTAMP = auto()
         TIMESTAMPTZ = auto()
         TIMESTAMPLTZ = auto()

@@ -2171,11 +2208,24 @@ class DataType(Expression):
     }

     @classmethod
-    def build(cls, dtype, **kwargs) -> DataType:
-        return DataType(
-            this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
-            **kwargs,
-        )
+    def build(
+        cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
+    ) -> DataType:
+        from sqlglot import parse_one
+
+        if isinstance(dtype, str):
+            data_type_exp: t.Optional[Expression]
+            if dtype.upper() in cls.Type.__members__:
+                data_type_exp = DataType(this=DataType.Type[dtype.upper()])
+            else:
+                data_type_exp = parse_one(dtype, read=dialect, into=DataType)
+            if data_type_exp is None:
+                raise ValueError(f"Unparsable data type value: {dtype}")
+        elif isinstance(dtype, DataType.Type):
+            data_type_exp = DataType(this=dtype)
+        else:
+            raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
+        return DataType(**{**data_type_exp.args, **kwargs})

 # https://www.postgresql.org/docs/15/datatype-pseudo.html
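Annotation: DataType.build now accepts arbitrary type strings by parsing them (optionally with a dialect), not just bare enum names. A hedged sketch — the rendered outputs are assumptions:

from sqlglot import exp

print(exp.DataType.build("int").sql())             # INT
print(exp.DataType.build("decimal(10, 2)").sql())  # DECIMAL(10, 2), parsed via parse_one
print(exp.DataType.build("array<int64>", dialect="bigquery").sql())
# Expected shape: ARRAY<BIGINT>, since BigQuery maps INT64 to BIGINT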
@@ -2429,6 +2479,7 @@ class In(Predicate):
         "query": False,
         "unnest": False,
         "field": False,
+        "is_global": False,
     }


@@ -2678,6 +2729,10 @@ class DatetimeTrunc(Func, TimeUnit):
     arg_types = {"this": True, "unit": True, "zone": False}


+class LastDateOfMonth(Func):
+    pass
+
+
 class Extract(Func):
     arg_types = {"this": True, "expression": True}


@@ -2815,7 +2870,13 @@ class Length(Func):


 class Levenshtein(Func):
-    arg_types = {"this": True, "expression": False}
+    arg_types = {
+        "this": True,
+        "expression": False,
+        "ins_cost": False,
+        "del_cost": False,
+        "sub_cost": False,
+    }


 class Ln(Func):

@@ -2890,6 +2951,16 @@ class Quantile(AggFunc):
     arg_types = {"this": True, "quantile": True}


+# Clickhouse-specific:
+# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles
+class Quantiles(AggFunc):
+    arg_types = {"parameters": True, "expressions": True}
+
+
+class QuantileIf(AggFunc):
+    arg_types = {"parameters": True, "expressions": True}
+
+
 class ApproxQuantile(Quantile):
     arg_types = {"this": True, "quantile": True, "accuracy": False}


@@ -2962,8 +3033,10 @@ class StrToTime(Func):
     arg_types = {"this": True, "format": True}


+# Spark allows unix_timestamp()
+# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html
 class StrToUnix(Func):
-    arg_types = {"this": True, "format": True}
+    arg_types = {"this": False, "format": False}


 class NumberToStr(Func):

@@ -3131,7 +3204,7 @@ def maybe_parse(
     dialect=None,
     prefix=None,
     **opts,
-) -> t.Optional[Expression]:
+) -> Expression:
     """Gracefully handle a possible string or expression.

     Example:

@@ -3627,11 +3700,11 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
     if not isinstance(sql_path, str):
         raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")

-    catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)]
+    catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3))
     return Table(this=table_name, db=db, catalog=catalog, **kwargs)


-def to_column(sql_path: str, **kwargs) -> Column:
+def to_column(sql_path: str | Column, **kwargs) -> Column:
     """
     Create a column from a `[table].[column]` sql path. Schema is optional.

@@ -3646,7 +3719,7 @@ def to_column(sql_path: str, **kwargs) -> Column:
         return sql_path
     if not isinstance(sql_path, str):
         raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
-    table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)]
+    table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
     return Column(this=column_name, table=table_name, **kwargs)


@@ -3748,7 +3821,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
 def values(
     values: t.Iterable[t.Tuple[t.Any, ...]],
     alias: t.Optional[str] = None,
-    columns: t.Optional[t.Iterable[str]] = None,
+    columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None,
 ) -> Values:
     """Build VALUES statement.

@@ -3759,7 +3832,10 @@ def values(
     Args:
         values: values statements that will be converted to SQL
         alias: optional alias
-        columns: Optional list of ordered column names. An alias is required when providing column names.
+        columns: Optional list of ordered column names or ordered dictionary of column names to types.
+            If either are provided then an alias is also required.
+            If a dictionary is provided then the first column of the values will be casted to the expected type
+            in order to help with type inference.

     Returns:
         Values: the Values expression object

@@ -3771,8 +3847,15 @@ def values(
         if columns
         else TableAlias(this=to_identifier(alias) if alias else None)
     )
+    expressions = [convert(tup) for tup in values]
+    if columns and isinstance(columns, dict):
+        types = list(columns.values())
+        expressions[0].set(
+            "expressions",
+            [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
+        )
     return Values(
-        expressions=[convert(tup) for tup in values],
+        expressions=expressions,
         alias=table_alias,
    )
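Annotation: when columns is a dict, only the first row gets wrapped in casts; downstream type inference can propagate those types to the remaining rows. A hedged sketch — the printed shape is an assumption:

from sqlglot import exp
from sqlglot.expressions import values

node = values(
    [(1, "a"), (2, "b")],
    alias="t",
    columns={"x": exp.DataType.build("int"), "y": exp.DataType.build("text")},
)
print(node.sql())
# Expected shape: VALUES (CAST(1 AS INT), CAST('a' AS TEXT)), (2, 'b') AS t(x, y)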
sqlglot/generator.py

@@ -50,7 +50,7 @@ class Generator:
         The default is on the smaller end because the length only represents a segment and not the true
         line length.
         Default: 80
-        comments: Whether or not to preserve comments in the ouput SQL code.
+        comments: Whether or not to preserve comments in the output SQL code.
         Default: True
     """

@@ -236,7 +236,10 @@ class Generator:
             return sql

         sep = "\n" if self.pretty else " "
-        comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
+        comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment)
+
+        if not comments:
+            return sql

         if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
             return f"{comments}{self.sep()}{sql}"

@@ -362,10 +365,10 @@ class Generator:
         kind = self.sql(expression, "kind")
         constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
         exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
+        kind = f" {kind}" if kind else ""
+        constraints = f" {constraints}" if constraints else ""

-        if not constraints:
-            return f"{exists}{column} {kind}"
-        return f"{exists}{column} {kind} {constraints}"
+        return f"{exists}{column}{kind}{constraints}"

     def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
         this = self.sql(expression, "this")

@@ -416,7 +419,7 @@ class Generator:
         this = self.sql(expression, "this")
         kind = self.sql(expression, "kind").upper()
         expression_sql = self.sql(expression, "expression")
-        expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
+        expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else ""
         temporary = " TEMPORARY" if expression.args.get("temporary") else ""
         transient = (
             " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""

@@ -427,6 +430,40 @@ class Generator:
         unique = " UNIQUE" if expression.args.get("unique") else ""
         materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
         properties = self.sql(expression, "properties")
+        data = expression.args.get("data")
+        if data is None:
+            data = ""
+        elif data:
+            data = " WITH DATA"
+        else:
+            data = " WITH NO DATA"
+        statistics = expression.args.get("statistics")
+        if statistics is None:
+            statistics = ""
+        elif statistics:
+            statistics = " AND STATISTICS"
+        else:
+            statistics = " AND NO STATISTICS"
+        no_primary_index = " NO PRIMARY INDEX" if expression.args.get("no_primary_index") else ""
+
+        indexes = expression.args.get("indexes")
+        index_sql = ""
+        if indexes is not None:
+            indexes_sql = []
+            for index in indexes:
+                ind_unique = " UNIQUE" if index.args.get("unique") else ""
+                ind_primary = " PRIMARY" if index.args.get("primary") else ""
+                ind_amp = " AMP" if index.args.get("amp") else ""
+                ind_name = f" {index.name}" if index.name else ""
+                ind_columns = (
+                    f' ({self.expressions(index, key="columns", flat=True)})'
+                    if index.args.get("columns")
+                    else ""
+                )
+                indexes_sql.append(
+                    f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
+                )
+            index_sql = "".join(indexes_sql)

@@ -438,7 +475,10 @@ class Generator:
                 materialized,
             )
         )
-        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}"
+
+        post_expression_modifiers = "".join((data, statistics, no_primary_index))
+
+        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}"
         return self.prepend_ctes(expression, expression_sql)

     def describe_sql(self, expression: exp.Describe) -> str:

@@ -668,6 +708,8 @@ class Generator:

         alias = self.sql(expression, "alias")
         alias = f"{sep}{alias}" if alias else ""
+        hints = self.expressions(expression, key="hints", sep=", ", flat=True)
+        hints = f" WITH ({hints})" if hints else ""
         laterals = self.expressions(expression, key="laterals", sep="")
         joins = self.expressions(expression, key="joins", sep="")
         pivots = self.expressions(expression, key="pivots", sep="")

@@ -676,7 +718,7 @@ class Generator:
             pivots = f"{pivots}{alias}"
             alias = ""

-        return f"{table}{alias}{laterals}{joins}{pivots}"
+        return f"{table}{alias}{hints}{laterals}{joins}{pivots}"

     def tablesample_sql(self, expression: exp.TableSample) -> str:
         if self.alias_post_tablesample and expression.this.alias:

@@ -1020,7 +1062,9 @@ class Generator:
         if not partition and not order and not spec and alias:
             return f"{this} {alias}"

-        return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})"
+        window_args = alias + partition_sql + order_sql + spec_sql
+
+        return f"{this} ({window_args.strip()})"

     def window_spec_sql(self, expression: exp.WindowSpec) -> str:
         kind = self.sql(expression, "kind")

@@ -1130,6 +1174,8 @@ class Generator:
         query = expression.args.get("query")
         unnest = expression.args.get("unnest")
         field = expression.args.get("field")
+        is_global = " GLOBAL" if expression.args.get("is_global") else ""
+
         if query:
             in_sql = self.wrap(query)
         elif unnest:

@@ -1138,7 +1184,8 @@ class Generator:
             in_sql = self.sql(field)
         else:
             in_sql = f"({self.expressions(expression, flat=True)})"
-        return f"{self.sql(expression, 'this')} IN {in_sql}"
+
+        return f"{self.sql(expression, 'this')}{is_global} IN {in_sql}"

     def in_unnest_op(self, unnest: exp.Unnest) -> str:
         return f"(SELECT {self.sql(unnest)})"

@@ -1433,7 +1480,7 @@ class Generator:
         result_sqls = []
         for i, e in enumerate(expressions):
             sql = self.sql(e, comment=False)
-            comments = self.maybe_comment("", e)
+            comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""

             if self.pretty:
                 if self._leading_comma:
sqlglot/helper.py

@@ -131,7 +131,7 @@ def subclasses(
     ]


-def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
+def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]:
     """
     Applies an offset to a given integer literal expression.

@@ -148,10 +148,10 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:

     expression = expressions[0]

-    if expression.is_int:
+    if expression and expression.is_int:
         expression = expression.copy()
         logger.warning("Applying array index offset (%s)", offset)
-        expression.args["this"] = str(int(expression.this) + offset)
+        expression.args["this"] = str(int(expression.this) + offset)  # type: ignore
         return [expression]

     return expressions

@@ -225,7 +225,7 @@ def open_file(file_name: str) -> t.TextIO:

         return gzip.open(file_name, "rt", newline="")

-    return open(file_name, "rt", encoding="utf-8", newline="")
+    return open(file_name, encoding="utf-8", newline="")


 @contextmanager

@@ -256,7 +256,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
         file.close()


-def find_new_name(taken: t.Sequence[str], base: str) -> str:
+def find_new_name(taken: t.Collection[str], base: str) -> str:
     """
     Searches for a new name.

@@ -356,6 +356,15 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any,
         yield value


+def count_params(function: t.Callable) -> int:
+    """
+    Returns the number of formal parameters expected by a function, without counting "self"
+    and "cls", in case of instance and class methods, respectively.
+    """
+    count = function.__code__.co_argcount
+    return count - 1 if inspect.ismethod(function) else count
+
+
 def dict_depth(d: t.Dict) -> int:
     """
     Get the nesting depth of a dictionary.

@@ -374,6 +383,7 @@ def dict_depth(d: t.Dict) -> int:

     Args:
         d (dict): dictionary
+
     Returns:
         int: depth
     """
sqlglot/optimizer/annotate_types.py

@@ -43,7 +43,7 @@ class TypeAnnotator:
         },
         exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
         exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
-        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
+        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
         exp.Alias: lambda self, expr: self._annotate_unary(expr),
         exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
         exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
sqlglot/optimizer/eliminate_joins.py

@@ -57,7 +57,7 @@ def _join_is_used(scope, join, alias):
     # But columns in the ON clause shouldn't count.
     on = join.args.get("on")
     if on:
-        on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
+        on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
     else:
         on_clause_columns = set()
     return any(

@@ -71,7 +71,7 @@ def _is_joined_on_all_unique_outputs(scope, join):
         return False

     _, join_keys, _ = join_condition(join)
-    remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys)
+    remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
     return not remaining_unique_outputs
sqlglot/optimizer/merge_subqueries.py

@@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False):

     singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
     for outer_scope, inner_scope, table in singular_cte_selections:
-        inner_select = inner_scope.expression.unnest()
         from_or_join = table.find_ancestor(exp.From, exp.Join)
-        if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
+        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
             alias = table.alias_or_name
-
             _rename_inner_sources(outer_scope, inner_scope, alias)
             _merge_from(outer_scope, inner_scope, table, alias)
             _merge_expressions(outer_scope, inner_scope, alias)

@@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False):
             _merge_order(outer_scope, inner_scope)
             _merge_hints(outer_scope, inner_scope)
             _pop_cte(inner_scope)
+            outer_scope.clear_cache()
     return expression


 def merge_derived_tables(expression, leave_tables_isolated=False):
     for outer_scope in traverse_scope(expression):
         for subquery in outer_scope.derived_tables:
-            inner_select = subquery.unnest()
             from_or_join = subquery.find_ancestor(exp.From, exp.Join)
-            if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
-                alias = subquery.alias_or_name
-                inner_scope = outer_scope.sources[alias]
-
+            alias = subquery.alias_or_name
+            inner_scope = outer_scope.sources[alias]
+            if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                 _rename_inner_sources(outer_scope, inner_scope, alias)
                 _merge_from(outer_scope, inner_scope, subquery, alias)
                 _merge_expressions(outer_scope, inner_scope, alias)

@@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
                 _merge_where(outer_scope, inner_scope, from_or_join)
                 _merge_order(outer_scope, inner_scope)
                 _merge_hints(outer_scope, inner_scope)
+                outer_scope.clear_cache()
     return expression


-def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
+def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
     """
     Return True if `inner_select` can be merged into outer query.

     Args:
         outer_scope (Scope)
-        inner_select (exp.Select)
+        inner_scope (Scope)
         leave_tables_isolated (bool)
         from_or_join (exp.From|exp.Join)
     Returns:
         bool: True if can be merged
     """
+    inner_select = inner_scope.expression.unnest()
+
     def _is_a_window_expression_in_unmergable_operation():
         window_expressions = inner_select.find_all(exp.Window)

@@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
         ]
         return any(window_expressions_in_unmergable)

+    def _outer_select_joins_on_inner_select_join():
+        """
+        All columns from the inner select in the ON clause must be from the first FROM table.
+
+        That is, this can be merged:
+            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
+                                         ^^^           ^
+        But this can't:
+            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
+                                         ^^^                 ^
+        """
+        if not isinstance(from_or_join, exp.Join):
+            return False
+
+        alias = from_or_join.this.alias_or_name
+
+        on = from_or_join.args.get("on")
+        if not on:
+            return False
+        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
+        inner_from = inner_scope.expression.args.get("from")
+        if not inner_from:
+            return False
+        inner_from_table = inner_from.expressions[0].alias_or_name
+        inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
+        return any(
+            col.table != inner_from_table
+            for selection in selections
+            for col in inner_projections[selection].find_all(exp.Column)
+        )
+
     return (
         isinstance(outer_scope.expression, exp.Select)
         and isinstance(inner_select, exp.Select)
         and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
         and inner_select.args.get("from")
        and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)

@@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
         )
     )
+        and not _outer_select_joins_on_inner_select_join()
         and not _is_a_window_expression_in_unmergable_operation()
     )

@@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
     """
     taken = set(outer_scope.selected_sources)
     conflicts = taken.intersection(set(inner_scope.selected_sources))
-    conflicts = conflicts - {alias}
+    conflicts -= {alias}

     for conflict in conflicts:
         new_name = find_new_name(taken, conflict)
sqlglot/optimizer/optimizer.py

@@ -15,6 +15,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections
 from sqlglot.optimizer.qualify_columns import qualify_columns
 from sqlglot.optimizer.qualify_tables import qualify_tables
 from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
+from sqlglot.schema import ensure_schema

 RULES = (
     lower_identities,

@@ -51,12 +52,13 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
             If no schema is provided then the default schema defined at `sqlgot.schema` will be used
         db (str): specify the default database, as might be set by a `USE DATABASE db` statement
         catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
-        rules (list): sequence of optimizer rules to use
+        rules (sequence): sequence of optimizer rules to use
         **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
     Returns:
         sqlglot.Expression: optimized expression
     """
-    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
+    schema = ensure_schema(schema or sqlglot.schema)
+    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
     expression = expression.copy()
     for rule in rules:
sqlglot/optimizer/pushdown_projections.py

@@ -79,6 +79,7 @@ def _remove_unused_selections(scope, parent_selections):
     order_refs = set()

     new_selections = []
+    removed = False
     for i, selection in enumerate(scope.selects):
         if (
             SELECT_ALL in parent_selections

@@ -88,12 +89,15 @@ def _remove_unused_selections(scope, parent_selections):
             new_selections.append(selection)
         else:
             removed_indexes.append(i)
+            removed = True

     # If there are no remaining selections, just select a single constant
     if not new_selections:
         new_selections.append(DEFAULT_SELECTION.copy())

     scope.expression.set("expressions", new_selections)
+    if removed:
+        scope.clear_cache()
     return removed_indexes
sqlglot/optimizer/qualify_columns.py

@@ -365,9 +365,9 @@ class _Resolver:
     def all_columns(self):
         """All available columns of all sources in this scope"""
         if self._all_columns is None:
-            self._all_columns = set(
+            self._all_columns = {
                 column for columns in self._get_all_source_columns().values() for column in columns
-            )
+            }
         return self._all_columns

     def get_source_columns(self, name, only_visible=False):
sqlglot/optimizer/simplify.py

@@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b):
             return boolean
     elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
         a, b = extract_date(a), extract_interval(b)
-        if b:
+        if a and b:
             if isinstance(expression, exp.Add):
                 return date_literal(a + b)
             if isinstance(expression, exp.Sub):

@@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b):
     elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
         a, b = extract_interval(a), extract_date(b)
         # you cannot subtract a date from an interval
-        if a and isinstance(expression, exp.Add):
+        if a and b and isinstance(expression, exp.Add):
             return date_literal(a + b)

     return None

@@ -424,9 +424,15 @@ def eval_boolean(expression, a, b):


 def extract_date(cast):
-    if cast.args["to"].this == exp.DataType.Type.DATE:
-        return datetime.date.fromisoformat(cast.name)
-    return None
+    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
+    # so in that case we can't extract the date.
+    try:
+        if cast.args["to"].this == exp.DataType.Type.DATE:
+            return datetime.date.fromisoformat(cast.name)
+        if cast.args["to"].this == exp.DataType.Type.DATETIME:
+            return datetime.datetime.fromisoformat(cast.name)
+    except ValueError:
+        return None


 def extract_interval(interval):

@@ -450,7 +456,8 @@ def extract_interval(interval):


 def date_literal(date):
-    return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))
+    expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
+    return exp.Cast(this=exp.Literal.string(date), to=expr_type)


 def boolean_literal(condition):
sqlglot/optimizer/unnest_subqueries.py

@@ -15,8 +15,7 @@ def unnest_subqueries(expression):
         >>> import sqlglot
         >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
         >>> unnest_subqueries(expression).sql()
-        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
-        AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
+        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

     Args:
         expression (sqlglot.Expression): expression to unnest

@@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
     other = _other_operand(parent_predicate)

     if isinstance(parent_predicate, exp.Exists):
-        if value.this in group_by:
-            parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
-        else:
-            parent_predicate = _replace(parent_predicate, "TRUE")
+        alias = exp.column(list(key_aliases.values())[0], table_alias)
+        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
     elif isinstance(parent_predicate, exp.All):
         parent_predicate = _replace(
             parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"

@@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence):
     else:
         if is_subquery_projection:
             alias = exp.alias_(alias, select.parent.alias)
+
+        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
+        # by transforming all counts into 0 and using that as the coalesced value
+        if value.find(exp.Count):
+
+            def remove_aggs(node):
+                if isinstance(node, exp.Count):
+                    return exp.Literal.number(0)
+                elif isinstance(node, exp.AggFunc):
+                    return exp.null()
+                return node
+
+            alias = exp.Coalesce(
+                this=alias,
+                expressions=[value.this.transform(remove_aggs)],
+            )
+
         select.parent.replace(alias)

     for key, column, predicate in keys:

@@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence):

         if key in group_by:
             key.replace(nested)
-            parent_predicate = _replace(
-                parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
-            )
         elif isinstance(predicate, exp.EQ):
             parent_predicate = _replace(
                 parent_predicate,

@@ -245,7 +256,14 @@ def _other_operand(expression):
     if isinstance(expression, exp.In):
         return expression.this

+    if isinstance(expression, (exp.Any, exp.All)):
+        return _other_operand(expression.parent)
+
     if isinstance(expression, exp.Binary):
-        return expression.right if expression.arg_key == "this" else expression.left
+        return (
+            expression.right
+            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
+            else expression.left
+        )

     return None
(File diff suppressed because it is too large.)
sqlglot/schema.py

@@ -3,6 +3,7 @@ from __future__ import annotations

 import abc
 import typing as t

+import sqlglot
 from sqlglot import expressions as exp
 from sqlglot.errors import SchemaError
 from sqlglot.helper import dict_depth

@@ -157,10 +158,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         visible: t.Optional[t.Dict] = None,
         dialect: t.Optional[str] = None,
     ) -> None:
-        super().__init__(schema)
-        self.visible = visible or {}
         self.dialect = dialect
+        self.visible = visible or {}
         self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
+        super().__init__(self._normalize(schema or {}))

     @classmethod
     def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:

@@ -180,6 +181,33 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
             }
         )

+    def _normalize(self, schema: t.Dict) -> t.Dict:
+        """
+        Converts all identifiers in the schema into lowercase, unless they're quoted.
+
+        Args:
+            schema: the schema to normalize.
+
+        Returns:
+            The normalized schema mapping.
+        """
+        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
+
+        normalized_mapping: t.Dict = {}
+        for keys in flattened_schema:
+            columns = _nested_get(schema, *zip(keys, keys))
+            assert columns is not None
+
+            normalized_keys = [self._normalize_name(key) for key in keys]
+            for column_name, column_type in columns.items():
+                _nested_set(
+                    normalized_mapping,
+                    normalized_keys + [self._normalize_name(column_name)],
+                    column_type,
+                )
+
+        return normalized_mapping
+
     def add_table(
         self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
     ) -> None:

@@ -204,6 +232,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         )
         self.mapping_trie = self._build_trie(self.mapping)

+    def _normalize_name(self, name: str) -> str:
+        try:
+            identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
+                name, read=self.dialect, into=exp.Identifier
+            )
+        except:
+            identifier = exp.to_identifier(name)
+        assert isinstance(identifier, exp.Identifier)
+
+        if identifier.quoted:
+            return identifier.name
+        return identifier.name.lower()
+
     def _depth(self) -> int:
         # The columns themselves are a mapping, but we don't want to include those
         return super()._depth() - 1
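Annotation: schemas are now normalized on construction — unquoted identifiers fold to lowercase while quoted ones keep their case, matching how columns are resolved during qualification. A hedged sketch; the printed mapping shape is an assumption:

from sqlglot.schema import MappingSchema

schema = MappingSchema({"Users": {"ID": "int", '"CamelCase"': "text"}})
print(schema.mapping)
# Expected shape: {'users': {'id': 'int', 'CamelCase': 'text'}}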
sqlglot/serde.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+
+if t.TYPE_CHECKING:
+    JSON = t.Union[dict, list, str, float, int, bool]
+    Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON]
+
+
+def dump(node: Node) -> JSON:
+    """
+    Recursively dump an AST into a JSON-serializable dict.
+    """
+    if isinstance(node, list):
+        return [dump(i) for i in node]
+    if isinstance(node, exp.DataType.Type):
+        return {
+            "class": "DataType.Type",
+            "value": node.value,
+        }
+    if isinstance(node, exp.Expression):
+        klass = node.__class__.__qualname__
+        if node.__class__.__module__ != exp.__name__:
+            klass = f"{node.__module__}.{klass}"
+        obj = {
+            "class": klass,
+            "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
+        }
+        if node.type:
+            obj["type"] = node.type.sql()
+        if node.comments:
+            obj["comments"] = node.comments
+        return obj
+    return node
+
+
+def load(obj: JSON) -> Node:
+    """
+    Recursively load a dict (as returned by `dump`) into an AST.
+    """
+    if isinstance(obj, list):
+        return [load(i) for i in obj]
+    if isinstance(obj, dict):
+        class_name = obj["class"]
+
+        if class_name == "DataType.Type":
+            return exp.DataType.Type(obj["value"])
+
+        if "." in class_name:
+            module_path, class_name = class_name.rsplit(".", maxsplit=1)
+            module = __import__(module_path, fromlist=[class_name])
+        else:
+            module = exp
+
+        klass = getattr(module, class_name)
+
+        expression = klass(**{k: load(v) for k, v in obj["args"].items()})
+        type_ = obj.get("type")
+        if type_:
+            expression.type = exp.DataType.build(type_)
+        comments = obj.get("comments")
+        if comments:
+            expression.comments = load(comments)
+        return expression
+    return obj
sqlglot/tokens.py

@@ -86,6 +86,7 @@ class TokenType(AutoName):
     VARBINARY = auto()
     JSON = auto()
+    JSONB = auto()
     TIME = auto()
     TIMESTAMP = auto()
     TIMESTAMPTZ = auto()
     TIMESTAMPLTZ = auto()

@@ -181,6 +182,7 @@ class TokenType(AutoName):
     FUNCTION = auto()
     FROM = auto()
     GENERATED = auto()
+    GLOBAL = auto()
     GROUP_BY = auto()
     GROUPING_SETS = auto()
     HAVING = auto()

@@ -656,6 +658,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "FLOAT4": TokenType.FLOAT,
         "FLOAT8": TokenType.DOUBLE,
         "DOUBLE": TokenType.DOUBLE,
+        "DOUBLE PRECISION": TokenType.DOUBLE,
         "JSON": TokenType.JSON,
         "CHAR": TokenType.CHAR,
         "NCHAR": TokenType.NCHAR,

@@ -671,6 +674,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "BLOB": TokenType.VARBINARY,
         "BYTEA": TokenType.VARBINARY,
         "VARBINARY": TokenType.VARBINARY,
         "TIME": TokenType.TIME,
         "TIMESTAMP": TokenType.TIMESTAMP,
         "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
+        "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,

@@ -721,6 +725,8 @@ class Tokenizer(metaclass=_Tokenizer):
     COMMENTS = ["--", ("/*", "*/")]
     KEYWORD_TRIE = None  # autofilled

+    IDENTIFIER_CAN_START_WITH_DIGIT = False
+
     __slots__ = (
         "sql",
         "size",

@@ -938,17 +944,24 @@ class Tokenizer(metaclass=_Tokenizer):
             elif self._peek.upper() == "E" and not scientific:  # type: ignore
                 scientific += 1
                 self._advance()
-            elif self._peek.isalpha():  # type: ignore
-                self._add(TokenType.NUMBER)
+            elif self._peek.isidentifier():  # type: ignore
+                number_text = self._text
                 literal = []
-                while self._peek.isalpha():  # type: ignore
+                while self._peek.isidentifier():  # type: ignore
                     literal.append(self._peek.upper())  # type: ignore
                     self._advance()
+
                 literal = "".join(literal)  # type: ignore
                 token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))  # type: ignore
+
                 if token_type:
+                    self._add(TokenType.NUMBER, number_text)
                     self._add(TokenType.DCOLON, "::")
                     return self._add(token_type, literal)  # type: ignore
+                elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
+                    return self._add(TokenType.VAR)
+
+                self._add(TokenType.NUMBER, number_text)
                 return self._advance(-len(literal))
             else:
                 return self._add(TokenType.NUMBER)
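Annotation: with IDENTIFIER_CAN_START_WITH_DIGIT (enabled for Hive above) the tokenizer can treat digit-led tokens as identifiers, while suffixed numeric literals still become typed casts via NUMERIC_LITERALS. A hedged sketch — the output is an assumption:

import sqlglot

# Hive-style typed numeric literal: the L suffix maps through NUMERIC_LITERALS
print(sqlglot.transpile("SELECT 10L", read="hive")[0])
# Expected shape: SELECT CAST(10 AS BIGINT)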
sqlglot/transforms.py

@@ -82,6 +82,27 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
     return expression


+def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
+    """
+    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions.
+    This transforms removes the precision from parameterized types in expressions.
+    """
+    return expression.transform(
+        lambda node: exp.DataType(
+            **{
+                **node.args,
+                "expressions": [
+                    node_expression
+                    for node_expression in node.expressions
+                    if isinstance(node_expression, exp.DataType)
+                ],
+            }
+        )
+        if isinstance(node, exp.DataType)
+        else node,
+    )
+
+
 def preprocess(
     transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
     to_sql: t.Callable[[Generator, exp.Expression], str],

@@ -121,3 +142,6 @@ def delegate(attr: str) -> t.Callable:

 UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
 ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
+REMOVE_PRECISION_PARAMETERIZED_TYPES = {
+    exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
+}
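Annotation: BigQuery wires this transform into its generator (see the bigquery.py TRANSFORMS hunk earlier), so parameterized precisions are dropped from casts outside DDL. A hedged sketch — the output shape is an assumption:

import sqlglot

print(sqlglot.transpile("SELECT CAST(x AS VARCHAR(10))", write="bigquery")[0])
# Expected shape: SELECT CAST(x AS STRING), with the (10) precision dropped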
sqlglot/trie.py

@@ -52,7 +52,7 @@ def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:

     Returns:
         A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
-        is either 0 (search was unsuccessfull), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`).
+        is either 0 (search was unsuccessful), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`).
     """
     if not key:
         return (0, trie)