
Merging upstream version 6.1.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Author: Daniel Baumann, 2025-02-13 08:04:41 +01:00
parent 3c6d649c90
commit 08ecea3adf
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
61 changed files with 1844 additions and 1555 deletions

sqlglot/dialects/__init__.py

@@ -7,6 +7,7 @@ from sqlglot.dialects.mysql import MySQL
from sqlglot.dialects.oracle import Oracle
from sqlglot.dialects.postgres import Postgres
from sqlglot.dialects.presto import Presto
+ from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.sqlite import SQLite

sqlglot/dialects/bigquery.py

@@ -44,6 +44,7 @@ class BigQuery(Dialect):
]
IDENTIFIERS = ["`"]
ESCAPE = "\\"
+ HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
@@ -120,9 +121,5 @@ class BigQuery(Dialect):
def intersect_op(self, expression):
if not expression.args.get("distinct", False):
- self.unsupported(
- "INTERSECT without DISTINCT is not supported in BigQuery"
- )
- return (
- f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
- )
+ self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
+ return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
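
A quick way to exercise the new HEX_STRINGS setting (a hedged sketch using the public sqlglot.transpile API; the exact output is assumed, not taken from this diff):

    import sqlglot

    # BigQuery now tokenizes 0x/0X literals as hex strings; Spark
    # re-emits them with its own X'..' delimiters (see spark.py below).
    print(sqlglot.transpile("SELECT 0xCC", read="bigquery", write="spark"))
    # assumed output: ["SELECT X'CC'"]
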

sqlglot/dialects/dialect.py

@@ -20,6 +20,7 @@ class Dialects(str, Enum):
ORACLE = "oracle"
POSTGRES = "postgres"
PRESTO = "presto"
+ REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
SQLITE = "sqlite"
@@ -53,12 +54,19 @@ class _Dialect(type):
klass.generator_class = getattr(klass, "Generator", Generator)
klass.tokenizer = klass.tokenizer_class()
- klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[
- 0
- ]
- klass.identifier_start, klass.identifier_end = list(
- klass.tokenizer_class.IDENTIFIERS.items()
- )[0]
+ klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
+ klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
+ if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+ bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
+ klass.generator_class.TRANSFORMS[
+ exp.BitString
+ ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
+ if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+ hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
+ klass.generator_class.TRANSFORMS[
+ exp.HexString
+ ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
return klass
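
In effect, any dialect whose tokenizer declares bit or hex string delimiters now gets a default BitString/HexString generator transform, unless the dialect already overrides one. A minimal sketch of the resulting behavior (output assumed):

    import sqlglot

    # MySQL and Postgres both declare b'..' bit strings, so the literal
    # survives the round trip via the transform installed above.
    print(sqlglot.transpile("SELECT b'1010'", read="mysql", write="postgres"))
    # assumed output: ["SELECT b'1010'"]
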
@@ -122,9 +130,7 @@ class Dialect(metaclass=_Dialect):
return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
def parse_into(self, expression_type, sql, **opts):
- return self.parser(**opts).parse_into(
- expression_type, self.tokenizer.tokenize(sql), sql
- )
+ return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
def generate(self, expression, **opts):
return self.generator(**opts).generate(expression)
@@ -164,9 +170,7 @@ class Dialect(metaclass=_Dialect):
def rename_func(name):
- return (
- lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
- )
+ return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
def approx_count_distinct_sql(self, expression):
@@ -260,8 +264,7 @@ def format_time_lambda(exp_class, dialect, default=None):
return exp_class(
this=list_get(args, 0),
format=Dialect[dialect].format_time(
- list_get(args, 1)
- or (Dialect[dialect].time_format if default is True else default)
+ list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
),
)

sqlglot/dialects/duckdb.py

@@ -63,10 +63,7 @@ def _sort_array_reverse(args):
def _struct_pack_sql(self, expression):
- args = [
- self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
- for e in expression.expressions
- ]
+ args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
return f"STRUCT_PACK({', '.join(args)})"

sqlglot/dialects/hive.py

@@ -109,9 +109,7 @@ def _unnest_to_explode_sql(self, expression):
alias=exp.TableAlias(this=alias.this, columns=[column]),
)
)
- for expression, column in zip(
- unnest.expressions, alias.columns if alias else []
- )
+ for expression, column in zip(unnest.expressions, alias.columns if alias else [])
)
return self.join_sql(expression)
@@ -206,14 +204,11 @@ class Hive(Dialect):
substr=list_get(args, 0),
position=list_get(args, 2),
),
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
else exp.Ln.from_arg_list(args)
),
"LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
"MAP": _parse_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
"PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
"COLLECT_SET": exp.SetAgg.from_arg_list,
"SIZE": exp.ArraySize.from_arg_list,
"SPLIT": exp.RegexpSplit.from_arg_list,
@@ -262,6 +257,7 @@ class Hive(Dialect):
HiveMap: _map_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}",
exp.Quantile: rename_func("PERCENTILE"),
+ exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.SafeDivide: no_safe_divide_sql,
@@ -296,8 +292,7 @@ class Hive(Dialect):
def datatype_sql(self, expression):
if (
- expression.this
- in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
+ expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
and not expression.expressions
):
expression = exp.DataType.build("text")

sqlglot/dialects/mysql.py

@@ -49,6 +49,21 @@ def _str_to_date_sql(self, expression):
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
+ def _trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+ remove_chars = self.sql(expression, "expression")
+ # Use TRIM/LTRIM/RTRIM syntax if the expression isn't mysql-specific
+ if not remove_chars:
+ return self.trim_sql(expression)
+ trim_type = f"{trim_type} " if trim_type else ""
+ remove_chars = f"{remove_chars} " if remove_chars else ""
+ from_part = "FROM " if trim_type or remove_chars else ""
+ return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
def _date_add(expression_class):
def func(args):
interval = list_get(args, 1)
@@ -88,9 +103,12 @@ class MySQL(Dialect):
QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
+ BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@@ -145,6 +163,15 @@ class MySQL(Dialect):
"STR_TO_DATE": _str_to_date,
}
+ FUNCTION_PARSERS = {
+ **Parser.FUNCTION_PARSERS,
+ "GROUP_CONCAT": lambda self: self.expression(
+ exp.GroupConcat,
+ this=self._parse_lambda(),
+ separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
+ ),
+ }
class Generator(Generator):
NULL_ORDERING_SUPPORTED = False
@@ -158,6 +185,8 @@ class MySQL(Dialect):
exp.DateAdd: _date_add_sql("ADD"),
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
+ exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
+ exp.Trim: _trim_sql,
}
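
Taken together, the SEPARATOR token, the GROUP_CONCAT function parser, and the GroupConcat transform round-trip MySQL's aggregate syntax. A sketch (output assumed):

    import sqlglot

    # Per the transform above, a missing SEPARATOR defaults to ','.
    print(sqlglot.transpile("SELECT GROUP_CONCAT(name SEPARATOR '; ') FROM t", read="mysql", write="mysql"))
    # assumed output: ["SELECT GROUP_CONCAT(name SEPARATOR '; ') FROM t"]
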

sqlglot/dialects/oracle.py

@@ -51,6 +51,14 @@ class Oracle(Dialect):
sep="",
)
+ def alias_sql(self, expression):
+ if isinstance(expression.this, exp.Table):
+ to_sql = self.sql(expression, "alias")
+ # oracle does not allow "AS" between table and alias
+ to_sql = f" {to_sql}" if to_sql else ""
+ return f"{self.sql(expression, 'this')}{to_sql}"
+ return super().alias_sql(expression)
def offset_sql(self, expression):
return f"{super().offset_sql(expression)} ROWS"

sqlglot/dialects/postgres.py

@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
+ from sqlglot.transforms import delegate, preprocess
def _date_add_sql(kind):
@@ -32,11 +33,96 @@ def _date_add_sql(kind):
return func
+ def _lateral_sql(self, expression):
+ this = self.sql(expression, "this")
+ if isinstance(expression.this, exp.Subquery):
+ return f"LATERAL{self.sep()}{this}"
+ alias = expression.args["alias"]
+ table = alias.name
+ table = f" {table}" if table else table
+ columns = self.expressions(alias, key="columns", flat=True)
+ columns = f" AS {columns}" if columns else ""
+ return f"LATERAL{self.sep()}{this}{table}{columns}"
+ def _substring_sql(self, expression):
+ this = self.sql(expression, "this")
+ start = self.sql(expression, "start")
+ length = self.sql(expression, "length")
+ from_part = f" FROM {start}" if start else ""
+ for_part = f" FOR {length}" if length else ""
+ return f"SUBSTRING({this}{from_part}{for_part})"
+ def _trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+ remove_chars = self.sql(expression, "expression")
+ collation = self.sql(expression, "collation")
+ # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
+ if not remove_chars and not collation:
+ return self.trim_sql(expression)
+ trim_type = f"{trim_type} " if trim_type else ""
+ remove_chars = f"{remove_chars} " if remove_chars else ""
+ from_part = "FROM " if trim_type or remove_chars else ""
+ collation = f" COLLATE {collation}" if collation else ""
+ return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
+ def _auto_increment_to_serial(expression):
+ auto = expression.find(exp.AutoIncrementColumnConstraint)
+ if auto:
+ expression = expression.copy()
+ expression.args["constraints"].remove(auto.parent)
+ kind = expression.args["kind"]
+ if kind.this == exp.DataType.Type.INT:
+ kind.replace(exp.DataType(this=exp.DataType.Type.SERIAL))
+ elif kind.this == exp.DataType.Type.SMALLINT:
+ kind.replace(exp.DataType(this=exp.DataType.Type.SMALLSERIAL))
+ elif kind.this == exp.DataType.Type.BIGINT:
+ kind.replace(exp.DataType(this=exp.DataType.Type.BIGSERIAL))
+ return expression
+ def _serial_to_generated(expression):
+ kind = expression.args["kind"]
+ if kind.this == exp.DataType.Type.SERIAL:
+ data_type = exp.DataType(this=exp.DataType.Type.INT)
+ elif kind.this == exp.DataType.Type.SMALLSERIAL:
+ data_type = exp.DataType(this=exp.DataType.Type.SMALLINT)
+ elif kind.this == exp.DataType.Type.BIGSERIAL:
+ data_type = exp.DataType(this=exp.DataType.Type.BIGINT)
+ else:
+ data_type = None
+ if data_type:
+ expression = expression.copy()
+ expression.args["kind"].replace(data_type)
+ constraints = expression.args["constraints"]
+ generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
+ notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint())
+ if notnull not in constraints:
+ constraints.insert(0, notnull)
+ if generated not in constraints:
+ constraints.insert(0, generated)
+ return expression
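
The two preprocessors complement each other: _auto_increment_to_serial rewrites MySQL-style AUTO_INCREMENT columns to SERIAL-family types, and _serial_to_generated then lowers SERIAL types to a plain integer plus GENERATED/NOT NULL constraints. A hedged sketch (exact output not asserted):

    import sqlglot

    # An AUTO_INCREMENT column is normalized on its way into Postgres.
    print(sqlglot.transpile("CREATE TABLE t (id INT AUTO_INCREMENT)", read="mysql", write="postgres"))
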
class Postgres(Dialect):
null_ordering = "nulls_are_large"
time_format = "'YYYY-MM-DD HH24:MI:SS'"
time_mapping = {
"AM": "%p", # AM or PM
"AM": "%p",
"PM": "%p",
"D": "%w", # 1-based day of week
"DD": "%d", # day of month
"DDD": "%j", # zero padded day of year
@@ -65,14 +151,25 @@ class Postgres(Dialect):
}
class Tokenizer(Tokenizer):
+ BIT_STRINGS = [("b'", "'"), ("B'", "'")]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"SERIAL": TokenType.AUTO_INCREMENT,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"IDENTITY": TokenType.IDENTITY,
"FOR": TokenType.FOR,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
"BIGSERIAL": TokenType.BIGSERIAL,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
}
class Parser(Parser):
+ STRICT_CAST = False
FUNCTIONS = {
**Parser.FUNCTIONS,
"TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
@@ -86,14 +183,18 @@ class Postgres(Dialect):
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.BINARY: "BYTEA",
- }
- TOKEN_MAPPING = {
- TokenType.AUTO_INCREMENT: "SERIAL",
+ exp.DataType.Type.DATETIME: "TIMESTAMP",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
+ exp.ColumnDef: preprocess(
+ [
+ _auto_increment_to_serial,
+ _serial_to_generated,
+ ],
+ delegate("columndef_sql"),
+ ),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
@@ -102,8 +203,11 @@ class Postgres(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
+ exp.Lateral: _lateral_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.Substring: _substring_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
+ exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
}

sqlglot/dialects/presto.py

@@ -96,9 +96,7 @@ def _ts_or_ds_to_date_sql(self, expression):
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.time_format, Presto.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
- return (
- f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
- )
+ return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
def _ts_or_ds_add_sql(self, expression):
@@ -141,6 +139,7 @@ class Presto(Dialect):
"FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
"STRPOS": exp.StrPosition.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
class Generator(Generator):
@@ -193,6 +192,7 @@ class Presto(Dialect):
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}",
exp.Quantile: _quantile_sql,
+ exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.SortArray: _no_sort_array,
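
With ApproxQuantile registered on both sides, approximate percentiles now translate across engines. A sketch (output assumed):

    import sqlglot

    # Presto's APPROX_PERCENTILE maps to Hive/Spark's PERCENTILE_APPROX.
    print(sqlglot.transpile("SELECT APPROX_PERCENTILE(x, 0.5)", read="presto", write="hive"))
    # assumed output: ['SELECT PERCENTILE_APPROX(x, 0.5)']
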

sqlglot/dialects/redshift.py (new file)

@@ -0,0 +1,34 @@
+ from sqlglot import exp
+ from sqlglot.dialects.postgres import Postgres
+ from sqlglot.tokens import TokenType
+ class Redshift(Postgres):
+ time_format = "'YYYY-MM-DD HH:MI:SS'"
+ time_mapping = {
+ **Postgres.time_mapping,
+ "MON": "%b",
+ "HH": "%H",
+ }
+ class Tokenizer(Postgres.Tokenizer):
+ ESCAPE = "\\"
+ KEYWORDS = {
+ **Postgres.Tokenizer.KEYWORDS,
+ "GEOMETRY": TokenType.GEOMETRY,
+ "GEOGRAPHY": TokenType.GEOGRAPHY,
+ "HLLSKETCH": TokenType.HLLSKETCH,
+ "SUPER": TokenType.SUPER,
+ "TIME": TokenType.TIMESTAMP,
+ "TIMETZ": TokenType.TIMESTAMPTZ,
+ "VARBYTE": TokenType.BINARY,
+ "SIMILAR TO": TokenType.SIMILAR_TO,
+ }
+ class Generator(Postgres.Generator):
+ TYPE_MAPPING = {
+ **Postgres.Generator.TYPE_MAPPING,
+ exp.DataType.Type.BINARY: "VARBYTE",
+ exp.DataType.Type.INT: "INTEGER",
+ }
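
Since Redshift subclasses Postgres, everything not overridden here is inherited. A minimal sketch of the new dialect in use (output assumed):

    import sqlglot

    # VARBYTE tokenizes as BINARY, which Postgres renders as BYTEA.
    print(sqlglot.transpile("CREATE TABLE t (c VARBYTE)", read="redshift", write="postgres"))
    # assumed output: ['CREATE TABLE t (c BYTEA)']
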

sqlglot/dialects/snowflake.py

@@ -23,9 +23,7 @@ def _snowflake_to_timestamp(args):
# case: <numeric_expr> [ , <scale> ]
if second_arg.name not in ["0", "3", "9"]:
- raise ValueError(
- f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
- )
+ raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
if second_arg.name == "0":
timescale = exp.UnixToTime.SECONDS

sqlglot/dialects/spark.py

@@ -65,12 +65,11 @@ class Spark(Hive):
this=list_get(args, 0),
start=exp.Sub(
this=exp.Length(this=list_get(args, 0)),
- expression=exp.Add(
- this=list_get(args, 1), expression=exp.Literal.number(1)
- ),
+ expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
),
length=list_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
class Generator(Hive.Generator):
@@ -82,11 +81,7 @@ class Spark(Hive):
}
TRANSFORMS = {
- **{
- k: v
- for k, v in Hive.Generator.TRANSFORMS.items()
- if k not in {exp.ArraySort}
- },
+ **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
@@ -102,5 +97,5 @@ class Spark(Hive):
HiveMap: _map_sql,
}
- def bitstring_sql(self, expression):
- return f"X'{self.sql(expression, 'this')}'"
+ class Tokenizer(Hive.Tokenizer):
+ HEX_STRINGS = [("X'", "'")]
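
The custom bitstring_sql hook is gone; Spark hex literals now flow through the shared tokenizer/generator machinery. A sketch (output assumed):

    import sqlglot

    # X'1A' survives a Spark round trip via the default HexString transform.
    print(sqlglot.transpile("SELECT X'1A'", read="spark", write="spark"))
    # assumed output: ["SELECT X'1A'"]
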

sqlglot/dialects/sqlite.py

@@ -16,6 +16,7 @@ from sqlglot.tokens import Tokenizer, TokenType
class SQLite(Dialect):
class Tokenizer(Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,

sqlglot/dialects/trino.py

@@ -8,3 +8,6 @@ class Trino(Presto):
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}
+ class Tokenizer(Presto.Tokenizer):
+ HEX_STRINGS = [("X'", "'")]