Merging upstream version 11.7.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
0c053462ae
commit
8d96084fad
144 changed files with 44104 additions and 39367 deletions
|
@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
|
|||
max_or_greatest,
|
||||
min_or_least,
|
||||
no_ilike_sql,
|
||||
parse_date_delta_with_interval,
|
||||
rename_func,
|
||||
timestrtotime_sql,
|
||||
ts_or_ds_to_date_sql,
|
||||
|
@ -23,18 +24,6 @@ from sqlglot.tokens import TokenType
|
|||
E = t.TypeVar("E", bound=exp.Expression)
|
||||
|
||||
|
||||
def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
|
||||
def func(args):
|
||||
interval = seq_get(args, 1)
|
||||
return expression_class(
|
||||
this=seq_get(args, 0),
|
||||
expression=interval.this,
|
||||
unit=interval.args.get("unit"),
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def _date_add_sql(
|
||||
data_type: str, kind: str
|
||||
) -> t.Callable[[generator.Generator, exp.Expression], str]:
|
||||
|
@ -142,6 +131,7 @@ class BigQuery(Dialect):
|
|||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"ANY TYPE": TokenType.VARIANT,
|
||||
"BEGIN": TokenType.COMMAND,
|
||||
"BEGIN TRANSACTION": TokenType.BEGIN,
|
||||
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
|
||||
|
@ -155,14 +145,19 @@ class BigQuery(Dialect):
|
|||
KEYWORDS.pop("DIV")
|
||||
|
||||
class Parser(parser.Parser):
|
||||
PREFIXED_PIVOT_COLUMNS = True
|
||||
|
||||
LOG_BASE_FIRST = False
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
"DATE_TRUNC": lambda args: exp.DateTrunc(
|
||||
unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore
|
||||
this=seq_get(args, 0),
|
||||
),
|
||||
"DATE_ADD": _date_add(exp.DateAdd),
|
||||
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
|
||||
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
|
||||
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
|
||||
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
|
||||
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
|
||||
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
|
||||
|
@ -174,12 +169,12 @@ class BigQuery(Dialect):
|
|||
if re.compile(str(seq_get(args, 1))).groups == 1
|
||||
else None,
|
||||
),
|
||||
"TIME_ADD": _date_add(exp.TimeAdd),
|
||||
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
|
||||
"DATE_SUB": _date_add(exp.DateSub),
|
||||
"DATETIME_SUB": _date_add(exp.DatetimeSub),
|
||||
"TIME_SUB": _date_add(exp.TimeSub),
|
||||
"TIMESTAMP_SUB": _date_add(exp.TimestampSub),
|
||||
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
|
||||
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
|
||||
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
|
||||
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
|
||||
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
|
||||
"TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
|
||||
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(
|
||||
this=seq_get(args, 1), format=seq_get(args, 0)
|
||||
),
|
||||
|
@ -209,14 +204,17 @@ class BigQuery(Dialect):
|
|||
PROPERTY_PARSERS = {
|
||||
**parser.Parser.PROPERTY_PARSERS, # type: ignore
|
||||
"NOT DETERMINISTIC": lambda self: self.expression(
|
||||
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
|
||||
exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
|
||||
),
|
||||
}
|
||||
|
||||
LOG_BASE_FIRST = False
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
|
||||
class Generator(generator.Generator):
|
||||
EXPLICIT_UNION = True
|
||||
INTERVAL_ALLOWS_PLURAL_FORM = False
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
LIMIT_FETCH = "LIMIT"
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
|
||||
|
@ -236,9 +234,7 @@ class BigQuery(Dialect):
|
|||
exp.IntDiv: rename_func("DIV"),
|
||||
exp.Max: max_or_greatest,
|
||||
exp.Min: min_or_least,
|
||||
exp.Select: transforms.preprocess(
|
||||
[_unqualify_unnest], transforms.delegate("select_sql")
|
||||
),
|
||||
exp.Select: transforms.preprocess([_unqualify_unnest]),
|
||||
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
|
||||
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
|
||||
exp.TimeSub: _date_add_sql("TIME", "SUB"),
|
||||
|
@ -253,7 +249,7 @@ class BigQuery(Dialect):
|
|||
exp.ReturnsProperty: _returnsproperty_sql,
|
||||
exp.Create: _create_sql,
|
||||
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
|
||||
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
|
||||
exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
|
||||
if e.name == "IMMUTABLE"
|
||||
else "NOT DETERMINISTIC",
|
||||
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
|
||||
|
@ -261,6 +257,7 @@ class BigQuery(Dialect):
|
|||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
|
||||
exp.DataType.Type.BIGINT: "INT64",
|
||||
exp.DataType.Type.BOOLEAN: "BOOL",
|
||||
exp.DataType.Type.CHAR: "STRING",
|
||||
|
@ -272,17 +269,19 @@ class BigQuery(Dialect):
|
|||
exp.DataType.Type.NVARCHAR: "STRING",
|
||||
exp.DataType.Type.SMALLINT: "INT64",
|
||||
exp.DataType.Type.TEXT: "STRING",
|
||||
exp.DataType.Type.TIMESTAMP: "DATETIME",
|
||||
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
|
||||
exp.DataType.Type.TINYINT: "INT64",
|
||||
exp.DataType.Type.VARCHAR: "STRING",
|
||||
exp.DataType.Type.VARIANT: "ANY TYPE",
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
EXPLICIT_UNION = True
|
||||
LIMIT_FETCH = "LIMIT"
|
||||
|
||||
def array_sql(self, expression: exp.Array) -> str:
|
||||
first_arg = seq_get(expression.expressions, 0)
|
||||
if isinstance(first_arg, exp.Subqueryable):
|
||||
|
|
|
@ -144,6 +144,13 @@ class ClickHouse(Dialect):
|
|||
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
EXPLICIT_UNION = True
|
||||
|
||||
def _param_args_sql(
|
||||
|
|
|
@ -9,6 +9,8 @@ from sqlglot.tokens import TokenType
|
|||
|
||||
class Databricks(Spark):
|
||||
class Parser(Spark.Parser):
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
|
||||
FUNCTIONS = {
|
||||
**Spark.Parser.FUNCTIONS,
|
||||
"DATEADD": parse_date_delta(exp.DateAdd),
|
||||
|
@ -16,13 +18,17 @@ class Databricks(Spark):
|
|||
"DATEDIFF": parse_date_delta(exp.DateDiff),
|
||||
}
|
||||
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
FACTOR = {
|
||||
**Spark.Parser.FACTOR,
|
||||
TokenType.COLON: exp.JSONExtract,
|
||||
}
|
||||
|
||||
class Generator(Spark.Generator):
|
||||
TRANSFORMS = {
|
||||
**Spark.Generator.TRANSFORMS, # type: ignore
|
||||
exp.DateAdd: generate_date_delta_with_unit_sql,
|
||||
exp.DateDiff: generate_date_delta_with_unit_sql,
|
||||
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
|
||||
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
|
||||
}
|
||||
TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation
|
||||
|
|
|
@ -293,6 +293,13 @@ def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
|
|||
return ""
|
||||
|
||||
|
||||
def no_comment_column_constraint_sql(
|
||||
self: Generator, expression: exp.CommentColumnConstraint
|
||||
) -> str:
|
||||
self.unsupported("CommentColumnConstraint unsupported")
|
||||
return ""
|
||||
|
||||
|
||||
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
substr = self.sql(expression, "substr")
|
||||
|
@ -379,15 +386,35 @@ def parse_date_delta(
|
|||
) -> t.Callable[[t.Sequence], E]:
|
||||
def inner_func(args: t.Sequence) -> E:
|
||||
unit_based = len(args) == 3
|
||||
this = seq_get(args, 2) if unit_based else seq_get(args, 0)
|
||||
expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
|
||||
unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
|
||||
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore
|
||||
return exp_class(this=this, expression=expression, unit=unit)
|
||||
this = args[2] if unit_based else seq_get(args, 0)
|
||||
unit = args[0] if unit_based else exp.Literal.string("DAY")
|
||||
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
|
||||
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
|
||||
|
||||
return inner_func
|
||||
|
||||
|
||||
def parse_date_delta_with_interval(
|
||||
expression_class: t.Type[E],
|
||||
) -> t.Callable[[t.Sequence], t.Optional[E]]:
|
||||
def func(args: t.Sequence) -> t.Optional[E]:
|
||||
if len(args) < 2:
|
||||
return None
|
||||
|
||||
interval = args[1]
|
||||
expression = interval.this
|
||||
if expression and expression.is_string:
|
||||
expression = exp.Literal.number(expression.this)
|
||||
|
||||
return expression_class(
|
||||
this=args[0],
|
||||
expression=expression,
|
||||
unit=exp.Literal.string(interval.text("unit")),
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
|
||||
unit = seq_get(args, 0)
|
||||
this = seq_get(args, 1)
|
||||
|
|
|
@ -104,6 +104,9 @@ class Drill(Dialect):
|
|||
LOG_DEFAULTS_TO_LN = True
|
||||
|
||||
class Generator(generator.Generator):
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
exp.DataType.Type.INT: "INTEGER",
|
||||
|
@ -120,6 +123,7 @@ class Drill(Dialect):
|
|||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
|
|
|
@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
|
|||
arrow_json_extract_sql,
|
||||
datestrtodate_sql,
|
||||
format_time_lambda,
|
||||
no_comment_column_constraint_sql,
|
||||
no_pivot_sql,
|
||||
no_properties_sql,
|
||||
no_safe_divide_sql,
|
||||
|
@ -23,7 +24,7 @@ from sqlglot.tokens import TokenType
|
|||
|
||||
|
||||
def _ts_or_ds_add(self, expression):
|
||||
this = expression.args.get("this")
|
||||
this = self.sql(expression, "this")
|
||||
unit = self.sql(expression, "unit").strip("'") or "DAY"
|
||||
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
|
||||
|
||||
|
@ -139,6 +140,8 @@ class DuckDB(Dialect):
|
|||
}
|
||||
|
||||
class Generator(generator.Generator):
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
STRUCT_DELIMITER = ("(", ")")
|
||||
|
||||
TRANSFORMS = {
|
||||
|
@ -150,6 +153,7 @@ class DuckDB(Dialect):
|
|||
exp.ArraySize: rename_func("ARRAY_LENGTH"),
|
||||
exp.ArraySort: _array_sort_sql,
|
||||
exp.ArraySum: rename_func("LIST_SUM"),
|
||||
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
|
||||
exp.DayOfMonth: rename_func("DAYOFMONTH"),
|
||||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
exp.DayOfYear: rename_func("DAYOFYEAR"),
|
||||
|
@ -213,6 +217,11 @@ class DuckDB(Dialect):
|
|||
"except": "EXCLUDE",
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
LIMIT_FETCH = "LIMIT"
|
||||
|
||||
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
|
||||
|
|
|
@ -45,16 +45,23 @@ TIME_DIFF_FACTOR = {
|
|||
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
|
||||
|
||||
|
||||
def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
|
||||
def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
|
||||
unit = expression.text("unit").upper()
|
||||
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
|
||||
modified_increment = (
|
||||
int(expression.text("expression")) * multiplier
|
||||
if expression.expression.is_number
|
||||
else expression.expression
|
||||
)
|
||||
modified_increment = exp.Literal.number(modified_increment)
|
||||
return self.func(func, expression.this, modified_increment.this)
|
||||
|
||||
if isinstance(expression, exp.DateSub):
|
||||
multiplier *= -1
|
||||
|
||||
if expression.expression.is_number:
|
||||
modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
|
||||
else:
|
||||
modified_increment = expression.expression
|
||||
if multiplier != 1:
|
||||
modified_increment = exp.Mul( # type: ignore
|
||||
this=modified_increment, expression=exp.Literal.number(multiplier)
|
||||
)
|
||||
|
||||
return self.func(func, expression.this, modified_increment)
|
||||
|
||||
|
||||
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
|
||||
|
@ -127,24 +134,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
|
|||
return f"TO_DATE({this})"
|
||||
|
||||
|
||||
def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
|
||||
unnest = expression.this
|
||||
if isinstance(unnest, exp.Unnest):
|
||||
alias = unnest.args.get("alias")
|
||||
udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
|
||||
return "".join(
|
||||
self.sql(
|
||||
exp.Lateral(
|
||||
this=udtf(this=expression),
|
||||
view=True,
|
||||
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
|
||||
)
|
||||
)
|
||||
for expression, column in zip(unnest.expressions, alias.columns if alias else [])
|
||||
)
|
||||
return self.join_sql(expression)
|
||||
|
||||
|
||||
def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
table = self.sql(expression, "table")
|
||||
|
@ -195,6 +184,7 @@ class Hive(Dialect):
|
|||
IDENTIFIERS = ["`"]
|
||||
STRING_ESCAPES = ["\\"]
|
||||
ENCODE = "utf-8"
|
||||
IDENTIFIER_CAN_START_WITH_DIGIT = True
|
||||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
|
@ -217,9 +207,8 @@ class Hive(Dialect):
|
|||
"BD": "DECIMAL",
|
||||
}
|
||||
|
||||
IDENTIFIER_CAN_START_WITH_DIGIT = True
|
||||
|
||||
class Parser(parser.Parser):
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
STRICT_CAST = False
|
||||
|
||||
FUNCTIONS = {
|
||||
|
@ -273,9 +262,13 @@ class Hive(Dialect):
|
|||
),
|
||||
}
|
||||
|
||||
LOG_DEFAULTS_TO_LN = True
|
||||
|
||||
class Generator(generator.Generator):
|
||||
LIMIT_FETCH = "LIMIT"
|
||||
TABLESAMPLE_WITH_METHOD = False
|
||||
TABLESAMPLE_SIZE_IS_PERCENT = True
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
exp.DataType.Type.TEXT: "STRING",
|
||||
|
@ -289,6 +282,9 @@ class Hive(Dialect):
|
|||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**transforms.UNALIAS_GROUP, # type: ignore
|
||||
**transforms.ELIMINATE_QUALIFY, # type: ignore
|
||||
exp.Select: transforms.preprocess(
|
||||
[transforms.eliminate_qualify, transforms.unnest_to_explode]
|
||||
),
|
||||
exp.Property: _property_sql,
|
||||
exp.ApproxDistinct: approx_count_distinct_sql,
|
||||
exp.ArrayConcat: rename_func("CONCAT"),
|
||||
|
@ -298,13 +294,13 @@ class Hive(Dialect):
|
|||
exp.DateAdd: _add_date_sql,
|
||||
exp.DateDiff: _date_diff_sql,
|
||||
exp.DateStrToDate: rename_func("TO_DATE"),
|
||||
exp.DateSub: _add_date_sql,
|
||||
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
|
||||
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
|
||||
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
|
||||
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
|
||||
exp.If: if_sql,
|
||||
exp.Index: _index_sql,
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.Join: _unnest_to_explode_sql,
|
||||
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
|
||||
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
|
||||
exp.JSONFormat: rename_func("TO_JSON"),
|
||||
|
@ -354,10 +350,9 @@ class Hive(Dialect):
|
|||
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
LIMIT_FETCH = "LIMIT"
|
||||
|
||||
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
|
||||
return self.func(
|
||||
"COLLECT_LIST",
|
||||
|
@ -378,4 +373,5 @@ class Hive(Dialect):
|
|||
expression = exp.DataType.build("text")
|
||||
elif expression.this in exp.DataType.TEMPORAL_TYPES:
|
||||
expression = exp.DataType.build(expression.this)
|
||||
|
||||
return super().datatype_sql(expression)
|
||||
|
|
|
@ -4,6 +4,8 @@ from sqlglot import exp, generator, parser, tokens
|
|||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
arrow_json_extract_scalar_sql,
|
||||
datestrtodate_sql,
|
||||
format_time_lambda,
|
||||
locate_to_strposition,
|
||||
max_or_greatest,
|
||||
min_or_least,
|
||||
|
@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import (
|
|||
no_paren_current_date_sql,
|
||||
no_tablesample_sql,
|
||||
no_trycast_sql,
|
||||
parse_date_delta_with_interval,
|
||||
rename_func,
|
||||
strposition_to_locate_sql,
|
||||
)
|
||||
|
@ -76,18 +79,6 @@ def _trim_sql(self, expression):
|
|||
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
|
||||
|
||||
|
||||
def _date_add(expression_class):
|
||||
def func(args):
|
||||
interval = seq_get(args, 1)
|
||||
return expression_class(
|
||||
this=seq_get(args, 0),
|
||||
expression=interval.this,
|
||||
unit=exp.Literal.string(interval.text("unit").lower()),
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def _date_add_sql(kind):
|
||||
def func(self, expression):
|
||||
this = self.sql(expression, "this")
|
||||
|
@ -115,6 +106,7 @@ class MySQL(Dialect):
|
|||
"%k": "%-H",
|
||||
"%l": "%-I",
|
||||
"%T": "%H:%M:%S",
|
||||
"%W": "%a",
|
||||
}
|
||||
|
||||
class Tokenizer(tokens.Tokenizer):
|
||||
|
@ -127,12 +119,13 @@ class MySQL(Dialect):
|
|||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
|
||||
"CHARSET": TokenType.CHARACTER_SET,
|
||||
"LONGBLOB": TokenType.LONGBLOB,
|
||||
"LONGTEXT": TokenType.LONGTEXT,
|
||||
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
|
||||
"LONGBLOB": TokenType.LONGBLOB,
|
||||
"START": TokenType.BEGIN,
|
||||
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
|
||||
"SEPARATOR": TokenType.SEPARATOR,
|
||||
"START": TokenType.BEGIN,
|
||||
"_ARMSCII8": TokenType.INTRODUCER,
|
||||
"_ASCII": TokenType.INTRODUCER,
|
||||
"_BIG5": TokenType.INTRODUCER,
|
||||
|
@ -186,14 +179,15 @@ class MySQL(Dialect):
|
|||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
"DATE_ADD": _date_add(exp.DateAdd),
|
||||
"DATE_SUB": _date_add(exp.DateSub),
|
||||
"STR_TO_DATE": _str_to_date,
|
||||
"LOCATE": locate_to_strposition,
|
||||
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
|
||||
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
|
||||
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
|
||||
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
|
||||
"LEFT": lambda args: exp.Substring(
|
||||
this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
|
||||
),
|
||||
"LOCATE": locate_to_strposition,
|
||||
"STR_TO_DATE": _str_to_date,
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
|
@ -388,32 +382,36 @@ class MySQL(Dialect):
|
|||
class Generator(generator.Generator):
|
||||
LOCKING_READS_SUPPORTED = True
|
||||
NULL_ORDERING_SUPPORTED = False
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
exp.CurrentDate: no_paren_current_date_sql,
|
||||
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
|
||||
exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
|
||||
exp.DateAdd: _date_add_sql("ADD"),
|
||||
exp.DateStrToDate: datestrtodate_sql,
|
||||
exp.DateSub: _date_add_sql("SUB"),
|
||||
exp.DateTrunc: _date_trunc_sql,
|
||||
exp.DayOfMonth: rename_func("DAYOFMONTH"),
|
||||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
exp.DayOfYear: rename_func("DAYOFYEAR"),
|
||||
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
|
||||
exp.Max: max_or_greatest,
|
||||
exp.Min: min_or_least,
|
||||
exp.TableSample: no_tablesample_sql,
|
||||
exp.TryCast: no_trycast_sql,
|
||||
exp.DateAdd: _date_add_sql("ADD"),
|
||||
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
|
||||
exp.DateSub: _date_add_sql("SUB"),
|
||||
exp.DateTrunc: _date_trunc_sql,
|
||||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
exp.DayOfMonth: rename_func("DAYOFMONTH"),
|
||||
exp.DayOfYear: rename_func("DAYOFYEAR"),
|
||||
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
|
||||
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
|
||||
exp.StrToDate: _str_to_date_sql,
|
||||
exp.StrToTime: _str_to_date_sql,
|
||||
exp.Trim: _trim_sql,
|
||||
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
|
||||
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
|
||||
exp.StrPosition: strposition_to_locate_sql,
|
||||
exp.StrToDate: _str_to_date_sql,
|
||||
exp.StrToTime: _str_to_date_sql,
|
||||
exp.TableSample: no_tablesample_sql,
|
||||
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
|
||||
exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
|
||||
exp.Trim: _trim_sql,
|
||||
exp.TryCast: no_trycast_sql,
|
||||
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
|
||||
}
|
||||
|
||||
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
|
||||
|
@ -425,6 +423,7 @@ class MySQL(Dialect):
|
|||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
LIMIT_FETCH = "LIMIT"
|
||||
|
|
|
@ -7,11 +7,6 @@ from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sq
|
|||
from sqlglot.helper import seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
|
||||
TokenType.COLUMN,
|
||||
TokenType.RETURNING,
|
||||
}
|
||||
|
||||
|
||||
def _parse_xml_table(self) -> exp.XMLTable:
|
||||
this = self._parse_string()
|
||||
|
@ -22,9 +17,7 @@ def _parse_xml_table(self) -> exp.XMLTable:
|
|||
if self._match_text_seq("PASSING"):
|
||||
# The BY VALUE keywords are optional and are provided for semantic clarity
|
||||
self._match_text_seq("BY", "VALUE")
|
||||
passing = self._parse_csv(
|
||||
lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS)
|
||||
)
|
||||
passing = self._parse_csv(self._parse_column)
|
||||
|
||||
by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
|
||||
|
||||
|
@ -68,6 +61,8 @@ class Oracle(Dialect):
|
|||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
|
||||
|
@ -78,6 +73,12 @@ class Oracle(Dialect):
|
|||
"XMLTABLE": _parse_xml_table,
|
||||
}
|
||||
|
||||
TYPE_LITERAL_PARSERS = {
|
||||
exp.DataType.Type.DATE: lambda self, this, _: self.expression(
|
||||
exp.DateStrToDate, this=this
|
||||
)
|
||||
}
|
||||
|
||||
def _parse_column(self) -> t.Optional[exp.Expression]:
|
||||
column = super()._parse_column()
|
||||
if column:
|
||||
|
@ -100,6 +101,8 @@ class Oracle(Dialect):
|
|||
|
||||
class Generator(generator.Generator):
|
||||
LOCKING_READS_SUPPORTED = True
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
|
@ -119,6 +122,9 @@ class Oracle(Dialect):
|
|||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**transforms.UNALIAS_GROUP, # type: ignore
|
||||
exp.DateStrToDate: lambda self, e: self.func(
|
||||
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
|
||||
),
|
||||
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
|
@ -129,6 +135,12 @@ class Oracle(Dialect):
|
|||
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
|
||||
exp.Trim: trim_sql,
|
||||
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
|
||||
exp.IfNull: rename_func("NVL"),
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
LIMIT_FETCH = "FETCH"
|
||||
|
@ -142,9 +154,9 @@ class Oracle(Dialect):
|
|||
|
||||
def xmltable_sql(self, expression: exp.XMLTable) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
passing = self.expressions(expression, "passing")
|
||||
passing = self.expressions(expression, key="passing")
|
||||
passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else ""
|
||||
columns = self.expressions(expression, "columns")
|
||||
columns = self.expressions(expression, key="columns")
|
||||
columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else ""
|
||||
by_ref = (
|
||||
f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else ""
|
||||
|
|
|
@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
|
|||
Dialect,
|
||||
arrow_json_extract_scalar_sql,
|
||||
arrow_json_extract_sql,
|
||||
datestrtodate_sql,
|
||||
format_time_lambda,
|
||||
max_or_greatest,
|
||||
min_or_least,
|
||||
|
@ -19,7 +20,7 @@ from sqlglot.dialects.dialect import (
|
|||
from sqlglot.helper import seq_get
|
||||
from sqlglot.parser import binary_range_parser
|
||||
from sqlglot.tokens import TokenType
|
||||
from sqlglot.transforms import delegate, preprocess
|
||||
from sqlglot.transforms import preprocess, remove_target_from_merge
|
||||
|
||||
DATE_DIFF_FACTOR = {
|
||||
"MICROSECOND": " * 1000000",
|
||||
|
@ -239,7 +240,6 @@ class Postgres(Dialect):
|
|||
"SERIAL": TokenType.SERIAL,
|
||||
"SMALLSERIAL": TokenType.SMALLSERIAL,
|
||||
"TEMP": TokenType.TEMPORARY,
|
||||
"UUID": TokenType.UUID,
|
||||
"CSTRING": TokenType.PSEUDO_TYPE,
|
||||
}
|
||||
|
||||
|
@ -248,18 +248,25 @@ class Postgres(Dialect):
|
|||
"$": TokenType.PARAMETER,
|
||||
}
|
||||
|
||||
VAR_SINGLE_TOKENS = {"$"}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
STRICT_CAST = False
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
"NOW": exp.CurrentTimestamp.from_arg_list,
|
||||
"TO_TIMESTAMP": _to_timestamp,
|
||||
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
|
||||
"GENERATE_SERIES": _generate_series,
|
||||
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
|
||||
this=seq_get(args, 1), unit=seq_get(args, 0)
|
||||
),
|
||||
"GENERATE_SERIES": _generate_series,
|
||||
"NOW": exp.CurrentTimestamp.from_arg_list,
|
||||
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
|
||||
"TO_TIMESTAMP": _to_timestamp,
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
**parser.Parser.FUNCTION_PARSERS,
|
||||
"DATE_PART": lambda self: self._parse_date_part(),
|
||||
}
|
||||
|
||||
BITWISE = {
|
||||
|
@ -279,8 +286,21 @@ class Postgres(Dialect):
|
|||
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
|
||||
}
|
||||
|
||||
def _parse_date_part(self) -> exp.Expression:
|
||||
part = self._parse_type()
|
||||
self._match(TokenType.COMMA)
|
||||
value = self._parse_bitwise()
|
||||
|
||||
if part and part.is_string:
|
||||
part = exp.Var(this=part.name)
|
||||
|
||||
return self.expression(exp.Extract, this=part, expression=value)
|
||||
|
||||
class Generator(generator.Generator):
|
||||
INTERVAL_ALLOWS_PLURAL_FORM = False
|
||||
LOCKING_READS_SUPPORTED = True
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
PARAMETER_TOKEN = "$"
|
||||
|
||||
TYPE_MAPPING = {
|
||||
|
@ -301,7 +321,6 @@ class Postgres(Dialect):
|
|||
_auto_increment_to_serial,
|
||||
_serial_to_generated,
|
||||
],
|
||||
delegate("columndef_sql"),
|
||||
),
|
||||
exp.JSONExtract: arrow_json_extract_sql,
|
||||
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
|
||||
|
@ -312,6 +331,7 @@ class Postgres(Dialect):
|
|||
exp.CurrentDate: no_paren_current_date_sql,
|
||||
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
|
||||
exp.DateAdd: _date_add_sql("+"),
|
||||
exp.DateStrToDate: datestrtodate_sql,
|
||||
exp.DateSub: _date_add_sql("-"),
|
||||
exp.DateDiff: _date_diff_sql,
|
||||
exp.LogicalOr: rename_func("BOOL_OR"),
|
||||
|
@ -321,6 +341,7 @@ class Postgres(Dialect):
|
|||
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
|
||||
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
|
||||
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
|
||||
exp.Merge: preprocess([remove_target_from_merge]),
|
||||
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
|
||||
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
|
||||
exp.StrPosition: str_position_sql,
|
||||
|
@ -344,4 +365,5 @@ class Postgres(Dialect):
|
|||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens, transforms
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
|
@ -19,20 +21,20 @@ from sqlglot.helper import seq_get
|
|||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
def _approx_distinct_sql(self, expression):
|
||||
def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
|
||||
accuracy = expression.args.get("accuracy")
|
||||
accuracy = ", " + self.sql(accuracy) if accuracy else ""
|
||||
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
|
||||
|
||||
|
||||
def _datatype_sql(self, expression):
|
||||
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
|
||||
sql = self.datatype_sql(expression)
|
||||
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
|
||||
sql = f"{sql} WITH TIME ZONE"
|
||||
return sql
|
||||
|
||||
|
||||
def _explode_to_unnest_sql(self, expression):
|
||||
def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
|
||||
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
|
||||
return self.sql(
|
||||
exp.Join(
|
||||
|
@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression):
|
|||
return self.lateral_sql(expression)
|
||||
|
||||
|
||||
def _initcap_sql(self, expression):
|
||||
def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
|
||||
regex = r"(\w)(\w*)"
|
||||
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
|
||||
|
||||
|
||||
def _decode_sql(self, expression):
|
||||
_ensure_utf8(expression.args.get("charset"))
|
||||
def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
|
||||
_ensure_utf8(expression.args["charset"])
|
||||
return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
|
||||
|
||||
|
||||
def _encode_sql(self, expression):
|
||||
_ensure_utf8(expression.args.get("charset"))
|
||||
def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
|
||||
_ensure_utf8(expression.args["charset"])
|
||||
return f"TO_UTF8({self.sql(expression, 'this')})"
|
||||
|
||||
|
||||
def _no_sort_array(self, expression):
|
||||
def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
|
||||
if expression.args.get("asc") == exp.false():
|
||||
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
|
||||
else:
|
||||
|
@ -70,49 +72,62 @@ def _no_sort_array(self, expression):
|
|||
return self.func("ARRAY_SORT", expression.this, comparator)
|
||||
|
||||
|
||||
def _schema_sql(self, expression):
|
||||
def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
|
||||
if isinstance(expression.parent, exp.Property):
|
||||
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
|
||||
return f"ARRAY[{columns}]"
|
||||
|
||||
for schema in expression.parent.find_all(exp.Schema):
|
||||
if isinstance(schema.parent, exp.Property):
|
||||
expression = expression.copy()
|
||||
expression.expressions.extend(schema.expressions)
|
||||
if expression.parent:
|
||||
for schema in expression.parent.find_all(exp.Schema):
|
||||
if isinstance(schema.parent, exp.Property):
|
||||
expression = expression.copy()
|
||||
expression.expressions.extend(schema.expressions)
|
||||
|
||||
return self.schema_sql(expression)
|
||||
|
||||
|
||||
def _quantile_sql(self, expression):
|
||||
def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
|
||||
self.unsupported("Presto does not support exact quantiles")
|
||||
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
|
||||
|
||||
|
||||
def _str_to_time_sql(self, expression):
|
||||
def _str_to_time_sql(
|
||||
self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
|
||||
) -> str:
|
||||
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
|
||||
|
||||
|
||||
def _ts_or_ds_to_date_sql(self, expression):
|
||||
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
|
||||
time_format = self.format_time(expression)
|
||||
if time_format and time_format not in (Presto.time_format, Presto.date_format):
|
||||
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
|
||||
return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
|
||||
|
||||
|
||||
def _ts_or_ds_add_sql(self, expression):
|
||||
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
|
||||
this = expression.this
|
||||
|
||||
if not isinstance(this, exp.CurrentDate):
|
||||
this = self.func(
|
||||
"DATE_PARSE",
|
||||
self.func(
|
||||
"SUBSTR",
|
||||
this if this.is_string else exp.cast(this, "VARCHAR"),
|
||||
exp.Literal.number(1),
|
||||
exp.Literal.number(10),
|
||||
),
|
||||
Presto.date_format,
|
||||
)
|
||||
|
||||
return self.func(
|
||||
"DATE_ADD",
|
||||
exp.Literal.string(expression.text("unit") or "day"),
|
||||
expression.expression,
|
||||
self.func(
|
||||
"DATE_PARSE",
|
||||
self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
|
||||
Presto.date_format,
|
||||
),
|
||||
this,
|
||||
)
|
||||
|
||||
|
||||
def _sequence_sql(self, expression):
|
||||
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
|
||||
start = expression.args["start"]
|
||||
end = expression.args["end"]
|
||||
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
|
||||
|
@ -135,12 +150,12 @@ def _sequence_sql(self, expression):
|
|||
return self.func("SEQUENCE", start, end, step)
|
||||
|
||||
|
||||
def _ensure_utf8(charset):
|
||||
def _ensure_utf8(charset: exp.Literal) -> None:
|
||||
if charset.name.lower() != "utf-8":
|
||||
raise UnsupportedError(f"Unsupported charset {charset}")
|
||||
|
||||
|
||||
def _approx_percentile(args):
|
||||
def _approx_percentile(args: t.Sequence) -> exp.Expression:
|
||||
if len(args) == 4:
|
||||
return exp.ApproxQuantile(
|
||||
this=seq_get(args, 0),
|
||||
|
@ -157,7 +172,7 @@ def _approx_percentile(args):
|
|||
return exp.ApproxQuantile.from_arg_list(args)
|
||||
|
||||
|
||||
def _from_unixtime(args):
|
||||
def _from_unixtime(args: t.Sequence) -> exp.Expression:
|
||||
if len(args) == 3:
|
||||
return exp.UnixToTime(
|
||||
this=seq_get(args, 0),
|
||||
|
@ -226,11 +241,15 @@ class Presto(Dialect):
|
|||
FUNCTION_PARSERS.pop("TRIM")
|
||||
|
||||
class Generator(generator.Generator):
|
||||
INTERVAL_ALLOWS_PLURAL_FORM = False
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
STRUCT_DELIMITER = ("(", ")")
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
TYPE_MAPPING = {
|
||||
|
@ -246,7 +265,6 @@ class Presto(Dialect):
|
|||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**transforms.UNALIAS_GROUP, # type: ignore
|
||||
**transforms.ELIMINATE_QUALIFY, # type: ignore
|
||||
exp.ApproxDistinct: _approx_distinct_sql,
|
||||
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
|
||||
exp.ArrayConcat: rename_func("CONCAT"),
|
||||
|
@ -284,6 +302,9 @@ class Presto(Dialect):
|
|||
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
|
||||
exp.SafeDivide: no_safe_divide_sql,
|
||||
exp.Schema: _schema_sql,
|
||||
exp.Select: transforms.preprocess(
|
||||
[transforms.eliminate_qualify, transforms.explode_to_unnest]
|
||||
),
|
||||
exp.SortArray: _no_sort_array,
|
||||
exp.StrPosition: rename_func("STRPOS"),
|
||||
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
|
||||
|
@ -308,7 +329,13 @@ class Presto(Dialect):
|
|||
exp.VariancePop: rename_func("VAR_POP"),
|
||||
}
|
||||
|
||||
def transaction_sql(self, expression):
|
||||
def interval_sql(self, expression: exp.Interval) -> str:
|
||||
unit = self.sql(expression, "unit")
|
||||
if expression.this and unit.lower().startswith("week"):
|
||||
return f"({expression.this.name} * INTERVAL '7' day)"
|
||||
return super().interval_sql(expression)
|
||||
|
||||
def transaction_sql(self, expression: exp.Transaction) -> str:
|
||||
modes = expression.args.get("modes")
|
||||
modes = f" {', '.join(modes)}" if modes else ""
|
||||
return f"START TRANSACTION{modes}"
|
||||
|
|
|
@ -8,6 +8,10 @@ from sqlglot.helper import seq_get
|
|||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
def _json_sql(self, e) -> str:
|
||||
return f'{self.sql(e, "this")}."{e.expression.name}"'
|
||||
|
||||
|
||||
class Redshift(Postgres):
|
||||
time_format = "'YYYY-MM-DD HH:MI:SS'"
|
||||
time_mapping = {
|
||||
|
@ -56,6 +60,7 @@ class Redshift(Postgres):
|
|||
"GEOGRAPHY": TokenType.GEOGRAPHY,
|
||||
"HLLSKETCH": TokenType.HLLSKETCH,
|
||||
"SUPER": TokenType.SUPER,
|
||||
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
|
||||
"TIME": TokenType.TIMESTAMP,
|
||||
"TIMETZ": TokenType.TIMESTAMPTZ,
|
||||
"TOP": TokenType.TOP,
|
||||
|
@ -63,7 +68,14 @@ class Redshift(Postgres):
|
|||
"VARBYTE": TokenType.VARBINARY,
|
||||
}
|
||||
|
||||
# Redshift allows # to appear as a table identifier prefix
|
||||
SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy()
|
||||
SINGLE_TOKENS.pop("#")
|
||||
|
||||
class Generator(Postgres.Generator):
|
||||
LOCKING_READS_SUPPORTED = False
|
||||
SINGLE_STRING_INTERVAL = True
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**Postgres.Generator.TYPE_MAPPING, # type: ignore
|
||||
exp.DataType.Type.BINARY: "VARBYTE",
|
||||
|
@ -79,6 +91,7 @@ class Redshift(Postgres):
|
|||
TRANSFORMS = {
|
||||
**Postgres.Generator.TRANSFORMS, # type: ignore
|
||||
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
|
||||
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
|
||||
exp.DateAdd: lambda self, e: self.func(
|
||||
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
|
||||
),
|
||||
|
@ -87,12 +100,16 @@ class Redshift(Postgres):
|
|||
),
|
||||
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
|
||||
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
|
||||
exp.JSONExtract: _json_sql,
|
||||
exp.JSONExtractScalar: _json_sql,
|
||||
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
|
||||
}
|
||||
|
||||
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
|
||||
TRANSFORMS.pop(exp.Pow)
|
||||
|
||||
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"}
|
||||
|
||||
def values_sql(self, expression: exp.Values) -> str:
|
||||
"""
|
||||
Converts `VALUES...` expression into a series of unions.
|
||||
|
|
|
@ -23,14 +23,14 @@ from sqlglot.parser import binary_range_parser
|
|||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
def _check_int(s):
|
||||
def _check_int(s: str) -> bool:
|
||||
if s[0] in ("-", "+"):
|
||||
return s[1:].isdigit()
|
||||
return s.isdigit()
|
||||
|
||||
|
||||
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
|
||||
def _snowflake_to_timestamp(args):
|
||||
def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
|
||||
if len(args) == 2:
|
||||
first_arg, second_arg = args
|
||||
if second_arg.is_string:
|
||||
|
@ -69,7 +69,7 @@ def _snowflake_to_timestamp(args):
|
|||
return exp.UnixToTime.from_arg_list(args)
|
||||
|
||||
|
||||
def _unix_to_time_sql(self, expression):
|
||||
def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
|
||||
scale = expression.args.get("scale")
|
||||
timestamp = self.sql(expression, "this")
|
||||
if scale in [None, exp.UnixToTime.SECONDS]:
|
||||
|
@ -84,8 +84,12 @@ def _unix_to_time_sql(self, expression):
|
|||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
|
||||
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
|
||||
def _parse_date_part(self):
|
||||
def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
|
||||
this = self._parse_var() or self._parse_type()
|
||||
|
||||
if not this:
|
||||
return None
|
||||
|
||||
self._match(TokenType.COMMA)
|
||||
expression = self._parse_bitwise()
|
||||
|
||||
|
@ -101,7 +105,7 @@ def _parse_date_part(self):
|
|||
scale = None
|
||||
|
||||
ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
|
||||
to_unix = self.expression(exp.TimeToUnix, this=ts)
|
||||
to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
|
||||
|
||||
if scale:
|
||||
to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
|
||||
|
@ -112,7 +116,7 @@ def _parse_date_part(self):
|
|||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/div0
|
||||
def _div0_to_if(args):
|
||||
def _div0_to_if(args: t.Sequence) -> exp.Expression:
|
||||
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
|
||||
true = exp.Literal.number(0)
|
||||
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
|
||||
|
@ -120,18 +124,18 @@ def _div0_to_if(args):
|
|||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
|
||||
def _zeroifnull_to_if(args):
|
||||
def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
|
||||
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
|
||||
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
|
||||
def _nullifzero_to_if(args):
|
||||
def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
|
||||
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
|
||||
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
|
||||
|
||||
|
||||
def _datatype_sql(self, expression):
|
||||
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
|
||||
if expression.this == exp.DataType.Type.ARRAY:
|
||||
return "ARRAY"
|
||||
elif expression.this == exp.DataType.Type.MAP:
|
||||
|
@ -155,9 +159,8 @@ class Snowflake(Dialect):
|
|||
"MM": "%m",
|
||||
"mm": "%m",
|
||||
"DD": "%d",
|
||||
"dd": "%d",
|
||||
"d": "%-d",
|
||||
"DY": "%w",
|
||||
"dd": "%-d",
|
||||
"DY": "%a",
|
||||
"dy": "%w",
|
||||
"HH24": "%H",
|
||||
"hh24": "%H",
|
||||
|
@ -174,6 +177,8 @@ class Snowflake(Dialect):
|
|||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
QUOTED_PIVOT_COLUMNS = True
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
|
||||
|
@ -269,9 +274,14 @@ class Snowflake(Dialect):
|
|||
"$": TokenType.PARAMETER,
|
||||
}
|
||||
|
||||
VAR_SINGLE_TOKENS = {"$"}
|
||||
|
||||
class Generator(generator.Generator):
|
||||
PARAMETER_TOKEN = "$"
|
||||
MATCHED_BY_SOURCE = False
|
||||
SINGLE_STRING_INTERVAL = True
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
|
@ -287,26 +297,30 @@ class Snowflake(Dialect):
|
|||
),
|
||||
exp.DateStrToDate: datestrtodate_sql,
|
||||
exp.DataType: _datatype_sql,
|
||||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
exp.If: rename_func("IFF"),
|
||||
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
|
||||
exp.LogicalOr: rename_func("BOOLOR_AGG"),
|
||||
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
|
||||
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
|
||||
exp.LogicalOr: rename_func("BOOLOR_AGG"),
|
||||
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
|
||||
exp.Max: max_or_greatest,
|
||||
exp.Min: min_or_least,
|
||||
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
|
||||
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
|
||||
exp.StrPosition: lambda self, e: self.func(
|
||||
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
|
||||
),
|
||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.TimestampTrunc: timestamptrunc_sql,
|
||||
exp.TimeStrToTime: timestrtotime_sql,
|
||||
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
|
||||
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
|
||||
exp.TimeToStr: lambda self, e: self.func(
|
||||
"TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
|
||||
),
|
||||
exp.TimestampTrunc: timestamptrunc_sql,
|
||||
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
|
||||
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
|
||||
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
|
||||
exp.UnixToTime: _unix_to_time_sql,
|
||||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
exp.Max: max_or_greatest,
|
||||
exp.Min: min_or_least,
|
||||
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
|
||||
}
|
||||
|
||||
TYPE_MAPPING = {
|
||||
|
@ -322,14 +336,15 @@ class Snowflake(Dialect):
|
|||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
def except_op(self, expression):
|
||||
def except_op(self, expression: exp.Except) -> str:
|
||||
if not expression.args.get("distinct", False):
|
||||
self.unsupported("EXCEPT with All is not supported in Snowflake")
|
||||
return super().except_op(expression)
|
||||
|
||||
def intersect_op(self, expression):
|
||||
def intersect_op(self, expression: exp.Intersect) -> str:
|
||||
if not expression.args.get("distinct", False):
|
||||
self.unsupported("INTERSECT with All is not supported in Snowflake")
|
||||
return super().intersect_op(expression)
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, parser
|
||||
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
|
||||
from sqlglot.dialects.hive import Hive
|
||||
from sqlglot.helper import seq_get
|
||||
|
||||
|
||||
def _create_sql(self, e):
|
||||
kind = e.args.get("kind")
|
||||
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
|
||||
kind = e.args["kind"]
|
||||
properties = e.args.get("properties")
|
||||
|
||||
if kind.upper() == "TABLE" and any(
|
||||
|
@ -18,13 +20,13 @@ def _create_sql(self, e):
|
|||
return create_with_partitions_sql(self, e)
|
||||
|
||||
|
||||
def _map_sql(self, expression):
|
||||
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
|
||||
keys = self.sql(expression.args["keys"])
|
||||
values = self.sql(expression.args["values"])
|
||||
return f"MAP_FROM_ARRAYS({keys}, {values})"
|
||||
|
||||
|
||||
def _str_to_date(self, expression):
|
||||
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
time_format = self.format_time(expression)
|
||||
if time_format == Hive.date_format:
|
||||
|
@ -32,7 +34,7 @@ def _str_to_date(self, expression):
|
|||
return f"TO_DATE({this}, {time_format})"
|
||||
|
||||
|
||||
def _unix_to_time(self, expression):
|
||||
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
|
||||
scale = expression.args.get("scale")
|
||||
timestamp = self.sql(expression, "this")
|
||||
if scale is None:
|
||||
|
@ -75,7 +77,11 @@ class Spark(Hive):
|
|||
length=seq_get(args, 1),
|
||||
),
|
||||
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
|
||||
"BOOLEAN": lambda args: exp.Cast(
|
||||
this=seq_get(args, 0), to=exp.DataType.build("boolean")
|
||||
),
|
||||
"IIF": exp.If.from_arg_list,
|
||||
"INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
|
||||
"AGGREGATE": exp.Reduce.from_arg_list,
|
||||
"DAYOFWEEK": lambda args: exp.DayOfWeek(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
|
@ -89,11 +95,16 @@ class Spark(Hive):
|
|||
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
),
|
||||
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
|
||||
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
|
||||
this=seq_get(args, 1),
|
||||
unit=exp.var(seq_get(args, 0)),
|
||||
),
|
||||
"STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
|
||||
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
|
||||
"TIMESTAMP": lambda args: exp.Cast(
|
||||
this=seq_get(args, 0), to=exp.DataType.build("timestamp")
|
||||
),
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
|
@ -108,16 +119,43 @@ class Spark(Hive):
|
|||
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
|
||||
}
|
||||
|
||||
def _parse_add_column(self):
|
||||
def _parse_add_column(self) -> t.Optional[exp.Expression]:
|
||||
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
|
||||
|
||||
def _parse_drop_column(self):
|
||||
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
|
||||
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
|
||||
exp.Drop,
|
||||
this=self._parse_schema(),
|
||||
kind="COLUMNS",
|
||||
)
|
||||
|
||||
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
|
||||
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
|
||||
if len(pivot_columns) == 1:
|
||||
return [""]
|
||||
|
||||
names = []
|
||||
for agg in pivot_columns:
|
||||
if isinstance(agg, exp.Alias):
|
||||
names.append(agg.alias)
|
||||
else:
|
||||
"""
|
||||
This case corresponds to aggregations without aliases being used as suffixes
|
||||
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
|
||||
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
|
||||
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
|
||||
|
||||
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
|
||||
"""
|
||||
agg_all_unquoted = agg.transform(
|
||||
lambda node: exp.Identifier(this=node.name, quoted=False)
|
||||
if isinstance(node, exp.Identifier)
|
||||
else node
|
||||
)
|
||||
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
|
||||
|
||||
return names
|
||||
|
||||
class Generator(Hive.Generator):
|
||||
TYPE_MAPPING = {
|
||||
**Hive.Generator.TYPE_MAPPING, # type: ignore
|
||||
|
@ -145,7 +183,7 @@ class Spark(Hive):
|
|||
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
|
||||
exp.StrToDate: _str_to_date,
|
||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.UnixToTime: _unix_to_time,
|
||||
exp.UnixToTime: _unix_to_time_sql,
|
||||
exp.Create: _create_sql,
|
||||
exp.Map: _map_sql,
|
||||
exp.Reduce: rename_func("AGGREGATE"),
|
||||
|
|
|
@ -16,7 +16,7 @@ from sqlglot.tokens import TokenType
|
|||
|
||||
def _date_add_sql(self, expression):
|
||||
modifier = expression.expression
|
||||
modifier = expression.name if modifier.is_string else self.sql(modifier)
|
||||
modifier = modifier.name if modifier.is_string else self.sql(modifier)
|
||||
unit = expression.args.get("unit")
|
||||
modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
|
||||
return self.func("DATE", expression.this, modifier)
|
||||
|
@ -38,6 +38,9 @@ class SQLite(Dialect):
|
|||
}
|
||||
|
||||
class Generator(generator.Generator):
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
exp.DataType.Type.BOOLEAN: "INTEGER",
|
||||
|
@ -82,6 +85,11 @@ class SQLite(Dialect):
|
|||
exp.TryCast: no_trycast_sql,
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
LIMIT_FETCH = "LIMIT"
|
||||
|
||||
def cast_sql(self, expression: exp.Cast) -> str:
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
|
||||
from sqlglot.dialects.dialect import (
|
||||
approx_count_distinct_sql,
|
||||
arrow_json_extract_sql,
|
||||
rename_func,
|
||||
)
|
||||
from sqlglot.dialects.mysql import MySQL
|
||||
from sqlglot.helper import seq_get
|
||||
|
||||
|
@ -10,6 +14,7 @@ class StarRocks(MySQL):
|
|||
class Parser(MySQL.Parser): # type: ignore
|
||||
FUNCTIONS = {
|
||||
**MySQL.Parser.FUNCTIONS,
|
||||
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
|
||||
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
|
||||
this=seq_get(args, 1), unit=seq_get(args, 0)
|
||||
),
|
||||
|
@ -25,6 +30,7 @@ class StarRocks(MySQL):
|
|||
|
||||
TRANSFORMS = {
|
||||
**MySQL.Generator.TRANSFORMS, # type: ignore
|
||||
exp.ApproxDistinct: approx_count_distinct_sql,
|
||||
exp.JSONExtractScalar: arrow_json_extract_sql,
|
||||
exp.JSONExtract: arrow_json_extract_sql,
|
||||
exp.DateDiff: rename_func("DATEDIFF"),
|
||||
|
|
|
@ -21,6 +21,9 @@ def _count_sql(self, expression):
|
|||
|
||||
class Tableau(Dialect):
|
||||
class Generator(generator.Generator):
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
exp.If: _if_sql,
|
||||
|
@ -28,6 +31,11 @@ class Tableau(Dialect):
|
|||
exp.Count: _count_sql,
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens
|
||||
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
format_time_lambda,
|
||||
max_or_greatest,
|
||||
min_or_least,
|
||||
)
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
|
@ -115,7 +122,18 @@ class Teradata(Dialect):
|
|||
|
||||
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
|
||||
|
||||
def _parse_cast(self, strict: bool) -> exp.Expression:
|
||||
cast = t.cast(exp.Cast, super()._parse_cast(strict))
|
||||
if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT):
|
||||
return format_time_lambda(exp.TimeToStr, "teradata")(
|
||||
[cast.this, self._parse_string()]
|
||||
)
|
||||
return cast
|
||||
|
||||
class Generator(generator.Generator):
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
|
||||
|
@ -130,6 +148,7 @@ class Teradata(Dialect):
|
|||
**generator.Generator.TRANSFORMS,
|
||||
exp.Max: max_or_greatest,
|
||||
exp.Min: min_or_least,
|
||||
exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
|
||||
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
|
||||
}
|
||||
|
||||
|
|
|
@ -96,6 +96,23 @@ def _parse_eomonth(args):
|
|||
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
|
||||
|
||||
|
||||
def _parse_hashbytes(args):
|
||||
kind, data = args
|
||||
kind = kind.name.upper() if kind.is_string else ""
|
||||
|
||||
if kind == "MD5":
|
||||
args.pop(0)
|
||||
return exp.MD5(this=data)
|
||||
if kind in ("SHA", "SHA1"):
|
||||
args.pop(0)
|
||||
return exp.SHA(this=data)
|
||||
if kind == "SHA2_256":
|
||||
return exp.SHA2(this=data, length=exp.Literal.number(256))
|
||||
if kind == "SHA2_512":
|
||||
return exp.SHA2(this=data, length=exp.Literal.number(512))
|
||||
return exp.func("HASHBYTES", *args)
|
||||
|
||||
|
||||
def generate_date_delta_with_unit_sql(self, e):
|
||||
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
|
||||
return self.func(func, e.text("unit"), e.expression, e.this)
|
||||
|
@ -266,6 +283,7 @@ class TSQL(Dialect):
|
|||
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
|
||||
"VARCHAR(MAX)": TokenType.TEXT,
|
||||
"XML": TokenType.XML,
|
||||
"SYSTEM_USER": TokenType.CURRENT_USER,
|
||||
}
|
||||
|
||||
# TSQL allows @, # to appear as a variable/identifier prefix
|
||||
|
@ -287,6 +305,7 @@ class TSQL(Dialect):
|
|||
"EOMONTH": _parse_eomonth,
|
||||
"FORMAT": _parse_format,
|
||||
"GETDATE": exp.CurrentTimestamp.from_arg_list,
|
||||
"HASHBYTES": _parse_hashbytes,
|
||||
"IIF": exp.If.from_arg_list,
|
||||
"ISNULL": exp.Coalesce.from_arg_list,
|
||||
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
|
||||
|
@ -296,6 +315,14 @@ class TSQL(Dialect):
|
|||
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
|
||||
"SUSER_NAME": exp.CurrentUser.from_arg_list,
|
||||
"SUSER_SNAME": exp.CurrentUser.from_arg_list,
|
||||
"SYSTEM_USER": exp.CurrentUser.from_arg_list,
|
||||
}
|
||||
|
||||
JOIN_HINTS = {
|
||||
"LOOP",
|
||||
"HASH",
|
||||
"MERGE",
|
||||
"REMOTE",
|
||||
}
|
||||
|
||||
VAR_LENGTH_DATATYPES = {
|
||||
|
@ -441,11 +468,21 @@ class TSQL(Dialect):
|
|||
exp.TimeToStr: _format_sql,
|
||||
exp.GroupConcat: _string_agg_sql,
|
||||
exp.Max: max_or_greatest,
|
||||
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
|
||||
exp.Min: min_or_least,
|
||||
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
|
||||
exp.SHA2: lambda self, e: self.func(
|
||||
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
|
||||
),
|
||||
}
|
||||
|
||||
TRANSFORMS.pop(exp.ReturnsProperty)
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
LIMIT_FETCH = "FETCH"
|
||||
|
||||
def offset_sql(self, expression: exp.Offset) -> str:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue