
Merging upstream version 10.5.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:03:38 +01:00
parent 77197f1e44
commit e0f3bbb5f3
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
58 changed files with 1480 additions and 383 deletions

sqlglot/__init__.py

@ -32,7 +32,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.4.2"
__version__ = "10.5.2"
pretty = False
@ -60,9 +60,9 @@ def parse(
def parse_one(
sql: str,
read: t.Optional[str | Dialect] = None,
into: t.Optional[Expression | str] = None,
into: t.Optional[t.Type[Expression] | str] = None,
**opts,
) -> t.Optional[Expression]:
) -> Expression:
"""
Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
@ -83,7 +83,12 @@ def parse_one(
else:
result = dialect.parse(sql, **opts)
return result[0] if result else None
for expression in result:
if not expression:
raise ParseError(f"No expression was parsed from '{sql}'")
return expression
else:
raise ParseError(f"No expression was parsed from '{sql}'")
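parse_one now always returns an Expression and raises instead of returning None. A minimal sketch of the new contract, assuming the 10.5.2 API as merged here:

    import sqlglot
    from sqlglot.errors import ParseError

    ast = sqlglot.parse_one("SELECT 1")  # always an Expression now
    try:
        sqlglot.parse_one("")            # previously evaluated to None
    except ParseError as e:
        print(e)                         # No expression was parsed from ''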
def transpile(

sqlglot/dialects/bigquery.py

@ -2,7 +2,7 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
@ -46,8 +46,9 @@ def _date_add_sql(data_type, kind):
def _derived_table_values_to_unnest(self, expression):
if not isinstance(expression.unnest().parent, exp.From):
expression = transforms.remove_precision_parameterized_types(expression)
return self.values_sql(expression)
rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)]
rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
structs = []
for row in rows:
aliases = [
@ -118,6 +119,7 @@ class BigQuery(Dialect):
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"DECLARE": TokenType.COMMAND,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
@ -166,6 +168,7 @@ class BigQuery(Dialect):
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
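The VALUES-to-UNNEST rewrite now keeps whole tuple expressions rather than just literals, so NULLs and computed values survive the conversion to structs. A hedged sketch; the exact output formatting may differ:

    import sqlglot

    sql = "SELECT x, y FROM (VALUES (1, 'a'), (NULL, 'b')) AS t(x, y)"
    print(sqlglot.transpile(sql, write="bigquery")[0])
    # roughly: SELECT x, y FROM UNNEST([STRUCT(1 AS x, 'a' AS y), STRUCT(NULL AS x, 'b' AS y)])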

sqlglot/dialects/clickhouse.py

@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.parser import parse_var_map
@ -22,6 +24,7 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ASOF": TokenType.ASOF,
"GLOBAL": TokenType.GLOBAL,
"DATETIME64": TokenType.DATETIME,
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
@ -37,14 +40,32 @@ class ClickHouse(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"MAP": parse_var_map,
"QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
"QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
"QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
}
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN)
and self._parse_in(this, is_global=True),
}
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
def _parse_table(self, schema=False):
this = super()._parse_table(schema)
def _parse_in(
self, this: t.Optional[exp.Expression], is_global: bool = False
) -> exp.Expression:
this = super()._parse_in(this)
this.set("is_global", is_global)
return this
def _parse_table(
self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
) -> t.Optional[exp.Expression]:
this = super()._parse_table(schema=schema, alias_tokens=alias_tokens)
if self._match(TokenType.FINAL):
this = self.expression(exp.Final, this=this)
@ -76,6 +97,16 @@ class ClickHouse(Dialect):
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
}
EXPLICIT_UNION = True
def _param_args_sql(
self, expression: exp.Expression, params_name: str, args_name: str
) -> str:
params = self.format_args(self.expressions(expression, params_name))
args = self.format_args(self.expressions(expression, args_name))
return f"({params})({args})"

sqlglot/dialects/dialect.py

@ -381,3 +381,20 @@ def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
def trim_sql(self, expression):
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
collation = self.sql(expression, "collation")
# Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
if not remove_chars and not collation:
return self.trim_sql(expression)
trim_type = f"{trim_type} " if trim_type else ""
remove_chars = f"{remove_chars} " if remove_chars else ""
from_part = "FROM " if trim_type or remove_chars else ""
collation = f" COLLATE {collation}" if collation else ""
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"

sqlglot/dialects/hive.py

@ -175,14 +175,6 @@ class Hive(Dialect):
ESCAPES = ["\\"]
ENCODE = "utf-8"
NUMERIC_LITERALS = {
"L": "BIGINT",
"S": "SMALLINT",
"Y": "TINYINT",
"D": "DOUBLE",
"F": "FLOAT",
"BD": "DECIMAL",
}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ADD ARCHIVE": TokenType.COMMAND,
@ -191,9 +183,21 @@ class Hive(Dialect):
"ADD FILES": TokenType.COMMAND,
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
"MSCK REPAIR": TokenType.COMMAND,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
}
NUMERIC_LITERALS = {
"L": "BIGINT",
"S": "SMALLINT",
"Y": "TINYINT",
"D": "DOUBLE",
"F": "FLOAT",
"BD": "DECIMAL",
}
IDENTIFIER_CAN_START_WITH_DIGIT = True
class Parser(parser.Parser):
STRICT_CAST = False
@ -315,6 +319,7 @@ class Hive(Dialect):
exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
}
WITH_PROPERTIES = {exp.Property}
@ -342,4 +347,6 @@ class Hive(Dialect):
and not expression.expressions
):
expression = exp.DataType.build("text")
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
return super().datatype_sql(expression)
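Two Hive-specific effects of these hunks, sketched under the assumption that the reader parses parameterized temporal types:

    import sqlglot

    # numeric literal suffixes become casts: 10L -> BIGINT, 1.5BD -> DECIMAL
    print(sqlglot.transpile("SELECT 10L, 1.5BD", read="hive")[0])
    # SELECT CAST(10 AS BIGINT), CAST(1.5 AS DECIMAL)

    # temporal precision is dropped when generating Hive types
    print(sqlglot.transpile("SELECT CAST(x AS TIMESTAMP(9))", write="hive")[0])
    # SELECT CAST(x AS TIMESTAMP)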

sqlglot/dialects/oracle.py

@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
from sqlglot.helper import csv
from sqlglot.tokens import TokenType
@ -64,6 +64,7 @@ class Oracle(Dialect):
**transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
exp.Trim: trim_sql,
exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",

sqlglot/dialects/postgres.py

@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
str_position_sql,
trim_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -81,23 +82,6 @@ def _substring_sql(self, expression):
return f"SUBSTRING({this}{from_part}{for_part})"
def _trim_sql(self, expression):
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
collation = self.sql(expression, "collation")
# Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
if not remove_chars and not collation:
return self.trim_sql(expression)
trim_type = f"{trim_type} " if trim_type else ""
remove_chars = f"{remove_chars} " if remove_chars else ""
from_part = "FROM " if trim_type or remove_chars else ""
collation = f" COLLATE {collation}" if collation else ""
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def _string_agg_sql(self, expression):
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
@ -248,7 +232,6 @@ class Postgres(Dialect):
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
"DOUBLE PRECISION": TokenType.DOUBLE,
"GENERATED": TokenType.GENERATED,
"GRANT": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
@ -318,7 +301,7 @@ class Postgres(Dialect):
exp.Substring: _substring_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
exp.Trim: _trim_sql,
exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,

sqlglot/dialects/snowflake.py

@ -195,7 +195,6 @@ class Snowflake(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY,
"DOUBLE PRECISION": TokenType.DOUBLE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@ -294,3 +293,10 @@ class Snowflake(Dialect):
)
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
return super().select_sql(expression)
def describe_sql(self, expression: exp.Describe) -> str:
# Default to table if kind is unknown
kind_value = expression.args.get("kind") or "TABLE"
kind = f" {kind_value}" if kind_value else ""
this = f" {self.sql(expression, 'this')}"
return f"DESCRIBE{kind}{this}"

sqlglot/dialects/tsql.py

@ -75,6 +75,20 @@ def _parse_format(args):
)
def _parse_eomonth(args):
date = seq_get(args, 0)
month_lag = seq_get(args, 1)
unit = DATE_DELTA_INTERVAL.get("month")
if month_lag is None:
return exp.LastDateOfMonth(this=date)
# Remove month lag argument in parser as it's compared with the number of arguments of the resulting class
args.remove(month_lag)
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
@ -256,12 +270,14 @@ class TSQL(Dialect):
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
"GETDATE": exp.CurrentDate.from_arg_list,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
"IIF": exp.If.from_arg_list,
"LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"FORMAT": _parse_format,
"EOMONTH": _parse_eomonth,
}
VAR_LENGTH_DATATYPES = {
@ -271,6 +287,9 @@ class TSQL(Dialect):
DataType.Type.NCHAR,
}
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
TABLE_PREFIX_TOKENS = {TokenType.HASH}
def _parse_convert(self, strict):
to = self._parse_types()
self._match(TokenType.COMMA)
@ -323,6 +342,7 @@ class TSQL(Dialect):
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,

sqlglot/expressions.py

@ -22,6 +22,7 @@ from sqlglot.helper import (
split_num_words,
subclasses,
)
from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect
@ -457,6 +458,23 @@ class Expression(metaclass=_Expression):
assert isinstance(self, type_)
return self
def dump(self):
"""
Dump this Expression to a JSON-serializable dict.
"""
from sqlglot.serde import dump
return dump(self)
@classmethod
def load(cls, obj):
"""
Load a dict (as returned by `Expression.dump`) into an Expression instance.
"""
from sqlglot.serde import load
return load(obj)
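A round-trip sketch of the new serialization hooks:

    from sqlglot import parse_one
    from sqlglot.expressions import Expression

    ast = parse_one("SELECT a FROM t WHERE a > 1")
    data = ast.dump()                 # JSON-serializable dict
    assert Expression.load(data).sql() == ast.sql()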
class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts):
@ -631,11 +649,15 @@ class Create(Expression):
"replace": False,
"unique": False,
"materialized": False,
"data": False,
"statistics": False,
"no_primary_index": False,
"indexes": False,
}
class Describe(Expression):
pass
arg_types = {"this": True, "kind": False}
class Set(Expression):
@ -731,7 +753,7 @@ class Column(Condition):
class ColumnDef(Expression):
arg_types = {
"this": True,
"kind": True,
"kind": False,
"constraints": False,
"exists": False,
}
@ -879,7 +901,15 @@ class Identifier(Expression):
class Index(Expression):
arg_types = {"this": False, "table": False, "where": False, "columns": False}
arg_types = {
"this": False,
"table": False,
"where": False,
"columns": False,
"unique": False,
"primary": False,
"amp": False, # teradata
}
class Insert(Expression):
@ -1361,6 +1391,7 @@ class Table(Expression):
"laterals": False,
"joins": False,
"pivots": False,
"hints": False,
}
@ -1818,7 +1849,12 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery())
if join_type:
natural: t.Optional[Token]
side: t.Optional[Token]
kind: t.Optional[Token]
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
if natural:
join.set("natural", True)
if side:
@ -2111,6 +2147,7 @@ class DataType(Expression):
JSON = auto()
JSONB = auto()
INTERVAL = auto()
TIME = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
@ -2171,11 +2208,24 @@ class DataType(Expression):
}
@classmethod
def build(cls, dtype, **kwargs) -> DataType:
return DataType(
this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
**kwargs,
)
def build(
cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
) -> DataType:
from sqlglot import parse_one
if isinstance(dtype, str):
data_type_exp: t.Optional[Expression]
if dtype.upper() in cls.Type.__members__:
data_type_exp = DataType(this=DataType.Type[dtype.upper()])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
if data_type_exp is None:
raise ValueError(f"Unparsable data type value: {dtype}")
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
return DataType(**{**data_type_exp.args, **kwargs})
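build now accepts full type strings by falling back to the parser when the name isn't a plain enum member. A sketch:

    from sqlglot import exp

    exp.DataType.build("INT")                    # plain enum lookup
    exp.DataType.build(exp.DataType.Type.TEXT)   # enum value passthrough
    t = exp.DataType.build("ARRAY<INT>")         # parsed as a nested type
    print(t.sql())                               # ARRAY<INT>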
# https://www.postgresql.org/docs/15/datatype-pseudo.html
@ -2429,6 +2479,7 @@ class In(Predicate):
"query": False,
"unnest": False,
"field": False,
"is_global": False,
}
@ -2678,6 +2729,10 @@ class DatetimeTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}
class LastDateOfMonth(Func):
pass
class Extract(Func):
arg_types = {"this": True, "expression": True}
@ -2815,7 +2870,13 @@ class Length(Func):
class Levenshtein(Func):
arg_types = {"this": True, "expression": False}
arg_types = {
"this": True,
"expression": False,
"ins_cost": False,
"del_cost": False,
"sub_cost": False,
}
class Ln(Func):
@ -2890,6 +2951,16 @@ class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
# Clickhouse-specific:
# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles
class Quantiles(AggFunc):
arg_types = {"parameters": True, "expressions": True}
class QuantileIf(AggFunc):
arg_types = {"parameters": True, "expressions": True}
class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False}
@ -2962,8 +3033,10 @@ class StrToTime(Func):
arg_types = {"this": True, "format": True}
# Spark allows unix_timestamp()
# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html
class StrToUnix(Func):
arg_types = {"this": True, "format": True}
arg_types = {"this": False, "format": False}
class NumberToStr(Func):
@ -3131,7 +3204,7 @@ def maybe_parse(
dialect=None,
prefix=None,
**opts,
) -> t.Optional[Expression]:
) -> Expression:
"""Gracefully handle a possible string or expression.
Example:
@ -3627,11 +3700,11 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)]
catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3))
return Table(this=table_name, db=db, catalog=catalog, **kwargs)
def to_column(sql_path: str, **kwargs) -> Column:
def to_column(sql_path: str | Column, **kwargs) -> Column:
"""
Create a column from a `[table].[column]` sql path. Schema is optional.
@ -3646,7 +3719,7 @@ def to_column(sql_path: str, **kwargs) -> Column:
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)]
table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
return Column(this=column_name, table=table_name, **kwargs)
@ -3748,7 +3821,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
def values(
values: t.Iterable[t.Tuple[t.Any, ...]],
alias: t.Optional[str] = None,
columns: t.Optional[t.Iterable[str]] = None,
columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None,
) -> Values:
"""Build VALUES statement.
@ -3759,7 +3832,10 @@ def values(
Args:
values: values statements that will be converted to SQL
alias: optional alias
columns: Optional list of ordered column names. An alias is required when providing column names.
columns: Optional list of ordered column names or ordered dictionary of column names to types.
If either are provided then an alias is also required.
If a dictionary is provided then the first column of the values will be casted to the expected type
in order to help with type inference.
Returns:
Values: the Values expression object
@ -3771,8 +3847,15 @@ def values(
if columns
else TableAlias(this=to_identifier(alias) if alias else None)
)
expressions = [convert(tup) for tup in values]
if columns and isinstance(columns, dict):
types = list(columns.values())
expressions[0].set(
"expressions",
[Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
)
return Values(
expressions=[convert(tup) for tup in values],
expressions=expressions,
alias=table_alias,
)
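When columns is a dict, the first row is cast to the given types to drive type inference. A sketch with approximate output formatting:

    from sqlglot import exp

    v = exp.values(
        [(1, "a"), (2, "b")],
        alias="t",
        columns={"x": exp.DataType.build("INT"), "y": exp.DataType.build("TEXT")},
    )
    print(v.sql())
    # roughly: VALUES (CAST(1 AS INT), 'a'), (2, 'b') AS t(x, y)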

sqlglot/generator.py

@ -50,7 +50,7 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
comments: Whether or not to preserve comments in the ouput SQL code.
comments: Whether or not to preserve comments in the output SQL code.
Default: True
"""
@ -236,7 +236,10 @@ class Generator:
return sql
sep = "\n" if self.pretty else " "
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment)
if not comments:
return sql
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"{comments}{self.sep()}{sql}"
@ -362,10 +365,10 @@ class Generator:
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
kind = f" {kind}" if kind else ""
constraints = f" {constraints}" if constraints else ""
if not constraints:
return f"{exists}{column} {kind}"
return f"{exists}{column} {kind} {constraints}"
return f"{exists}{column}{kind}{constraints}"
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
@ -416,7 +419,7 @@ class Generator:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind").upper()
expression_sql = self.sql(expression, "expression")
expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@ -427,6 +430,40 @@ class Generator:
unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
properties = self.sql(expression, "properties")
data = expression.args.get("data")
if data is None:
data = ""
elif data:
data = " WITH DATA"
else:
data = " WITH NO DATA"
statistics = expression.args.get("statistics")
if statistics is None:
statistics = ""
elif statistics:
statistics = " AND STATISTICS"
else:
statistics = " AND NO STATISTICS"
no_primary_index = " NO PRIMARY INDEX" if expression.args.get("no_primary_index") else ""
indexes = expression.args.get("indexes")
index_sql = ""
if indexes is not None:
indexes_sql = []
for index in indexes:
ind_unique = " UNIQUE" if index.args.get("unique") else ""
ind_primary = " PRIMARY" if index.args.get("primary") else ""
ind_amp = " AMP" if index.args.get("amp") else ""
ind_name = f" {index.name}" if index.name else ""
ind_columns = (
f' ({self.expressions(index, key="columns", flat=True)})'
if index.args.get("columns")
else ""
)
indexes_sql.append(
f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
)
index_sql = "".join(indexes_sql)
modifiers = "".join(
(
@ -438,7 +475,10 @@ class Generator:
materialized,
)
)
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}"
post_expression_modifiers = "".join((data, statistics, no_primary_index))
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}"
return self.prepend_ctes(expression, expression_sql)
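The new Create args render as post-expression clauses (Teradata-style). A hypothetical sketch toggling them by hand rather than via the parser:

    from sqlglot import parse_one

    create = parse_one("CREATE TABLE t2 AS SELECT * FROM t1")
    create.set("data", True)         # -> WITH DATA
    create.set("statistics", False)  # -> AND NO STATISTICS
    print(create.sql())
    # CREATE TABLE t2 AS SELECT * FROM t1 WITH DATA AND NO STATISTICS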
def describe_sql(self, expression: exp.Describe) -> str:
@ -668,6 +708,8 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=", ", flat=True)
hints = f" WITH ({hints})" if hints else ""
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="")
@ -676,7 +718,7 @@ class Generator:
pivots = f"{pivots}{alias}"
alias = ""
return f"{table}{alias}{laterals}{joins}{pivots}"
return f"{table}{alias}{hints}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression: exp.TableSample) -> str:
if self.alias_post_tablesample and expression.this.alias:
@ -1020,7 +1062,9 @@ class Generator:
if not partition and not order and not spec and alias:
return f"{this} {alias}"
return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})"
window_args = alias + partition_sql + order_sql + spec_sql
return f"{this} ({window_args.strip()})"
def window_spec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind")
@ -1130,6 +1174,8 @@ class Generator:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
field = expression.args.get("field")
is_global = " GLOBAL" if expression.args.get("is_global") else ""
if query:
in_sql = self.wrap(query)
elif unnest:
@ -1138,7 +1184,8 @@ class Generator:
in_sql = self.sql(field)
else:
in_sql = f"({self.expressions(expression, flat=True)})"
return f"{self.sql(expression, 'this')} IN {in_sql}"
return f"{self.sql(expression, 'this')}{is_global} IN {in_sql}"
def in_unnest_op(self, unnest: exp.Unnest) -> str:
return f"(SELECT {self.sql(unnest)})"
@ -1433,7 +1480,7 @@ class Generator:
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
comments = self.maybe_comment("", e)
comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
if self.pretty:
if self._leading_comma:

sqlglot/helper.py

@ -131,7 +131,7 @@ def subclasses(
]
def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]:
"""
Applies an offset to a given integer literal expression.
@ -148,10 +148,10 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
expression = expressions[0]
if expression.is_int:
if expression and expression.is_int:
expression = expression.copy()
logger.warning("Applying array index offset (%s)", offset)
expression.args["this"] = str(int(expression.this) + offset)
expression.args["this"] = str(int(expression.this) + offset) # type: ignore
return [expression]
return expressions
@ -225,7 +225,7 @@ def open_file(file_name: str) -> t.TextIO:
return gzip.open(file_name, "rt", newline="")
return open(file_name, "rt", encoding="utf-8", newline="")
return open(file_name, encoding="utf-8", newline="")
@contextmanager
@ -256,7 +256,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
file.close()
def find_new_name(taken: t.Sequence[str], base: str) -> str:
def find_new_name(taken: t.Collection[str], base: str) -> str:
"""
Searches for a new name.
@ -356,6 +356,15 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any,
yield value
def count_params(function: t.Callable) -> int:
"""
Returns the number of formal parameters expected by a function, without counting "self"
and "cls", in case of instance and class methods, respectively.
"""
count = function.__code__.co_argcount
return count - 1 if inspect.ismethod(function) else count
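count_params presumably lets the parser distinguish single-argument FUNCTIONS lambdas from (params, args) ones, like the ClickHouse quantile entries above. A sketch:

    from sqlglot.helper import count_params

    count_params(lambda args: args)          # 1
    count_params(lambda params, args: args)  # 2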
def dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
@ -374,6 +383,7 @@ def dict_depth(d: t.Dict) -> int:
Args:
d (dict): dictionary
Returns:
int: depth
"""

sqlglot/optimizer/annotate_types.py

@ -43,7 +43,7 @@ class TypeAnnotator:
},
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
exp.Alias: lambda self, expr: self._annotate_unary(expr),
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),

sqlglot/optimizer/eliminate_joins.py

@ -57,7 +57,7 @@ def _join_is_used(scope, join, alias):
# But columns in the ON clause shouldn't count.
on = join.args.get("on")
if on:
on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
else:
on_clause_columns = set()
return any(
@ -71,7 +71,7 @@ def _is_joined_on_all_unique_outputs(scope, join):
return False
_, join_keys, _ = join_condition(join)
remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys)
remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
return not remaining_unique_outputs

sqlglot/optimizer/merge_subqueries.py

@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False):
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
for outer_scope, inner_scope, table in singular_cte_selections:
inner_select = inner_scope.expression.unnest()
from_or_join = table.find_ancestor(exp.From, exp.Join)
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
alias = table.alias_or_name
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, table, alias)
_merge_expressions(outer_scope, inner_scope, alias)
@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False):
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
_pop_cte(inner_scope)
outer_scope.clear_cache()
return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
alias = subquery.alias_or_name
inner_scope = outer_scope.sources[alias]
alias = subquery.alias_or_name
inner_scope = outer_scope.sources[alias]
if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_expressions(outer_scope, inner_scope, alias)
@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
outer_scope.clear_cache()
return expression
def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
"""
Return True if `inner_select` can be merged into outer query.
Args:
outer_scope (Scope)
inner_select (exp.Select)
inner_scope (Scope)
leave_tables_isolated (bool)
from_or_join (exp.From|exp.Join)
Returns:
bool: True if can be merged
"""
inner_select = inner_scope.expression.unnest()
def _is_a_window_expression_in_unmergable_operation():
window_expressions = inner_select.find_all(exp.Window)
@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
]
return any(window_expressions_in_unmergable)
def _outer_select_joins_on_inner_select_join():
"""
All columns from the inner select in the ON clause must be from the first FROM table.
That is, this can be merged:
SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
^^^ ^
But this can't:
SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
^^^ ^
"""
if not isinstance(from_or_join, exp.Join):
return False
alias = from_or_join.this.alias_or_name
on = from_or_join.args.get("on")
if not on:
return False
selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
inner_from = inner_scope.expression.args.get("from")
if not inner_from:
return False
inner_from_table = inner_from.expressions[0].alias_or_name
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
return any(
col.table != inner_from_table
for selection in selections
for col in inner_projections[selection].find_all(exp.Column)
)
return (
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
)
)
and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation()
)
@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
"""
taken = set(outer_scope.selected_sources)
conflicts = taken.intersection(set(inner_scope.selected_sources))
conflicts = conflicts - {alias}
conflicts -= {alias}
for conflict in conflicts:
new_name = find_new_name(taken, conflict)

sqlglot/optimizer/optimizer.py

@ -15,6 +15,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
RULES = (
lower_identities,
@ -51,12 +52,13 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
If no schema is provided then the default schema defined at `sqlglot.schema` will be used
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
rules (list): sequence of optimizer rules to use
rules (sequence): sequence of optimizer rules to use
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
Returns:
sqlglot.Expression: optimized expression
"""
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
schema = ensure_schema(schema or sqlglot.schema)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = expression.copy()
for rule in rules:
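Plain mappings are now wrapped with ensure_schema up front, so every rule sees a Schema object. A minimal sketch:

    import sqlglot
    from sqlglot.optimizer import optimize

    optimized = optimize(sqlglot.parse_one("SELECT a FROM x"), schema={"x": {"a": "INT"}})
    print(optimized.sql())
    # SELECT "x"."a" AS "a" FROM "x" AS "x"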

sqlglot/optimizer/pushdown_projections.py

@ -79,6 +79,7 @@ def _remove_unused_selections(scope, parent_selections):
order_refs = set()
new_selections = []
removed = False
for i, selection in enumerate(scope.selects):
if (
SELECT_ALL in parent_selections
@ -88,12 +89,15 @@ def _remove_unused_selections(scope, parent_selections):
new_selections.append(selection)
else:
removed_indexes.append(i)
removed = True
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)
if removed:
scope.clear_cache()
return removed_indexes

sqlglot/optimizer/qualify_columns.py

@ -365,9 +365,9 @@ class _Resolver:
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
self._all_columns = set(
self._all_columns = {
column for columns in self._get_all_source_columns().values() for column in columns
)
}
return self._all_columns
def get_source_columns(self, name, only_visible=False):

sqlglot/optimizer/simplify.py

@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b):
return boolean
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if b:
if a and b:
if isinstance(expression, exp.Add):
return date_literal(a + b)
if isinstance(expression, exp.Sub):
@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b):
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and isinstance(expression, exp.Add):
if a and b and isinstance(expression, exp.Add):
return date_literal(a + b)
return None
@ -424,9 +424,15 @@ def eval_boolean(expression, a, b):
def extract_date(cast):
if cast.args["to"].this == exp.DataType.Type.DATE:
return datetime.date.fromisoformat(cast.name)
return None
# The "fromisoformat" conversion could fail if the cast is used on an identifier,
# so in that case we can't extract the date.
try:
if cast.args["to"].this == exp.DataType.Type.DATE:
return datetime.date.fromisoformat(cast.name)
if cast.args["to"].this == exp.DataType.Type.DATETIME:
return datetime.datetime.fromisoformat(cast.name)
except ValueError:
return None
def extract_interval(interval):
@ -450,7 +456,8 @@ def extract_interval(interval):
def date_literal(date):
return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))
expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
return exp.Cast(this=exp.Literal.string(date), to=expr_type)
def boolean_literal(condition):
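extract_date now tolerates casts of non-literals and understands DATETIME. A sketch:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    e = simplify(sqlglot.parse_one("CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY"))
    print(e.sql())  # CAST('2023-01-02' AS DATE)

    # a column name can't be parsed as a date, so this is now left alone
    e = simplify(sqlglot.parse_one("CAST(x AS DATE) + INTERVAL '1' DAY"))
    print(e.sql())  # CAST(x AS DATE) + INTERVAL '1' DAY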

sqlglot/optimizer/unnest_subqueries.py

@ -15,8 +15,7 @@ def unnest_subqueries(expression):
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Args:
expression (sqlglot.Expression): expression to unnest
@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
other = _other_operand(parent_predicate)
if isinstance(parent_predicate, exp.Exists):
if value.this in group_by:
parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
else:
parent_predicate = _replace(parent_predicate, "TRUE")
alias = exp.column(list(key_aliases.values())[0], table_alias)
parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
elif isinstance(parent_predicate, exp.All):
parent_predicate = _replace(
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence):
else:
if is_subquery_projection:
alias = exp.alias_(alias, select.parent.alias)
# COUNT always returns 0 on empty datasets, so we need take that into consideration here
# by transforming all counts into 0 and using that as the coalesced value
if value.find(exp.Count):
def remove_aggs(node):
if isinstance(node, exp.Count):
return exp.Literal.number(0)
elif isinstance(node, exp.AggFunc):
return exp.null()
return node
alias = exp.Coalesce(
this=alias,
expressions=[value.this.transform(remove_aggs)],
)
select.parent.replace(alias)
for key, column, predicate in keys:
@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by:
key.replace(nested)
parent_predicate = _replace(
parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
)
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,
@ -245,7 +256,14 @@ def _other_operand(expression):
if isinstance(expression, exp.In):
return expression.this
if isinstance(expression, (exp.Any, exp.All)):
return _other_operand(expression.parent)
if isinstance(expression, exp.Binary):
return expression.right if expression.arg_key == "this" else expression.left
return (
expression.right
if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
else expression.left
)
return None
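Scalar COUNT subqueries now coalesce to 0 after decorrelation, since COUNT over an empty group is 0, not NULL. A hedged sketch; the join and alias shape below is approximate:

    import sqlglot
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT (SELECT COUNT(*) FROM y WHERE y.a = x.a) FROM x"
    print(unnest_subqueries(sqlglot.parse_one(sql)).sql())
    # roughly: SELECT COALESCE(_u_0.count, 0) FROM x
    #          LEFT JOIN (SELECT COUNT(*) AS count, y.a ... GROUP BY y.a) AS _u_0 ON ...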

File diff suppressed because it is too large.

sqlglot/schema.py

@ -3,6 +3,7 @@ from __future__ import annotations
import abc
import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
@ -157,10 +158,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None,
) -> None:
super().__init__(schema)
self.visible = visible or {}
self.dialect = dialect
self.visible = visible or {}
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
super().__init__(self._normalize(schema or {}))
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
@ -180,6 +181,33 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
}
)
def _normalize(self, schema: t.Dict) -> t.Dict:
"""
Converts all identifiers in the schema into lowercase, unless they're quoted.
Args:
schema: the schema to normalize.
Returns:
The normalized schema mapping.
"""
flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
normalized_mapping: t.Dict = {}
for keys in flattened_schema:
columns = _nested_get(schema, *zip(keys, keys))
assert columns is not None
normalized_keys = [self._normalize_name(key) for key in keys]
for column_name, column_type in columns.items():
_nested_set(
normalized_mapping,
normalized_keys + [self._normalize_name(column_name)],
column_type,
)
return normalized_mapping
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
@ -204,6 +232,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
)
self.mapping_trie = self._build_trie(self.mapping)
def _normalize_name(self, name: str) -> str:
try:
identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
name, read=self.dialect, into=exp.Identifier
)
except:
identifier = exp.to_identifier(name)
assert isinstance(identifier, exp.Identifier)
if identifier.quoted:
return identifier.name
return identifier.name.lower()
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
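Normalization lowercases unquoted identifiers and preserves quoted ones. A sketch:

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"Foo": {"Bar": "INT", '"Baz"': "TEXT"}})
    print(schema.mapping)  # {'foo': {'bar': 'INT', 'Baz': 'TEXT'}}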

sqlglot/serde.py (new file, 67 lines)

@ -0,0 +1,67 @@
from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
if t.TYPE_CHECKING:
JSON = t.Union[dict, list, str, float, int, bool]
Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON]
def dump(node: Node) -> JSON:
"""
Recursively dump an AST into a JSON-serializable dict.
"""
if isinstance(node, list):
return [dump(i) for i in node]
if isinstance(node, exp.DataType.Type):
return {
"class": "DataType.Type",
"value": node.value,
}
if isinstance(node, exp.Expression):
klass = node.__class__.__qualname__
if node.__class__.__module__ != exp.__name__:
klass = f"{node.__module__}.{klass}"
obj = {
"class": klass,
"args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
}
if node.type:
obj["type"] = node.type.sql()
if node.comments:
obj["comments"] = node.comments
return obj
return node
def load(obj: JSON) -> Node:
"""
Recursively load a dict (as returned by `dump`) into an AST.
"""
if isinstance(obj, list):
return [load(i) for i in obj]
if isinstance(obj, dict):
class_name = obj["class"]
if class_name == "DataType.Type":
return exp.DataType.Type(obj["value"])
if "." in class_name:
module_path, class_name = class_name.rsplit(".", maxsplit=1)
module = __import__(module_path, fromlist=[class_name])
else:
module = exp
klass = getattr(module, class_name)
expression = klass(**{k: load(v) for k, v in obj["args"].items()})
type_ = obj.get("type")
if type_:
expression.type = exp.DataType.build(type_)
comments = obj.get("comments")
if comments:
expression.comments = load(comments)
return expression
return obj
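The dumped form is plain JSON data keyed by class name. A sketch of the shape:

    from sqlglot import parse_one
    from sqlglot.serde import dump, load

    payload = dump(parse_one("SELECT a"))
    # e.g. {'class': 'Select', 'args': {'expressions': [{'class': 'Column', ...}]}}
    assert load(payload).sql() == "SELECT a"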

sqlglot/tokens.py

@ -86,6 +86,7 @@ class TokenType(AutoName):
VARBINARY = auto()
JSON = auto()
JSONB = auto()
TIME = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
@ -181,6 +182,7 @@ class TokenType(AutoName):
FUNCTION = auto()
FROM = auto()
GENERATED = auto()
GLOBAL = auto()
GROUP_BY = auto()
GROUPING_SETS = auto()
HAVING = auto()
@ -656,6 +658,7 @@ class Tokenizer(metaclass=_Tokenizer):
"FLOAT4": TokenType.FLOAT,
"FLOAT8": TokenType.DOUBLE,
"DOUBLE": TokenType.DOUBLE,
"DOUBLE PRECISION": TokenType.DOUBLE,
"JSON": TokenType.JSON,
"CHAR": TokenType.CHAR,
"NCHAR": TokenType.NCHAR,
@ -671,6 +674,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BLOB": TokenType.VARBINARY,
"BYTEA": TokenType.VARBINARY,
"VARBINARY": TokenType.VARBINARY,
"TIME": TokenType.TIME,
"TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@ -721,6 +725,8 @@ class Tokenizer(metaclass=_Tokenizer):
COMMENTS = ["--", ("/*", "*/")]
KEYWORD_TRIE = None # autofilled
IDENTIFIER_CAN_START_WITH_DIGIT = False
__slots__ = (
"sql",
"size",
@ -938,17 +944,24 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek.upper() == "E" and not scientific: # type: ignore
scientific += 1
self._advance()
elif self._peek.isalpha(): # type: ignore
self._add(TokenType.NUMBER)
elif self._peek.isidentifier(): # type: ignore
number_text = self._text
literal = []
while self._peek.isalpha(): # type: ignore
while self._peek.isidentifier(): # type: ignore
literal.append(self._peek.upper()) # type: ignore
self._advance()
literal = "".join(literal) # type: ignore
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
if token_type:
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal) # type: ignore
elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
return self._add(TokenType.VAR)
self._add(TokenType.NUMBER, number_text)
return self._advance(-len(literal))
else:
return self._add(TokenType.NUMBER)
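IDENTIFIER_CAN_START_WITH_DIGIT lets dialects like Hive treat 20_sales as a name rather than a malformed number. A sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT * FROM 20_sales", read="hive", write="hive")[0])
    # SELECT * FROM 20_sales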

sqlglot/transforms.py

@ -82,6 +82,27 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
return expression
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
"""
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions.
This transform removes the precision from parameterized types in expressions.
"""
return expression.transform(
lambda node: exp.DataType(
**{
**node.args,
"expressions": [
node_expression
for node_expression in node.expressions
if isinstance(node_expression, exp.DataType)
],
}
)
if isinstance(node, exp.DataType)
else node,
)
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
to_sql: t.Callable[[Generator, exp.Expression], str],
@ -121,3 +142,6 @@ def delegate(attr: str) -> t.Callable:
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
}
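Wired into a dialect's TRANSFORMS (as BigQuery does above), the preprocess hook strips precision arguments from casts. A sketch; the emitted type name may be remapped per dialect:

    import sqlglot

    print(sqlglot.transpile("SELECT CAST(x AS DECIMAL(10, 2))", write="bigquery")[0])
    # roughly: SELECT CAST(x AS DECIMAL) -- the (10, 2) parameters are dropped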

sqlglot/trie.py

@ -52,7 +52,7 @@ def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
Returns:
A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
is either 0 (search was unsuccessfull), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`).
is either 0 (search was unsuccessful), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key` is in `trie`).
"""
if not key:
return (0, trie)