
Merging upstream version 11.7.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:52:09 +01:00
parent 0c053462ae
commit 8d96084fad
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
144 changed files with 44104 additions and 39367 deletions

sqlglot/dialects/bigquery.py

@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
no_ilike_sql,
parse_date_delta_with_interval,
rename_func,
timestrtotime_sql,
ts_or_ds_to_date_sql,
@ -23,18 +24,6 @@ from sqlglot.tokens import TokenType
E = t.TypeVar("E", bound=exp.Expression)
def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
def func(args):
interval = seq_get(args, 1)
return expression_class(
this=seq_get(args, 0),
expression=interval.this,
unit=interval.args.get("unit"),
)
return func
def _date_add_sql(
data_type: str, kind: str
) -> t.Callable[[generator.Generator, exp.Expression], str]:
@ -142,6 +131,7 @@ class BigQuery(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ANY TYPE": TokenType.VARIANT,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
@ -155,14 +145,19 @@ class BigQuery(Dialect):
KEYWORDS.pop("DIV")
class Parser(parser.Parser):
PREFIXED_PIVOT_COLUMNS = True
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore
this=seq_get(args, 0),
),
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
@ -174,12 +169,12 @@ class BigQuery(Dialect):
if re.compile(str(seq_get(args, 1))).groups == 1
else None,
),
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
"DATETIME_SUB": _date_add(exp.DatetimeSub),
"TIME_SUB": _date_add(exp.TimeSub),
"TIMESTAMP_SUB": _date_add(exp.TimestampSub),
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
"TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(
this=seq_get(args, 1), format=seq_get(args, 0)
),
@ -209,14 +204,17 @@ class BigQuery(Dialect):
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, # type: ignore
"NOT DETERMINISTIC": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
),
}
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
class Generator(generator.Generator):
EXPLICIT_UNION = True
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
@ -236,9 +234,7 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess(
[_unqualify_unnest], transforms.delegate("select_sql")
),
exp.Select: transforms.preprocess([_unqualify_unnest]),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@ -253,7 +249,7 @@ class BigQuery(Dialect):
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
@ -261,6 +257,7 @@ class BigQuery(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.CHAR: "STRING",
@ -272,17 +269,19 @@ class BigQuery(Dialect):
exp.DataType.Type.NVARCHAR: "STRING",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.VARIANT: "ANY TYPE",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
EXPLICIT_UNION = True
LIMIT_FETCH = "LIMIT"
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
if isinstance(first_arg, exp.Subqueryable):

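The BigQuery changes above can be sanity-checked from sqlglot's public API. A minimal sketch, assuming the merged 11.7.1 behaviour (identifiers and the exact output strings are illustrative):

import sqlglot

# DATE_ADD and friends now go through the shared parse_date_delta_with_interval
# helper, so an INTERVAL argument round-trips instead of being unpacked ad hoc.
sqlglot.transpile("SELECT DATE_ADD(d, INTERVAL 1 DAY)", read="bigquery", write="bigquery")
# roughly: ['SELECT DATE_ADD(d, INTERVAL 1 DAY)']
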
sqlglot/dialects/clickhouse.py

@ -144,6 +144,13 @@ class ClickHouse(Dialect):
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
JOIN_HINTS = False
TABLE_HINTS = False
EXPLICIT_UNION = True
def _param_args_sql(

sqlglot/dialects/databricks.py

@ -9,6 +9,8 @@ from sqlglot.tokens import TokenType
class Databricks(Spark):
class Parser(Spark.Parser):
LOG_DEFAULTS_TO_LN = True
FUNCTIONS = {
**Spark.Parser.FUNCTIONS,
"DATEADD": parse_date_delta(exp.DateAdd),
@ -16,13 +18,17 @@ class Databricks(Spark):
"DATEDIFF": parse_date_delta(exp.DateDiff),
}
LOG_DEFAULTS_TO_LN = True
FACTOR = {
**Spark.Parser.FACTOR,
TokenType.COLON: exp.JSONExtract,
}
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation

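A hedged sketch of the two Databricks additions, the FACTOR entry for ':' and the matching generator transform (identifiers are illustrative):

import sqlglot
from sqlglot import exp

# The colon operator now parses at factor precedence into exp.JSONExtract...
node = sqlglot.parse_one("SELECT col:item", read="databricks")
assert node.find(exp.JSONExtract) is not None
# ...and is generated back through the binary ':' transform shown above.
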
sqlglot/dialects/dialect.py

@ -293,6 +293,13 @@ def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
return ""
def no_comment_column_constraint_sql(
self: Generator, expression: exp.CommentColumnConstraint
) -> str:
self.unsupported("CommentColumnConstraint unsupported")
return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
@ -379,15 +386,35 @@ def parse_date_delta(
) -> t.Callable[[t.Sequence], E]:
def inner_func(args: t.Sequence) -> E:
unit_based = len(args) == 3
this = seq_get(args, 2) if unit_based else seq_get(args, 0)
expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore
return exp_class(this=this, expression=expression, unit=unit)
this = args[2] if unit_based else seq_get(args, 0)
unit = args[0] if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
return inner_func
def parse_date_delta_with_interval(
expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
def func(args: t.Sequence) -> t.Optional[E]:
if len(args) < 2:
return None
interval = args[1]
expression = interval.this
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)
return expression_class(
this=args[0],
expression=expression,
unit=exp.Literal.string(interval.text("unit")),
)
return func
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
unit = seq_get(args, 0)
this = seq_get(args, 1)

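Both helpers normalize different call shapes onto the same date-delta expressions. A hedged sketch of parse_date_delta_with_interval through the public API (the assertion follows from the code above; names are illustrative):

import sqlglot
from sqlglot import exp

# MySQL's DATE_ADD(d, INTERVAL 1 DAY) is unpacked: the interval's operand becomes
# `expression` and its unit becomes a string-literal `unit` on the DateAdd node.
node = sqlglot.parse_one("SELECT DATE_ADD(d, INTERVAL 1 DAY)", read="mysql").find(exp.DateAdd)
assert node is not None and node.text("unit").upper() == "DAY"
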
sqlglot/dialects/drill.py

@ -104,6 +104,9 @@ class Drill(Dialect):
LOG_DEFAULTS_TO_LN = True
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER",
@ -120,6 +123,7 @@ class Drill(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {

sqlglot/dialects/duckdb.py

@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
datestrtodate_sql,
format_time_lambda,
no_comment_column_constraint_sql,
no_pivot_sql,
no_properties_sql,
no_safe_divide_sql,
@ -23,7 +24,7 @@ from sqlglot.tokens import TokenType
def _ts_or_ds_add(self, expression):
this = expression.args.get("this")
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
@ -139,6 +140,8 @@ class DuckDB(Dialect):
}
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
@ -150,6 +153,7 @@ class DuckDB(Dialect):
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
@ -213,6 +217,11 @@ class DuckDB(Dialect):
"except": "EXCLUDE",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:

sqlglot/dialects/hive.py

@ -45,16 +45,23 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
modified_increment = (
int(expression.text("expression")) * multiplier
if expression.expression.is_number
else expression.expression
)
modified_increment = exp.Literal.number(modified_increment)
return self.func(func, expression.this, modified_increment.this)
if isinstance(expression, exp.DateSub):
multiplier *= -1
if expression.expression.is_number:
modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
else:
modified_increment = expression.expression
if multiplier != 1:
modified_increment = exp.Mul( # type: ignore
this=modified_increment, expression=exp.Literal.number(multiplier)
)
return self.func(func, expression.this, modified_increment)
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
@ -127,24 +134,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
return f"TO_DATE({this})"
def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
unnest = expression.this
if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
return "".join(
self.sql(
exp.Lateral(
this=udtf(this=expression),
view=True,
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
)
)
for expression, column in zip(unnest.expressions, alias.columns if alias else [])
)
return self.join_sql(expression)
def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
@ -195,6 +184,7 @@ class Hive(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
IDENTIFIER_CAN_START_WITH_DIGIT = True
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -217,9 +207,8 @@ class Hive(Dialect):
"BD": "DECIMAL",
}
IDENTIFIER_CAN_START_WITH_DIGIT = True
class Parser(parser.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
FUNCTIONS = {
@ -273,9 +262,13 @@ class Hive(Dialect):
),
}
LOG_DEFAULTS_TO_LN = True
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
TABLESAMPLE_SIZE_IS_PERCENT = True
JOIN_HINTS = False
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
@ -289,6 +282,9 @@ class Hive(Dialect):
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.unnest_to_explode]
),
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
@ -298,13 +294,13 @@ class Hive(Dialect):
exp.DateAdd: _add_date_sql,
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateSub: _add_date_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.Join: _unnest_to_explode_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: rename_func("TO_JSON"),
@ -354,10 +350,9 @@ class Hive(Dialect):
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",
@ -378,4 +373,5 @@ class Hive(Dialect):
expression = exp.DataType.build("text")
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
return super().datatype_sql(expression)

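The reworked _add_date_sql folds the unit multiplier into literal increments (negated for DateSub) and emits an explicit multiplication otherwise. A hedged sketch of the observable effect, assuming WEEK maps to ("DATE_ADD", 7) in DATE_DELTA_INTERVAL:

import sqlglot

# A literal one-week subtraction becomes a -7 day addition (output illustrative).
sqlglot.transpile("SELECT DATE_SUB(d, INTERVAL 1 WEEK)", read="mysql", write="hive")
# roughly: ['SELECT DATE_ADD(d, -7)']
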
sqlglot/dialects/mysql.py

@ -4,6 +4,8 @@ from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
datestrtodate_sql,
format_time_lambda,
locate_to_strposition,
max_or_greatest,
min_or_least,
@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
parse_date_delta_with_interval,
rename_func,
strposition_to_locate_sql,
)
@ -76,18 +79,6 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
def _date_add(expression_class):
def func(args):
interval = seq_get(args, 1)
return expression_class(
this=seq_get(args, 0),
expression=interval.this,
unit=exp.Literal.string(interval.text("unit").lower()),
)
return func
def _date_add_sql(kind):
def func(self, expression):
this = self.sql(expression, "this")
@ -115,6 +106,7 @@ class MySQL(Dialect):
"%k": "%-H",
"%l": "%-I",
"%T": "%H:%M:%S",
"%W": "%a",
}
class Tokenizer(tokens.Tokenizer):
@ -127,12 +119,13 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"CHARSET": TokenType.CHARACTER_SET,
"LONGBLOB": TokenType.LONGBLOB,
"LONGTEXT": TokenType.LONGTEXT,
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
"LONGBLOB": TokenType.LONGBLOB,
"START": TokenType.BEGIN,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"SEPARATOR": TokenType.SEPARATOR,
"START": TokenType.BEGIN,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@ -186,14 +179,15 @@ class MySQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
"LOCATE": locate_to_strposition,
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
),
"LOCATE": locate_to_strposition,
"STR_TO_DATE": _str_to_date,
}
FUNCTION_PARSERS = {
@ -388,32 +382,36 @@ class MySQL(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
JOIN_HINTS = False
TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
exp.DateAdd: _date_add_sql("ADD"),
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
}
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
@ -425,6 +423,7 @@ class MySQL(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"

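A hedged round trip over the reordered MySQL tables, exercising the new "%W" time-mapping entry and the shared interval parsing (outputs illustrative):

import sqlglot

sqlglot.transpile(
    "SELECT DATE_FORMAT(d, '%W'), DATE_ADD(d, INTERVAL 1 DAY)", read="mysql", write="mysql"
)
# roughly: ["SELECT DATE_FORMAT(d, '%W'), DATE_ADD(d, INTERVAL 1 DAY)"]
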
sqlglot/dialects/oracle.py

@ -7,11 +7,6 @@ from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sq
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.COLUMN,
TokenType.RETURNING,
}
def _parse_xml_table(self) -> exp.XMLTable:
this = self._parse_string()
@ -22,9 +17,7 @@ def _parse_xml_table(self) -> exp.XMLTable:
if self._match_text_seq("PASSING"):
# The BY VALUE keywords are optional and are provided for semantic clarity
self._match_text_seq("BY", "VALUE")
passing = self._parse_csv(
lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS)
)
passing = self._parse_csv(self._parse_column)
by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
@ -68,6 +61,8 @@ class Oracle(Dialect):
}
class Parser(parser.Parser):
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
@ -78,6 +73,12 @@ class Oracle(Dialect):
"XMLTABLE": _parse_xml_table,
}
TYPE_LITERAL_PARSERS = {
exp.DataType.Type.DATE: lambda self, this, _: self.expression(
exp.DateStrToDate, this=this
)
}
def _parse_column(self) -> t.Optional[exp.Expression]:
column = super()._parse_column()
if column:
@ -100,6 +101,8 @@ class Oracle(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@ -119,6 +122,9 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@ -129,6 +135,12 @@ class Oracle(Dialect):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
exp.IfNull: rename_func("NVL"),
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "FETCH"
@ -142,9 +154,9 @@ class Oracle(Dialect):
def xmltable_sql(self, expression: exp.XMLTable) -> str:
this = self.sql(expression, "this")
passing = self.expressions(expression, "passing")
passing = self.expressions(expression, key="passing")
passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else ""
columns = self.expressions(expression, "columns")
columns = self.expressions(expression, key="columns")
columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else ""
by_ref = (
f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else ""

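Taken together, the new TYPE_LITERAL_PARSERS entry and the DateStrToDate transform turn ANSI date literals into TO_DATE calls. A hedged sketch:

import sqlglot

sqlglot.transpile("SELECT DATE '2023-01-01'", read="oracle", write="oracle")
# roughly: ["SELECT TO_DATE('2023-01-01', 'YYYY-MM-DD')"]
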
sqlglot/dialects/postgres.py

@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
datestrtodate_sql,
format_time_lambda,
max_or_greatest,
min_or_least,
@ -19,7 +20,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
from sqlglot.transforms import preprocess, remove_target_from_merge
DATE_DIFF_FACTOR = {
"MICROSECOND": " * 1000000",
@ -239,7 +240,6 @@ class Postgres(Dialect):
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
"CSTRING": TokenType.PSEUDO_TYPE,
}
@ -248,18 +248,25 @@ class Postgres(Dialect):
"$": TokenType.PARAMETER,
}
VAR_SINGLE_TOKENS = {"$"}
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"GENERATE_SERIES": _generate_series,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
"GENERATE_SERIES": _generate_series,
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"TO_TIMESTAMP": _to_timestamp,
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": lambda self: self._parse_date_part(),
}
BITWISE = {
@ -279,8 +286,21 @@ class Postgres(Dialect):
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
def _parse_date_part(self) -> exp.Expression:
part = self._parse_type()
self._match(TokenType.COMMA)
value = self._parse_bitwise()
if part and part.is_string:
part = exp.Var(this=part.name)
return self.expression(exp.Extract, this=part, expression=value)
class Generator(generator.Generator):
INTERVAL_ALLOWS_PLURAL_FORM = False
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
@ -301,7 +321,6 @@ class Postgres(Dialect):
_auto_increment_to_serial,
_serial_to_generated,
],
delegate("columndef_sql"),
),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@ -312,6 +331,7 @@ class Postgres(Dialect):
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("-"),
exp.DateDiff: _date_diff_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
@ -321,6 +341,7 @@ class Postgres(Dialect):
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.Merge: preprocess([remove_target_from_merge]),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
@ -344,4 +365,5 @@ class Postgres(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

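The new DATE_PART function parser rewrites a string part into a Var so it can be generated as a standard EXTRACT. A hedged sketch (identifiers illustrative):

import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("SELECT DATE_PART('year', d)", read="postgres")
assert node.find(exp.Extract) is not None
# generation is then roughly: SELECT EXTRACT(year FROM d)
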
sqlglot/dialects/presto.py

@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@ -19,20 +21,20 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _approx_distinct_sql(self, expression):
def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
accuracy = expression.args.get("accuracy")
accuracy = ", " + self.sql(accuracy) if accuracy else ""
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
def _datatype_sql(self, expression):
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
sql = f"{sql} WITH TIME ZONE"
return sql
def _explode_to_unnest_sql(self, expression):
def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
return self.sql(
exp.Join(
@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression):
return self.lateral_sql(expression)
def _initcap_sql(self, expression):
def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
_ensure_utf8(expression.args["charset"])
return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
def _encode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
_ensure_utf8(expression.args["charset"])
return f"TO_UTF8({self.sql(expression, 'this')})"
def _no_sort_array(self, expression):
def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
@ -70,49 +72,62 @@ def _no_sort_array(self, expression):
return self.func("ARRAY_SORT", expression.this, comparator)
def _schema_sql(self, expression):
def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
for schema in expression.parent.find_all(exp.Schema):
if isinstance(schema.parent, exp.Property):
expression = expression.copy()
expression.expressions.extend(schema.expressions)
if expression.parent:
for schema in expression.parent.find_all(exp.Schema):
if isinstance(schema.parent, exp.Property):
expression = expression.copy()
expression.expressions.extend(schema.expressions)
return self.schema_sql(expression)
def _quantile_sql(self, expression):
def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
self.unsupported("Presto does not support exact quantiles")
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
def _str_to_time_sql(self, expression):
def _str_to_time_sql(
self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_to_date_sql(self, expression):
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.time_format, Presto.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
def _ts_or_ds_add_sql(self, expression):
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this
if not isinstance(this, exp.CurrentDate):
this = self.func(
"DATE_PARSE",
self.func(
"SUBSTR",
this if this.is_string else exp.cast(this, "VARCHAR"),
exp.Literal.number(1),
exp.Literal.number(10),
),
Presto.date_format,
)
return self.func(
"DATE_ADD",
exp.Literal.string(expression.text("unit") or "day"),
expression.expression,
self.func(
"DATE_PARSE",
self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
Presto.date_format,
),
this,
)
def _sequence_sql(self, expression):
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
@ -135,12 +150,12 @@ def _sequence_sql(self, expression):
return self.func("SEQUENCE", start, end, step)
def _ensure_utf8(charset):
def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
def _approx_percentile(args):
def _approx_percentile(args: t.Sequence) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
@ -157,7 +172,7 @@ def _approx_percentile(args):
return exp.ApproxQuantile.from_arg_list(args)
def _from_unixtime(args):
def _from_unixtime(args: t.Sequence) -> exp.Expression:
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
@ -226,11 +241,15 @@ class Presto(Dialect):
FUNCTION_PARSERS.pop("TRIM")
class Generator(generator.Generator):
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TYPE_MAPPING = {
@ -246,7 +265,6 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@ -284,6 +302,9 @@ class Presto(Dialect):
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.explode_to_unnest]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
@ -308,7 +329,13 @@ class Presto(Dialect):
exp.VariancePop: rename_func("VAR_POP"),
}
def transaction_sql(self, expression):
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
if expression.this and unit.lower().startswith("week"):
return f"({expression.this.name} * INTERVAL '7' day)"
return super().interval_sql(expression)
def transaction_sql(self, expression: exp.Transaction) -> str:
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"

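The new interval_sql override rewrites week intervals, which Presto lacks, as a multiplication over day intervals. A hedged sketch:

import sqlglot

sqlglot.transpile("SELECT d + INTERVAL '2' week", read="presto", write="presto")
# roughly: ["SELECT d + (2 * INTERVAL '7' day)"]
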
sqlglot/dialects/redshift.py

@ -8,6 +8,10 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _json_sql(self, e) -> str:
return f'{self.sql(e, "this")}."{e.expression.name}"'
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
@ -56,6 +60,7 @@ class Redshift(Postgres):
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
"TOP": TokenType.TOP,
@ -63,7 +68,14 @@ class Redshift(Postgres):
"VARBYTE": TokenType.VARBINARY,
}
# Redshift allows # to appear as a table identifier prefix
SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy()
SINGLE_TOKENS.pop("#")
class Generator(Postgres.Generator):
LOCKING_READS_SUPPORTED = False
SINGLE_STRING_INTERVAL = True
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BINARY: "VARBYTE",
@ -79,6 +91,7 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
),
@ -87,12 +100,16 @@ class Redshift(Postgres):
),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"}
def values_sql(self, expression: exp.Values) -> str:
"""
Converts `VALUES...` expression into a series of unions.

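A hedged sketch of the new _json_sql transform, which renders JSON extraction with Redshift's dotted-path syntax (output illustrative):

import sqlglot

sqlglot.transpile("SELECT x -> 'name'", read="postgres", write="redshift")
# roughly: ['SELECT x."name"']
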
sqlglot/dialects/snowflake.py

@ -23,14 +23,14 @@ from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
def _check_int(s):
def _check_int(s: str) -> bool:
if s[0] in ("-", "+"):
return s[1:].isdigit()
return s.isdigit()
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _snowflake_to_timestamp(args):
def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@ -69,7 +69,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime.from_arg_list(args)
def _unix_to_time_sql(self, expression):
def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@ -84,8 +84,12 @@ def _unix_to_time_sql(self, expression):
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
def _parse_date_part(self):
def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
this = self._parse_var() or self._parse_type()
if not this:
return None
self._match(TokenType.COMMA)
expression = self._parse_bitwise()
@ -101,7 +105,7 @@ def _parse_date_part(self):
scale = None
ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
to_unix = self.expression(exp.TimeToUnix, this=ts)
to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
if scale:
to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
@ -112,7 +116,7 @@ def _parse_date_part(self):
# https://docs.snowflake.com/en/sql-reference/functions/div0
def _div0_to_if(args):
def _div0_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@ -120,18 +124,18 @@ def _div0_to_if(args):
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
def _zeroifnull_to_if(args):
def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
def _nullifzero_to_if(args):
def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
def _datatype_sql(self, expression):
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
elif expression.this == exp.DataType.Type.MAP:
@ -155,9 +159,8 @@ class Snowflake(Dialect):
"MM": "%m",
"mm": "%m",
"DD": "%d",
"dd": "%d",
"d": "%-d",
"DY": "%w",
"dd": "%-d",
"DY": "%a",
"dy": "%w",
"HH24": "%H",
"hh24": "%H",
@ -174,6 +177,8 @@ class Snowflake(Dialect):
}
class Parser(parser.Parser):
QUOTED_PIVOT_COLUMNS = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
@ -269,9 +274,14 @@ class Snowflake(Dialect):
"$": TokenType.PARAMETER,
}
VAR_SINGLE_TOKENS = {"$"}
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
MATCHED_BY_SOURCE = False
SINGLE_STRING_INTERVAL = True
JOIN_HINTS = False
TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
@ -287,26 +297,30 @@ class Snowflake(Dialect):
),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TimeToStr: lambda self, e: self.func(
"TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
),
exp.TimestampTrunc: timestamptrunc_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
}
TYPE_MAPPING = {
@ -322,14 +336,15 @@ class Snowflake(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def except_op(self, expression):
def except_op(self, expression: exp.Except) -> str:
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
return super().except_op(expression)
def intersect_op(self, expression):
def intersect_op(self, expression: exp.Intersect) -> str:
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)

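A hedged sketch of the new TimeToStr transform, which wraps the operand in a TIMESTAMP cast before TO_CHAR (format-token casing in the output may differ):

import sqlglot

sqlglot.transpile("SELECT TO_CHAR(col, 'YYYY-MM-DD')", read="snowflake", write="snowflake")
# roughly: SELECT TO_CHAR(CAST(col AS TIMESTAMP), ...)
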
sqlglot/dialects/spark.py

@ -1,13 +1,15 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
def _create_sql(self, e):
kind = e.args.get("kind")
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
if kind.upper() == "TABLE" and any(
@ -18,13 +20,13 @@ def _create_sql(self, e):
return create_with_partitions_sql(self, e)
def _map_sql(self, expression):
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
keys = self.sql(expression.args["keys"])
values = self.sql(expression.args["values"])
return f"MAP_FROM_ARRAYS({keys}, {values})"
def _str_to_date(self, expression):
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.date_format:
@ -32,7 +34,7 @@ def _str_to_date(self, expression):
return f"TO_DATE({this}, {time_format})"
def _unix_to_time(self, expression):
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
@ -75,7 +77,11 @@ class Spark(Hive):
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"BOOLEAN": lambda args: exp.Cast(
this=seq_get(args, 0), to=exp.DataType.build("boolean")
),
"IIF": exp.If.from_arg_list,
"INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
"AGGREGATE": exp.Reduce.from_arg_list,
"DAYOFWEEK": lambda args: exp.DayOfWeek(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@ -89,11 +95,16 @@ class Spark(Hive):
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"TIMESTAMP": lambda args: exp.Cast(
this=seq_get(args, 0), to=exp.DataType.build("timestamp")
),
}
FUNCTION_PARSERS = {
@ -108,16 +119,43 @@ class Spark(Hive):
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
def _parse_add_column(self):
def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def _parse_drop_column(self):
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
)
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
if len(pivot_columns) == 1:
return [""]
names = []
for agg in pivot_columns:
if isinstance(agg, exp.Alias):
names.append(agg.alias)
else:
"""
This case corresponds to aggregations without aliases being used as suffixes
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
"""
agg_all_unquoted = agg.transform(
lambda node: exp.Identifier(this=node.name, quoted=False)
if isinstance(node, exp.Identifier)
else node
)
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
return names
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore
@ -145,7 +183,7 @@ class Spark(Hive):
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.UnixToTime: _unix_to_time_sql,
exp.Create: _create_sql,
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),

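The new cast-shorthand entries (BOOLEAN, DATE, INT, STRING, TIMESTAMP) all parse to exp.Cast. A hedged sketch:

import sqlglot

sqlglot.transpile("SELECT BOOLEAN(x), DATE(y), STRING(z)", read="spark", write="spark")
# roughly: ['SELECT CAST(x AS BOOLEAN), CAST(y AS DATE), CAST(z AS STRING)']
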
sqlglot/dialects/sqlite.py

@ -16,7 +16,7 @@ from sqlglot.tokens import TokenType
def _date_add_sql(self, expression):
modifier = expression.expression
modifier = expression.name if modifier.is_string else self.sql(modifier)
modifier = modifier.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
return self.func("DATE", expression.this, modifier)
@ -38,6 +38,9 @@ class SQLite(Dialect):
}
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "INTEGER",
@ -82,6 +85,11 @@ class SQLite(Dialect):
exp.TryCast: no_trycast_sql,
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
def cast_sql(self, expression: exp.Cast) -> str:

sqlglot/dialects/starrocks.py

@ -1,7 +1,11 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_sql,
rename_func,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get
@ -10,6 +14,7 @@ class StarRocks(MySQL):
class Parser(MySQL.Parser): # type: ignore
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
@ -25,6 +30,7 @@ class StarRocks(MySQL):
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: approx_count_distinct_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),

sqlglot/dialects/tableau.py

@ -21,6 +21,9 @@ def _count_sql(self, expression):
class Tableau(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.If: _if_sql,
@ -28,6 +31,11 @@ class Tableau(Dialect):
exp.Count: _count_sql,
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore

sqlglot/dialects/teradata.py

@ -1,7 +1,14 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
max_or_greatest,
min_or_least,
)
from sqlglot.tokens import TokenType
@ -115,7 +122,18 @@ class Teradata(Dialect):
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
def _parse_cast(self, strict: bool) -> exp.Expression:
cast = t.cast(exp.Cast, super()._parse_cast(strict))
if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT):
return format_time_lambda(exp.TimeToStr, "teradata")(
[cast.this, self._parse_string()]
)
return cast
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
@ -130,6 +148,7 @@ class Teradata(Dialect):
**generator.Generator.TRANSFORMS,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}

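The overridden _parse_cast turns Teradata's CAST ... FORMAT syntax into exp.TimeToStr, and the new TimeToStr transform generates it back. A hedged round trip:

import sqlglot

sqlglot.transpile("SELECT CAST(d AS DATE FORMAT 'YYYY-MM-DD')", read="teradata", write="teradata")
# roughly: ["SELECT CAST(d AS DATE FORMAT 'YYYY-MM-DD')"]
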
sqlglot/dialects/tsql.py

@ -96,6 +96,23 @@ def _parse_eomonth(args):
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
def _parse_hashbytes(args):
kind, data = args
kind = kind.name.upper() if kind.is_string else ""
if kind == "MD5":
args.pop(0)
return exp.MD5(this=data)
if kind in ("SHA", "SHA1"):
args.pop(0)
return exp.SHA(this=data)
if kind == "SHA2_256":
return exp.SHA2(this=data, length=exp.Literal.number(256))
if kind == "SHA2_512":
return exp.SHA2(this=data, length=exp.Literal.number(512))
return exp.func("HASHBYTES", *args)
def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return self.func(func, e.text("unit"), e.expression, e.this)
@ -266,6 +283,7 @@ class TSQL(Dialect):
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
"SYSTEM_USER": TokenType.CURRENT_USER,
}
# TSQL allows @, # to appear as a variable/identifier prefix
@ -287,6 +305,7 @@ class TSQL(Dialect):
"EOMONTH": _parse_eomonth,
"FORMAT": _parse_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
"HASHBYTES": _parse_hashbytes,
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
@ -296,6 +315,14 @@ class TSQL(Dialect):
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
"SUSER_NAME": exp.CurrentUser.from_arg_list,
"SUSER_SNAME": exp.CurrentUser.from_arg_list,
"SYSTEM_USER": exp.CurrentUser.from_arg_list,
}
JOIN_HINTS = {
"LOOP",
"HASH",
"MERGE",
"REMOTE",
}
VAR_LENGTH_DATATYPES = {
@ -441,11 +468,21 @@ class TSQL(Dialect):
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
}
TRANSFORMS.pop(exp.ReturnsProperty)
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "FETCH"
def offset_sql(self, expression: exp.Offset) -> str:
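Finally, a hedged sketch of the new HASHBYTES handling: recognized algorithm names become dedicated hash expressions, and the reverse transforms emit HASHBYTES again (cross-dialect outputs illustrative):

import sqlglot

sqlglot.transpile("SELECT HASHBYTES('MD5', x)", read="tsql", write="spark")
# roughly: ['SELECT MD5(x)']
sqlglot.transpile("SELECT SHA2(x, 256)", read="spark", write="tsql")
# roughly: ["SELECT HASHBYTES('SHA2_256', x)"]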