148 lines
6.1 KiB
Python
148 lines
6.1 KiB
Python
from __future__ import annotations
|
|
|
|
import typing as t
|
|
|
|
from sqlglot import exp, transforms
|
|
from sqlglot.dialects.postgres import Postgres
|
|
from sqlglot.helper import seq_get
|
|
from sqlglot.tokens import TokenType
|
|
|
|
|
|
class Redshift(Postgres):
    """Redshift dialect, layered on top of the Postgres dialect.

    Only the parsing, tokenizing and generation rules where Redshift diverges from
    Postgres are overridden here; everything else is inherited.
    """

    # Default time format string for this dialect.
    time_format = "'YYYY-MM-DD HH:MI:SS'"

    # Postgres format tokens plus the Redshift-specific ones below.
    time_mapping = {
        **Postgres.time_mapping,  # type: ignore
        "MON": "%b",
        "HH": "%H",
    }

    class Parser(Postgres.Parser):
        FUNCTIONS = {
            **Postgres.Parser.FUNCTIONS,  # type: ignore
            # DATEADD(unit, amount, date) -> DateAdd(this=date, expression=amount, unit=unit)
            "DATEADD": lambda args: exp.DateAdd(
                this=seq_get(args, 2),
                expression=seq_get(args, 1),
                unit=seq_get(args, 0),
            ),
            # DATEDIFF(unit, a, b) -> DateDiff(this=b, expression=a, unit=unit)
            "DATEDIFF": lambda args: exp.DateDiff(
                this=seq_get(args, 2),
                expression=seq_get(args, 1),
                unit=seq_get(args, 0),
            ),
            # NVL is parsed as COALESCE.
            "NVL": exp.Coalesce.from_arg_list,
        }

        CONVERT_TYPE_FIRST = True

        def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
            """Parse a data type, normalizing the `MAX` in `VARCHAR(MAX)`.

            Args:
                check_func: forwarded unchanged to the Postgres parser.

            Returns:
                The parsed expression. For `VARCHAR(MAX)`, the size argument is
                replaced with a `Var` named MAX.
            """
            this = super()._parse_types(check_func=check_func)

            # The base parser reads the MAX in VARCHAR(MAX) as a column reference;
            # swap it for a bare Var so it is not treated as a column.
            if (
                isinstance(this, exp.DataType)
                and this.this == exp.DataType.Type.VARCHAR
                and this.expressions
                and this.expressions[0] == exp.column("MAX")
            ):
                this.set("expressions", [exp.Var(this="MAX")])

            return this

    class Tokenizer(Postgres.Tokenizer):
        # Backslash is the escape character inside string literals.
        STRING_ESCAPES = ["\\"]

        KEYWORDS = {
            **Postgres.Tokenizer.KEYWORDS,  # type: ignore
            "GEOMETRY": TokenType.GEOMETRY,
            "GEOGRAPHY": TokenType.GEOGRAPHY,
            "HLLSKETCH": TokenType.HLLSKETCH,
            "SUPER": TokenType.SUPER,
            "TIME": TokenType.TIMESTAMP,
            "TIMETZ": TokenType.TIMESTAMPTZ,
            "TOP": TokenType.TOP,
            # UNLOAD statements are passed through as opaque commands.
            "UNLOAD": TokenType.COMMAND,
            "VARBYTE": TokenType.VARBINARY,
        }

    class Generator(Postgres.Generator):
        TYPE_MAPPING = {
            **Postgres.Generator.TYPE_MAPPING,  # type: ignore
            exp.DataType.Type.BINARY: "VARBYTE",
            exp.DataType.Type.VARBINARY: "VARBYTE",
            exp.DataType.Type.INT: "INTEGER",
        }

        PROPERTIES_LOCATION = {
            **Postgres.Generator.PROPERTIES_LOCATION,  # type: ignore
            exp.LikeProperty: exp.Properties.Location.POST_WITH,
        }

        TRANSFORMS = {
            **Postgres.Generator.TRANSFORMS,  # type: ignore
            **transforms.ELIMINATE_DISTINCT_ON,  # type: ignore
            # Emit DATEADD/DATEDIFF(unit, ...) with the argument order the parser
            # above expects; the unit falls back to `day` when none is set.
            exp.DateAdd: lambda self, e: self.func(
                "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
            ),
            exp.DateDiff: lambda self, e: self.func(
                "DATEDIFF", exp.var(e.text("unit") or "day"), e.expression, e.this
            ),
            exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
            exp.DistStyleProperty: lambda self, e: self.naked_property(e),
            # `compound` may be absent from args, so read it with .get rather than [].
            exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args.get('compound') else ''}SORTKEY({self.format_args(*e.this)})",
        }

        # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres).
        # Dropping the Postgres override makes Pow fall back to function-call generation.
        TRANSFORMS.pop(exp.Pow, None)

        def values_sql(self, expression: exp.Values) -> str:
            """
            Converts `VALUES...` expression into a series of unions.

            When the VALUES clause is the source of a FROM or JOIN, it is rewritten
            as a subquery of `SELECT` rows combined with `UNION ALL`
            (`exp.union(..., distinct=False)`). The subquery keeps the original
            alias. Column aliases, if any, are attached to the first SELECT only,
            because a UNION takes its column names from its first branch. Any other
            use of VALUES is delegated to the Postgres generator.

            Note: If you have a lot of unions then this will result in a large number of recursive statements to
            evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
            very slow.
            """
            if not isinstance(expression.unnest().parent, (exp.From, exp.Join)):
                return super().values_sql(expression)

            # A bare alias such as `AS t` has no column list, so treat a missing
            # "columns" arg as empty rather than raising KeyError.
            column_names = []
            if expression.alias:
                column_names = expression.args["alias"].args.get("columns") or []

            selects = []
            for i, tuple_exp in enumerate(expression.expressions):
                row = tuple_exp.expressions
                if i == 0 and column_names:
                    row = [
                        exp.alias_(value, column_name)
                        for value, column_name in zip(row, column_names)
                    ]
                selects.append(exp.Select(expressions=row))

            # There are no rows to union; let the base generator handle it.
            if not selects:
                return super().values_sql(expression)

            subquery_expression: exp.Expression = selects[0]
            for select in selects[1:]:
                subquery_expression = exp.union(subquery_expression, select, distinct=False)

            return self.subquery_sql(subquery_expression.subquery(expression.alias))

        def with_properties(self, properties: exp.Properties) -> str:
            """Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
            return self.properties(properties, prefix=" ", suffix="")

        def renametable_sql(self, expression: exp.RenameTable) -> str:
            """Redshift only supports defining the table name itself (not the db) when renaming tables"""
            # Work on a copy so the caller's tree is not mutated.
            expression = expression.copy()
            target_table = expression.this

            # Clear every qualifier (db, catalog, ...) except the bare table name.
            # Snapshot the keys first so set() cannot disturb the iteration.
            for arg in list(target_table.args):
                if arg != "this":
                    target_table.set(arg, None)

            this = self.sql(expression, "this")
            return f"RENAME TO {this}"

        def datatype_sql(self, expression: exp.DataType) -> str:
            """
            Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
            VARCHAR of max length which is `VARCHAR(max)` in Redshift. Therefore if we get a `TEXT` data type
            without precision we convert it to `VARCHAR(max)` and if it does have precision then we just convert
            `TEXT` to `VARCHAR`.
            """
            if expression.this == exp.DataType.Type.TEXT:
                # Copy before mutating so the caller's DataType node is untouched.
                expression = expression.copy()
                expression.set("this", exp.DataType.Type.VARCHAR)
                precision = expression.args.get("expressions")

                if not precision:
                    expression.append("expressions", exp.Var(this="MAX"))

            return super().datatype_sql(expression)
|