Merging upstream version 10.1.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 582b160275, commit a5128ea109
57 changed files with 1542 additions and 529 deletions
@@ -30,7 +30,7 @@ from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema
 from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "10.0.8"
+__version__ = "10.1.3"

 pretty = False

@@ -56,12 +56,12 @@ def _derived_table_values_to_unnest(self, expression):


 def _returnsproperty_sql(self, expression):
-    value = expression.args.get("value")
-    if isinstance(value, exp.Schema):
-        value = f"{value.this} <{self.expressions(value)}>"
+    this = expression.this
+    if isinstance(this, exp.Schema):
+        this = f"{this.this} <{self.expressions(this)}>"
     else:
-        value = self.sql(value)
-    return f"RETURNS {value}"
+        this = self.sql(this)
+    return f"RETURNS {this}"


 def _create_sql(self, expression):
@@ -142,6 +142,11 @@ class BigQuery(Dialect):
             ),
         }

+        FUNCTION_PARSERS = {
+            **parser.Parser.FUNCTION_PARSERS,
+        }
+        FUNCTION_PARSERS.pop("TRIM")
+
         NO_PAREN_FUNCTIONS = {
             **parser.Parser.NO_PAREN_FUNCTIONS,
             TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
@@ -174,6 +179,7 @@ class BigQuery(Dialect):
             exp.Values: _derived_table_values_to_unnest,
             exp.ReturnsProperty: _returnsproperty_sql,
             exp.Create: _create_sql,
+            exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
             exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
             if e.name == "IMMUTABLE"
             else "NOT DETERMINISTIC",
@@ -200,9 +206,7 @@ class BigQuery(Dialect):
             exp.VolatilityProperty,
         }

-        WITH_PROPERTIES = {
-            exp.AnonymousProperty,
-        }
+        WITH_PROPERTIES = {exp.Property}

         EXPLICIT_UNION = True

@@ -21,14 +21,15 @@ class ClickHouse(Dialect):

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
-            "FINAL": TokenType.FINAL,
+            "ASOF": TokenType.ASOF,
             "DATETIME64": TokenType.DATETIME,
-            "INT8": TokenType.TINYINT,
+            "FINAL": TokenType.FINAL,
+            "FLOAT32": TokenType.FLOAT,
+            "FLOAT64": TokenType.DOUBLE,
             "INT16": TokenType.SMALLINT,
             "INT32": TokenType.INT,
             "INT64": TokenType.BIGINT,
-            "FLOAT32": TokenType.FLOAT,
-            "FLOAT64": TokenType.DOUBLE,
+            "INT8": TokenType.TINYINT,
             "TUPLE": TokenType.STRUCT,
         }

@@ -38,6 +39,10 @@ class ClickHouse(Dialect):
             "MAP": parse_var_map,
         }

+        JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
+
+        TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
+
         def _parse_table(self, schema=False):
             this = super()._parse_table(schema)

@@ -289,19 +289,19 @@ def struct_extract_sql(self, expression):
     return f"{this}.{struct_key}"


-def var_map_sql(self, expression):
+def var_map_sql(self, expression, map_func_name="MAP"):
     keys = expression.args["keys"]
     values = expression.args["values"]

     if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
         self.unsupported("Cannot convert array columns into map.")
-        return f"MAP({self.format_args(keys, values)})"
+        return f"{map_func_name}({self.format_args(keys, values)})"

     args = []
     for key, value in zip(keys.expressions, values.expressions):
         args.append(self.sql(key))
         args.append(self.sql(value))
-    return f"MAP({self.format_args(*args)})"
+    return f"{map_func_name}({self.format_args(*args)})"


 def format_time_lambda(exp_class, dialect, default=None):
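The new map_func_name parameter lets a dialect reuse this helper for map constructors with other names; Snowflake binds it to OBJECT_CONSTRUCT later in this commit. A minimal sketch of the effect (output shape assumed, not verified here):

    import sqlglot

    # Spark's MAP(...) should now come out as OBJECT_CONSTRUCT(...) in Snowflake
    print(sqlglot.transpile("SELECT MAP('a', 1)", read="spark", write="snowflake")[0])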
@@ -336,18 +336,13 @@ def create_with_partitions_sql(self, expression):
     if has_schema and is_partitionable:
         expression = expression.copy()
         prop = expression.find(exp.PartitionedByProperty)
-        value = prop and prop.args.get("value")
-        if prop and not isinstance(value, exp.Schema):
+        this = prop and prop.this
+        if prop and not isinstance(this, exp.Schema):
             schema = expression.this
-            columns = {v.name.upper() for v in value.expressions}
+            columns = {v.name.upper() for v in this.expressions}
             partitions = [col for col in schema.expressions if col.name.upper() in columns]
-            schema.set(
-                "expressions",
-                [e for e in schema.expressions if e not in partitions],
-            )
-            prop.replace(
-                exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
-            )
+            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
+            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
             expression.set("this", schema)

     return self.create_sql(expression)
@@ -153,7 +153,7 @@ class Drill(Dialect):
             exp.If: if_sql,
             exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
             exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
-            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.Pivot: no_pivot_sql,
             exp.RegexpLike: rename_func("REGEXP_MATCHES"),
             exp.StrPosition: str_position_sql,
@@ -61,9 +61,7 @@ def _array_sort(self, expression):


 def _property_sql(self, expression):
-    key = expression.name
-    value = self.sql(expression, "value")
-    return f"'{key}'={value}"
+    return f"'{expression.name}'={self.sql(expression, 'value')}"


 def _str_to_unix(self, expression):
@@ -250,7 +248,7 @@ class Hive(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             **transforms.UNALIAS_GROUP,  # type: ignore
-            exp.AnonymousProperty: _property_sql,
+            exp.Property: _property_sql,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.ArrayAgg: rename_func("COLLECT_LIST"),
             exp.ArrayConcat: rename_func("CONCAT"),
@@ -262,7 +260,7 @@ class Hive(Dialect):
             exp.DateStrToDate: rename_func("TO_DATE"),
             exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
-            exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}",
+            exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
             exp.If: if_sql,
             exp.Index: _index_sql,
             exp.ILike: no_ilike_sql,
@@ -285,7 +283,7 @@ class Hive(Dialect):
             exp.StrToTime: _str_to_time,
             exp.StrToUnix: _str_to_unix,
             exp.StructExtract: struct_extract_sql,
-            exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}",
+            exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
             exp.TimeStrToDate: rename_func("TO_DATE"),
             exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -298,11 +296,11 @@ class Hive(Dialect):
             exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
-            exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
+            exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
             exp.NumberToStr: rename_func("FORMAT_NUMBER"),
         }

-        WITH_PROPERTIES = {exp.AnonymousProperty}
+        WITH_PROPERTIES = {exp.Property}

         ROOT_PROPERTIES = {
             exp.PartitionedByProperty,
@@ -453,6 +453,7 @@ class MySQL(Dialect):
             exp.CharacterSetProperty,
             exp.CollateProperty,
             exp.SchemaCommentProperty,
+            exp.LikeProperty,
         }

         WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
@@ -1,7 +1,7 @@
 from __future__ import annotations

-from sqlglot import exp, generator, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, no_ilike_sql
+from sqlglot import exp, generator, parser, tokens, transforms
+from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
 from sqlglot.helper import csv
 from sqlglot.tokens import TokenType

@@ -37,6 +37,12 @@ class Oracle(Dialect):
         "YYYY": "%Y",  # 2015
     }

+    class Parser(parser.Parser):
+        FUNCTIONS = {
+            **parser.Parser.FUNCTIONS,
+            "DECODE": exp.Matches.from_arg_list,
+        }
+
     class Generator(generator.Generator):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -58,6 +64,7 @@ class Oracle(Dialect):
             **transforms.UNALIAS_GROUP,  # type: ignore
             exp.ILike: no_ilike_sql,
             exp.Limit: _limit_sql,
+            exp.Matches: rename_func("DECODE"),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
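With DECODE parsed into the new exp.Matches node and generated back via rename_func("DECODE"), Oracle's conditional matching now round-trips. A hedged look at the parse result:

    import sqlglot

    ast = sqlglot.parse_one("SELECT DECODE(a, 1, 'one', 'other') FROM t", read="oracle")
    print(ast.find(sqlglot.exp.Matches))  # DECODE parses into the new Matches node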
@@ -74,6 +74,27 @@ def _trim_sql(self, expression):
     return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


+def _string_agg_sql(self, expression):
+    expression = expression.copy()
+    separator = expression.args.get("separator") or exp.Literal.string(",")
+
+    order = ""
+    this = expression.this
+    if isinstance(this, exp.Order):
+        if this.this:
+            this = this.this
+            this.pop()
+        order = self.sql(expression.this)  # Order has a leading space
+
+    return f"STRING_AGG({self.format_args(this, separator)}{order})"
+
+
+def _datatype_sql(self, expression):
+    if expression.this == exp.DataType.Type.ARRAY:
+        return f"{self.expressions(expression, flat=True)}[]"
+    return self.datatype_sql(expression)
+
+
 def _auto_increment_to_serial(expression):
     auto = expression.find(exp.AutoIncrementColumnConstraint)
@@ -191,25 +212,27 @@ class Postgres(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "ALWAYS": TokenType.ALWAYS,
-            "BY DEFAULT": TokenType.BY_DEFAULT,
-            "IDENTITY": TokenType.IDENTITY,
-            "GENERATED": TokenType.GENERATED,
-            "DOUBLE PRECISION": TokenType.DOUBLE,
-            "BIGSERIAL": TokenType.BIGSERIAL,
-            "SERIAL": TokenType.SERIAL,
-            "SMALLSERIAL": TokenType.SMALLSERIAL,
-            "UUID": TokenType.UUID,
-            "TEMP": TokenType.TEMPORARY,
-            "BEGIN TRANSACTION": TokenType.BEGIN,
             "BEGIN": TokenType.COMMAND,
+            "BEGIN TRANSACTION": TokenType.BEGIN,
+            "BIGSERIAL": TokenType.BIGSERIAL,
+            "BY DEFAULT": TokenType.BY_DEFAULT,
             "COMMENT ON": TokenType.COMMAND,
             "DECLARE": TokenType.COMMAND,
             "DO": TokenType.COMMAND,
+            "DOUBLE PRECISION": TokenType.DOUBLE,
+            "GENERATED": TokenType.GENERATED,
+            "GRANT": TokenType.COMMAND,
+            "HSTORE": TokenType.HSTORE,
+            "IDENTITY": TokenType.IDENTITY,
+            "JSONB": TokenType.JSONB,
             "REFRESH": TokenType.COMMAND,
             "REINDEX": TokenType.COMMAND,
             "RESET": TokenType.COMMAND,
             "REVOKE": TokenType.COMMAND,
-            "GRANT": TokenType.COMMAND,
+            "SERIAL": TokenType.SERIAL,
+            "SMALLSERIAL": TokenType.SMALLSERIAL,
+            "TEMP": TokenType.TEMPORARY,
+            "UUID": TokenType.UUID,
             **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
             **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
         }
@@ -265,4 +288,7 @@ class Postgres(Dialect):
             exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,
             exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
+            exp.DataType: _datatype_sql,
+            exp.GroupConcat: _string_agg_sql,
+            exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
         }
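A quick, hedged sanity check of the new GroupConcat handling (expected output shape, not verified here):

    import sqlglot

    # GROUP_CONCAT with an ORDER BY should land as STRING_AGG(x, ',' ORDER BY y)
    print(sqlglot.transpile("SELECT GROUP_CONCAT(x ORDER BY y) FROM t", read="mysql", write="postgres")[0])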
@@ -171,16 +171,7 @@ class Presto(Dialect):

         STRUCT_DELIMITER = ("(", ")")

-        ROOT_PROPERTIES = {
-            exp.SchemaCommentProperty,
-        }
-
-        WITH_PROPERTIES = {
-            exp.PartitionedByProperty,
-            exp.FileFormatProperty,
-            exp.AnonymousProperty,
-            exp.TableFormatProperty,
-        }
+        ROOT_PROPERTIES = {exp.SchemaCommentProperty}

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -231,7 +222,8 @@ class Presto(Dialect):
             exp.StrToTime: _str_to_time_sql,
             exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
             exp.StructExtract: struct_extract_sql,
-            exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'",
+            exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
+            exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
             exp.TimeStrToDate: _date_parse_sql,
             exp.TimeStrToTime: _date_parse_sql,
             exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from sqlglot import exp
+from sqlglot import exp, transforms
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.tokens import TokenType

@@ -18,12 +18,14 @@ class Redshift(Postgres):

         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,  # type: ignore
+            "COPY": TokenType.COMMAND,
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "HLLSKETCH": TokenType.HLLSKETCH,
             "SUPER": TokenType.SUPER,
             "TIME": TokenType.TIMESTAMP,
             "TIMETZ": TokenType.TIMESTAMPTZ,
+            "UNLOAD": TokenType.COMMAND,
             "VARBYTE": TokenType.VARBINARY,
             "SIMILAR TO": TokenType.SIMILAR_TO,
         }
@@ -35,3 +37,17 @@ class Redshift(Postgres):
             exp.DataType.Type.VARBINARY: "VARBYTE",
             exp.DataType.Type.INT: "INTEGER",
         }
+
+        ROOT_PROPERTIES = {
+            exp.DistKeyProperty,
+            exp.SortKeyProperty,
+            exp.DistStyleProperty,
+        }
+
+        TRANSFORMS = {
+            **Postgres.Generator.TRANSFORMS,  # type: ignore
+            **transforms.ELIMINATE_DISTINCT_ON,  # type: ignore
+            exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
+            exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+            exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+        }
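A hedged round trip for the new Redshift properties (note that INT maps to INTEGER per the TYPE_MAPPING above; parsing of DISTKEY/SORTKEY lands in parser.py at the end of this commit):

    import sqlglot

    sql = "CREATE TABLE t (a INT) DISTKEY(a) SORTKEY(a)"
    print(sqlglot.transpile(sql, read="redshift", write="redshift")[0])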
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
     format_time_lambda,
     inline_array_sql,
     rename_func,
+    var_map_sql,
 )
 from sqlglot.expressions import Literal
 from sqlglot.helper import seq_get
@@ -100,6 +101,14 @@ def _parse_date_part(self):
     return self.expression(exp.Extract, this=this, expression=expression)


+def _datatype_sql(self, expression):
+    if expression.this == exp.DataType.Type.ARRAY:
+        return "ARRAY"
+    elif expression.this == exp.DataType.Type.MAP:
+        return "OBJECT"
+    return self.datatype_sql(expression)
+
+
 class Snowflake(Dialect):
     null_ordering = "nulls_are_large"
     time_format = "'yyyy-mm-dd hh24:mi:ss'"
@@ -142,6 +151,8 @@ class Snowflake(Dialect):
             "TO_TIMESTAMP": _snowflake_to_timestamp,
             "ARRAY_CONSTRUCT": exp.Array.from_arg_list,
             "RLIKE": exp.RegexpLike.from_arg_list,
+            "DECODE": exp.Matches.from_arg_list,
+            "OBJECT_CONSTRUCT": parser.parse_var_map,
         }

         FUNCTION_PARSERS = {
@@ -195,16 +206,20 @@ class Snowflake(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
-            exp.ArrayConcat: rename_func("ARRAY_CAT"),
-            exp.If: rename_func("IFF"),
-            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time_sql,
-            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
             exp.Array: inline_array_sql,
-            exp.StrPosition: rename_func("POSITION"),
+            exp.ArrayConcat: rename_func("ARRAY_CAT"),
+            exp.DataType: _datatype_sql,
+            exp.If: rename_func("IFF"),
+            exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+            exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
             exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
-            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+            exp.Matches: rename_func("DECODE"),
+            exp.StrPosition: rename_func("POSITION"),
+            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
+            exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
+            exp.UnixToTime: _unix_to_time_sql,
         }

         TYPE_MAPPING = {
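Since Oracle and Snowflake now both route DECODE through exp.Matches, the function should survive cross-dialect transpilation; an illustrative, unverified example:

    import sqlglot

    print(sqlglot.transpile("SELECT DECODE(x, 1, 'one', 'other') FROM t", read="oracle", write="snowflake")[0])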
@@ -98,7 +98,7 @@ class Spark(Hive):
         TRANSFORMS = {
             **Hive.Generator.TRANSFORMS,  # type: ignore
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
-            exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
+            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
             exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
@@ -13,6 +13,23 @@ from sqlglot.dialects.dialect import (
 from sqlglot.tokens import TokenType


+# https://www.sqlite.org/lang_aggfunc.html#group_concat
+def _group_concat_sql(self, expression):
+    this = expression.this
+    distinct = expression.find(exp.Distinct)
+    if distinct:
+        this = distinct.expressions[0]
+        distinct = "DISTINCT "
+
+    if isinstance(expression.this, exp.Order):
+        self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
+        if expression.this.this and not distinct:
+            this = expression.this.this
+
+    separator = expression.args.get("separator")
+    return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
+
+
 class SQLite(Dialect):
     class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]"), "`"]
@@ -62,6 +79,7 @@ class SQLite(Dialect):
             exp.Levenshtein: rename_func("EDITDIST3"),
             exp.TableSample: no_tablesample_sql,
             exp.TryCast: no_trycast_sql,
+            exp.GroupConcat: _group_concat_sql,
         }

         def transaction_sql(self, expression):
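A hedged example of the new SQLite GroupConcat output:

    import sqlglot

    # DISTINCT is rendered; ORDER BY inside GROUP_CONCAT is reported as unsupported
    print(sqlglot.transpile("SELECT GROUP_CONCAT(DISTINCT x) FROM t", read="sqlite", write="sqlite")[0])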
@@ -17,6 +17,7 @@ FULL_FORMAT_TIME_MAPPING = {
     "mm": "%B",
     "m": "%B",
 }

 DATE_DELTA_INTERVAL = {
     "year": "year",
     "yyyy": "year",
@@ -37,11 +38,12 @@ DATE_DELTA_INTERVAL = {


 DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")

 # N = Numeric, C=Currency
 TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}


-def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
+def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
     def _format_time(args):
         return exp_class(
             this=seq_get(args, 1),
@@ -58,7 +60,7 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
     return _format_time


-def parse_format(args):
+def _parse_format(args):
     fmt = seq_get(args, 1)
     number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
     if number_fmt:
@@ -78,7 +80,7 @@ def generate_date_delta_with_unit_sql(self, e):
     return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"


-def generate_format_sql(self, e):
+def _format_sql(self, e):
     fmt = (
         e.args["format"]
         if isinstance(e, exp.NumberToStr)
@@ -87,6 +89,28 @@ def generate_format_sql(self, e):
     return f"FORMAT({self.format_args(e.this, fmt)})"


+def _string_agg_sql(self, e):
+    e = e.copy()
+
+    this = e.this
+    distinct = e.find(exp.Distinct)
+    if distinct:
+        # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
+        self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
+        this = distinct.expressions[0]
+        distinct.pop()
+
+    order = ""
+    if isinstance(e.this, exp.Order):
+        if e.this.this:
+            this = e.this.this
+            e.this.this.pop()
+        order = f" WITHIN GROUP ({self.sql(e.this)[1:]})"  # Order has a leading space
+
+    separator = e.args.get("separator") or exp.Literal.string(",")
+    return f"STRING_AGG({self.format_args(this, separator)}){order}"
+
+
 class TSQL(Dialect):
     null_ordering = "nulls_are_small"
     time_format = "'yyyy-mm-dd hh:mm:ss'"
@@ -228,14 +252,14 @@ class TSQL(Dialect):
             "ISNULL": exp.Coalesce.from_arg_list,
             "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
             "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
-            "DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True),
-            "DATEPART": tsql_format_time_lambda(exp.TimeToStr),
+            "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
+            "DATEPART": _format_time_lambda(exp.TimeToStr),
             "GETDATE": exp.CurrentDate.from_arg_list,
             "IIF": exp.If.from_arg_list,
             "LEN": exp.Length.from_arg_list,
             "REPLICATE": exp.Repeat.from_arg_list,
             "JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
-            "FORMAT": parse_format,
+            "FORMAT": _parse_format,
         }

         VAR_LENGTH_DATATYPES = {
@@ -298,6 +322,7 @@ class TSQL(Dialect):
             exp.DateDiff: generate_date_delta_with_unit_sql,
             exp.CurrentDate: rename_func("GETDATE"),
             exp.If: rename_func("IIF"),
-            exp.NumberToStr: generate_format_sql,
-            exp.TimeToStr: generate_format_sql,
+            exp.NumberToStr: _format_sql,
+            exp.TimeToStr: _format_sql,
+            exp.GroupConcat: _string_agg_sql,
         }
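Illustrative (unverified) output for ordered aggregation, which now emits T-SQL's WITHIN GROUP syntax:

    import sqlglot

    # expected shape: STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y)
    print(sqlglot.transpile("SELECT GROUP_CONCAT(x ORDER BY y) FROM t", read="mysql", write="tsql")[0])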
@@ -22,7 +22,40 @@ class UnsupportedError(SqlglotError):


 class ParseError(SqlglotError):
-    pass
+    def __init__(
+        self,
+        message: str,
+        errors: t.Optional[t.List[t.Dict[str, t.Any]]] = None,
+    ):
+        super().__init__(message)
+        self.errors = errors or []
+
+    @classmethod
+    def new(
+        cls,
+        message: str,
+        description: t.Optional[str] = None,
+        line: t.Optional[int] = None,
+        col: t.Optional[int] = None,
+        start_context: t.Optional[str] = None,
+        highlight: t.Optional[str] = None,
+        end_context: t.Optional[str] = None,
+        into_expression: t.Optional[str] = None,
+    ) -> ParseError:
+        return cls(
+            message,
+            [
+                {
+                    "description": description,
+                    "line": line,
+                    "col": col,
+                    "start_context": start_context,
+                    "highlight": highlight,
+                    "end_context": end_context,
+                    "into_expression": into_expression,
+                }
+            ],
+        )


 class TokenError(SqlglotError):
@@ -41,9 +74,13 @@ class ExecuteError(SqlglotError):
     pass


-def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
+def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str:
     msg = [str(e) for e in errors[:maximum]]
     remaining = len(errors) - maximum
     if remaining > 0:
         msg.append(f"... and {remaining} more")
     return "\n\n".join(msg)
+
+
+def merge_errors(errors: t.Sequence[ParseError]) -> t.List[t.Dict[str, t.Any]]:
+    return [e_dict for error in errors for e_dict in error.errors]
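ParseError now carries structured error metadata instead of only a message. A sketch of consuming it (the failing query is arbitrary):

    import sqlglot
    from sqlglot.errors import ParseError

    try:
        sqlglot.parse_one("SELECT * FROM")
    except ParseError as e:
        # dicts with description, line, col, start_context, highlight, end_context, ...
        for error in e.errors:
            print(error["line"], error["col"], error["description"])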
@@ -122,7 +122,6 @@ def interval(this, unit):


 ENV = {
     "__builtins__": {},
     "exp": exp,
     # aggs
     "SUM": filter_nulls(sum),
@@ -115,6 +115,9 @@ class PythonExecutor:
         sink = self.table(context.columns)

         for reader in table_iter:
+            if len(sink) >= step.limit:
+                break
+
             if condition and not context.eval(condition):
                 continue

@@ -123,9 +126,6 @@ class PythonExecutor:
             else:
                 sink.append(reader.row)

-            if len(sink) >= step.limit:
-                break
-
         return self.context({step.name: sink})

     def static(self):
@@ -288,21 +288,32 @@ class PythonExecutor:
         end = 1
         length = len(context.table)
         table = self.table(list(step.group) + step.aggregations)
+        condition = self.generate(step.condition)

-        for i in range(length):
-            context.set_index(i)
-            key = context.eval_tuple(group_by)
-            group = key if group is None else group
-            end += 1
-            if key != group:
-                context.set_range(start, end - 2)
-                table.append(group + context.eval_tuple(aggregations))
-                group = key
-                start = end - 2
-            if i == length - 1:
-                context.set_range(start, end - 1)
-                table.append(group + context.eval_tuple(aggregations))
+        def add_row():
+            if not condition or context.eval(condition):
+                table.append(group + context.eval_tuple(aggregations))
+
+        if length:
+            for i in range(length):
+                context.set_index(i)
+                key = context.eval_tuple(group_by)
+                group = key if group is None else group
+                end += 1
+                if key != group:
+                    context.set_range(start, end - 2)
+                    add_row()
+                    group = key
+                    start = end - 2
+                    if len(table.rows) >= step.limit:
+                        break
+                if i == length - 1:
+                    context.set_range(start, end - 1)
+                    add_row()
+        elif step.limit > 0:
+            context.set_range(0, 0)
+            table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))

         context = self.context({step.name: table, **{name: table for name in context.tables}})

         if step.projections:
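The aggregate step now evaluates the HAVING condition through add_row() and honors LIMIT. A hedged end-to-end example with the Python executor:

    from sqlglot.executor import execute

    tables = {"t": [{"x": "a"}, {"x": "a"}, {"x": "b"}]}
    result = execute("SELECT x, COUNT(1) AS c FROM t GROUP BY x HAVING COUNT(1) = 2", tables=tables)
    print(result.rows)  # expected: [('a', 2)]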
@@ -311,11 +322,9 @@ class PythonExecutor:

     def sort(self, step, context):
         projections = self.generate_tuple(step.projections)

         projection_columns = [p.alias_or_name for p in step.projections]
         all_columns = list(context.columns) + projection_columns
         sink = self.table(all_columns)

         for reader, ctx in context:
             sink.append(reader.row + ctx.eval_tuple(projections))

@@ -401,8 +410,9 @@ class Python(Dialect):
             exp.Boolean: lambda self, e: "True" if e.this else "False",
             exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
             exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
+            exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
             exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
-            exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
+            exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
             exp.Is: lambda self, e: self.binary(e, "is"),
             exp.Not: lambda self, e: f"not {self.sql(e.this)}",
             exp.Null: lambda *_: "None",
@@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):

     key = "Expression"
     arg_types = {"this": True}
-    __slots__ = ("args", "parent", "arg_key", "type", "comment")
+    __slots__ = ("args", "parent", "arg_key", "type", "comments")

     def __init__(self, **args):
         self.args = args
         self.parent = None
         self.arg_key = None
         self.type = None
-        self.comment = None
+        self.comments = None

         for arg_key, value in self.args.items():
             self._set_parent(arg_key, value)
@@ -88,19 +88,6 @@ class Expression(metaclass=_Expression):
             return field.this
         return ""

-    def find_comment(self, key: str) -> str:
-        """
-        Finds the comment that is attached to a specified child node.
-
-        Args:
-            key: the key of the target child node (e.g. "this", "expression", etc).
-
-        Returns:
-            The comment attached to the child node, or the empty string, if it doesn't exist.
-        """
-        field = self.args.get(key)
-        return field.comment if isinstance(field, Expression) else ""
-
     @property
     def is_string(self):
         return isinstance(self, Literal) and self.args["is_string"]
@@ -137,7 +124,7 @@ class Expression(metaclass=_Expression):

     def __deepcopy__(self, memo):
         copy = self.__class__(**deepcopy(self.args))
-        copy.comment = self.comment
+        copy.comments = self.comments
         copy.type = self.type
         return copy

@@ -369,7 +356,7 @@ class Expression(metaclass=_Expression):
             )
             for k, vs in self.args.items()
         }
-        args["comment"] = self.comment
+        args["comments"] = self.comments
         args["type"] = self.type
         args = {k: v for k, v in args.items() if v or not hide_missing}

@@ -767,7 +754,7 @@ class NotNullColumnConstraint(ColumnConstraintKind):


 class PrimaryKeyColumnConstraint(ColumnConstraintKind):
-    pass
+    arg_types = {"desc": False}


 class UniqueColumnConstraint(ColumnConstraintKind):
@@ -819,6 +806,12 @@ class Unique(Expression):
     arg_types = {"expressions": True}


+# https://www.postgresql.org/docs/9.1/sql-selectinto.html
+# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples
+class Into(Expression):
+    arg_types = {"this": True, "temporary": False, "unlogged": False}
+
+
 class From(Expression):
     arg_types = {"expressions": True}

@@ -1065,67 +1058,67 @@ class Property(Expression):


 class TableFormatProperty(Property):
-    pass
+    arg_types = {"this": True}


 class PartitionedByProperty(Property):
-    pass
+    arg_types = {"this": True}


 class FileFormatProperty(Property):
-    pass
+    arg_types = {"this": True}


 class DistKeyProperty(Property):
-    pass
+    arg_types = {"this": True}


 class SortKeyProperty(Property):
-    pass
+    arg_types = {"this": True, "compound": False}


 class DistStyleProperty(Property):
-    pass
+    arg_types = {"this": True}


 class LikeProperty(Property):
     arg_types = {"this": True, "expressions": False}


 class LocationProperty(Property):
-    pass
+    arg_types = {"this": True}


 class EngineProperty(Property):
-    pass
+    arg_types = {"this": True}


 class AutoIncrementProperty(Property):
-    pass
+    arg_types = {"this": True}


 class CharacterSetProperty(Property):
-    arg_types = {"this": True, "value": True, "default": True}
+    arg_types = {"this": True, "default": True}


 class CollateProperty(Property):
-    pass
+    arg_types = {"this": True}


 class SchemaCommentProperty(Property):
     pass


 class AnonymousProperty(Property):
-    pass
+    arg_types = {"this": True}


 class ReturnsProperty(Property):
-    arg_types = {"this": True, "value": True, "is_table": False}
+    arg_types = {"this": True, "is_table": False}


 class LanguageProperty(Property):
-    pass
+    arg_types = {"this": True}


 class ExecuteAsProperty(Property):
-    pass
+    arg_types = {"this": True}


 class VolatilityProperty(Property):
@@ -1135,27 +1128,36 @@ class VolatilityProperty(Property):
 class Properties(Expression):
     arg_types = {"expressions": True}

-    PROPERTY_KEY_MAPPING = {
+    NAME_TO_PROPERTY = {
         "AUTO_INCREMENT": AutoIncrementProperty,
-        "CHARACTER_SET": CharacterSetProperty,
+        "CHARACTER SET": CharacterSetProperty,
         "COLLATE": CollateProperty,
         "COMMENT": SchemaCommentProperty,
-        "ENGINE": EngineProperty,
-        "FORMAT": FileFormatProperty,
-        "LOCATION": LocationProperty,
-        "PARTITIONED_BY": PartitionedByProperty,
-        "TABLE_FORMAT": TableFormatProperty,
+        "DISTKEY": DistKeyProperty,
+        "DISTSTYLE": DistStyleProperty,
+        "ENGINE": EngineProperty,
+        "EXECUTE AS": ExecuteAsProperty,
+        "FORMAT": FileFormatProperty,
+        "LANGUAGE": LanguageProperty,
+        "LOCATION": LocationProperty,
+        "PARTITIONED_BY": PartitionedByProperty,
+        "RETURNS": ReturnsProperty,
+        "SORTKEY": SortKeyProperty,
+        "TABLE_FORMAT": TableFormatProperty,
     }

+    PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
+
     @classmethod
     def from_dict(cls, properties_dict) -> Properties:
         expressions = []
         for key, value in properties_dict.items():
-            property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
-            expressions.append(property_cls(this=Literal.string(key), value=convert(value)))
+            property_cls = cls.NAME_TO_PROPERTY.get(key.upper())
+            if property_cls:
+                expressions.append(property_cls(this=convert(value)))
+            else:
+                expressions.append(Property(this=Literal.string(key), value=convert(value)))

         return cls(expressions=expressions)
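A hedged illustration of the new from_dict behavior — known names map to typed properties that hold the value in this, while unknown keys fall back to a generic Property:

    from sqlglot import exp

    props = exp.Properties.from_dict({"FORMAT": "parquet", "custom_key": "x"})
    # "FORMAT" -> FileFormatProperty(this=Literal), "custom_key" -> Property(this=..., value=...)
    print(props)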
@@ -1383,6 +1385,7 @@ class Select(Subqueryable):
         "expressions": False,
         "hint": False,
         "distinct": False,
+        "into": False,
         "from": False,
         **QUERY_MODIFIERS,
     }
@@ -2015,6 +2018,7 @@ class DataType(Expression):
         DECIMAL = auto()
         BOOLEAN = auto()
         JSON = auto()
+        JSONB = auto()
         INTERVAL = auto()
         TIMESTAMP = auto()
         TIMESTAMPTZ = auto()
@@ -2029,6 +2033,7 @@ class DataType(Expression):
         STRUCT = auto()
         NULLABLE = auto()
         HLLSKETCH = auto()
+        HSTORE = auto()
         SUPER = auto()
         SERIAL = auto()
         SMALLSERIAL = auto()
@@ -2109,7 +2114,7 @@ class Transaction(Command):


 class Commit(Command):
-    arg_types = {}  # type: ignore
+    arg_types = {"chain": False}


 class Rollback(Command):
@@ -2442,7 +2447,7 @@ class ArrayFilter(Func):


 class ArraySize(Func):
-    pass
+    arg_types = {"this": True, "expression": False}


 class ArraySort(Func):
@@ -2726,6 +2731,16 @@ class VarMap(Func):
     is_var_len_args = True


+class Matches(Func):
+    """Oracle/Snowflake decode.
+    https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm
+    Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else)
+    """
+
+    arg_types = {"this": True, "expressions": True}
+    is_var_len_args = True
+
+
 class Max(AggFunc):
     pass

@@ -2785,6 +2800,10 @@ class Round(Func):
     arg_types = {"this": True, "decimals": False}


+class RowNumber(Func):
+    arg_types: t.Dict[str, t.Any] = {}
+
+
 class SafeDivide(Func):
     arg_types = {"this": True, "expression": True}

@@ -1,19 +1,16 @@
 from __future__ import annotations

 import logging
-import re
 import typing as t

 from sqlglot import exp
-from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
+from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
 from sqlglot.helper import apply_index_offset, csv
 from sqlglot.time import format_time
 from sqlglot.tokens import TokenType

 logger = logging.getLogger("sqlglot")

-NEWLINE_RE = re.compile("\r\n?|\n")
-

 class Generator:
     """
@@ -58,11 +55,11 @@ class Generator:
     """

     TRANSFORMS = {
-        exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
         exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
         exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
         exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
         exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
+        exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
         exp.LanguageProperty: lambda self, e: self.naked_property(e),
         exp.LocationProperty: lambda self, e: self.naked_property(e),
         exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@@ -97,16 +94,17 @@ class Generator:
         exp.DistStyleProperty,
         exp.DistKeyProperty,
         exp.SortKeyProperty,
+        exp.LikeProperty,
     }

     WITH_PROPERTIES = {
-        exp.AnonymousProperty,
+        exp.Property,
         exp.FileFormatProperty,
         exp.PartitionedByProperty,
         exp.TableFormatProperty,
     }

-    WITH_SEPARATED_COMMENTS = (exp.Select,)
+    WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)

     __slots__ = (
         "time_mapping",
@@ -211,7 +209,7 @@ class Generator:
             for msg in self.unsupported_messages:
                 logger.warning(msg)
         elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
-            raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))
+            raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))

         return sql

@@ -226,25 +224,24 @@ class Generator:
     def seg(self, sql, sep=" "):
         return f"{self.sep(sep)}{sql}"

-    def maybe_comment(self, sql, expression, single_line=False):
-        comment = expression.comment if self._comments else None
-
-        if not comment:
-            return sql
+    def pad_comment(self, comment):
+        comment = " " + comment if comment[0].strip() else comment
+        comment = comment + " " if comment[-1].strip() else comment
+        return comment
+
+    def maybe_comment(self, sql, expression):
+        comments = expression.comments if self._comments else None
+
+        if not comments:
+            return sql
+
+        sep = "\n" if self.pretty else " "
+        comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)

         if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
-            return f"/*{comment}*/{self.sep()}{sql}"
+            return f"{comments}{self.sep()}{sql}"

-        if not self.pretty:
-            return f"{sql} /*{comment}*/"
-
-        if not NEWLINE_RE.search(comment):
-            return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
-
-        return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
+        return f"{sql} {comments}"

     def wrap(self, expression):
         this_sql = self.indent(
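Expressions now carry a list of comments instead of a single string, so multiple comments on one node can survive a round trip; an illustrative, unverified example:

    import sqlglot

    # both leading comments attach to the Select node and should be emitted back-to-back
    print(sqlglot.transpile("/* one */ /* two */ SELECT 1")[0])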
@@ -387,8 +384,11 @@ class Generator:
     def notnullcolumnconstraint_sql(self, _):
         return "NOT NULL"

-    def primarykeycolumnconstraint_sql(self, _):
-        return "PRIMARY KEY"
+    def primarykeycolumnconstraint_sql(self, expression):
+        desc = expression.args.get("desc")
+        if desc is not None:
+            return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
+        return f"PRIMARY KEY"

     def uniquecolumnconstraint_sql(self, _):
         return "UNIQUE"
@@ -546,36 +546,33 @@ class Generator:

     def root_properties(self, properties):
         if properties.expressions:
-            return self.sep() + self.expressions(
-                properties,
-                indent=False,
-                sep=" ",
-            )
+            return self.sep() + self.expressions(properties, indent=False, sep=" ")
         return ""

     def properties(self, properties, prefix="", sep=", "):
         if properties.expressions:
-            expressions = self.expressions(
-                properties,
-                sep=sep,
-                indent=False,
-            )
+            expressions = self.expressions(properties, sep=sep, indent=False)
             return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
         return ""

     def with_properties(self, properties):
-        return self.properties(
-            properties,
-            prefix="WITH",
-        )
+        return self.properties(properties, prefix="WITH")

     def property_sql(self, expression):
-        if isinstance(expression.this, exp.Literal):
-            key = expression.this.this
-        else:
-            key = expression.name
-        value = self.sql(expression, "value")
-        return f"{key}={value}"
+        property_cls = expression.__class__
+        if property_cls == exp.Property:
+            return f"{expression.name}={self.sql(expression, 'value')}"
+
+        property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
+        if not property_name:
+            self.unsupported(f"Unsupported property {property_name}")
+
+        return f"{property_name}={self.sql(expression, 'this')}"
+
+    def likeproperty_sql(self, expression):
+        options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
+        options = f" {options}" if options else ""
+        return f"LIKE {self.sql(expression, 'this')}{options}"

     def insert_sql(self, expression):
         overwrite = expression.args.get("overwrite")
@@ -700,6 +697,11 @@ class Generator:
     def var_sql(self, expression):
         return self.sql(expression, "this")

+    def into_sql(self, expression):
+        temporary = " TEMPORARY" if expression.args.get("temporary") else ""
+        unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
+        return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
+
     def from_sql(self, expression):
         expressions = self.expressions(expression, flat=True)
         return f"{self.seg('FROM')} {expressions}"
@@ -883,6 +885,7 @@ class Generator:
         sql = self.query_modifiers(
             expression,
             f"SELECT{hint}{distinct}{expressions}",
+            self.sql(expression, "into", comment=False),
             self.sql(expression, "from", comment=False),
         )
         return self.prepend_ctes(expression, sql)
@@ -1061,6 +1064,11 @@ class Generator:
         else:
             return f"TRIM({target})"

+    def concat_sql(self, expression):
+        if len(expression.expressions) == 1:
+            return self.sql(expression.expressions[0])
+        return self.function_fallback_sql(expression)
+
     def check_sql(self, expression):
         this = self.sql(expression, key="this")
         return f"CHECK ({this})"
@@ -1125,7 +1133,10 @@ class Generator:
         return self.prepend_ctes(expression, sql)

     def neg_sql(self, expression):
-        return f"-{self.sql(expression, 'this')}"
+        # This makes sure we don't convert "- - 5" to "--5", which is a comment
+        this_sql = self.sql(expression, "this")
+        sep = " " if this_sql[0] == "-" else ""
+        return f"-{sep}{this_sql}"

     def not_sql(self, expression):
         return f"NOT {self.sql(expression, 'this')}"
@@ -1191,8 +1202,12 @@ class Generator:
     def transaction_sql(self, *_):
         return "BEGIN"

-    def commit_sql(self, *_):
-        return "COMMIT"
+    def commit_sql(self, expression):
+        chain = expression.args.get("chain")
+        if chain is not None:
+            chain = " AND CHAIN" if chain else " AND NO CHAIN"
+
+        return f"COMMIT{chain or ''}"

     def rollback_sql(self, expression):
         savepoint = expression.args.get("savepoint")
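A hedged round trip for the new COMMIT options (the parsing side, _parse_commit_or_rollback, is added in parser.py at the end of this commit):

    import sqlglot

    # "AND [NO] CHAIN" is preserved; a bare COMMIT stays a bare COMMIT
    print(sqlglot.transpile("COMMIT AND CHAIN")[0])  # expected: COMMIT AND CHAIN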
@@ -1334,15 +1349,15 @@ class Generator:
         result_sqls = []
         for i, e in enumerate(expressions):
             sql = self.sql(e, comment=False)
-            comment = self.maybe_comment("", e, single_line=True)
+            comments = self.maybe_comment("", e)

             if self.pretty:
                 if self._leading_comma:
-                    result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
+                    result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
                 else:
-                    result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
+                    result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
             else:
-                result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
+                result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")

         result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
         return self.indent(result_sqls, skip_first=False) if indent else result_sqls
@@ -1354,7 +1369,10 @@ class Generator:
         return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"

     def naked_property(self, expression):
-        return f"{expression.name} {self.sql(expression, 'value')}"
+        property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
+        if not property_name:
+            self.unsupported(f"Unsupported property {expression.__class__.__name__}")
+        return f"{property_name} {self.sql(expression, 'this')}"

     def set_operation(self, expression, op):
         this = self.sql(expression, "this")
@@ -68,6 +68,9 @@ def eliminate_subqueries(expression):
     for cte_scope in root.cte_scopes:
         # Append all the new CTEs from this existing CTE
         for scope in cte_scope.traverse():
+            if scope is cte_scope:
+                # Don't try to eliminate this CTE itself
+                continue
             new_cte = _eliminate(scope, existing_ctes, taken)
             if new_cte:
                 new_ctes.append(new_cte)
@@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken):
     if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
         return _eliminate_derived_table(scope, existing_ctes, taken)

+    if scope.is_cte:
+        return _eliminate_cte(scope, existing_ctes, taken)
+

 def _eliminate_union(scope, existing_ctes, taken):
     duplicate_cte_alias = existing_ctes.get(scope.expression)
@@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken):


 def _eliminate_derived_table(scope, existing_ctes, taken):
+    parent = scope.expression.parent
+    name, cte = _new_cte(scope, existing_ctes, taken)
+
+    table = exp.alias_(exp.table_(name), alias=parent.alias or name)
+    parent.replace(table)
+
+    return cte
+
+
+def _eliminate_cte(scope, existing_ctes, taken):
+    parent = scope.expression.parent
+    name, cte = _new_cte(scope, existing_ctes, taken)
+
+    with_ = parent.parent
+    parent.pop()
+    if not with_.expressions:
+        with_.pop()
+
+    # Rename references to this CTE
+    for child_scope in scope.parent.traverse():
+        for table, source in child_scope.selected_sources.values():
+            if source is scope:
+                new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
+                table.replace(new_table)
+
+    return cte
+
+
+def _new_cte(scope, existing_ctes, taken):
+    """
+    Returns:
+        tuple of (name, cte)
+        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
+        If this CTE duplicates an existing CTE, `cte` will be None.
+    """
     duplicate_cte_alias = existing_ctes.get(scope.expression)
     parent = scope.expression.parent
-    name = alias = parent.alias
+    name = parent.alias

-    if not alias:
-        name = alias = find_new_name(taken=taken, base="cte")
+    if not name:
+        name = find_new_name(taken=taken, base="cte")

     if duplicate_cte_alias:
         name = duplicate_cte_alias
-    elif taken.get(alias):
-        name = find_new_name(taken=taken, base=alias)
+    elif taken.get(name):
+        name = find_new_name(taken=taken, base=name)

     taken[name] = scope

-    table = exp.alias_(exp.table_(name), alias=alias)
-    parent.replace(table)
-
     if not duplicate_cte_alias:
         existing_ctes[scope.expression] = name
-        return exp.CTE(
+        cte = exp.CTE(
             this=scope.expression,
             alias=exp.TableAlias(this=exp.to_identifier(name)),
         )
+    else:
+        cte = None
+    return name, cte
sqlglot/optimizer/lower_identities.py (new file, 92 lines)
@@ -0,0 +1,92 @@
+from sqlglot import exp
+from sqlglot.helper import ensure_collection
+
+
+def lower_identities(expression):
+    """
+    Convert all unquoted identifiers to lower case.
+
+    Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
+
+    Example:
+        >>> import sqlglot
+        >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
+        >>> lower_identities(expression).sql()
+        'SELECT bar.a AS A FROM "Foo".bar'
+
+    Args:
+        expression (sqlglot.Expression): expression to quote
+    Returns:
+        sqlglot.Expression: quoted expression
+    """
+    # We need to leave the output aliases unchanged, so the selects need special handling
+    _lower_selects(expression)
+
+    # These clauses can reference output aliases and also need special handling
+    _lower_order(expression)
+    _lower_having(expression)
+
+    # We've already handled these args, so don't traverse into them
+    traversed = {"expressions", "order", "having"}
+
+    if isinstance(expression, exp.Subquery):
+        # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
+        lower_identities(expression.this)
+        traversed |= {"this"}
+
+    if isinstance(expression, exp.Union):
+        # Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
+        lower_identities(expression.left)
+        lower_identities(expression.right)
+        traversed |= {"this", "expression"}
+
+    for k, v in expression.args.items():
+        if k in traversed:
+            continue
+
+        for child in ensure_collection(v):
+            if isinstance(child, exp.Expression):
+                child.transform(_lower, copy=False)
+
+    return expression
+
+
+def _lower_selects(expression):
+    for e in expression.expressions:
+        # Leave output aliases as-is
+        e.unalias().transform(_lower, copy=False)
+
+
+def _lower_order(expression):
+    order = expression.args.get("order")
+
+    if not order:
+        return
+
+    output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
+
+    for ordered in order.expressions:
+        # Don't lower references to output aliases
+        if not (
+            isinstance(ordered.this, exp.Column)
+            and not ordered.this.table
+            and ordered.this.name in output_aliases
+        ):
+            ordered.transform(_lower, copy=False)
+
+
+def _lower_having(expression):
+    having = expression.args.get("having")
+
+    if not having:
+        return
+
+    # Don't lower references to output aliases
+    for agg in having.find_all(exp.AggFunc):
+        agg.transform(_lower, copy=False)
+
+
+def _lower(node):
+    if isinstance(node, exp.Identifier) and not node.quoted:
+        node.set("this", node.this.lower())
+    return node
@@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins
 from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
 from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
 from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.lower_identities import lower_identities
 from sqlglot.optimizer.merge_subqueries import merge_subqueries
 from sqlglot.optimizer.normalize import normalize
 from sqlglot.optimizer.optimize_joins import optimize_joins
@@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities
 from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

 RULES = (
+    lower_identities,
     qualify_tables,
     isolate_table_selects,
     qualify_columns,
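A hedged example of the new rule running first in the optimizer pipeline, so mixed-case identifiers resolve against a lower-case schema:

    import sqlglot
    from sqlglot.optimizer import optimize

    print(optimize(sqlglot.parse_one("SELECT X.A FROM X"), schema={"x": {"a": "INT"}}).sql())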
@@ -1,16 +1,15 @@
 import itertools

 from sqlglot import exp
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import ScopeType, traverse_scope


 def unnest_subqueries(expression):
     """
     Rewrite sqlglot AST to convert some predicates with subqueries into joins.

-    Convert the subquery into a group by so it is not a many to many left join.
-    Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
-    Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
+    Convert scalar subqueries into cross joins.
+    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.

     Example:
         >>> import sqlglot
@@ -29,21 +28,43 @@ def unnest_subqueries(expression):
     for scope in traverse_scope(expression):
         select = scope.expression
         parent = select.parent_select
+        if not parent:
+            continue
         if scope.external_columns:
             decorrelate(select, parent, scope.external_columns, sequence)
-        else:
+        elif scope.scope_type == ScopeType.SUBQUERY:
             unnest(select, parent, sequence)

     return expression


 def unnest(select, parent_select, sequence):
-    predicate = select.find_ancestor(exp.In, exp.Any)
+    if len(select.selects) > 1:
+        return
+
+    predicate = select.find_ancestor(exp.Condition)
+    alias = _alias(sequence)

     if not predicate or parent_select is not predicate.parent_select:
         return

-    if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
+    # this subquery returns a scalar and can just be converted to a cross join
+    if not isinstance(predicate, (exp.In, exp.Any)):
+        having = predicate.find_ancestor(exp.Having)
+        column = exp.column(select.selects[0].alias_or_name, alias)
+        if having and having.parent_select is parent_select:
+            column = exp.Max(this=column)
+        _replace(select.parent, column)
+
+        parent_select.join(
+            select,
+            join_type="CROSS",
+            join_alias=alias,
+            copy=False,
+        )
+        return
+
+    if select.find(exp.Limit, exp.Offset):
         return

     if isinstance(predicate, exp.Any):
@@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):

     column = _other_operand(predicate)
     value = select.selects[0]
-    alias = _alias(sequence)

     on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
     _replace(predicate, f"NOT {on.right} IS NULL")
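An illustrative (unverified) run of the new scalar-subquery path:

    import sqlglot
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT * FROM x WHERE x.a > (SELECT SUM(y.a) FROM y)"
    # the scalar subquery should become a CROSS JOIN on a generated alias such as "_u_0"
    print(unnest_subqueries(sqlglot.parse_one(sql)).sql())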
@@ -4,7 +4,7 @@ import logging
 import typing as t

 from sqlglot import exp
-from sqlglot.errors import ErrorLevel, ParseError, concat_errors
+from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
 from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
 from sqlglot.tokens import Token, Tokenizer, TokenType
 from sqlglot.trie import in_trie, new_trie
@ -104,6 +104,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.BINARY,
|
||||
TokenType.VARBINARY,
|
||||
TokenType.JSON,
|
||||
TokenType.JSONB,
|
||||
TokenType.INTERVAL,
|
||||
TokenType.TIMESTAMP,
|
||||
TokenType.TIMESTAMPTZ,
|
||||
|
@ -115,6 +116,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.GEOGRAPHY,
|
||||
TokenType.GEOMETRY,
|
||||
TokenType.HLLSKETCH,
|
||||
TokenType.HSTORE,
|
||||
TokenType.SUPER,
|
||||
TokenType.SERIAL,
|
||||
TokenType.SMALLSERIAL,
|
||||
|
@ -153,6 +155,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.COLLATE,
|
||||
TokenType.COMMAND,
|
||||
TokenType.COMMIT,
|
||||
TokenType.COMPOUND,
|
||||
TokenType.CONSTRAINT,
|
||||
TokenType.CURRENT_TIME,
|
||||
TokenType.DEFAULT,
|
||||
|
@ -194,6 +197,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.RANGE,
|
||||
TokenType.REFERENCES,
|
||||
TokenType.RETURNS,
|
||||
TokenType.ROW,
|
||||
TokenType.ROWS,
|
||||
TokenType.SCHEMA,
|
||||
TokenType.SCHEMA_COMMENT,
|
||||
|
@ -213,6 +217,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.TRUE,
|
||||
TokenType.UNBOUNDED,
|
||||
TokenType.UNIQUE,
|
||||
TokenType.UNLOGGED,
|
||||
TokenType.UNPIVOT,
|
||||
TokenType.PROPERTIES,
|
||||
TokenType.PROCEDURE,
|
||||
|
@ -400,9 +405,17 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
|
||||
TokenType.BEGIN: lambda self: self._parse_transaction(),
|
||||
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
|
||||
TokenType.END: lambda self: self._parse_commit_or_rollback(),
|
||||
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
|
||||
}
|
||||
|
||||
UNARY_PARSERS = {
|
||||
TokenType.PLUS: lambda self: self._parse_unary(), # Unary + is handled as a no-op
|
||||
TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()),
|
||||
TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()),
|
||||
TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()),
|
||||
}
|
||||
|
||||
PRIMARY_PARSERS = {
|
||||
TokenType.STRING: lambda self, token: self.expression(
|
||||
exp.Literal, this=token.text, is_string=True
|
||||
|
@ -446,19 +459,20 @@ class Parser(metaclass=_Parser):
|
|||
}
|
||||
|
||||
PROPERTY_PARSERS = {
|
||||
TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(),
|
||||
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
|
||||
TokenType.LOCATION: lambda self: self.expression(
|
||||
exp.LocationProperty,
|
||||
this=exp.Literal.string("LOCATION"),
|
||||
value=self._parse_string(),
|
||||
TokenType.AUTO_INCREMENT: lambda self: self._parse_property_assignment(
|
||||
exp.AutoIncrementProperty
|
||||
),
|
||||
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
|
||||
TokenType.LOCATION: lambda self: self._parse_property_assignment(exp.LocationProperty),
|
||||
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
|
||||
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
|
||||
TokenType.STORED: lambda self: self._parse_stored(),
|
||||
TokenType.SCHEMA_COMMENT: lambda self: self._parse_property_assignment(
|
||||
exp.SchemaCommentProperty
|
||||
),
|
||||
TokenType.STORED: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
|
||||
TokenType.DISTKEY: lambda self: self._parse_distkey(),
|
||||
TokenType.DISTSTYLE: lambda self: self._parse_diststyle(),
|
||||
TokenType.DISTSTYLE: lambda self: self._parse_property_assignment(exp.DistStyleProperty),
|
||||
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
|
||||
TokenType.LIKE: lambda self: self._parse_create_like(),
|
||||
TokenType.RETURNS: lambda self: self._parse_returns(),
|
||||
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
|
||||
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
|
||||
|
@ -468,7 +482,7 @@ class Parser(metaclass=_Parser):
|
|||
),
|
||||
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
|
||||
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
|
||||
TokenType.EXECUTE: lambda self: self._parse_execute_as(),
|
||||
TokenType.EXECUTE: lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
|
||||
TokenType.DETERMINISTIC: lambda self: self.expression(
|
||||
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
|
||||
),
|
||||
|
@ -489,6 +503,7 @@ class Parser(metaclass=_Parser):
|
|||
),
|
||||
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
|
||||
TokenType.UNIQUE: lambda self: self._parse_unique(),
|
||||
TokenType.LIKE: lambda self: self._parse_create_like(),
|
||||
}
|
||||
|
||||
NO_PAREN_FUNCTION_PARSERS = {
|
||||
|
@ -505,6 +520,7 @@ class Parser(metaclass=_Parser):
|
|||
"TRIM": lambda self: self._parse_trim(),
|
||||
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
|
||||
"TRY_CAST": lambda self: self._parse_cast(False),
|
||||
"STRING_AGG": lambda self: self._parse_string_agg(),
|
||||
}
|
||||
|
||||
QUERY_MODIFIER_PARSERS = {
|
||||
|
@ -556,7 +572,7 @@ class Parser(metaclass=_Parser):
|
|||
"_curr",
|
||||
"_next",
|
||||
"_prev",
|
||||
"_prev_comment",
|
||||
"_prev_comments",
|
||||
"_show_trie",
|
||||
"_set_trie",
|
||||
)
|
||||
|
@ -589,7 +605,7 @@ class Parser(metaclass=_Parser):
|
|||
self._curr = None
|
||||
self._next = None
|
||||
self._prev = None
|
||||
self._prev_comment = None
|
||||
self._prev_comments = None
|
||||
|
||||
def parse(self, raw_tokens, sql=None):
|
||||
"""
|
||||
|
@ -608,6 +624,7 @@ class Parser(metaclass=_Parser):
|
|||
)
|
||||
|
||||
def parse_into(self, expression_types, raw_tokens, sql=None):
|
||||
errors = []
|
||||
for expression_type in ensure_collection(expression_types):
|
||||
parser = self.EXPRESSION_PARSERS.get(expression_type)
|
||||
if not parser:
|
||||
|
@ -615,8 +632,12 @@ class Parser(metaclass=_Parser):
|
|||
try:
|
||||
return self._parse(parser, raw_tokens, sql)
|
||||
except ParseError as e:
|
||||
error = e
|
||||
raise ParseError(f"Failed to parse into {expression_types}") from error
|
||||
e.errors[0]["into_expression"] = expression_type
|
||||
errors.append(e)
|
||||
raise ParseError(
|
||||
f"Failed to parse into {expression_types}",
|
||||
errors=merge_errors(errors),
|
||||
) from errors[-1]
|
||||
|
||||
def _parse(self, parse_method, raw_tokens, sql=None):
|
||||
self.reset()
|
||||
|
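With merge_errors, a failed parse now carries a structured list instead of a flat message. A small
sketch (the failing input is the project's classic example; the dict keys follow the ParseError.new
call in the hunks below):

    import sqlglot
    from sqlglot.errors import ParseError

    try:
        sqlglot.transpile("SELECT foo( FROM bar")
    except ParseError as e:
        # Each entry carries the context attached in raise_error.
        for error in e.errors:
            print(error["description"], error["line"], error["col"])
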
@ -650,7 +671,10 @@
            for error in self.errors:
                logger.error(str(error))
        elif self.error_level == ErrorLevel.RAISE and self.errors:
            raise ParseError(concat_errors(self.errors, self.max_errors))
            raise ParseError(
                concat_messages(self.errors, self.max_errors),
                errors=merge_errors(self.errors),
            )

    def raise_error(self, message, token=None):
        token = token or self._curr or self._prev or Token.string("")

@ -659,19 +683,27 @@
        start_context = self.sql[max(start - self.error_message_context, 0) : start]
        highlight = self.sql[start:end]
        end_context = self.sql[end : end + self.error_message_context]
        error = ParseError(
        error = ParseError.new(
            f"{message}. Line {token.line}, Col: {token.col}.\n"
            f" {start_context}\033[4m{highlight}\033[0m{end_context}"
            f" {start_context}\033[4m{highlight}\033[0m{end_context}",
            description=message,
            line=token.line,
            col=token.col,
            start_context=start_context,
            highlight=highlight,
            end_context=end_context,
        )
        if self.error_level == ErrorLevel.IMMEDIATE:
            raise error
        self.errors.append(error)

    def expression(self, exp_class, **kwargs):
    def expression(self, exp_class, comments=None, **kwargs):
        instance = exp_class(**kwargs)
        if self._prev_comment:
            instance.comment = self._prev_comment
            self._prev_comment = None
        if self._prev_comments:
            instance.comments = self._prev_comments
            self._prev_comments = None
        if comments:
            instance.comments = comments
        self.validate_expression(instance)
        return instance

@ -714,10 +746,10 @@
        self._next = seq_get(self._tokens, self._index + 1)
        if self._index > 0:
            self._prev = self._tokens[self._index - 1]
            self._prev_comment = self._prev.comment
            self._prev_comments = self._prev.comments
        else:
            self._prev = None
            self._prev_comment = None
            self._prev_comments = None

    def _retreat(self, index):
        self._advance(index - self._index)

@ -768,7 +800,7 @@
        )

    def _parse_create(self):
        replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
        replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
        temporary = self._match(TokenType.TEMPORARY)
        transient = self._match(TokenType.TRANSIENT)
        unique = self._match(TokenType.UNIQUE)

@ -822,97 +854,57 @@
    def _parse_property(self):
        if self._match_set(self.PROPERTY_PARSERS):
            return self.PROPERTY_PARSERS[self._prev.token_type](self)

        if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
            return self._parse_character_set(True)

        if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
            key = self._parse_var().this
            self._match(TokenType.EQ)
        if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
            return self._parse_sortkey(compound=True)

            return self.expression(
                exp.AnonymousProperty,
                this=exp.Literal.string(key),
                value=self._parse_column(),
            )
        if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
            key = self._parse_var()
            self._match(TokenType.EQ)
            return self.expression(exp.Property, this=key, value=self._parse_column())

        return None

    def _parse_property_assignment(self, exp_class):
        prop = self._prev.text
        self._match(TokenType.EQ)
        return self.expression(exp_class, this=prop, value=self._parse_var_or_string())
        self._match(TokenType.ALIAS)
        return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())

    def _parse_partitioned_by(self):
        self._match(TokenType.EQ)
        return self.expression(
            exp.PartitionedByProperty,
            this=exp.Literal.string("PARTITIONED_BY"),
            value=self._parse_schema() or self._parse_bracket(self._parse_field()),
        )

    def _parse_stored(self):
        self._match(TokenType.ALIAS)
        self._match(TokenType.EQ)
        return self.expression(
            exp.FileFormatProperty,
            this=exp.Literal.string("FORMAT"),
            value=exp.Literal.string(self._parse_var_or_string().name),
            this=self._parse_schema() or self._parse_bracket(self._parse_field()),
        )

    def _parse_distkey(self):
        self._match_l_paren()
        this = exp.Literal.string("DISTKEY")
        value = exp.Literal.string(self._parse_var().name)
        self._match_r_paren()
        return self.expression(
            exp.DistKeyProperty,
            this=this,
            value=value,
        )
        return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))

    def _parse_sortkey(self):
        self._match_l_paren()
        this = exp.Literal.string("SORTKEY")
        value = exp.Literal.string(self._parse_var().name)
        self._match_r_paren()
        return self.expression(
            exp.SortKeyProperty,
            this=this,
            value=value,
        )
    def _parse_create_like(self):
        table = self._parse_table(schema=True)
        options = []
        while self._match_texts(("INCLUDING", "EXCLUDING")):
            options.append(
                self.expression(
                    exp.Property,
                    this=self._prev.text.upper(),
                    value=exp.Var(this=self._parse_id_var().this.upper()),
                )
            )
        return self.expression(exp.LikeProperty, this=table, expressions=options)

    def _parse_diststyle(self):
        this = exp.Literal.string("DISTSTYLE")
        value = exp.Literal.string(self._parse_var().name)
    def _parse_sortkey(self, compound=False):
        return self.expression(
            exp.DistStyleProperty,
            this=this,
            value=value,
        )

    def _parse_auto_increment(self):
        self._match(TokenType.EQ)
        return self.expression(
            exp.AutoIncrementProperty,
            this=exp.Literal.string("AUTO_INCREMENT"),
            value=self._parse_number(),
        )

    def _parse_schema_comment(self):
        self._match(TokenType.EQ)
        return self.expression(
            exp.SchemaCommentProperty,
            this=exp.Literal.string("COMMENT"),
            value=self._parse_string(),
            exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
        )

    def _parse_character_set(self, default=False):
        self._match(TokenType.EQ)
        return self.expression(
            exp.CharacterSetProperty,
            this=exp.Literal.string("CHARACTER_SET"),
            value=self._parse_var_or_string(),
            default=default,
            exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
        )

    def _parse_returns(self):
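The property refactor is visible from the public API. A hedged check of the new COMPOUND SORTKEY
path (assuming the default tokenizer recognizes the Redshift-style keywords, as the KEYWORDS
additions further down suggest):

    import sqlglot
    from sqlglot import exp

    ddl = sqlglot.parse_one("CREATE TABLE t (a INT, b INT) COMPOUND SORTKEY(a, b)")
    prop = ddl.find(exp.SortKeyProperty)
    print(prop.args.get("compound"))  # expected: True
    print(ddl.sql())                  # the property is expected to round-trip
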
@ -931,20 +923,7 @@
        else:
            value = self._parse_types()

        return self.expression(
            exp.ReturnsProperty,
            this=exp.Literal.string("RETURNS"),
            value=value,
            is_table=is_table,
        )

    def _parse_execute_as(self):
        self._match(TokenType.ALIAS)
        return self.expression(
            exp.ExecuteAsProperty,
            this=exp.Literal.string("EXECUTE AS"),
            value=self._parse_var(),
        )
        return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)

    def _parse_properties(self):
        properties = []

@ -956,7 +935,7 @@
            properties.extend(
                self._parse_wrapped_csv(
                    lambda: self.expression(
                        exp.AnonymousProperty,
                        exp.Property,
                        this=self._parse_string(),
                        value=self._match(TokenType.EQ) and self._parse_string(),
                    )

@ -1076,7 +1055,12 @@
        options = []

        if self._match(TokenType.OPTIONS):
            options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ)
            self._match_l_paren()
            k = self._parse_string()
            self._match(TokenType.EQ)
            v = self._parse_string()
            options = [k, v]
            self._match_r_paren()

        self._match(TokenType.ALIAS)
        return self.expression(

@ -1116,7 +1100,7 @@
                self.raise_error(f"{this.key} does not support CTE")
            this = cte
        elif self._match(TokenType.SELECT):
            comment = self._prev_comment
            comments = self._prev_comments

            hint = self._parse_hint()
            all_ = self._match(TokenType.ALL)

@ -1141,10 +1125,16 @@
                expressions=expressions,
                limit=limit,
            )
            this.comment = comment
            this.comments = comments

            into = self._parse_into()
            if into:
                this.set("into", into)

            from_ = self._parse_from()
            if from_:
                this.set("from", from_)

            self._parse_query_modifiers(this)
        elif (table or nested) and self._match(TokenType.L_PAREN):
            this = self._parse_table() if table else self._parse_select(nested=True)

@ -1248,11 +1238,24 @@
            return self.expression(exp.Hint, expressions=hints)
        return None

    def _parse_into(self):
        if not self._match(TokenType.INTO):
            return None

        temp = self._match(TokenType.TEMPORARY)
        unlogged = self._match(TokenType.UNLOGGED)
        self._match(TokenType.TABLE)

        return self.expression(
            exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
        )

    def _parse_from(self):
        if not self._match(TokenType.FROM):
            return None

        return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))
        return self.expression(
            exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
        )

    def _parse_lateral(self):
        outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
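The new _parse_into hook means SELECT ... INTO now lands on the AST. A minimal sketch of the
resulting node (attribute access mirrors the expression built above):

    import sqlglot

    select = sqlglot.parse_one("SELECT * INTO TEMPORARY TABLE t2 FROM t1")
    into = select.args.get("into")
    print(into.args.get("temporary"))  # expected: True
    print(into.this.sql())             # expected: t2
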
@ -1515,7 +1518,9 @@
    def _parse_where(self, skip_where_token=False):
        if not skip_where_token and not self._match(TokenType.WHERE):
            return None
        return self.expression(exp.Where, this=self._parse_conjunction())
        return self.expression(
            exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
        )

    def _parse_group(self, skip_group_by_token=False):
        if not skip_group_by_token and not self._match(TokenType.GROUP_BY):

@ -1737,12 +1742,8 @@
        return self._parse_tokens(self._parse_unary, self.FACTOR)

    def _parse_unary(self):
        if self._match(TokenType.NOT):
            return self.expression(exp.Not, this=self._parse_equality())
        if self._match(TokenType.TILDA):
            return self.expression(exp.BitwiseNot, this=self._parse_unary())
        if self._match(TokenType.DASH):
            return self.expression(exp.Neg, this=self._parse_unary())
        if self._match_set(self.UNARY_PARSERS):
            return self.UNARY_PARSERS[self._prev.token_type](self)
        return self._parse_at_time_zone(self._parse_type())

    def _parse_type(self):

@ -1775,17 +1776,6 @@
        expressions = None
        maybe_func = False

        if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
            return exp.DataType(
                this=exp.DataType.Type.ARRAY,
                expressions=[exp.DataType.build(type_token.value)],
                nested=True,
            )

        if self._match(TokenType.L_BRACKET):
            self._retreat(index)
            return None

        if self._match(TokenType.L_PAREN):
            if is_struct:
                expressions = self._parse_csv(self._parse_struct_kwargs)

@ -1801,6 +1791,17 @@
            self._match_r_paren()
            maybe_func = True

        if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
            return exp.DataType(
                this=exp.DataType.Type.ARRAY,
                expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
                nested=True,
            )

        if self._match(TokenType.L_BRACKET):
            self._retreat(index)
            return None

        if nested and self._match(TokenType.LT):
            if is_struct:
                expressions = self._parse_csv(self._parse_struct_kwargs)

@ -1904,7 +1905,7 @@
            return exp.Literal.number(f"0.{self._prev.text}")

        if self._match(TokenType.L_PAREN):
            comment = self._prev_comment
            comments = self._prev_comments
            query = self._parse_select()

            if query:

@ -1924,8 +1925,8 @@
                this = self.expression(exp.Tuple, expressions=expressions)
            else:
                this = self.expression(exp.Paren, this=this)
            if comment:
                this.comment = comment
            if comments:
                this.comments = comments
            return this

        return None

@ -2098,7 +2099,10 @@
        elif self._match(TokenType.SCHEMA_COMMENT):
            kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
        elif self._match(TokenType.PRIMARY_KEY):
            kind = exp.PrimaryKeyColumnConstraint()
            desc = None
            if self._match(TokenType.ASC) or self._match(TokenType.DESC):
                desc = self._prev.token_type == TokenType.DESC
            kind = exp.PrimaryKeyColumnConstraint(desc=desc)
        elif self._match(TokenType.UNIQUE):
            kind = exp.UniqueColumnConstraint()
        elif self._match(TokenType.GENERATED):

@ -2189,7 +2193,7 @@
        if not self._match(TokenType.R_BRACKET):
            self.raise_error("Expected ]")

        this.comment = self._prev_comment
        this.comments = self._prev_comments
        return self._parse_bracket(this)

    def _parse_case(self):

@ -2256,6 +2260,33 @@
        return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)

    def _parse_string_agg(self):
        if self._match(TokenType.DISTINCT):
            args = self._parse_csv(self._parse_conjunction)
            expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
        else:
            args = self._parse_csv(self._parse_conjunction)
            expression = seq_get(args, 0)

        index = self._index
        if not self._match(TokenType.R_PAREN):
            # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
            order = self._parse_order(this=expression)
            return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))

        # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
        # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
        # the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them.
        if not self._match(TokenType.WITHIN_GROUP):
            self._retreat(index)
            this = exp.GroupConcat.from_arg_list(args)
            self.validate_expression(this, args)
            return this

        self._match_l_paren()  # The corresponding match_r_paren will be called in parse_function (caller)
        order = self._parse_order(this=expression)
        return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))

    def _parse_convert(self, strict):
        this = self._parse_column()
        if self._match(TokenType.USING):
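As the comments in _parse_string_agg note, both Postgres spellings normalize to the same
exp.GroupConcat node, which makes transpilation to MySQL / SQLite straightforward. A hedged sketch
(exact output strings not verified against 10.1.3):

    import sqlglot

    for sql in (
        "SELECT STRING_AGG(x, ',' ORDER BY y) FROM t",
        "SELECT STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y) FROM t",
    ):
        # Both forms should come out as a single GROUP_CONCAT call.
        print(sqlglot.transpile(sql, read="postgres", write="mysql")[0])
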
@ -2511,8 +2542,8 @@
        items = [parse_result] if parse_result is not None else []

        while self._match(sep):
            if parse_result and self._prev_comment is not None:
                parse_result.comment = self._prev_comment
            if parse_result and self._prev_comments:
                parse_result.comments = self._prev_comments

            parse_result = parse_method()
            if parse_result is not None:

@ -2525,7 +2556,10 @@
        while self._match_set(expressions):
            this = self.expression(
                expressions[self._prev.token_type], this=this, expression=parse_method()
                expressions[self._prev.token_type],
                this=this,
                comments=self._prev_comments,
                expression=parse_method(),
            )

        return this

@ -2566,6 +2600,7 @@
        return self.expression(exp.Transaction, this=this, modes=modes)

    def _parse_commit_or_rollback(self):
        chain = None
        savepoint = None
        is_rollback = self._prev.token_type == TokenType.ROLLBACK

@ -2575,9 +2610,13 @@
            self._match_text_seq("SAVEPOINT")
            savepoint = self._parse_id_var()

        if self._match(TokenType.AND):
            chain = not self._match_text_seq("NO")
            self._match_text_seq("CHAIN")

        if is_rollback:
            return self.expression(exp.Rollback, savepoint=savepoint)
        return self.expression(exp.Commit)
        return self.expression(exp.Commit, chain=chain)

    def _parse_show(self):
        parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
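With BEGIN / COMMIT / END / ROLLBACK wired into STATEMENT_PARSERS, transaction statements parse
into proper nodes. A small sketch of the new chain flag (per the branch above, AND CHAIN sets it
to True and AND NO CHAIN to False):

    import sqlglot

    commit = sqlglot.parse_one("COMMIT AND NO CHAIN")
    print(type(commit).__name__, commit.args.get("chain"))  # expected: Commit False
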
@ -2651,14 +2690,14 @@
    def _match_l_paren(self, expression=None):
        if not self._match(TokenType.L_PAREN):
            self.raise_error("Expecting (")
        if expression and self._prev_comment:
            expression.comment = self._prev_comment
        if expression and self._prev_comments:
            expression.comments = self._prev_comments

    def _match_r_paren(self, expression=None):
        if not self._match(TokenType.R_PAREN):
            self.raise_error("Expecting )")
        if expression and self._prev_comment:
            expression.comment = self._prev_comment
        if expression and self._prev_comments:
            expression.comments = self._prev_comments

    def _match_texts(self, texts):
        if self._curr and self._curr.text.upper() in texts:
sqlglot/planner.py

@ -130,18 +130,20 @@ class Step:
        aggregations = []
        sequence = itertools.count()

        for e in expression.expressions:
            aggregation = e.find(exp.AggFunc)

            if aggregation:
                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
                aggregations.append(e)
                for operand in aggregation.unnest_operands():
        def extract_agg_operands(expression):
            for agg in expression.find_all(exp.AggFunc):
                for operand in agg.unnest_operands():
                    if isinstance(operand, exp.Column):
                        continue
                    if operand not in operands:
                        operands[operand] = f"_a_{next(sequence)}"
                    operand.replace(exp.column(operands[operand], quoted=True))

        for e in expression.expressions:
            if e.find(exp.AggFunc):
                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
                aggregations.append(e)
                extract_agg_operands(e)
            else:
                projections.append(e)

@ -156,6 +158,13 @@ class Step:
            aggregate = Aggregate()
            aggregate.source = step.name
            aggregate.name = step.name

            having = expression.args.get("having")

            if having:
                extract_agg_operands(having)
                aggregate.condition = having.this

            aggregate.operands = tuple(
                alias(operand, alias_) for operand, alias_ in operands.items()
            )

@ -172,11 +181,6 @@ class Step:
            aggregate.add_dependency(step)
            step = aggregate

        having = expression.args.get("having")

        if having:
            step.condition = having.this

        order = expression.args.get("order")

        if order:

@ -188,6 +192,17 @@ class Step:

        step.projections = projections

        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
            distinct = Aggregate()
            distinct.source = step.name
            distinct.name = step.name
            distinct.group = {
                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
                for e in projections or expression.expressions
            }
            distinct.add_dependency(step)
            step = distinct

        limit = expression.args.get("limit")

        if limit:

@ -231,6 +246,9 @@ class Step:
        if self.condition:
            lines.append(f"{nested}Condition: {self.condition.sql()}")

        if self.limit is not math.inf:
            lines.append(f"{nested}Limit: {self.limit}")

        if self.dependencies:
            lines.append(f"{nested}Dependencies:")
            for dependency in self.dependencies:

@ -258,12 +276,7 @@ class Scan(Step):
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Step:
        table = expression
        alias_ = expression.alias

        if not alias_:
            raise UnsupportedError(
                "Tables/Subqueries must be aliased. Run it through the optimizer"
            )
        alias_ = expression.alias_or_name

        if isinstance(expression, exp.Subquery):
            table = expression.this

@ -338,6 +351,9 @@ class Aggregate(Step):
            lines.append(f"{indent}Group:")
            for expression in self.group.values():
                lines.append(f"{indent} - {expression.sql()}")
        if self.condition:
            lines.append(f"{indent}Having:")
            lines.append(f"{indent} - {self.condition.sql()}")
        if self.operands:
            lines.append(f"{indent}Operands:")
            for expression in self.operands:
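The planner changes are observable end to end: DISTINCT now becomes its own Aggregate step and
HAVING lands on the aggregate's condition. A hedged sketch (assuming the expression is first run
through the optimizer so tables are qualified, as the planner expects):

    import sqlglot
    from sqlglot.optimizer import optimize
    from sqlglot.planner import Plan

    expression = optimize(
        sqlglot.parse_one("SELECT DISTINCT a FROM x"),
        schema={"x": {"a": "INT"}},
    )
    # The plan root should be an Aggregate step grouping on the projections.
    print(Plan(expression).root)
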
sqlglot/tokens.py

@ -81,6 +81,7 @@ class TokenType(AutoName):
    BINARY = auto()
    VARBINARY = auto()
    JSON = auto()
    JSONB = auto()
    TIMESTAMP = auto()
    TIMESTAMPTZ = auto()
    TIMESTAMPLTZ = auto()

@ -91,6 +92,7 @@
    NULLABLE = auto()
    GEOMETRY = auto()
    HLLSKETCH = auto()
    HSTORE = auto()
    SUPER = auto()
    SERIAL = auto()
    SMALLSERIAL = auto()

@ -113,6 +115,7 @@
    APPLY = auto()
    ARRAY = auto()
    ASC = auto()
    ASOF = auto()
    AT_TIME_ZONE = auto()
    AUTO_INCREMENT = auto()
    BEGIN = auto()

@ -130,6 +133,7 @@
    COMMAND = auto()
    COMMENT = auto()
    COMMIT = auto()
    COMPOUND = auto()
    CONSTRAINT = auto()
    CREATE = auto()
    CROSS = auto()

@ -271,6 +275,7 @@
    UNBOUNDED = auto()
    UNCACHE = auto()
    UNION = auto()
    UNLOGGED = auto()
    UNNEST = auto()
    UNPIVOT = auto()
    UPDATE = auto()

@ -291,7 +296,7 @@


class Token:
    __slots__ = ("token_type", "text", "line", "col", "comment")
    __slots__ = ("token_type", "text", "line", "col", "comments")

    @classmethod
    def number(cls, number: int) -> Token:

@ -319,13 +324,13 @@
        text: str,
        line: int = 1,
        col: int = 1,
        comment: t.Optional[str] = None,
        comments: t.List[str] = [],
    ) -> None:
        self.token_type = token_type
        self.text = text
        self.line = line
        self.col = max(col - len(text), 1)
        self.comment = comment
        self.comments = comments

    def __repr__(self) -> str:
        attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)

@ -452,6 +457,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "COLLATE": TokenType.COLLATE,
        "COMMENT": TokenType.SCHEMA_COMMENT,
        "COMMIT": TokenType.COMMIT,
        "COMPOUND": TokenType.COMPOUND,
        "CONSTRAINT": TokenType.CONSTRAINT,
        "CREATE": TokenType.CREATE,
        "CROSS": TokenType.CROSS,

@ -582,8 +588,9 @@
        "TRAILING": TokenType.TRAILING,
        "UNBOUNDED": TokenType.UNBOUNDED,
        "UNION": TokenType.UNION,
        "UNPIVOT": TokenType.UNPIVOT,
        "UNLOGGED": TokenType.UNLOGGED,
        "UNNEST": TokenType.UNNEST,
        "UNPIVOT": TokenType.UNPIVOT,
        "UPDATE": TokenType.UPDATE,
        "USE": TokenType.USE,
        "USING": TokenType.USING,

@ -686,12 +693,12 @@
        "_current",
        "_line",
        "_col",
        "_comment",
        "_comments",
        "_char",
        "_end",
        "_peek",
        "_prev_token_line",
        "_prev_token_comment",
        "_prev_token_comments",
        "_prev_token_type",
        "_replace_backslash",
    )

@ -708,13 +715,13 @@
        self._current = 0
        self._line = 1
        self._col = 1
        self._comment = None
        self._comments: t.List[str] = []

        self._char = None
        self._end = None
        self._peek = None
        self._prev_token_line = -1
        self._prev_token_comment = None
        self._prev_token_comments: t.List[str] = []
        self._prev_token_type = None

    def tokenize(self, sql: str) -> t.List[Token]:

@ -767,7 +774,7 @@

    def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
        self._prev_token_line = self._line
        self._prev_token_comment = self._comment
        self._prev_token_comments = self._comments
        self._prev_token_type = token_type  # type: ignore
        self.tokens.append(
            Token(

@ -775,10 +782,10 @@
                self._text if text is None else text,
                self._line,
                self._col,
                self._comment,
                self._comments,
            )
        )
        self._comment = None
        self._comments = []

        if token_type in self.COMMANDS and (
            len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON

@ -857,22 +864,18 @@
            while not self._end and self._chars(comment_end_size) != comment_end:
                self._advance()

            self._comment = self._text[comment_start_size : -comment_end_size + 1]  # type: ignore
            self._comments.append(self._text[comment_start_size : -comment_end_size + 1])  # type: ignore
            self._advance(comment_end_size - 1)
        else:
            while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:  # type: ignore
                self._advance()
            self._comment = self._text[comment_start_size:]  # type: ignore

        # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
        # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
            self._comments.append(self._text[comment_start_size:])  # type: ignore

        # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
        # Multiple consecutive comments are preserved by appending them to the current comments list.
        if comment_start_line == self._prev_token_line:
            if self._prev_token_comment is None:
                self.tokens[-1].comment = self._comment
                self._prev_token_comment = self._comment

            self._comment = None
            self.tokens[-1].comments.extend(self._comments)
            self._comments = []

        return True
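The comment -> comments rename threads a list of comments through tokenizer, parser and AST, so
consecutive comments are no longer collapsed to the first one. A hedged round-trip sketch (exact
comment placement in the generated SQL is not verified here):

    import sqlglot

    # Both block comments should survive the round trip on the same token.
    print(sqlglot.transpile("SELECT /* one */ /* two */ a FROM t")[0])
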
sqlglot/transforms.py

@ -2,6 +2,8 @@ from __future__ import annotations

import typing as t

from sqlglot.helper import find_new_name

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator

@ -43,6 +45,43 @@ def unalias_group(expression: exp.Expression) -> exp.Expression:
    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions]
        outer_selects = [e.copy() for e in expression.expressions]
        nested = expression.copy()
        nested.args["distinct"].pop()
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(
            this=exp.RowNumber(),
            partition_by=distinct_cols,
        )
        order = nested.args.get("order")
        if order:
            window.set("order", order.copy())
            order.pop()
        window = exp.alias_(window, row_number)
        nested.select(window, copy=False)
        return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1')
    return expression
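The transform is usable on its own, outside the generator hook registered below. A hedged sketch
(the output shape follows the construction above; the exact SQL string is not verified against
10.1.3):

    import sqlglot
    from sqlglot.transforms import eliminate_distinct_on

    expression = sqlglot.parse_one(
        "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, c", read="postgres"
    )
    # Expect a subquery computing ROW_NUMBER() OVER (PARTITION BY a ORDER BY a, c)
    # AS _row_number, filtered by "_row_number" = 1 in the outer query.
    print(eliminate_distinct_on(expression).sql())
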
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
    to_sql: t.Callable[[Generator, exp.Expression], str],

@ -81,3 +120,4 @@ def delegate(attr: str) -> t.Callable:


UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}