
Merging upstream version 11.1.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:26:26 +01:00
parent 8c1c1864c5
commit fb546b57e5
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
95 changed files with 32569 additions and 30081 deletions

sqlglot/dialects/bigquery.py

@@ -2,6 +2,7 @@
from __future__ import annotations
import re
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
@@ -31,13 +32,6 @@ def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
return func
def _date_trunc(args: t.Sequence) -> exp.Expression:
unit = seq_get(args, 1)
if isinstance(unit, exp.Column):
unit = exp.Var(this=unit.name)
return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
def _date_add_sql(
data_type: str, kind: str
) -> t.Callable[[generator.Generator, exp.Expression], str]:
@@ -158,11 +152,23 @@ class BigQuery(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": _date_trunc,
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore
this=seq_get(args, 0),
),
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0),
expression=seq_get(args, 1),
position=seq_get(args, 2),
occurrence=seq_get(args, 3),
group=exp.Literal.number(1)
if re.compile(str(seq_get(args, 1))).groups == 1
else None,
),
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
@@ -214,6 +220,7 @@ class BigQuery(Dialect):
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateStrToDate: datestrtodate_sql,
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
@@ -226,11 +233,12 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.VariancePop: rename_func("VAR_POP"),
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
@@ -251,6 +259,10 @@ class BigQuery(Dialect):
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.NVARCHAR: "STRING",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
}
EXPLICIT_UNION = True
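
A quick sanity check of the new DATE_TRUNC handling (a sketch against sqlglot's public transpile API; the table and column names are made up):

import sqlglot

# BigQuery spells the unit as a bare keyword. The new parser stores it as a
# string internally, and the new generator prints it back unquoted.
print(sqlglot.transpile("SELECT DATE_TRUNC(d, MONTH) FROM t", read="bigquery", write="bigquery")[0])
# expected to round-trip: SELECT DATE_TRUNC(d, MONTH) FROM t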

sqlglot/dialects/databricks.py

@@ -4,6 +4,7 @@ from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
from sqlglot.tokens import TokenType
class Databricks(Spark):
@@ -21,3 +22,11 @@ class Databricks(Spark):
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
}
PARAMETER_TOKEN = "$"
class Tokenizer(Spark.Tokenizer):
SINGLE_TOKENS = {
**Spark.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
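
A hedged sketch of what the new tokenizer entry enables; the parameter is hypothetical and the exact round-trip output is untested here:

import sqlglot

# "$" now tokenizes as a parameter marker in Databricks, and
# PARAMETER_TOKEN = "$" prints it back out unchanged.
print(sqlglot.transpile("SELECT $1", read="databricks", write="databricks")[0])
# expected: SELECT $1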

sqlglot/dialects/dialect.py

@@ -215,24 +215,19 @@ DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
def _rename(self, expression):
args = flatten(expression.args.values())
return f"{self.normalize_func(name)}({self.format_args(*args)})"
return _rename
return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
if expression.args.get("accuracy"):
self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql(self: Generator, expression: exp.If) -> str:
expressions = self.format_args(
expression.this, expression.args.get("true"), expression.args.get("false")
return self.func(
"IF", expression.this, expression.args.get("true"), expression.args.get("false")
)
return f"IF({expressions})"
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
@@ -318,13 +313,13 @@ def var_map_sql(
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map.")
return f"{map_func_name}({self.format_args(keys, values)})"
return self.func(map_func_name, keys, values)
args = []
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
return f"{map_func_name}({self.format_args(*args)})"
return self.func(map_func_name, *args)
def format_time_lambda(
@@ -400,10 +395,9 @@ def locate_to_strposition(args: t.Sequence) -> exp.Expression:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
return self.func(
"LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
)
return f"LOCATE({args})"
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
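
These helper rewrites all funnel through Generator.func, which normalizes the function name and drops trailing None arguments before rendering NAME(arg, ...). A minimal demonstration using two dialects touched by this commit (the column name x is made up):

from sqlglot import parse_one

# Presto parses APPROX_DISTINCT natively; DuckDB renders the same node via
# approx_count_distinct_sql, which is now a one-line self.func(...) call.
print(parse_one("SELECT APPROX_DISTINCT(x)", read="presto").sql(dialect="duckdb"))
# expected: SELECT APPROX_COUNT_DISTINCT(x)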

sqlglot/dialects/drill.py

@@ -39,23 +39,6 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
return func
def if_sql(self: generator.Generator, expression: exp.If) -> str:
"""
Drill requires backticks around certain SQL reserved words, IF being one of them, This function
adds the backticks around the keyword IF.
Args:
self: The Drill dialect
expression: The input IF expression
Returns: The expression with IF in backticks.
"""
expressions = self.format_args(
expression.this, expression.args.get("true"), expression.args.get("false")
)
return f"`IF`({expressions})"
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
@@ -134,7 +117,7 @@ class Drill(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
}
TRANSFORMS = {
@@ -148,7 +131,7 @@ class Drill(Dialect):
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
exp.If: if_sql,
exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
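
The inlined transform keeps the old behavior: Drill must backtick the reserved word IF. A small sketch (the expression is made up):

import sqlglot

# exp.If still renders with backticks around the function name in Drill.
print(sqlglot.transpile("SELECT IF(x > 1, 'a', 'b')", write="drill")[0])
# expected: SELECT `IF`(x > 1, 'a', 'b')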

sqlglot/dialects/duckdb.py

@@ -73,11 +73,24 @@ def _datatype_sql(self, expression):
return self.datatype_sql(expression)
def _regexp_extract_sql(self, expression):
bad_args = list(filter(expression.args.get, ("position", "occurrence")))
if bad_args:
self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")
return self.func(
"REGEXP_EXTRACT",
expression.args.get("this"),
expression.args.get("expression"),
expression.args.get("group"),
)
class DuckDB(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ,
"ATTACH": TokenType.COMMAND,
"CHARACTER VARYING": TokenType.VARCHAR,
}
@@ -117,7 +130,7 @@ class DuckDB(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if isinstance(seq_get(e.expressions, 0), exp.Select)
else rename_func("LIST_VALUE")(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -125,7 +138,9 @@ class DuckDB(Dialect):
exp.ArraySum: rename_func("LIST_SUM"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add,
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""",
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
@@ -137,6 +152,7 @@ class DuckDB(Dialect):
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql,
exp.RegexpExtract: _regexp_extract_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
exp.SafeDivide: no_safe_divide_sql,
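
With _regexp_extract_sql in place, BigQuery's REGEXP_EXTRACT can now be carried over to DuckDB, and unsupported position/occurrence arguments raise a warning instead of producing bad SQL. A hedged example (pattern and column invented):

import sqlglot

# The BigQuery parser infers group=1 for a single-group pattern, and the
# DuckDB generator forwards this, expression and group to REGEXP_EXTRACT.
print(sqlglot.transpile("SELECT REGEXP_EXTRACT(col, 'a(b)c')", read="bigquery", write="duckdb")[0])
# expected: SELECT REGEXP_EXTRACT(col, 'a(b)c', 1)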

sqlglot/dialects/hive.py

@@ -43,7 +43,7 @@ def _add_date_sql(self, expression):
else expression.expression
)
modified_increment = exp.Literal.number(modified_increment)
return f"{func}({self.format_args(expression.this, modified_increment.this)})"
return self.func(func, expression.this, modified_increment.this)
def _date_diff_sql(self, expression):
@@ -66,7 +66,7 @@ def _property_sql(self, expression):
def _str_to_unix(self, expression):
return f"UNIX_TIMESTAMP({self.format_args(expression.this, _time_format(self, expression))})"
return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
def _str_to_date(self, expression):
@@ -312,7 +312,9 @@ class Hive(Dialect):
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.TsOrDsToDate: _to_date_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
exp.UnixToStr: lambda self, e: self.func(
"FROM_UNIXTIME", e.this, _time_format(self, e)
),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
@@ -324,9 +326,9 @@ class Hive(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
}
def with_properties(self, properties):
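
The POST_SCHEMA relocation should be behavior-preserving for Hive DDL; a round-trip sketch (schema invented, output not verified against this exact build):

import sqlglot

# PARTITIONED BY still lands immediately after the column schema.
sql = "CREATE TABLE db.t (a STRING) PARTITIONED BY (dt STRING)"
print(sqlglot.transpile(sql, read="hive", write="hive")[0])
# expected to round-trip unchanged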

sqlglot/dialects/mysql.py

@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
rename_func,
strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
@@ -22,9 +23,8 @@ def _show_parser(*args, **kwargs):
def _date_trunc_sql(self, expression):
unit = expression.name.lower()
expr = self.sql(expression.expression)
expr = self.sql(expression, "this")
unit = expression.text("unit")
if unit == "day":
return f"DATE({expr})"
@@ -42,7 +42,7 @@ def _date_trunc_sql(self, expression):
concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
self.unsupported("Unexpected interval unit: {unit}")
self.unsupported(f"Unexpected interval unit: {unit}")
return f"DATE({expr})"
return f"STR_TO_DATE({concat}, '{date_format}')"
@@ -443,6 +443,10 @@ class MySQL(Dialect):
exp.DateAdd: _date_add_sql("ADD"),
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,

sqlglot/dialects/oracle.py

@@ -1,15 +1,49 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
from sqlglot.helper import csv
from sqlglot.tokens import TokenType
PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.COLUMN,
TokenType.RETURNING,
}
def _limit_sql(self, expression):
return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression))
def _parse_xml_table(self) -> exp.XMLTable:
this = self._parse_string()
passing = None
columns = None
if self._match_text_seq("PASSING"):
# The BY VALUE keywords are optional and are provided for semantic clarity
self._match_text_seq("BY", "VALUE")
passing = self._parse_csv(
lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS)
)
by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
if self._match_text_seq("COLUMNS"):
columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))
return self.expression(
exp.XMLTable,
this=this,
passing=passing,
columns=columns,
by_ref=by_ref,
)
class Oracle(Dialect):
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
@@ -43,6 +77,11 @@ class Oracle(Dialect):
"DECODE": exp.Matches.from_arg_list,
}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
**parser.Parser.FUNCTION_PARSERS,
"XMLTABLE": _parse_xml_table,
}
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
@@ -74,7 +113,7 @@ class Oracle(Dialect):
exp.Substring: rename_func("SUBSTR"),
}
def query_modifiers(self, expression, *sqls):
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins") or []],
@@ -97,19 +136,32 @@ class Oracle(Dialect):
sep="",
)
def offset_sql(self, expression):
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
def table_sql(self, expression):
return super().table_sql(expression, sep=" ")
def table_sql(self, expression: exp.Table, sep: str = " ") -> str:
return super().table_sql(expression, sep=sep)
def xmltable_sql(self, expression: exp.XMLTable) -> str:
this = self.sql(expression, "this")
passing = self.expressions(expression, "passing")
passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else ""
columns = self.expressions(expression, "columns")
columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else ""
by_ref = (
f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else ""
)
return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"COLUMNS": TokenType.COLUMN,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
"RETURNING": TokenType.RETURNING,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
}
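
A hedged check that the new XMLTABLE support parses into its own node and regenerates; the XPath, column names and types are invented, and output whitespace depends on the generator's sep/seg settings:

import sqlglot
from sqlglot import exp

sql = "SELECT * FROM XMLTABLE('/root/item' PASSING doc COLUMNS id NUMBER, name VARCHAR2(30))"
node = sqlglot.parse_one(sql, read="oracle")
assert node.find(exp.XMLTable) is not None  # parsed as a dedicated expression
print(node.sql(dialect="oracle"))           # regenerates the XMLTABLE(...) call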

sqlglot/dialects/postgres.py

@@ -58,17 +58,17 @@ def _date_diff_sql(self, expression):
age = f"AGE({end}, {start})"
if unit == "WEEK":
extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
elif unit == "MONTH":
extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
elif unit == "QUARTER":
extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
unit = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
elif unit == "YEAR":
extract = f"EXTRACT(year FROM {age})"
unit = f"EXTRACT(year FROM {age})"
else:
self.unsupported(f"Unsupported DATEDIFF unit {unit}")
unit = age
return f"CAST({extract} AS BIGINT)"
return f"CAST({unit} AS BIGINT)"
def _substring_sql(self, expression):
@@ -206,6 +206,8 @@ class Postgres(Dialect):
}
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
@@ -236,7 +238,7 @@ class Postgres(Dialect):
"UUID": TokenType.UUID,
"CSTRING": TokenType.PSEUDO_TYPE,
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
@@ -265,6 +267,7 @@ class Postgres(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
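
A short sketch of the parameter change (the statement is made up; the bit/hex/byte literals are only noted as now tokenizing, not asserted):

import sqlglot

# With "$" in SINGLE_TOKENS and PARAMETER_TOKEN = "$", numbered parameters
# survive a round-trip; b'...', x'...' and e'...' literals also lex now.
print(sqlglot.transpile("SELECT * FROM t WHERE id = $1", read="postgres", write="postgres")[0])
# expected: SELECT * FROM t WHERE id = $1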

sqlglot/dialects/presto.py

@@ -52,7 +52,7 @@ def _initcap_sql(self, expression):
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})"
return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
def _encode_sql(self, expression):
@@ -65,8 +65,7 @@ def _no_sort_array(self, expression):
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
comparator = None
args = self.format_args(expression.this, comparator)
return f"ARRAY_SORT({args})"
return self.func("ARRAY_SORT", expression.this, comparator)
def _schema_sql(self, expression):
@@ -125,7 +124,7 @@ def _sequence_sql(self, expression):
else:
start = exp.Cast(this=start, to=to)
return f"SEQUENCE({self.format_args(start, end, step)})"
return self.func("SEQUENCE", start, end, step)
def _ensure_utf8(charset):

sqlglot/dialects/redshift.py

@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -19,6 +20,11 @@ class Redshift(Postgres):
class Parser(Postgres.Parser):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS, # type: ignore
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
"DECODE": exp.Matches.from_arg_list,
"NVL": exp.Coalesce.from_arg_list,
}
@@ -41,7 +47,6 @@ class Redshift(Postgres):
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
"ENCODE": TokenType.ENCODE,
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
@@ -62,12 +67,15 @@ class Redshift(Postgres):
PROPERTIES_LOCATION = {
**Postgres.Generator.PROPERTIES_LOCATION, # type: ignore
exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH,
exp.LikeProperty: exp.Properties.Location.POST_WITH,
}
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.DateDiff: lambda self, e: self.func(
"DATEDIFF", e.args.get("unit") or "day", e.expression, e.this
),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
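
The parser/generator pair keeps Redshift's DATEDIFF(unit, start, end) argument order intact while storing it in sqlglot's usual this/expression slots. A hedged round-trip (identifiers invented):

import sqlglot

# The generator falls back to 'day' when no unit was parsed.
print(sqlglot.transpile("SELECT DATEDIFF(month, o_date, s_date) FROM x", read="redshift", write="redshift")[0])
# expected: SELECT DATEDIFF(month, o_date, s_date) FROM x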

sqlglot/dialects/snowflake.py

@@ -178,18 +178,25 @@ class Snowflake(Dialect):
),
}
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS, # type: ignore
TokenType.LIKE_ANY: lambda self, this: self._parse_escape(
self.expression(exp.LikeAny, this=this, expression=self._parse_bitwise())
),
TokenType.ILIKE_ANY: lambda self, this: self._parse_escape(
self.expression(exp.ILikeAny, this=this, expression=self._parse_bitwise())
),
}
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"EXCLUDE": TokenType.EXCEPT,
"ILIKE ANY": TokenType.ILIKE_ANY,
"LIKE ANY": TokenType.LIKE_ANY,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"PUT": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
@@ -201,8 +208,14 @@ class Snowflake(Dialect):
"SAMPLE": TokenType.TABLE_SAMPLE,
}
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
class Generator(generator.Generator):
CREATE_TRANSIENT = True
PARAMETER_TOKEN = "$"
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
@@ -214,14 +227,15 @@ class Snowflake(Dialect):
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.UnixToTime: _unix_to_time_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
}
@@ -236,6 +250,12 @@ class Snowflake(Dialect):
"replace": "RENAME",
}
def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
return self.binary(expression, "ILIKE ANY")
def likeany_sql(self, expression: exp.LikeAny) -> str:
return self.binary(expression, "LIKE ANY")
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
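
LIKE ANY / ILIKE ANY become first-class binary nodes with their own generator methods. A sketch of the intended round-trip (table and patterns invented):

import sqlglot

print(sqlglot.transpile("SELECT * FROM t WHERE name ILIKE ANY ('A%', '%B')", read="snowflake", write="snowflake")[0])
# expected: SELECT * FROM t WHERE name ILIKE ANY ('A%', '%B')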

sqlglot/dialects/spark.py

@@ -86,6 +86,11 @@ class Spark(Hive):
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
}
FUNCTION_PARSERS = {
@@ -133,7 +138,7 @@ class Spark(Hive):
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.DateTrunc: rename_func("TRUNC"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -142,7 +147,9 @@ class Spark(Hive):
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.Trim: trim_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
@@ -157,16 +164,16 @@ class Spark(Hive):
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_AS = False
CREATE_FUNCTION_RETURN_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
exp.DataType.Type.JSON
):
schema = f"'{self.sql(expression, 'to')}'"
return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})"
return self.func("FROM_JSON", expression.this.this, schema)
if expression.to.is_type(exp.DataType.Type.JSON):
return f"TO_JSON({self.sql(expression, 'this')})"
return self.func("TO_JSON", expression.this)
return super(Spark.Generator, self).cast_sql(expression)
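
Spark's two truncation functions now map to distinct nodes, DATE_TRUNC('unit', x) to exp.TimestampTrunc and TRUNC(x, 'unit') to exp.DateTrunc, so each regenerates in its own shape. A hedged sketch (column invented):

import sqlglot

print(sqlglot.transpile("SELECT DATE_TRUNC('month', ts)", read="spark", write="spark")[0])
# expected: SELECT DATE_TRUNC('month', ts)
print(sqlglot.transpile("SELECT TRUNC(ts, 'year')", read="spark", write="spark")[0])
# expected: SELECT TRUNC(ts, 'year')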

sqlglot/dialects/sqlite.py

@@ -39,7 +39,7 @@ def _date_add_sql(self, expression):
modifier = expression.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})"
return self.func("DATE", expression.this, modifier)
class SQLite(Dialect):

sqlglot/dialects/teradata.py

@@ -1,11 +1,33 @@
from __future__ import annotations
from sqlglot import exp, generator, parser
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType
class Teradata(Dialect):
class Tokenizer(tokens.Tokenizer):
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"BYTEINT": TokenType.SMALLINT,
"SEL": TokenType.SELECT,
"INS": TokenType.INSERT,
"MOD": TokenType.MOD,
"LT": TokenType.LT,
"LE": TokenType.LTE,
"GT": TokenType.GT,
"GE": TokenType.GTE,
"^=": TokenType.NEQ,
"NE": TokenType.NEQ,
"NOT=": TokenType.NEQ,
"ST_GEOMETRY": TokenType.GEOMETRY,
}
# teradata does not support % for modulus
SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS}
SINGLE_TOKENS.pop("%")
class Parser(parser.Parser):
CHARSET_TRANSLATORS = {
"GRAPHIC_TO_KANJISJIS",
@@ -42,6 +64,14 @@ class Teradata(Dialect):
"UNICODE_TO_UNICODE_NFKD",
}
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS}
FUNC_TOKENS.remove(TokenType.REPLACE)
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS, # type: ignore
TokenType.REPLACE: lambda self: self._parse_create(),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
@@ -76,6 +106,11 @@ class Teradata(Dialect):
)
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
@@ -93,3 +128,11 @@ class Teradata(Dialect):
where_sql = self.sql(expression, "where")
sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
def mod_sql(self, expression: exp.Mod) -> str:
return self.binary(expression, "MOD")
def datatype_sql(self, expression: exp.DataType) -> str:
type_sql = super().datatype_sql(expression)
prefix_sql = expression.args.get("prefix")
return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql

sqlglot/dialects/tsql.py

@@ -92,7 +92,7 @@ def _parse_eomonth(args):
def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
return self.func(func, e.text("unit"), e.expression, e.this)
def _format_sql(self, e):
@@ -101,7 +101,7 @@ def _format_sql(self, e):
if isinstance(e, exp.NumberToStr)
else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping))
)
return f"FORMAT({self.format_args(e.this, fmt)})"
return self.func("FORMAT", e.this, fmt)
def _string_agg_sql(self, e):
@@ -408,7 +408,7 @@ class TSQL(Dialect):
):
return this
expressions = self._parse_csv(self._parse_udf_kwarg)
expressions = self._parse_csv(self._parse_function_parameter)
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
class Generator(generator.Generator):
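
The DATEADD/DATEDIFF helper now renders through self.func while still emitting the unit as a bare keyword. A hedged round-trip (identifiers invented):

import sqlglot

print(sqlglot.transpile("SELECT DATEADD(year, 1, d)", read="tsql", write="tsql")[0])
# expected: SELECT DATEADD(year, 1, d)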