
Merging upstream version 10.1.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:56:25 +01:00
parent 582b160275
commit a5128ea109
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
57 changed files with 1542 additions and 529 deletions

View file

@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.0.8"
__version__ = "10.1.3"
pretty = False

View file

@ -56,12 +56,12 @@ def _derived_table_values_to_unnest(self, expression):
def _returnsproperty_sql(self, expression):
value = expression.args.get("value")
if isinstance(value, exp.Schema):
value = f"{value.this} <{self.expressions(value)}>"
this = expression.this
if isinstance(this, exp.Schema):
this = f"{this.this} <{self.expressions(this)}>"
else:
value = self.sql(value)
return f"RETURNS {value}"
this = self.sql(this)
return f"RETURNS {this}"
def _create_sql(self, expression):
@ -142,6 +142,11 @@ class BigQuery(Dialect):
),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
}
FUNCTION_PARSERS.pop("TRIM")
NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
@ -174,6 +179,7 @@ class BigQuery(Dialect):
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
@ -200,9 +206,7 @@ class BigQuery(Dialect):
exp.VolatilityProperty,
}
WITH_PROPERTIES = {
exp.AnonymousProperty,
}
WITH_PROPERTIES = {exp.Property}
EXPLICIT_UNION = True
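
Note: the TRIM entry popped from FUNCTION_PARSERS plus the new exp.Trim transform keep BigQuery's two-argument TRIM(value, chars) form instead of rewriting it to the ANSI TRIM(chars FROM value) syntax. A minimal sketch of the effect, assuming sqlglot 10.1.3 (the column name is illustrative):

import sqlglot

# BigQuery keeps the function-style TRIM instead of TRIM('*' FROM col)
print(sqlglot.transpile("SELECT TRIM(col, '*')", read="bigquery", write="bigquery")[0])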

View file

@ -21,14 +21,15 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"FINAL": TokenType.FINAL,
"ASOF": TokenType.ASOF,
"DATETIME64": TokenType.DATETIME,
"INT8": TokenType.TINYINT,
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"INT16": TokenType.SMALLINT,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"INT8": TokenType.TINYINT,
"TUPLE": TokenType.STRUCT,
}
@ -38,6 +39,10 @@ class ClickHouse(Dialect):
"MAP": parse_var_map,
}
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
def _parse_table(self, schema=False):
this = super()._parse_table(schema)
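
Note: with ASOF added to JOIN_KINDS (and excluded from the table-alias tokens so it is not swallowed as an alias), ClickHouse ASOF joins should now parse. A hedged sketch, assuming sqlglot 10.1.3; the table and column names are invented:

import sqlglot

sql = "SELECT * FROM trades ASOF JOIN quotes ON trades.sym = quotes.sym"
print(sqlglot.parse_one(sql, read="clickhouse").sql(dialect="clickhouse"))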

View file

@ -289,19 +289,19 @@ def struct_extract_sql(self, expression):
return f"{this}.{struct_key}"
def var_map_sql(self, expression):
def var_map_sql(self, expression, map_func_name="MAP"):
keys = expression.args["keys"]
values = expression.args["values"]
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map.")
return f"MAP({self.format_args(keys, values)})"
return f"{map_func_name}({self.format_args(keys, values)})"
args = []
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
return f"MAP({self.format_args(*args)})"
return f"{map_func_name}({self.format_args(*args)})"
def format_time_lambda(exp_class, dialect, default=None):
@ -336,18 +336,13 @@ def create_with_partitions_sql(self, expression):
if has_schema and is_partitionable:
expression = expression.copy()
prop = expression.find(exp.PartitionedByProperty)
value = prop and prop.args.get("value")
if prop and not isinstance(value, exp.Schema):
this = prop and prop.this
if prop and not isinstance(this, exp.Schema):
schema = expression.this
columns = {v.name.upper() for v in value.expressions}
columns = {v.name.upper() for v in this.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set(
"expressions",
[e for e in schema.expressions if e not in partitions],
)
prop.replace(
exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
)
schema.set("expressions", [e for e in schema.expressions if e not in partitions])
prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
expression.set("this", schema)
return self.create_sql(expression)

View file

@ -153,7 +153,7 @@ class Drill(Dialect):
exp.If: if_sql,
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Pivot: no_pivot_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.StrPosition: str_position_sql,

View file

@ -61,9 +61,7 @@ def _array_sort(self, expression):
def _property_sql(self, expression):
key = expression.name
value = self.sql(expression, "value")
return f"'{key}'={value}"
return f"'{expression.name}'={self.sql(expression, 'value')}"
def _str_to_unix(self, expression):
@ -250,7 +248,7 @@ class Hive(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore
exp.AnonymousProperty: _property_sql,
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayConcat: rename_func("CONCAT"),
@ -262,7 +260,7 @@ class Hive(Dialect):
exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}",
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
@ -285,7 +283,7 @@ class Hive(Dialect):
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}",
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@ -298,11 +296,11 @@ class Hive(Dialect):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
}
WITH_PROPERTIES = {exp.AnonymousProperty}
WITH_PROPERTIES = {exp.Property}
ROOT_PROPERTIES = {
exp.PartitionedByProperty,

View file

@ -453,6 +453,7 @@ class MySQL(Dialect):
exp.CharacterSetProperty,
exp.CollateProperty,
exp.SchemaCommentProperty,
exp.LikeProperty,
}
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
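
Note: adding exp.LikeProperty to ROOT_PROPERTIES makes MySQL render CREATE TABLE ... LIKE inline at the root rather than inside a WITH (...) property block. A sketch, assuming the LIKE property parser added later in this commit; the table names are illustrative:

import sqlglot

print(sqlglot.transpile("CREATE TABLE t2 LIKE t1", read="mysql", write="mysql")[0])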

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, generator, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
from sqlglot.helper import csv
from sqlglot.tokens import TokenType
@ -37,6 +37,12 @@ class Oracle(Dialect):
"YYYY": "%Y", # 2015
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DECODE": exp.Matches.from_arg_list,
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -58,6 +64,7 @@ class Oracle(Dialect):
**transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
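
Note: DECODE now parses into the new exp.Matches expression and is generated back as DECODE, which also makes it portable to Snowflake (see the Snowflake hunk below). A minimal sketch, assuming sqlglot 10.1.3:

import sqlglot

sql = "SELECT DECODE(x, 1, 'one', 2, 'two', 'other') FROM t"
# both Oracle and Snowflake map exp.Matches back to DECODE
print(sqlglot.transpile(sql, read="oracle", write="snowflake")[0])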

View file

@ -74,6 +74,27 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def _string_agg_sql(self, expression):
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
order = ""
this = expression.this
if isinstance(this, exp.Order):
if this.this:
this = this.this
this.pop()
order = self.sql(expression.this) # Order has a leading space
return f"STRING_AGG({self.format_args(this, separator)}{order})"
def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
def _auto_increment_to_serial(expression):
auto = expression.find(exp.AutoIncrementColumnConstraint)
@ -191,25 +212,27 @@ class Postgres(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"IDENTITY": TokenType.IDENTITY,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
"BIGSERIAL": TokenType.BIGSERIAL,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
"TEMP": TokenType.TEMPORARY,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"BY DEFAULT": TokenType.BY_DEFAULT,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
"DOUBLE PRECISION": TokenType.DOUBLE,
"GENERATED": TokenType.GENERATED,
"GRANT": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
"IDENTITY": TokenType.IDENTITY,
"JSONB": TokenType.JSONB,
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
"RESET": TokenType.COMMAND,
"REVOKE": TokenType.COMMAND,
"GRANT": TokenType.COMMAND,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
@ -265,4 +288,7 @@ class Postgres(Dialect):
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
}
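
Note: the new exp.DataType and exp.Array transforms keep Postgres-native array spellings. A minimal sketch, assuming sqlglot 10.1.3:

import sqlglot

print(sqlglot.transpile("SELECT ARRAY[1, 2, 3]", read="postgres", write="postgres")[0])
print(sqlglot.transpile("SELECT CAST(x AS INT[])", read="postgres", write="postgres")[0])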

View file

@ -171,16 +171,7 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
ROOT_PROPERTIES = {
exp.SchemaCommentProperty,
}
WITH_PROPERTIES = {
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.AnonymousProperty,
exp.TableFormatProperty,
}
ROOT_PROPERTIES = {exp.SchemaCommentProperty}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -231,7 +222,8 @@ class Presto(Dialect):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'",
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.TimeStrToDate: _date_parse_sql,
exp.TimeStrToTime: _date_parse_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",

View file

@ -1,6 +1,6 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot import exp, transforms
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@ -18,12 +18,14 @@ class Redshift(Postgres):
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
"COPY": TokenType.COMMAND,
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO,
}
@ -35,3 +37,17 @@ class Redshift(Postgres):
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}
ROOT_PROPERTIES = {
exp.DistKeyProperty,
exp.SortKeyProperty,
exp.DistStyleProperty,
}
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
}
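
Note: DISTKEY, SORTKEY and DISTSTYLE are now parsed as structured properties and rendered unparenthesized at the root, so Redshift table attributes should round-trip. A hedged sketch, assuming sqlglot 10.1.3:

import sqlglot

sql = "CREATE TABLE t (c1 INT, c2 INT) DISTSTYLE KEY DISTKEY(c1) COMPOUND SORTKEY(c1, c2)"
print(sqlglot.transpile(sql, read="redshift", write="redshift")[0])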

View file

@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
inline_array_sql,
rename_func,
var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
@ -100,6 +101,14 @@ def _parse_date_part(self):
return self.expression(exp.Extract, this=this, expression=expression)
def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
elif expression.this == exp.DataType.Type.MAP:
return "OBJECT"
return self.datatype_sql(expression)
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
@ -142,6 +151,8 @@ class Snowflake(Dialect):
"TO_TIMESTAMP": _snowflake_to_timestamp,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"DECODE": exp.Matches.from_arg_list,
"OBJECT_CONSTRUCT": parser.parse_var_map,
}
FUNCTION_PARSERS = {
@ -195,16 +206,20 @@ class Snowflake(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"),
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.UnixToTime: _unix_to_time_sql,
}
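
Note: among the new transforms, exp.Trim keeps Snowflake's two-argument TRIM form rather than the ANSI TRIM(chars FROM value) syntax. A minimal sketch, assuming sqlglot 10.1.3 (the column name is illustrative):

import sqlglot

# ANSI trim syntax parses dialect-agnostically and comes out as TRIM(col, '*')
print(sqlglot.transpile("SELECT TRIM('*' FROM col)", write="snowflake")[0])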
TYPE_MAPPING = {

View file

@ -98,7 +98,7 @@ class Spark(Hive):
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),

View file

@ -13,6 +13,23 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def _group_concat_sql(self, expression):
this = expression.this
distinct = expression.find(exp.Distinct)
if distinct:
this = distinct.expressions[0]
distinct = "DISTINCT "
if isinstance(expression.this, exp.Order):
self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
if expression.this.this and not distinct:
this = expression.this.this
separator = expression.args.get("separator")
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@ -62,6 +79,7 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
exp.GroupConcat: _group_concat_sql,
}
def transaction_sql(self, expression):

View file

@ -17,6 +17,7 @@ FULL_FORMAT_TIME_MAPPING = {
"mm": "%B",
"m": "%B",
}
DATE_DELTA_INTERVAL = {
"year": "year",
"yyyy": "year",
@ -37,11 +38,12 @@ DATE_DELTA_INTERVAL = {
DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")
# N = Numeric, C = Currency
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args):
return exp_class(
this=seq_get(args, 1),
@ -58,7 +60,7 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
return _format_time
def parse_format(args):
def _parse_format(args):
fmt = seq_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
if number_fmt:
@ -78,7 +80,7 @@ def generate_date_delta_with_unit_sql(self, e):
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
def generate_format_sql(self, e):
def _format_sql(self, e):
fmt = (
e.args["format"]
if isinstance(e, exp.NumberToStr)
@ -87,6 +89,28 @@ def generate_format_sql(self, e):
return f"FORMAT({self.format_args(e.this, fmt)})"
def _string_agg_sql(self, e):
e = e.copy()
this = e.this
distinct = e.find(exp.Distinct)
if distinct:
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
this = distinct.expressions[0]
distinct.pop()
order = ""
if isinstance(e.this, exp.Order):
if e.this.this:
this = e.this.this
e.this.this.pop()
order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
separator = e.args.get("separator") or exp.Literal.string(",")
return f"STRING_AGG({self.format_args(this, separator)}){order}"
class TSQL(Dialect):
null_ordering = "nulls_are_small"
time_format = "'yyyy-mm-dd hh:mm:ss'"
@ -228,14 +252,14 @@ class TSQL(Dialect):
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": tsql_format_time_lambda(exp.TimeToStr),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
"GETDATE": exp.CurrentDate.from_arg_list,
"IIF": exp.If.from_arg_list,
"LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"FORMAT": parse_format,
"FORMAT": _parse_format,
}
VAR_LENGTH_DATATYPES = {
@ -298,6 +322,7 @@ class TSQL(Dialect):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
exp.NumberToStr: generate_format_sql,
exp.TimeToStr: generate_format_sql,
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
}

View file

@ -22,7 +22,40 @@ class UnsupportedError(SqlglotError):
class ParseError(SqlglotError):
pass
def __init__(
self,
message: str,
errors: t.Optional[t.List[t.Dict[str, t.Any]]] = None,
):
super().__init__(message)
self.errors = errors or []
@classmethod
def new(
cls,
message: str,
description: t.Optional[str] = None,
line: t.Optional[int] = None,
col: t.Optional[int] = None,
start_context: t.Optional[str] = None,
highlight: t.Optional[str] = None,
end_context: t.Optional[str] = None,
into_expression: t.Optional[str] = None,
) -> ParseError:
return cls(
message,
[
{
"description": description,
"line": line,
"col": col,
"start_context": start_context,
"highlight": highlight,
"end_context": end_context,
"into_expression": into_expression,
}
],
)
class TokenError(SqlglotError):
@ -41,9 +74,13 @@ class ExecuteError(SqlglotError):
pass
def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str:
msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum
if remaining > 0:
msg.append(f"... and {remaining} more")
return "\n\n".join(msg)
def merge_errors(errors: t.Sequence[ParseError]) -> t.List[t.Dict[str, t.Any]]:
return [e_dict for error in errors for e_dict in error.errors]
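
Note: ParseError now carries a list of structured error dicts next to the formatted message, and merge_errors flattens them when parse_into tries several expression types. A minimal consumption sketch; the malformed SQL is only an example:

import sqlglot
from sqlglot.errors import ParseError

try:
    sqlglot.parse_one("SELECT 1 +")  # deliberately invalid
except ParseError as e:
    for err in e.errors:
        print(err["line"], err["col"], err["description"])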

View file

@ -122,7 +122,6 @@ def interval(this, unit):
ENV = {
"__builtins__": {},
"exp": exp,
# aggs
"SUM": filter_nulls(sum),

View file

@ -115,6 +115,9 @@ class PythonExecutor:
sink = self.table(context.columns)
for reader in table_iter:
if len(sink) >= step.limit:
break
if condition and not context.eval(condition):
continue
@ -123,9 +126,6 @@ class PythonExecutor:
else:
sink.append(reader.row)
if len(sink) >= step.limit:
break
return self.context({step.name: sink})
def static(self):
@ -288,21 +288,32 @@ class PythonExecutor:
end = 1
length = len(context.table)
table = self.table(list(step.group) + step.aggregations)
condition = self.generate(step.condition)
for i in range(length):
context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
if key != group:
context.set_range(start, end - 2)
table.append(group + context.eval_tuple(aggregations))
group = key
start = end - 2
if i == length - 1:
context.set_range(start, end - 1)
def add_row():
if not condition or context.eval(condition):
table.append(group + context.eval_tuple(aggregations))
if length:
for i in range(length):
context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
if key != group:
context.set_range(start, end - 2)
add_row()
group = key
start = end - 2
if len(table.rows) >= step.limit:
break
if i == length - 1:
context.set_range(start, end - 1)
add_row()
elif step.limit > 0:
context.set_range(0, 0)
table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}})
if step.projections:
@ -311,11 +322,9 @@ class PythonExecutor:
def sort(self, step, context):
projections = self.generate_tuple(step.projections)
projection_columns = [p.alias_or_name for p in step.projections]
all_columns = list(context.columns) + projection_columns
sink = self.table(all_columns)
for reader, ctx in context:
sink.append(reader.row + ctx.eval_tuple(projections))
@ -401,8 +410,9 @@ class Python(Dialect):
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
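
Note: the aggregate step now evaluates HAVING conditions per group, honors LIMIT while grouping, and handles empty input tables. A usage sketch, assuming the executor accepts in-memory data through its tables= argument:

from sqlglot.executor import execute

result = execute(
    "SELECT a, COUNT(*) AS cnt FROM t GROUP BY a HAVING COUNT(*) > 1 LIMIT 1",
    tables={"t": [{"a": 1}, {"a": 1}, {"a": 2}]},
)
print(result.rows)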

View file

@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
key = "Expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "type", "comment")
__slots__ = ("args", "parent", "arg_key", "type", "comments")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
self.comment = None
self.comments = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@ -88,19 +88,6 @@ class Expression(metaclass=_Expression):
return field.this
return ""
def find_comment(self, key: str) -> str:
"""
Finds the comment that is attached to a specified child node.
Args:
key: the key of the target child node (e.g. "this", "expression", etc).
Returns:
The comment attached to the child node, or the empty string, if it doesn't exist.
"""
field = self.args.get(key)
return field.comment if isinstance(field, Expression) else ""
@property
def is_string(self):
return isinstance(self, Literal) and self.args["is_string"]
@ -137,7 +124,7 @@ class Expression(metaclass=_Expression):
def __deepcopy__(self, memo):
copy = self.__class__(**deepcopy(self.args))
copy.comment = self.comment
copy.comments = self.comments
copy.type = self.type
return copy
@ -369,7 +356,7 @@ class Expression(metaclass=_Expression):
)
for k, vs in self.args.items()
}
args["comment"] = self.comment
args["comments"] = self.comments
args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing}
@ -767,7 +754,7 @@ class NotNullColumnConstraint(ColumnConstraintKind):
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
pass
arg_types = {"desc": False}
class UniqueColumnConstraint(ColumnConstraintKind):
@ -819,6 +806,12 @@ class Unique(Expression):
arg_types = {"expressions": True}
# https://www.postgresql.org/docs/9.1/sql-selectinto.html
# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples
class Into(Expression):
arg_types = {"this": True, "temporary": False, "unlogged": False}
class From(Expression):
arg_types = {"expressions": True}
@ -1065,67 +1058,67 @@ class Property(Expression):
class TableFormatProperty(Property):
pass
arg_types = {"this": True}
class PartitionedByProperty(Property):
pass
arg_types = {"this": True}
class FileFormatProperty(Property):
pass
arg_types = {"this": True}
class DistKeyProperty(Property):
pass
arg_types = {"this": True}
class SortKeyProperty(Property):
pass
arg_types = {"this": True, "compound": False}
class DistStyleProperty(Property):
pass
arg_types = {"this": True}
class LikeProperty(Property):
arg_types = {"this": True, "expressions": False}
class LocationProperty(Property):
pass
arg_types = {"this": True}
class EngineProperty(Property):
pass
arg_types = {"this": True}
class AutoIncrementProperty(Property):
pass
arg_types = {"this": True}
class CharacterSetProperty(Property):
arg_types = {"this": True, "value": True, "default": True}
arg_types = {"this": True, "default": True}
class CollateProperty(Property):
pass
arg_types = {"this": True}
class SchemaCommentProperty(Property):
pass
class AnonymousProperty(Property):
pass
arg_types = {"this": True}
class ReturnsProperty(Property):
arg_types = {"this": True, "value": True, "is_table": False}
arg_types = {"this": True, "is_table": False}
class LanguageProperty(Property):
pass
arg_types = {"this": True}
class ExecuteAsProperty(Property):
pass
arg_types = {"this": True}
class VolatilityProperty(Property):
@ -1135,27 +1128,36 @@ class VolatilityProperty(Property):
class Properties(Expression):
arg_types = {"expressions": True}
PROPERTY_KEY_MAPPING = {
NAME_TO_PROPERTY = {
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER_SET": CharacterSetProperty,
"CHARACTER SET": CharacterSetProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
"ENGINE": EngineProperty,
"FORMAT": FileFormatProperty,
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"TABLE_FORMAT": TableFormatProperty,
"DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty,
"ENGINE": EngineProperty,
"EXECUTE AS": ExecuteAsProperty,
"FORMAT": FileFormatProperty,
"LANGUAGE": LanguageProperty,
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"RETURNS": ReturnsProperty,
"SORTKEY": SortKeyProperty,
"TABLE_FORMAT": TableFormatProperty,
}
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
@classmethod
def from_dict(cls, properties_dict) -> Properties:
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
expressions.append(property_cls(this=Literal.string(key), value=convert(value)))
property_cls = cls.NAME_TO_PROPERTY.get(key.upper())
if property_cls:
expressions.append(property_cls(this=convert(value)))
else:
expressions.append(Property(this=Literal.string(key), value=convert(value)))
return cls(expressions=expressions)
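
Note: known property names now map to typed Property subclasses, while unknown keys fall back to a generic Property with an explicit value. A minimal sketch straight from this mapping:

from sqlglot import exp

props = exp.Properties.from_dict({"FORMAT": "parquet", "custom_key": 42})
print([p.__class__.__name__ for p in props.expressions])
# ['FileFormatProperty', 'Property']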
@ -1383,6 +1385,7 @@ class Select(Subqueryable):
"expressions": False,
"hint": False,
"distinct": False,
"into": False,
"from": False,
**QUERY_MODIFIERS,
}
@ -2015,6 +2018,7 @@ class DataType(Expression):
DECIMAL = auto()
BOOLEAN = auto()
JSON = auto()
JSONB = auto()
INTERVAL = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
@ -2029,6 +2033,7 @@ class DataType(Expression):
STRUCT = auto()
NULLABLE = auto()
HLLSKETCH = auto()
HSTORE = auto()
SUPER = auto()
SERIAL = auto()
SMALLSERIAL = auto()
@ -2109,7 +2114,7 @@ class Transaction(Command):
class Commit(Command):
arg_types = {} # type: ignore
arg_types = {"chain": False}
class Rollback(Command):
@ -2442,7 +2447,7 @@ class ArrayFilter(Func):
class ArraySize(Func):
pass
arg_types = {"this": True, "expression": False}
class ArraySort(Func):
@ -2726,6 +2731,16 @@ class VarMap(Func):
is_var_len_args = True
class Matches(Func):
"""Oracle/Snowflake decode.
https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm
Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else)
"""
arg_types = {"this": True, "expressions": True}
is_var_len_args = True
class Max(AggFunc):
pass
@ -2785,6 +2800,10 @@ class Round(Func):
arg_types = {"this": True, "decimals": False}
class RowNumber(Func):
arg_types: t.Dict[str, t.Any] = {}
class SafeDivide(Func):
arg_types = {"this": True, "expression": True}

View file

@ -1,19 +1,16 @@
from __future__ import annotations
import logging
import re
import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
NEWLINE_RE = re.compile("\r\n?|\n")
class Generator:
"""
@ -58,11 +55,11 @@ class Generator:
"""
TRANSFORMS = {
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@ -97,16 +94,17 @@ class Generator:
exp.DistStyleProperty,
exp.DistKeyProperty,
exp.SortKeyProperty,
exp.LikeProperty,
}
WITH_PROPERTIES = {
exp.AnonymousProperty,
exp.Property,
exp.FileFormatProperty,
exp.PartitionedByProperty,
exp.TableFormatProperty,
}
WITH_SEPARATED_COMMENTS = (exp.Select,)
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
__slots__ = (
"time_mapping",
@ -211,7 +209,7 @@ class Generator:
for msg in self.unsupported_messages:
logger.warning(msg)
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))
raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
return sql
@ -226,25 +224,24 @@ class Generator:
def seg(self, sql, sep=" "):
return f"{self.sep(sep)}{sql}"
def maybe_comment(self, sql, expression, single_line=False):
comment = expression.comment if self._comments else None
if not comment:
return sql
def pad_comment(self, comment):
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
return comment
def maybe_comment(self, sql, expression):
comments = expression.comments if self._comments else None
if not comments:
return sql
sep = "\n" if self.pretty else " "
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"/*{comment}*/{self.sep()}{sql}"
return f"{comments}{self.sep()}{sql}"
if not self.pretty:
return f"{sql} /*{comment}*/"
if not NEWLINE_RE.search(comment):
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
return f"{sql} {comments}"
def wrap(self, expression):
this_sql = self.indent(
@ -387,8 +384,11 @@ class Generator:
def notnullcolumnconstraint_sql(self, _):
return "NOT NULL"
def primarykeycolumnconstraint_sql(self, _):
return "PRIMARY KEY"
def primarykeycolumnconstraint_sql(self, expression):
desc = expression.args.get("desc")
if desc is not None:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
def uniquecolumnconstraint_sql(self, _):
return "UNIQUE"
@ -546,36 +546,33 @@ class Generator:
def root_properties(self, properties):
if properties.expressions:
return self.sep() + self.expressions(
properties,
indent=False,
sep=" ",
)
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
def properties(self, properties, prefix="", sep=", "):
if properties.expressions:
expressions = self.expressions(
properties,
sep=sep,
indent=False,
)
expressions = self.expressions(properties, sep=sep, indent=False)
return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
return ""
def with_properties(self, properties):
return self.properties(
properties,
prefix="WITH",
)
return self.properties(properties, prefix="WITH")
def property_sql(self, expression):
if isinstance(expression.this, exp.Literal):
key = expression.this.this
else:
key = expression.name
value = self.sql(expression, "value")
return f"{key}={value}"
property_cls = expression.__class__
if property_cls == exp.Property:
return f"{expression.name}={self.sql(expression, 'value')}"
property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
if not property_name:
self.unsupported(f"Unsupported property {property_name}")
return f"{property_name}={self.sql(expression, 'this')}"
def likeproperty_sql(self, expression):
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""
return f"LIKE {self.sql(expression, 'this')}{options}"
def insert_sql(self, expression):
overwrite = expression.args.get("overwrite")
@ -700,6 +697,11 @@ class Generator:
def var_sql(self, expression):
return self.sql(expression, "this")
def into_sql(self, expression):
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
def from_sql(self, expression):
expressions = self.expressions(expression, flat=True)
return f"{self.seg('FROM')} {expressions}"
@ -883,6 +885,7 @@ class Generator:
sql = self.query_modifiers(
expression,
f"SELECT{hint}{distinct}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
return self.prepend_ctes(expression, sql)
@ -1061,6 +1064,11 @@ class Generator:
else:
return f"TRIM({target})"
def concat_sql(self, expression):
if len(expression.expressions) == 1:
return self.sql(expression.expressions[0])
return self.function_fallback_sql(expression)
def check_sql(self, expression):
this = self.sql(expression, key="this")
return f"CHECK ({this})"
@ -1125,7 +1133,10 @@ class Generator:
return self.prepend_ctes(expression, sql)
def neg_sql(self, expression):
return f"-{self.sql(expression, 'this')}"
# This makes sure we don't convert "- - 5" to "--5", which is a comment
this_sql = self.sql(expression, "this")
sep = " " if this_sql[0] == "-" else ""
return f"-{sep}{this_sql}"
def not_sql(self, expression):
return f"NOT {self.sql(expression, 'this')}"
@ -1191,8 +1202,12 @@ class Generator:
def transaction_sql(self, *_):
return "BEGIN"
def commit_sql(self, *_):
return "COMMIT"
def commit_sql(self, expression):
chain = expression.args.get("chain")
if chain is not None:
chain = " AND CHAIN" if chain else " AND NO CHAIN"
return f"COMMIT{chain or ''}"
def rollback_sql(self, expression):
savepoint = expression.args.get("savepoint")
@ -1334,15 +1349,15 @@ class Generator:
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
comment = self.maybe_comment("", e, single_line=True)
comments = self.maybe_comment("", e)
if self.pretty:
if self._leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
else:
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
else:
result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sqls, skip_first=False) if indent else result_sqls
@ -1354,7 +1369,10 @@ class Generator:
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
def naked_property(self, expression):
return f"{expression.name} {self.sql(expression, 'value')}"
property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
if not property_name:
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
return f"{property_name} {self.sql(expression, 'this')}"
def set_operation(self, expression, op):
this = self.sql(expression, "this")

View file

@ -68,6 +68,9 @@ def eliminate_subqueries(expression):
for cte_scope in root.cte_scopes:
# Append all the new CTEs from this existing CTE
for scope in cte_scope.traverse():
if scope is cte_scope:
# Don't try to eliminate this CTE itself
continue
new_cte = _eliminate(scope, existing_ctes, taken)
if new_cte:
new_ctes.append(new_cte)
@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
return _eliminate_derived_table(scope, existing_ctes, taken)
if scope.is_cte:
return _eliminate_cte(scope, existing_ctes, taken)
def _eliminate_union(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
parent.replace(table)
return cte
def _eliminate_cte(scope, existing_ctes, taken):
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
with_ = parent.parent
parent.pop()
if not with_.expressions:
with_.pop()
# Rename references to this CTE
for child_scope in scope.parent.traverse():
for table, source in child_scope.selected_sources.values():
if source is scope:
new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
table.replace(new_table)
return cte
def _new_cte(scope, existing_ctes, taken):
"""
Returns:
tuple of (name, cte)
where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
If this CTE duplicates an existing CTE, `cte` will be None.
"""
duplicate_cte_alias = existing_ctes.get(scope.expression)
parent = scope.expression.parent
name = alias = parent.alias
name = parent.alias
if not alias:
name = alias = find_new_name(taken=taken, base="cte")
if not name:
name = find_new_name(taken=taken, base="cte")
if duplicate_cte_alias:
name = duplicate_cte_alias
elif taken.get(alias):
name = find_new_name(taken=taken, base=alias)
elif taken.get(name):
name = find_new_name(taken=taken, base=name)
taken[name] = scope
table = exp.alias_(exp.table_(name), alias=alias)
parent.replace(table)
if not duplicate_cte_alias:
existing_ctes[scope.expression] = name
return exp.CTE(
cte = exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(name)),
)
else:
cte = None
return name, cte
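
Note: the new CTE branch mirrors the derived-table path but also rewrites references to the eliminated CTE. A usage sketch matching the rule's documented behavior:

import sqlglot
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
print(eliminate_subqueries(expression).sql())
# WITH y AS (SELECT * FROM x) SELECT a FROM y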

View file

@ -0,0 +1,92 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection
def lower_identities(expression):
"""
Convert all unquoted identifiers to lower case.
Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> lower_identities(expression).sql()
'SELECT bar.a AS A FROM "Foo".bar'
Args:
expression (sqlglot.Expression): expression to transform
Returns:
sqlglot.Expression: transformed expression
"""
# We need to leave the output aliases unchanged, so the selects need special handling
_lower_selects(expression)
# These clauses can reference output aliases and also need special handling
_lower_order(expression)
_lower_having(expression)
# We've already handled these args, so don't traverse into them
traversed = {"expressions", "order", "having"}
if isinstance(expression, exp.Subquery):
# Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
lower_identities(expression.this)
traversed |= {"this"}
if isinstance(expression, exp.Union):
# Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
lower_identities(expression.left)
lower_identities(expression.right)
traversed |= {"this", "expression"}
for k, v in expression.args.items():
if k in traversed:
continue
for child in ensure_collection(v):
if isinstance(child, exp.Expression):
child.transform(_lower, copy=False)
return expression
def _lower_selects(expression):
for e in expression.expressions:
# Leave output aliases as-is
e.unalias().transform(_lower, copy=False)
def _lower_order(expression):
order = expression.args.get("order")
if not order:
return
output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
for ordered in order.expressions:
# Don't lower references to output aliases
if not (
isinstance(ordered.this, exp.Column)
and not ordered.this.table
and ordered.this.name in output_aliases
):
ordered.transform(_lower, copy=False)
def _lower_having(expression):
having = expression.args.get("having")
if not having:
return
# Don't lower references to output aliases
for agg in having.find_all(exp.AggFunc):
agg.transform(_lower, copy=False)
def _lower(node):
if isinstance(node, exp.Identifier) and not node.quoted:
node.set("this", node.this.lower())
return node

View file

@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = (
lower_identities,
qualify_tables,
isolate_table_selects,
qualify_columns,

View file

@ -1,16 +1,15 @@
import itertools
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.scope import ScopeType, traverse_scope
def unnest_subqueries(expression):
"""
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert the subquery into a group by so it is not a many to many left join.
Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
Convert scalar subqueries into cross joins.
Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
Example:
>>> import sqlglot
@ -29,21 +28,43 @@ def unnest_subqueries(expression):
for scope in traverse_scope(expression):
select = scope.expression
parent = select.parent_select
if not parent:
continue
if scope.external_columns:
decorrelate(select, parent, scope.external_columns, sequence)
else:
elif scope.scope_type == ScopeType.SUBQUERY:
unnest(select, parent, sequence)
return expression
def unnest(select, parent_select, sequence):
predicate = select.find_ancestor(exp.In, exp.Any)
if len(select.selects) > 1:
return
predicate = select.find_ancestor(exp.Condition)
alias = _alias(sequence)
if not predicate or parent_select is not predicate.parent_select:
return
if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
# this subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
having = predicate.find_ancestor(exp.Having)
column = exp.column(select.selects[0].alias_or_name, alias)
if having and having.parent_select is parent_select:
column = exp.Max(this=column)
_replace(select.parent, column)
parent_select.join(
select,
join_type="CROSS",
join_alias=alias,
copy=False,
)
return
if select.find(exp.Limit, exp.Offset):
return
if isinstance(predicate, exp.Any):
@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):
column = _other_operand(predicate)
value = select.selects[0]
alias = _alias(sequence)
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
_replace(predicate, f"NOT {on.right} IS NULL")
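
Note: scalar subqueries are now converted into cross joins instead of being skipped, and under a HAVING the projected column is wrapped in MAX so the join stays grouped. A hedged sketch; the generated _u_0 alias and exact output shape are approximate:

import sqlglot
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

expression = sqlglot.parse_one("SELECT * FROM x WHERE x.a > (SELECT SUM(y.a) FROM y)")
print(unnest_subqueries(expression).sql())
# roughly: SELECT * FROM x CROSS JOIN (SELECT SUM(y.a) AS a FROM y) AS _u_0 WHERE x.a > _u_0.a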

View file

@ -4,7 +4,7 @@ import logging
import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_errors
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
@ -104,6 +104,7 @@ class Parser(metaclass=_Parser):
TokenType.BINARY,
TokenType.VARBINARY,
TokenType.JSON,
TokenType.JSONB,
TokenType.INTERVAL,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
@ -115,6 +116,7 @@ class Parser(metaclass=_Parser):
TokenType.GEOGRAPHY,
TokenType.GEOMETRY,
TokenType.HLLSKETCH,
TokenType.HSTORE,
TokenType.SUPER,
TokenType.SERIAL,
TokenType.SMALLSERIAL,
@ -153,6 +155,7 @@ class Parser(metaclass=_Parser):
TokenType.COLLATE,
TokenType.COMMAND,
TokenType.COMMIT,
TokenType.COMPOUND,
TokenType.CONSTRAINT,
TokenType.CURRENT_TIME,
TokenType.DEFAULT,
@ -194,6 +197,7 @@ class Parser(metaclass=_Parser):
TokenType.RANGE,
TokenType.REFERENCES,
TokenType.RETURNS,
TokenType.ROW,
TokenType.ROWS,
TokenType.SCHEMA,
TokenType.SCHEMA_COMMENT,
@ -213,6 +217,7 @@ class Parser(metaclass=_Parser):
TokenType.TRUE,
TokenType.UNBOUNDED,
TokenType.UNIQUE,
TokenType.UNLOGGED,
TokenType.UNPIVOT,
TokenType.PROPERTIES,
TokenType.PROCEDURE,
@ -400,9 +405,17 @@ class Parser(metaclass=_Parser):
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
}
UNARY_PARSERS = {
TokenType.PLUS: lambda self: self._parse_unary(), # Unary + is handled as a no-op
TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()),
TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()),
TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()),
}
PRIMARY_PARSERS = {
TokenType.STRING: lambda self, token: self.expression(
exp.Literal, this=token.text, is_string=True
@ -446,19 +459,20 @@ class Parser(metaclass=_Parser):
}
PROPERTY_PARSERS = {
TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(),
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
TokenType.LOCATION: lambda self: self.expression(
exp.LocationProperty,
this=exp.Literal.string("LOCATION"),
value=self._parse_string(),
TokenType.AUTO_INCREMENT: lambda self: self._parse_property_assignment(
exp.AutoIncrementProperty
),
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
TokenType.LOCATION: lambda self: self._parse_property_assignment(exp.LocationProperty),
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
TokenType.STORED: lambda self: self._parse_stored(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_property_assignment(
exp.SchemaCommentProperty
),
TokenType.STORED: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
TokenType.DISTKEY: lambda self: self._parse_distkey(),
TokenType.DISTSTYLE: lambda self: self._parse_diststyle(),
TokenType.DISTSTYLE: lambda self: self._parse_property_assignment(exp.DistStyleProperty),
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.LIKE: lambda self: self._parse_create_like(),
TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
@ -468,7 +482,7 @@ class Parser(metaclass=_Parser):
),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
TokenType.EXECUTE: lambda self: self._parse_execute_as(),
TokenType.EXECUTE: lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
TokenType.DETERMINISTIC: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
@ -489,6 +503,7 @@ class Parser(metaclass=_Parser):
),
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
TokenType.UNIQUE: lambda self: self._parse_unique(),
TokenType.LIKE: lambda self: self._parse_create_like(),
}
NO_PAREN_FUNCTION_PARSERS = {
@ -505,6 +520,7 @@ class Parser(metaclass=_Parser):
"TRIM": lambda self: self._parse_trim(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"TRY_CAST": lambda self: self._parse_cast(False),
"STRING_AGG": lambda self: self._parse_string_agg(),
}
QUERY_MODIFIER_PARSERS = {
@ -556,7 +572,7 @@ class Parser(metaclass=_Parser):
"_curr",
"_next",
"_prev",
"_prev_comment",
"_prev_comments",
"_show_trie",
"_set_trie",
)
@ -589,7 +605,7 @@ class Parser(metaclass=_Parser):
self._curr = None
self._next = None
self._prev = None
self._prev_comment = None
self._prev_comments = None
def parse(self, raw_tokens, sql=None):
"""
@ -608,6 +624,7 @@ class Parser(metaclass=_Parser):
)
def parse_into(self, expression_types, raw_tokens, sql=None):
errors = []
for expression_type in ensure_collection(expression_types):
parser = self.EXPRESSION_PARSERS.get(expression_type)
if not parser:
@ -615,8 +632,12 @@ class Parser(metaclass=_Parser):
try:
return self._parse(parser, raw_tokens, sql)
except ParseError as e:
error = e
raise ParseError(f"Failed to parse into {expression_types}") from error
e.errors[0]["into_expression"] = expression_type
errors.append(e)
raise ParseError(
f"Failed to parse into {expression_types}",
errors=merge_errors(errors),
) from errors[-1]
def _parse(self, parse_method, raw_tokens, sql=None):
self.reset()
@ -650,7 +671,10 @@ class Parser(metaclass=_Parser):
for error in self.errors:
logger.error(str(error))
elif self.error_level == ErrorLevel.RAISE and self.errors:
raise ParseError(concat_errors(self.errors, self.max_errors))
raise ParseError(
concat_messages(self.errors, self.max_errors),
errors=merge_errors(self.errors),
)
def raise_error(self, message, token=None):
token = token or self._curr or self._prev or Token.string("")
@ -659,19 +683,27 @@ class Parser(metaclass=_Parser):
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context]
error = ParseError(
error = ParseError.new(
f"{message}. Line {token.line}, Col: {token.col}.\n"
f" {start_context}\033[4m{highlight}\033[0m{end_context}"
f" {start_context}\033[4m{highlight}\033[0m{end_context}",
description=message,
line=token.line,
col=token.col,
start_context=start_context,
highlight=highlight,
end_context=end_context,
)
if self.error_level == ErrorLevel.IMMEDIATE:
raise error
self.errors.append(error)
def expression(self, exp_class, **kwargs):
def expression(self, exp_class, comments=None, **kwargs):
instance = exp_class(**kwargs)
if self._prev_comment:
instance.comment = self._prev_comment
self._prev_comment = None
if self._prev_comments:
instance.comments = self._prev_comments
self._prev_comments = None
if comments:
instance.comments = comments
self.validate_expression(instance)
return instance
@ -714,10 +746,10 @@ class Parser(metaclass=_Parser):
self._next = seq_get(self._tokens, self._index + 1)
if self._index > 0:
self._prev = self._tokens[self._index - 1]
self._prev_comment = self._prev.comment
self._prev_comments = self._prev.comments
else:
self._prev = None
self._prev_comment = None
self._prev_comments = None
def _retreat(self, index):
self._advance(index - self._index)
@ -768,7 +800,7 @@ class Parser(metaclass=_Parser):
)
def _parse_create(self):
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
transient = self._match(TokenType.TRANSIENT)
unique = self._match(TokenType.UNIQUE)
@ -822,97 +854,57 @@ class Parser(metaclass=_Parser):
def _parse_property(self):
if self._match_set(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.token_type](self)
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
return self._parse_character_set(True)
if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
key = self._parse_var().this
self._match(TokenType.EQ)
if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
return self._parse_sortkey(compound=True)
return self.expression(
exp.AnonymousProperty,
this=exp.Literal.string(key),
value=self._parse_column(),
)
if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
key = self._parse_var()
self._match(TokenType.EQ)
return self.expression(exp.Property, this=key, value=self._parse_column())
return None
def _parse_property_assignment(self, exp_class):
prop = self._prev.text
self._match(TokenType.EQ)
return self.expression(exp_class, this=prop, value=self._parse_var_or_string())
self._match(TokenType.ALIAS)
return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())
def _parse_partitioned_by(self):
self._match(TokenType.EQ)
return self.expression(
exp.PartitionedByProperty,
this=exp.Literal.string("PARTITIONED_BY"),
value=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
def _parse_stored(self):
self._match(TokenType.ALIAS)
self._match(TokenType.EQ)
return self.expression(
exp.FileFormatProperty,
this=exp.Literal.string("FORMAT"),
value=exp.Literal.string(self._parse_var_or_string().name),
this=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
def _parse_distkey(self):
self._match_l_paren()
this = exp.Literal.string("DISTKEY")
value = exp.Literal.string(self._parse_var().name)
self._match_r_paren()
return self.expression(
exp.DistKeyProperty,
this=this,
value=value,
)
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))
def _parse_sortkey(self):
self._match_l_paren()
this = exp.Literal.string("SORTKEY")
value = exp.Literal.string(self._parse_var().name)
self._match_r_paren()
return self.expression(
exp.SortKeyProperty,
this=this,
value=value,
)
def _parse_create_like(self):
table = self._parse_table(schema=True)
options = []
while self._match_texts(("INCLUDING", "EXCLUDING")):
options.append(
self.expression(
exp.Property,
this=self._prev.text.upper(),
value=exp.Var(this=self._parse_id_var().this.upper()),
)
)
return self.expression(exp.LikeProperty, this=table, expressions=options)
def _parse_diststyle(self):
this = exp.Literal.string("DISTSTYLE")
value = exp.Literal.string(self._parse_var().name)
def _parse_sortkey(self, compound=False):
return self.expression(
exp.DistStyleProperty,
this=this,
value=value,
)
def _parse_auto_increment(self):
self._match(TokenType.EQ)
return self.expression(
exp.AutoIncrementProperty,
this=exp.Literal.string("AUTO_INCREMENT"),
value=self._parse_number(),
)
def _parse_schema_comment(self):
self._match(TokenType.EQ)
return self.expression(
exp.SchemaCommentProperty,
this=exp.Literal.string("COMMENT"),
value=self._parse_string(),
exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
)
def _parse_character_set(self, default=False):
self._match(TokenType.EQ)
return self.expression(
exp.CharacterSetProperty,
this=exp.Literal.string("CHARACTER_SET"),
value=self._parse_var_or_string(),
default=default,
exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
)
def _parse_returns(self):
@@ -931,20 +923,7 @@ class Parser(metaclass=_Parser):
else:
value = self._parse_types()
return self.expression(
exp.ReturnsProperty,
this=exp.Literal.string("RETURNS"),
value=value,
is_table=is_table,
)
def _parse_execute_as(self):
self._match(TokenType.ALIAS)
return self.expression(
exp.ExecuteAsProperty,
this=exp.Literal.string("EXECUTE AS"),
value=self._parse_var(),
)
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
def _parse_properties(self):
properties = []
@@ -956,7 +935,7 @@ class Parser(metaclass=_Parser):
properties.extend(
self._parse_wrapped_csv(
lambda: self.expression(
exp.AnonymousProperty,
exp.Property,
this=self._parse_string(),
value=self._match(TokenType.EQ) and self._parse_string(),
)
@@ -1076,7 +1055,12 @@ class Parser(metaclass=_Parser):
options = []
if self._match(TokenType.OPTIONS):
options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ)
self._match_l_paren()
k = self._parse_string()
self._match(TokenType.EQ)
v = self._parse_string()
options = [k, v]
self._match_r_paren()
self._match(TokenType.ALIAS)
return self.expression(
@@ -1116,7 +1100,7 @@ class Parser(metaclass=_Parser):
self.raise_error(f"{this.key} does not support CTE")
this = cte
elif self._match(TokenType.SELECT):
comment = self._prev_comment
comments = self._prev_comments
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
@@ -1141,10 +1125,16 @@ class Parser(metaclass=_Parser):
expressions=expressions,
limit=limit,
)
this.comment = comment
this.comments = comments
into = self._parse_into()
if into:
this.set("into", into)
from_ = self._parse_from()
if from_:
this.set("from", from_)
self._parse_query_modifiers(this)
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
@@ -1248,11 +1238,24 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Hint, expressions=hints)
return None
def _parse_into(self):
if not self._match(TokenType.INTO):
return None
temp = self._match(TokenType.TEMPORARY)
unlogged = self._match(TokenType.UNLOGGED)
self._match(TokenType.TABLE)
return self.expression(
exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
)
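# Aside: a sketch of the new SELECT ... INTO support; the arg names mirror the parser above.
import sqlglot
into = sqlglot.parse_one("SELECT * INTO UNLOGGED t2 FROM t1").args["into"]
into.args.get("unlogged")  # expected: True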
def _parse_from(self):
if not self._match(TokenType.FROM):
return None
return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))
return self.expression(
exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
)
def _parse_lateral(self):
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
@@ -1515,7 +1518,9 @@ class Parser(metaclass=_Parser):
def _parse_where(self, skip_where_token=False):
if not skip_where_token and not self._match(TokenType.WHERE):
return None
return self.expression(exp.Where, this=self._parse_conjunction())
return self.expression(
exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
)
def _parse_group(self, skip_group_by_token=False):
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
@@ -1737,12 +1742,8 @@ class Parser(metaclass=_Parser):
return self._parse_tokens(self._parse_unary, self.FACTOR)
def _parse_unary(self):
if self._match(TokenType.NOT):
return self.expression(exp.Not, this=self._parse_equality())
if self._match(TokenType.TILDA):
return self.expression(exp.BitwiseNot, this=self._parse_unary())
if self._match(TokenType.DASH):
return self.expression(exp.Neg, this=self._parse_unary())
if self._match_set(self.UNARY_PARSERS):
return self.UNARY_PARSERS[self._prev.token_type](self)
return self._parse_at_time_zone(self._parse_type())
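# Aside: a plausible sketch of the class-level dispatch table this refactor assumes,
# reconstructed from the three branches it replaces (the real table may hold more entries).
UNARY_PARSERS = {
    TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()),
    TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()),
    TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()),
}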
def _parse_type(self):
@@ -1775,17 +1776,6 @@ class Parser(metaclass=_Parser):
expressions = None
maybe_func = False
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
return exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value)],
nested=True,
)
if self._match(TokenType.L_BRACKET):
self._retreat(index)
return None
if self._match(TokenType.L_PAREN):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
@@ -1801,6 +1791,17 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
maybe_func = True
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
return exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
nested=True,
)
if self._match(TokenType.L_BRACKET):
self._retreat(index)
return None
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
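# Aside: moving the `[]` check below parameter parsing means parameterized element types
# now reach the ARRAY wrapper (illustrative; exact rendering is dialect-dependent).
import sqlglot
sqlglot.parse_one("CAST(x AS DECIMAL(10, 2)[])").sql(dialect="postgres")
# expected roughly: 'CAST(x AS DECIMAL(10, 2)[])'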
@@ -1904,7 +1905,7 @@ class Parser(metaclass=_Parser):
return exp.Literal.number(f"0.{self._prev.text}")
if self._match(TokenType.L_PAREN):
comment = self._prev_comment
comments = self._prev_comments
query = self._parse_select()
if query:
@@ -1924,8 +1925,8 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Tuple, expressions=expressions)
else:
this = self.expression(exp.Paren, this=this)
if comment:
this.comment = comment
if comments:
this.comments = comments
return this
return None
@@ -2098,7 +2099,10 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.SCHEMA_COMMENT):
kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
elif self._match(TokenType.PRIMARY_KEY):
kind = exp.PrimaryKeyColumnConstraint()
desc = None
if self._match(TokenType.ASC) or self._match(TokenType.DESC):
desc = self._prev.token_type == TokenType.DESC
kind = exp.PrimaryKeyColumnConstraint(desc=desc)
elif self._match(TokenType.UNIQUE):
kind = exp.UniqueColumnConstraint()
elif self._match(TokenType.GENERATED):
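# Aside: a sketch of the new desc flag; the ordering keyword is SQLite-style, and we only
# inspect the parsed constraint here rather than asserting a generator round-trip.
import sqlglot
from sqlglot import exp
create = sqlglot.parse_one("CREATE TABLE t (id INTEGER PRIMARY KEY DESC)", read="sqlite")
create.find(exp.PrimaryKeyColumnConstraint).args.get("desc")  # expected: True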
@@ -2189,7 +2193,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.R_BRACKET):
self.raise_error("Expected ]")
this.comment = self._prev_comment
this.comments = self._prev_comments
return self._parse_bracket(this)
def _parse_case(self):
@@ -2256,6 +2260,33 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_string_agg(self):
if self._match(TokenType.DISTINCT):
args = self._parse_csv(self._parse_conjunction)
expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
else:
args = self._parse_csv(self._parse_conjunction)
expression = seq_get(args, 0)
index = self._index
if not self._match(TokenType.R_PAREN):
# postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
# the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them.
if not self._match(TokenType.WITHIN_GROUP):
self._retreat(index)
this = exp.GroupConcat.from_arg_list(args)
self.validate_expression(this, args)
return this
self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller)
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
def _parse_convert(self, strict):
this = self._parse_column()
if self._match(TokenType.USING):
@@ -2511,8 +2542,8 @@ class Parser(metaclass=_Parser):
items = [parse_result] if parse_result is not None else []
while self._match(sep):
if parse_result and self._prev_comment is not None:
parse_result.comment = self._prev_comment
if parse_result and self._prev_comments:
parse_result.comments = self._prev_comments
parse_result = parse_method()
if parse_result is not None:
@@ -2525,7 +2556,10 @@ class Parser(metaclass=_Parser):
while self._match_set(expressions):
this = self.expression(
expressions[self._prev.token_type], this=this, expression=parse_method()
expressions[self._prev.token_type],
this=this,
comments=self._prev_comments,
expression=parse_method(),
)
return this
@@ -2566,6 +2600,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Transaction, this=this, modes=modes)
def _parse_commit_or_rollback(self):
chain = None
savepoint = None
is_rollback = self._prev.token_type == TokenType.ROLLBACK
@@ -2575,9 +2610,13 @@ class Parser(metaclass=_Parser):
self._match_text_seq("SAVEPOINT")
savepoint = self._parse_id_var()
if self._match(TokenType.AND):
chain = not self._match_text_seq("NO")
self._match_text_seq("CHAIN")
if is_rollback:
return self.expression(exp.Rollback, savepoint=savepoint)
return self.expression(exp.Commit)
return self.expression(exp.Commit, chain=chain)
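# Aside: the new chain flag distinguishes the two COMMIT forms (illustrative).
import sqlglot
sqlglot.parse_one("COMMIT AND CHAIN").args.get("chain")     # expected: True
sqlglot.parse_one("COMMIT AND NO CHAIN").args.get("chain")  # expected: False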
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
@@ -2651,14 +2690,14 @@ class Parser(metaclass=_Parser):
def _match_l_paren(self, expression=None):
if not self._match(TokenType.L_PAREN):
self.raise_error("Expecting (")
if expression and self._prev_comment:
expression.comment = self._prev_comment
if expression and self._prev_comments:
expression.comments = self._prev_comments
def _match_r_paren(self, expression=None):
if not self._match(TokenType.R_PAREN):
self.raise_error("Expecting )")
if expression and self._prev_comment:
expression.comment = self._prev_comment
if expression and self._prev_comments:
expression.comments = self._prev_comments
def _match_texts(self, texts):
if self._curr and self._curr.text.upper() in texts:

View file

@@ -130,18 +130,20 @@ class Step:
aggregations = []
sequence = itertools.count()
for e in expression.expressions:
aggregation = e.find(exp.AggFunc)
if aggregation:
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
aggregations.append(e)
for operand in aggregation.unnest_operands():
def extract_agg_operands(expression):
for agg in expression.find_all(exp.AggFunc):
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operand.replace(exp.column(operands[operand], quoted=True))
for e in expression.expressions:
if e.find(exp.AggFunc):
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
aggregations.append(e)
extract_agg_operands(e)
else:
projections.append(e)
@@ -156,6 +158,13 @@ class Step:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
having = expression.args.get("having")
if having:
extract_agg_operands(having)
aggregate.condition = having.this
aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items()
)
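# Aside: an illustrative plan for a HAVING over an aggregate. Wrapping a parsed select in
# Plan is assumed to work here since Scan now falls back to alias_or_name (see below).
import sqlglot
from sqlglot.planner import Plan
plan = Plan(sqlglot.parse_one("SELECT a, SUM(b) AS s FROM t GROUP BY a HAVING SUM(b) > 10"))
# the Aggregate step now carries the HAVING condition instead of a downstream filter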
@@ -172,11 +181,6 @@ class Step:
aggregate.add_dependency(step)
step = aggregate
having = expression.args.get("having")
if having:
step.condition = having.this
order = expression.args.get("order")
if order:
@@ -188,6 +192,17 @@ class Step:
step.projections = projections
if isinstance(expression, exp.Select) and expression.args.get("distinct"):
distinct = Aggregate()
distinct.source = step.name
distinct.name = step.name
distinct.group = {
e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
for e in projections or expression.expressions
}
distinct.add_dependency(step)
step = distinct
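# Aside: a sketch of the DISTINCT lowering above, i.e. SELECT DISTINCT a, b FROM t is
# executed like SELECT a, b FROM t GROUP BY a, b (assumed equivalence).
from sqlglot import parse_one
from sqlglot.planner import Plan
Plan(parse_one("SELECT DISTINCT a, b FROM t")).root  # an extra Aggregate grouping on a, b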
limit = expression.args.get("limit")
if limit:
@@ -231,6 +246,9 @@ class Step:
if self.condition:
lines.append(f"{nested}Condition: {self.condition.sql()}")
if self.limit is not math.inf:
lines.append(f"{nested}Limit: {self.limit}")
if self.dependencies:
lines.append(f"{nested}Dependencies:")
for dependency in self.dependencies:
@@ -258,12 +276,7 @@ class Scan(Step):
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
table = expression
alias_ = expression.alias
if not alias_:
raise UnsupportedError(
"Tables/Subqueries must be aliased. Run it through the optimizer"
)
alias_ = expression.alias_or_name
if isinstance(expression, exp.Subquery):
table = expression.this
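# Aside: with the alias requirement relaxed, un-aliased tables should plan directly
# (illustrative; previously this raised UnsupportedError).
import sqlglot
from sqlglot.planner import Plan
Plan(sqlglot.parse_one("SELECT a FROM t")).root  # a Scan named after the table itself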
@@ -338,6 +351,9 @@ class Aggregate(Step):
lines.append(f"{indent}Group:")
for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}")
if self.condition:
lines.append(f"{indent}Having:")
lines.append(f"{indent} - {self.condition.sql()}")
if self.operands:
lines.append(f"{indent}Operands:")
for expression in self.operands:

View file

@@ -81,6 +81,7 @@ class TokenType(AutoName):
BINARY = auto()
VARBINARY = auto()
JSON = auto()
JSONB = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
@@ -91,6 +92,7 @@ class TokenType(AutoName):
NULLABLE = auto()
GEOMETRY = auto()
HLLSKETCH = auto()
HSTORE = auto()
SUPER = auto()
SERIAL = auto()
SMALLSERIAL = auto()
@@ -113,6 +115,7 @@ class TokenType(AutoName):
APPLY = auto()
ARRAY = auto()
ASC = auto()
ASOF = auto()
AT_TIME_ZONE = auto()
AUTO_INCREMENT = auto()
BEGIN = auto()
@@ -130,6 +133,7 @@ class TokenType(AutoName):
COMMAND = auto()
COMMENT = auto()
COMMIT = auto()
COMPOUND = auto()
CONSTRAINT = auto()
CREATE = auto()
CROSS = auto()
@@ -271,6 +275,7 @@ class TokenType(AutoName):
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
UNLOGGED = auto()
UNNEST = auto()
UNPIVOT = auto()
UPDATE = auto()
@@ -291,7 +296,7 @@
class Token:
__slots__ = ("token_type", "text", "line", "col", "comment")
__slots__ = ("token_type", "text", "line", "col", "comments")
@classmethod
def number(cls, number: int) -> Token:
@@ -319,13 +324,13 @@ class Token:
text: str,
line: int = 1,
col: int = 1,
comment: t.Optional[str] = None,
comments: t.List[str] = [],
) -> None:
self.token_type = token_type
self.text = text
self.line = line
self.col = max(col - len(text), 1)
self.comment = comment
self.comments = comments
def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
@@ -452,6 +457,7 @@ class Tokenizer(metaclass=_Tokenizer):
"COLLATE": TokenType.COLLATE,
"COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT,
"COMPOUND": TokenType.COMPOUND,
"CONSTRAINT": TokenType.CONSTRAINT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
@@ -582,8 +588,9 @@ class Tokenizer(metaclass=_Tokenizer):
"TRAILING": TokenType.TRAILING,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNPIVOT": TokenType.UNPIVOT,
"UNLOGGED": TokenType.UNLOGGED,
"UNNEST": TokenType.UNNEST,
"UNPIVOT": TokenType.UNPIVOT,
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
@@ -686,12 +693,12 @@ class Tokenizer(metaclass=_Tokenizer):
"_current",
"_line",
"_col",
"_comment",
"_comments",
"_char",
"_end",
"_peek",
"_prev_token_line",
"_prev_token_comment",
"_prev_token_comments",
"_prev_token_type",
"_replace_backslash",
)
@@ -708,13 +715,13 @@ class Tokenizer(metaclass=_Tokenizer):
self._current = 0
self._line = 1
self._col = 1
self._comment = None
self._comments: t.List[str] = []
self._char = None
self._end = None
self._peek = None
self._prev_token_line = -1
self._prev_token_comment = None
self._prev_token_comments: t.List[str] = []
self._prev_token_type = None
def tokenize(self, sql: str) -> t.List[Token]:
@@ -767,7 +774,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self._prev_token_comment = self._comment
self._prev_token_comments = self._comments
self._prev_token_type = token_type # type: ignore
self.tokens.append(
Token(
@@ -775,10 +782,10 @@ class Tokenizer(metaclass=_Tokenizer):
self._text if text is None else text,
self._line,
self._col,
self._comment,
self._comments,
)
)
self._comment = None
self._comments = []
if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
@@ -857,22 +864,18 @@ class Tokenizer(metaclass=_Tokenizer):
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore
self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
self._advance(comment_end_size - 1)
else:
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
self._advance()
self._comment = self._text[comment_start_size:] # type: ignore
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
# types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
self._comments.append(self._text[comment_start_size:]) # type: ignore
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
# Multiple consecutive comments are preserved by appending them to the current comments list.
if comment_start_line == self._prev_token_line:
if self._prev_token_comment is None:
self.tokens[-1].comment = self._comment
self._prev_token_comment = self._comment
self._comment = None
self.tokens[-1].comments.extend(self._comments)
self._comments = []
return True
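# Aside: a quick check of the multi-comment behavior (expected round-trip; exact spacing
# may differ by generator settings).
import sqlglot
sqlglot.transpile("SELECT 1 /* one */ /* uno */")[0]
# expected roughly: 'SELECT 1 /* one */ /* uno */'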

View file

@@ -2,6 +2,8 @@ from __future__ import annotations
import typing as t
from sqlglot.helper import find_new_name
if t.TYPE_CHECKING:
from sqlglot.generator import Generator
@@ -43,6 +45,43 @@ def unalias_group(expression: exp.Expression) -> exp.Expression:
return expression
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
"""
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Args:
expression: the expression that will be transformed.
Returns:
The transformed expression.
"""
if (
isinstance(expression, exp.Select)
and expression.args.get("distinct")
and expression.args["distinct"].args.get("on")
and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
):
distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions]
outer_selects = [e.copy() for e in expression.expressions]
nested = expression.copy()
nested.args["distinct"].pop()
row_number = find_new_name(expression.named_selects, "_row_number")
window = exp.Window(
this=exp.RowNumber(),
partition_by=distinct_cols,
)
order = nested.args.get("order")
if order:
window.set("order", order.copy())
order.pop()
window = exp.alias_(window, row_number)
nested.select(window, copy=False)
return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1')
return expression
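# Aside: a usage sketch for the transform above; the output shape follows the docstring
# and is not guaranteed byte-exact.
from sqlglot import parse_one
from sqlglot.transforms import eliminate_distinct_on

sql = "SELECT DISTINCT ON (a) a, b FROM x ORDER BY a, c"
print(parse_one(sql, read="postgres").transform(eliminate_distinct_on).sql())
# expected roughly:
# SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a, c) AS _row_number FROM x)
# WHERE "_row_number" = 1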
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
to_sql: t.Callable[[Generator, exp.Expression], str],
@@ -81,3 +120,4 @@ def delegate(attr: str) -> t.Callable:
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}