Merging upstream version 11.1.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 8c1c1864c5
commit fb546b57e5

95 changed files with 32569 additions and 30081 deletions
sqlglot/dialects/bigquery.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
+import re
 import typing as t
 
 from sqlglot import exp, generator, parser, tokens, transforms
 
@@ -31,13 +32,6 @@ def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
     return func
 
 
-def _date_trunc(args: t.Sequence) -> exp.Expression:
-    unit = seq_get(args, 1)
-    if isinstance(unit, exp.Column):
-        unit = exp.Var(this=unit.name)
-    return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
-
-
 def _date_add_sql(
     data_type: str, kind: str
 ) -> t.Callable[[generator.Generator, exp.Expression], str]:
@@ -158,11 +152,23 @@ class BigQuery(Dialect):
     class Parser(parser.Parser):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
-            "DATE_TRUNC": _date_trunc,
+            "DATE_TRUNC": lambda args: exp.DateTrunc(
+                unit=exp.Literal.string(seq_get(args, 1).name),  # type: ignore
+                this=seq_get(args, 0),
+            ),
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATETIME_ADD": _date_add(exp.DatetimeAdd),
             "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
             "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
+            "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
+                position=seq_get(args, 2),
+                occurrence=seq_get(args, 3),
+                group=exp.Literal.number(1)
+                if re.compile(str(seq_get(args, 1))).groups == 1
+                else None,
+            ),
             "TIME_ADD": _date_add(exp.TimeAdd),
             "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
             "DATE_SUB": _date_add(exp.DateSub),
@@ -214,6 +220,7 @@ class BigQuery(Dialect):
             exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
             exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
             exp.DateStrToDate: datestrtodate_sql,
+            exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
             exp.GroupConcat: rename_func("STRING_AGG"),
             exp.ILike: no_ilike_sql,
             exp.IntDiv: rename_func("DIV"),
@@ -226,11 +233,12 @@ class BigQuery(Dialect):
             exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
             exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.VariancePop: rename_func("VAR_POP"),
             exp.Values: _derived_table_values_to_unnest,
             exp.ReturnsProperty: _returnsproperty_sql,
             exp.Create: _create_sql,
-            exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
+            exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
             exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
             if e.name == "IMMUTABLE"
             else "NOT DETERMINISTIC",
@@ -251,6 +259,10 @@ class BigQuery(Dialect):
             exp.DataType.Type.VARCHAR: "STRING",
             exp.DataType.Type.NVARCHAR: "STRING",
         }
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+        }
 
         EXPLICIT_UNION = True
 
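A usage sketch of the BigQuery changes above, via sqlglot's public API (the query, table, and column names are illustrative, and the exact output is not asserted by this diff):

    import sqlglot

    # the new DATE_TRUNC parser keeps the unit as a string literal, so the
    # unit should survive a BigQuery round trip
    print(sqlglot.transpile("SELECT DATE_TRUNC(created_at, MONTH) FROM t",
                            read="bigquery", write="bigquery")[0])

    # REGEXP_EXTRACT now parses into exp.RegexpExtract; a pattern with exactly
    # one capturing group gets group=1 per the re.compile(...).groups check
    expr = sqlglot.parse_one("SELECT REGEXP_EXTRACT(col, '(a)b') FROM t", read="bigquery")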
sqlglot/dialects/databricks.py
@@ -4,6 +4,7 @@ from sqlglot import exp
 from sqlglot.dialects.dialect import parse_date_delta
 from sqlglot.dialects.spark import Spark
 from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
+from sqlglot.tokens import TokenType
 
 
 class Databricks(Spark):
@@ -21,3 +22,11 @@ class Databricks(Spark):
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
         }
+
+        PARAMETER_TOKEN = "$"
+
+    class Tokenizer(Spark.Tokenizer):
+        SINGLE_TOKENS = {
+            **Spark.Tokenizer.SINGLE_TOKENS,
+            "$": TokenType.PARAMETER,
+        }
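A minimal sketch of the new Databricks "$" parameter handling (hypothetical query; output shape assumed from PARAMETER_TOKEN = "$"):

    import sqlglot

    # "$foo" now tokenizes as a parameter and is written back with a "$" prefix
    print(sqlglot.transpile("SELECT $foo FROM t", read="databricks", write="databricks")[0])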
sqlglot/dialects/dialect.py
@@ -215,24 +215,19 @@ DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 
 
 def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
-    def _rename(self, expression):
-        args = flatten(expression.args.values())
-        return f"{self.normalize_func(name)}({self.format_args(*args)})"
-
-    return _rename
+    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 
 
 def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
     if expression.args.get("accuracy"):
         self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
-    return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
+    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 
 
 def if_sql(self: Generator, expression: exp.If) -> str:
-    expressions = self.format_args(
-        expression.this, expression.args.get("true"), expression.args.get("false")
+    return self.func(
+        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
     )
-    return f"IF({expressions})"
 
 
 def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
@@ -318,13 +313,13 @@ def var_map_sql(
 
     if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
         self.unsupported("Cannot convert array columns into map.")
-        return f"{map_func_name}({self.format_args(keys, values)})"
+        return self.func(map_func_name, keys, values)
 
     args = []
     for key, value in zip(keys.expressions, values.expressions):
         args.append(self.sql(key))
         args.append(self.sql(value))
-    return f"{map_func_name}({self.format_args(*args)})"
+    return self.func(map_func_name, *args)
 
 
 def format_time_lambda(
@@ -400,10 +395,9 @@ def locate_to_strposition(args: t.Sequence) -> exp.Expression:
 
 
 def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
-    args = self.format_args(
-        expression.args.get("substr"), expression.this, expression.args.get("position")
+    return self.func(
+        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )
-    return f"LOCATE({args})"
 
 
 def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
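The recurring pattern in this commit replaces hand-built f-strings such as f"LOCATE({self.format_args(...)})" with self.func(...). Judging from the removed _rename body above, the helper centralizes the same two steps in one place; roughly (a sketch of the idea, not the verbatim sqlglot source):

    def func(self, name, *args):
        # normalize the function name (casing / identifier rules), then join
        # the formatted arguments; format_args drops absent (None) arguments,
        # so callers can pass optional args like args.get("position") directly
        return f"{self.normalize_func(name)}({self.format_args(*args)})"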
sqlglot/dialects/drill.py
@@ -39,23 +39,6 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
     return func
 
 
-def if_sql(self: generator.Generator, expression: exp.If) -> str:
-    """
-    Drill requires backticks around certain SQL reserved words, IF being one of them, This function
-    adds the backticks around the keyword IF.
-    Args:
-        self: The Drill dialect
-        expression: The input IF expression
-
-    Returns: The expression with IF in backticks.
-
-    """
-    expressions = self.format_args(
-        expression.this, expression.args.get("true"), expression.args.get("false")
-    )
-    return f"`IF`({expressions})"
-
-
 def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
@@ -134,7 +117,7 @@ class Drill(Dialect):
 
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
-            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
         }
 
         TRANSFORMS = {
@@ -148,7 +131,7 @@ class Drill(Dialect):
             exp.DateSub: _date_add_sql("SUB"),
             exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
-            exp.If: if_sql,
+            exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
             exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
             exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
sqlglot/dialects/duckdb.py
@@ -73,11 +73,24 @@ def _datatype_sql(self, expression):
     return self.datatype_sql(expression)
 
 
+def _regexp_extract_sql(self, expression):
+    bad_args = list(filter(expression.args.get, ("position", "occurrence")))
+    if bad_args:
+        self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")
+    return self.func(
+        "REGEXP_EXTRACT",
+        expression.args.get("this"),
+        expression.args.get("expression"),
+        expression.args.get("group"),
+    )
+
+
 class DuckDB(Dialect):
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             ":=": TokenType.EQ,
             "ATTACH": TokenType.COMMAND,
+            "CHARACTER VARYING": TokenType.VARCHAR,
         }
 
@@ -117,7 +130,7 @@ class DuckDB(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             exp.ApproxDistinct: approx_count_distinct_sql,
-            exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+            exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
             if isinstance(seq_get(e.expressions, 0), exp.Select)
             else rename_func("LIST_VALUE")(self, e),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -125,7 +138,9 @@ class DuckDB(Dialect):
             exp.ArraySum: rename_func("LIST_SUM"),
             exp.DataType: _datatype_sql,
             exp.DateAdd: _date_add,
-            exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""",
+            exp.DateDiff: lambda self, e: self.func(
+                "DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
+            ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
@@ -137,6 +152,7 @@ class DuckDB(Dialect):
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.Pivot: no_pivot_sql,
             exp.Properties: no_properties_sql,
+            exp.RegexpExtract: _regexp_extract_sql,
             exp.RegexpLike: rename_func("REGEXP_MATCHES"),
             exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
             exp.SafeDivide: no_safe_divide_sql,
sqlglot/dialects/hive.py
@@ -43,7 +43,7 @@ def _add_date_sql(self, expression):
         else expression.expression
     )
     modified_increment = exp.Literal.number(modified_increment)
-    return f"{func}({self.format_args(expression.this, modified_increment.this)})"
+    return self.func(func, expression.this, modified_increment.this)
 
 
 def _date_diff_sql(self, expression):
@@ -66,7 +66,7 @@ def _property_sql(self, expression):
 
 
 def _str_to_unix(self, expression):
-    return f"UNIX_TIMESTAMP({self.format_args(expression.this, _time_format(self, expression))})"
+    return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
 
 
 def _str_to_date(self, expression):
@@ -312,7 +312,9 @@ class Hive(Dialect):
             exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.TsOrDsToDate: _to_date_sql,
             exp.TryCast: no_trycast_sql,
-            exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
+            exp.UnixToStr: lambda self, e: self.func(
+                "FROM_UNIXTIME", e.this, _time_format(self, e)
+            ),
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
@@ -324,9 +326,9 @@ class Hive(Dialect):
 
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
-            exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
-            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
-            exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+            exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+            exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
         }
 
         def with_properties(self, properties):
sqlglot/dialects/mysql.py
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
     no_paren_current_date_sql,
     no_tablesample_sql,
     no_trycast_sql,
+    rename_func,
     strposition_to_locate_sql,
 )
 from sqlglot.helper import seq_get
@@ -22,9 +23,8 @@ def _show_parser(*args, **kwargs):
 
 
 def _date_trunc_sql(self, expression):
-    unit = expression.name.lower()
-
-    expr = self.sql(expression.expression)
+    expr = self.sql(expression, "this")
+    unit = expression.text("unit")
 
     if unit == "day":
         return f"DATE({expr})"
@@ -42,7 +42,7 @@ def _date_trunc_sql(self, expression):
         concat = f"CONCAT(YEAR({expr}), ' 1 1')"
         date_format = "%Y %c %e"
     else:
-        self.unsupported("Unexpected interval unit: {unit}")
+        self.unsupported(f"Unexpected interval unit: {unit}")
         return f"DATE({expr})"
 
     return f"STR_TO_DATE({concat}, '{date_format}')"
@@ -443,6 +443,10 @@ class MySQL(Dialect):
             exp.DateAdd: _date_add_sql("ADD"),
             exp.DateSub: _date_add_sql("SUB"),
             exp.DateTrunc: _date_trunc_sql,
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
+            exp.DayOfMonth: rename_func("DAYOFMONTH"),
+            exp.DayOfYear: rename_func("DAYOFYEAR"),
+            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
             exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
             exp.StrToDate: _str_to_date_sql,
            exp.StrToTime: _str_to_date_sql,
sqlglot/dialects/oracle.py
@@ -1,15 +1,49 @@
 from __future__ import annotations
 
+import typing as t
+
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
 from sqlglot.helper import csv
 from sqlglot.tokens import TokenType
 
+PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
+    TokenType.COLUMN,
+    TokenType.RETURNING,
+}
+
 
 def _limit_sql(self, expression):
     return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression))
 
 
+def _parse_xml_table(self) -> exp.XMLTable:
+    this = self._parse_string()
+
+    passing = None
+    columns = None
+
+    if self._match_text_seq("PASSING"):
+        # The BY VALUE keywords are optional and are provided for semantic clarity
+        self._match_text_seq("BY", "VALUE")
+        passing = self._parse_csv(
+            lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS)
+        )
+
+    by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
+
+    if self._match_text_seq("COLUMNS"):
+        columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))
+
+    return self.expression(
+        exp.XMLTable,
+        this=this,
+        passing=passing,
+        columns=columns,
+        by_ref=by_ref,
+    )
+
+
 class Oracle(Dialect):
     # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
     # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
@@ -43,6 +77,11 @@ class Oracle(Dialect):
             "DECODE": exp.Matches.from_arg_list,
         }
 
+        FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
+            **parser.Parser.FUNCTION_PARSERS,
+            "XMLTABLE": _parse_xml_table,
+        }
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
 
@@ -74,7 +113,7 @@ class Oracle(Dialect):
             exp.Substring: rename_func("SUBSTR"),
         }
 
-        def query_modifiers(self, expression, *sqls):
+        def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
             return csv(
                 *sqls,
                 *[self.sql(sql) for sql in expression.args.get("joins") or []],
@@ -97,19 +136,32 @@ class Oracle(Dialect):
                 sep="",
             )
 
-        def offset_sql(self, expression):
+        def offset_sql(self, expression: exp.Offset) -> str:
             return f"{super().offset_sql(expression)} ROWS"
 
-        def table_sql(self, expression):
-            return super().table_sql(expression, sep=" ")
+        def table_sql(self, expression: exp.Table, sep: str = " ") -> str:
+            return super().table_sql(expression, sep=sep)
+
+        def xmltable_sql(self, expression: exp.XMLTable) -> str:
+            this = self.sql(expression, "this")
+            passing = self.expressions(expression, "passing")
+            passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else ""
+            columns = self.expressions(expression, "columns")
+            columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else ""
+            by_ref = (
+                f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else ""
+            )
+            return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
 
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
+            "COLUMNS": TokenType.COLUMN,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
             "MINUS": TokenType.EXCEPT,
+            "NVARCHAR2": TokenType.NVARCHAR,
+            "RETURNING": TokenType.RETURNING,
             "START": TokenType.BEGIN,
             "TOP": TokenType.TOP,
             "VARCHAR2": TokenType.VARCHAR,
-            "NVARCHAR2": TokenType.NVARCHAR,
         }
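A usage sketch of the new Oracle XMLTABLE support (the query shape is assumed and the COLUMNS clause is simplified; this commit's diff does not include test output):

    import sqlglot

    sql = "SELECT * FROM XMLTABLE('/rows/row' PASSING x COLUMNS a NUMBER) t"
    # expected to parse into exp.XMLTable and round-trip via xmltable_sql
    print(sqlglot.transpile(sql, read="oracle", write="oracle")[0])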
sqlglot/dialects/postgres.py
@@ -58,17 +58,17 @@ def _date_diff_sql(self, expression):
     age = f"AGE({end}, {start})"
 
     if unit == "WEEK":
-        extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
+        unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
     elif unit == "MONTH":
-        extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
+        unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
     elif unit == "QUARTER":
-        extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
+        unit = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
     elif unit == "YEAR":
-        extract = f"EXTRACT(year FROM {age})"
+        unit = f"EXTRACT(year FROM {age})"
     else:
-        self.unsupported(f"Unsupported DATEDIFF unit {unit}")
+        unit = age
 
-    return f"CAST({extract} AS BIGINT)"
+    return f"CAST({unit} AS BIGINT)"
 
 
 def _substring_sql(self, expression):
@@ -206,6 +206,8 @@ class Postgres(Dialect):
     }
 
     class Tokenizer(tokens.Tokenizer):
+        QUOTES = ["'", "$$"]
+
         BIT_STRINGS = [("b'", "'"), ("B'", "'")]
         HEX_STRINGS = [("x'", "'"), ("X'", "'")]
         BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
@@ -236,7 +238,7 @@ class Postgres(Dialect):
             "UUID": TokenType.UUID,
             "CSTRING": TokenType.PSEUDO_TYPE,
         }
-        QUOTES = ["'", "$$"]
+
         SINGLE_TOKENS = {
             **tokens.Tokenizer.SINGLE_TOKENS,
             "$": TokenType.PARAMETER,
@@ -265,6 +267,7 @@ class Postgres(Dialect):
 
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
+        PARAMETER_TOKEN = "$"
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
sqlglot/dialects/presto.py
@@ -52,7 +52,7 @@ def _initcap_sql(self, expression):
 
 def _decode_sql(self, expression):
     _ensure_utf8(expression.args.get("charset"))
-    return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})"
+    return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
 
 
 def _encode_sql(self, expression):
@@ -65,8 +65,7 @@ def _no_sort_array(self, expression):
         comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
     else:
         comparator = None
-    args = self.format_args(expression.this, comparator)
-    return f"ARRAY_SORT({args})"
+    return self.func("ARRAY_SORT", expression.this, comparator)
 
 
 def _schema_sql(self, expression):
@@ -125,7 +124,7 @@ def _sequence_sql(self, expression):
     else:
         start = exp.Cast(this=start, to=to)
 
-    return f"SEQUENCE({self.format_args(start, end, step)})"
+    return self.func("SEQUENCE", start, end, step)
 
 
 def _ensure_utf8(charset):
sqlglot/dialects/redshift.py
@@ -5,6 +5,7 @@ import typing as t
 from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import rename_func
 from sqlglot.dialects.postgres import Postgres
+from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 
 
@@ -19,6 +20,11 @@ class Redshift(Postgres):
     class Parser(Postgres.Parser):
         FUNCTIONS = {
             **Postgres.Parser.FUNCTIONS,  # type: ignore
+            "DATEDIFF": lambda args: exp.DateDiff(
+                this=seq_get(args, 2),
+                expression=seq_get(args, 1),
+                unit=seq_get(args, 0),
+            ),
             "DECODE": exp.Matches.from_arg_list,
             "NVL": exp.Coalesce.from_arg_list,
         }
@@ -41,7 +47,6 @@ class Redshift(Postgres):
 
         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,  # type: ignore
-            "ENCODE": TokenType.ENCODE,
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "HLLSKETCH": TokenType.HLLSKETCH,
@@ -62,12 +67,15 @@ class Redshift(Postgres):
 
         PROPERTIES_LOCATION = {
             **Postgres.Generator.PROPERTIES_LOCATION,  # type: ignore
-            exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+            exp.LikeProperty: exp.Properties.Location.POST_WITH,
         }
 
         TRANSFORMS = {
             **Postgres.Generator.TRANSFORMS,  # type: ignore
             **transforms.ELIMINATE_DISTINCT_ON,  # type: ignore
+            exp.DateDiff: lambda self, e: self.func(
+                "DATEDIFF", e.args.get("unit") or "day", e.expression, e.this
+            ),
             exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
             exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
             exp.DistStyleProperty: lambda self, e: self.naked_property(e),
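A sketch of the new Redshift DATEDIFF handling: the unit-first argument order is parsed into exp.DateDiff and written back in the same order (illustrative query; output not asserted by this diff):

    import sqlglot

    sql = "SELECT DATEDIFF(day, created_at, closed_at) FROM t"
    print(sqlglot.transpile(sql, read="redshift", write="redshift")[0])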
sqlglot/dialects/snowflake.py
@@ -178,18 +178,25 @@ class Snowflake(Dialect):
             ),
         }
 
+        RANGE_PARSERS = {
+            **parser.Parser.RANGE_PARSERS,  # type: ignore
+            TokenType.LIKE_ANY: lambda self, this: self._parse_escape(
+                self.expression(exp.LikeAny, this=this, expression=self._parse_bitwise())
+            ),
+            TokenType.ILIKE_ANY: lambda self, this: self._parse_escape(
+                self.expression(exp.ILikeAny, this=this, expression=self._parse_bitwise())
+            ),
+        }
+
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
         STRING_ESCAPES = ["\\", "'"]
 
-        SINGLE_TOKENS = {
-            **tokens.Tokenizer.SINGLE_TOKENS,
-            "$": TokenType.PARAMETER,
-        }
-
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "EXCLUDE": TokenType.EXCEPT,
+            "ILIKE ANY": TokenType.ILIKE_ANY,
+            "LIKE ANY": TokenType.LIKE_ANY,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
             "PUT": TokenType.COMMAND,
             "RENAME": TokenType.REPLACE,
@@ -201,8 +208,14 @@ class Snowflake(Dialect):
             "SAMPLE": TokenType.TABLE_SAMPLE,
         }
 
+        SINGLE_TOKENS = {
+            **tokens.Tokenizer.SINGLE_TOKENS,
+            "$": TokenType.PARAMETER,
+        }
+
     class Generator(generator.Generator):
         CREATE_TRANSIENT = True
+        PARAMETER_TOKEN = "$"
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
@@ -214,14 +227,15 @@ class Snowflake(Dialect):
             exp.If: rename_func("IFF"),
             exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
             exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
-            exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.Matches: rename_func("DECODE"),
-            exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
+            exp.StrPosition: lambda self, e: self.func(
+                "POSITION", e.args.get("substr"), e.this, e.args.get("position")
+            ),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
-            exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
+            exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
             exp.UnixToTime: _unix_to_time_sql,
             exp.DayOfWeek: rename_func("DAYOFWEEK"),
         }
@@ -236,6 +250,12 @@ class Snowflake(Dialect):
             "replace": "RENAME",
         }
 
+        def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
+            return self.binary(expression, "ILIKE ANY")
+
+        def likeany_sql(self, expression: exp.LikeAny) -> str:
+            return self.binary(expression, "LIKE ANY")
+
         def except_op(self, expression):
             if not expression.args.get("distinct", False):
                 self.unsupported("EXCEPT with All is not supported in Snowflake")
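A sketch of the new Snowflake LIKE ANY / ILIKE ANY round trip (illustrative query; the pattern list is hypothetical):

    import sqlglot

    # LIKE ANY now parses via RANGE_PARSERS into exp.LikeAny and is written
    # back through likeany_sql
    sql = "SELECT * FROM t WHERE col LIKE ANY ('%a%', '%b%')"
    print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])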
sqlglot/dialects/spark.py
@@ -86,6 +86,11 @@ class Spark(Hive):
             "WEEKOFYEAR": lambda args: exp.WeekOfYear(
                 this=exp.TsOrDsToDate(this=seq_get(args, 0)),
             ),
+            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+                this=seq_get(args, 1),
+                unit=exp.var(seq_get(args, 0)),
+            ),
+            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
         }
 
         FUNCTION_PARSERS = {
@@ -133,7 +138,7 @@ class Spark(Hive):
             exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
-            exp.DateTrunc: rename_func("TRUNC"),
+            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
             exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
             exp.StrToDate: _str_to_date,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -142,7 +147,9 @@ class Spark(Hive):
             exp.Map: _map_sql,
             exp.Reduce: rename_func("AGGREGATE"),
             exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
-            exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
+            exp.TimestampTrunc: lambda self, e: self.func(
+                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
+            ),
             exp.Trim: trim_sql,
             exp.VariancePop: rename_func("VAR_POP"),
             exp.DateFromParts: rename_func("MAKE_DATE"),
@@ -157,16 +164,16 @@ class Spark(Hive):
         TRANSFORMS.pop(exp.ILike)
 
         WRAP_DERIVED_VALUES = False
-        CREATE_FUNCTION_AS = False
+        CREATE_FUNCTION_RETURN_AS = False
 
         def cast_sql(self, expression: exp.Cast) -> str:
             if isinstance(expression.this, exp.Cast) and expression.this.is_type(
                 exp.DataType.Type.JSON
             ):
                 schema = f"'{self.sql(expression, 'to')}'"
-                return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})"
+                return self.func("FROM_JSON", expression.this.this, schema)
             if expression.to.is_type(exp.DataType.Type.JSON):
-                return f"TO_JSON({self.sql(expression, 'this')})"
+                return self.func("TO_JSON", expression.this)
 
             return super(Spark.Generator, self).cast_sql(expression)
 
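A sketch of the Spark truncation split above: DATE_TRUNC(unit, ts) now parses to exp.TimestampTrunc while TRUNC(date, unit) parses to exp.DateTrunc, so the two no longer collide (illustrative queries):

    import sqlglot

    print(sqlglot.transpile("SELECT DATE_TRUNC('month', ts) FROM t", read="spark", write="spark")[0])
    print(sqlglot.transpile("SELECT TRUNC(d, 'month') FROM t", read="spark", write="spark")[0])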
sqlglot/dialects/sqlite.py
@@ -39,7 +39,7 @@ def _date_add_sql(self, expression):
     modifier = expression.name if modifier.is_string else self.sql(modifier)
     unit = expression.args.get("unit")
     modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
-    return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})"
+    return self.func("DATE", expression.this, modifier)
 
 
 class SQLite(Dialect):
sqlglot/dialects/teradata.py
@@ -1,11 +1,33 @@
 from __future__ import annotations
 
-from sqlglot import exp, generator, parser
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import Dialect
 from sqlglot.tokens import TokenType
 
 
 class Teradata(Dialect):
+    class Tokenizer(tokens.Tokenizer):
+        # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
+        KEYWORDS = {
+            **tokens.Tokenizer.KEYWORDS,
+            "BYTEINT": TokenType.SMALLINT,
+            "SEL": TokenType.SELECT,
+            "INS": TokenType.INSERT,
+            "MOD": TokenType.MOD,
+            "LT": TokenType.LT,
+            "LE": TokenType.LTE,
+            "GT": TokenType.GT,
+            "GE": TokenType.GTE,
+            "^=": TokenType.NEQ,
+            "NE": TokenType.NEQ,
+            "NOT=": TokenType.NEQ,
+            "ST_GEOMETRY": TokenType.GEOMETRY,
+        }
+
+        # teradata does not support % for modulus
+        SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS}
+        SINGLE_TOKENS.pop("%")
+
     class Parser(parser.Parser):
         CHARSET_TRANSLATORS = {
             "GRAPHIC_TO_KANJISJIS",
@@ -42,6 +64,14 @@ class Teradata(Dialect):
             "UNICODE_TO_UNICODE_NFKD",
         }
 
+        FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS}
+        FUNC_TOKENS.remove(TokenType.REPLACE)
+
+        STATEMENT_PARSERS = {
+            **parser.Parser.STATEMENT_PARSERS,  # type: ignore
+            TokenType.REPLACE: lambda self: self._parse_create(),
+        }
+
         FUNCTION_PARSERS = {
             **parser.Parser.FUNCTION_PARSERS,  # type: ignore
             "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
@@ -76,6 +106,11 @@ class Teradata(Dialect):
         )
 
     class Generator(generator.Generator):
+        TYPE_MAPPING = {
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
+            exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
+        }
+
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
             exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
@@ -93,3 +128,11 @@ class Teradata(Dialect):
             where_sql = self.sql(expression, "where")
             sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}"
             return self.prepend_ctes(expression, sql)
+
+        def mod_sql(self, expression: exp.Mod) -> str:
+            return self.binary(expression, "MOD")
+
+        def datatype_sql(self, expression: exp.DataType) -> str:
+            type_sql = super().datatype_sql(expression)
+            prefix_sql = expression.args.get("prefix")
+            return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql
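A sketch of the new Teradata tokenizer entries (SEL/INS shorthands, MOD, and the letter comparison operators), transpiled to a dialect that spells them differently (illustrative query and output):

    import sqlglot

    # SEL -> SELECT, MOD -> %, GE -> >= when writing PostgreSQL (assumed output)
    print(sqlglot.transpile("SEL x MOD 2 FROM t WHERE y GE 1",
                            read="teradata", write="postgres")[0])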
sqlglot/dialects/tsql.py
@@ -92,7 +92,7 @@ def _parse_eomonth(args):
 
 def generate_date_delta_with_unit_sql(self, e):
     func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
-    return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
+    return self.func(func, e.text("unit"), e.expression, e.this)
 
 
 def _format_sql(self, e):
@@ -101,7 +101,7 @@ def _format_sql(self, e):
         if isinstance(e, exp.NumberToStr)
         else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping))
     )
-    return f"FORMAT({self.format_args(e.this, fmt)})"
+    return self.func("FORMAT", e.this, fmt)
 
 
 def _string_agg_sql(self, e):
@@ -408,7 +408,7 @@ class TSQL(Dialect):
            ):
                return this
 
-            expressions = self._parse_csv(self._parse_udf_kwarg)
+            expressions = self._parse_csv(self._parse_function_parameter)
             return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
 
     class Generator(generator.Generator):