Merging upstream version 20.9.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:19:14 +01:00
parent 9421b254ec
commit 37a231f554
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
144 changed files with 78309 additions and 59609 deletions

@ -574,13 +574,13 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
return Column.invoke_expression_over_column(
col, expression.DateAdd, expression=days, unit=expression.Var(this="day")
col, expression.DateAdd, expression=days, unit=expression.Var(this="DAY")
)
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
return Column.invoke_expression_over_column(
col, expression.DateSub, expression=days, unit=expression.Var(this="day")
col, expression.DateSub, expression=days, unit=expression.Var(this="DAY")
)
@ -635,7 +635,7 @@ def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
def last_day(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "LAST_DAY")
return Column.invoke_expression_over_column(col, expression.LastDay)
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:

@ -16,20 +16,22 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
inline_array_sql,
json_keyvalue_comma_sql,
max_or_greatest,
min_or_least,
no_ilike_sql,
parse_date_delta_with_interval,
path_to_jsonpath,
regexp_replace_sql,
rename_func,
timestrtotime_sql,
ts_or_ds_add_cast,
ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
from typing_extensions import Literal
logger = logging.getLogger("sqlglot")
@ -206,12 +208,17 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
if scale == exp.UnixToTime.NANOS:
# We need to cast to INT64 because that's what BQ expects
return f"TIMESTAMP_MICROS(CAST({timestamp} / 1000 AS INT64))"
self.unsupported(f"Unsupported scale for timestamp: {scale}.")
return ""
return f"TIMESTAMP_SECONDS(CAST({timestamp} / POW(10, {scale}) AS INT64))"
def _parse_time(args: t.List) -> exp.Func:
if len(args) == 1:
return exp.TsOrDsToTime(this=args[0])
if len(args) == 3:
return exp.TimeFromParts.from_arg_list(args)
return exp.Anonymous(this="TIME", expressions=args)
class BigQuery(Dialect):
@ -329,7 +336,13 @@ class BigQuery(Dialect):
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"DIV": binary_from_function(exp.IntDiv),
"FORMAT_DATE": lambda args: exp.TimeToStr(
this=exp.TsOrDsToDate(this=seq_get(args, 1)), format=seq_get(args, 0)
),
"GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
"JSON_EXTRACT_SCALAR": lambda args: exp.JSONExtractScalar(
this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$")
),
"MD5": exp.MD5Digest.from_arg_list,
"TO_HEX": _parse_to_hex,
"PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
@ -351,6 +364,7 @@ class BigQuery(Dialect):
this=seq_get(args, 0),
expression=seq_get(args, 1) or exp.Literal.string(","),
),
"TIME": _parse_time,
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
@ -361,9 +375,7 @@ class BigQuery(Dialect):
"TIMESTAMP_MILLIS": lambda args: exp.UnixToTime(
this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
"TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(
this=seq_get(args, 0), scale=exp.UnixToTime.SECONDS
),
"TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(this=seq_get(args, 0)),
"TO_JSON_STRING": exp.JSONFormat.from_arg_list,
}
@ -460,7 +472,15 @@ class BigQuery(Dialect):
return table
def _parse_json_object(self) -> exp.JSONObject:
@t.overload
def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
...
@t.overload
def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
...
def _parse_json_object(self, agg=False):
json_object = super()._parse_json_object()
array_kv_pair = seq_get(json_object.expressions, 0)
@ -513,6 +533,10 @@ class BigQuery(Dialect):
UNNEST_WITH_ORDINALITY = False
COLLATE_IS_FUNC = True
LIMIT_ONLY_LITERALS = True
SUPPORTS_TABLE_ALIAS_COLUMNS = False
UNPIVOT_ALIASES_ARE_IDENTIFIERS = False
JSON_KEY_VALUE_PAIR_SEP = ","
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -525,6 +549,7 @@ class BigQuery(Dialect):
exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
if e.args.get("default")
else f"COLLATE {self.sql(e, 'this')}",
exp.CountIf: rename_func("COUNTIF"),
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
@ -536,13 +561,13 @@ class BigQuery(Dialect):
exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql(false_value="NULL"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
@ -578,16 +603,17 @@ class BigQuery(Dialect):
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
),
exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
exp.TimeFromParts: rename_func("TIME"),
exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: f"FORMAT_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.TsOrDsToTime: rename_func("TIME"),
exp.Unhex: rename_func("FROM_HEX"),
exp.UnixDate: rename_func("UNIX_DATE"),
exp.UnixToTime: _unix_to_time_sql,
exp.Values: _derived_table_values_to_unnest,
exp.VariancePop: rename_func("VAR_POP"),
@ -724,6 +750,26 @@ class BigQuery(Dialect):
"within",
}
def timetostr_sql(self, expression: exp.TimeToStr) -> str:
if isinstance(expression.this, exp.TsOrDsToDate):
this: exp.Expression = expression.this
else:
this = expression
return f"FORMAT_DATE({self.format_time(expression)}, {self.sql(this, 'this')})"
def struct_sql(self, expression: exp.Struct) -> str:
args = []
for expr in expression.expressions:
if isinstance(expr, self.KEY_VALUE_DEFINITIONS):
arg = f"{self.sql(expr, 'expression')} AS {expr.this.name}"
else:
arg = self.sql(expr)
args.append(arg)
return self.func("STRUCT", *args)
def eq_sql(self, expression: exp.EQ) -> str:
# Operands of = cannot be NULL in BigQuery
if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null):
@ -760,7 +806,20 @@ class BigQuery(Dialect):
return inline_array_sql(self, expression)
def bracket_sql(self, expression: exp.Bracket) -> str:
this = self.sql(expression, "this")
expressions = expression.expressions
if len(expressions) == 1:
arg = expressions[0]
if arg.type is None:
from sqlglot.optimizer.annotate_types import annotate_types
arg = annotate_types(arg)
if arg.type and arg.type.this in exp.DataType.TEXT_TYPES:
# BQ doesn't support bracket syntax with string values
return f"{this}.{arg.name}"
expressions_sql = ", ".join(self.sql(e) for e in expressions)
offset = expression.args.get("offset")
@ -768,13 +827,13 @@ class BigQuery(Dialect):
expressions_sql = f"OFFSET({expressions_sql})"
elif offset == 1:
expressions_sql = f"ORDINAL({expressions_sql})"
else:
elif offset is not None:
self.unsupported(f"Unsupported array offset: {offset}")
if expression.args.get("safe"):
expressions_sql = f"SAFE_{expressions_sql}"
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
return f"{this}[{expressions_sql}]"
def transaction_sql(self, *_) -> str:
return "BEGIN TRANSACTION"

@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arg_max_or_min_no_count,
date_delta_sql,
inline_array_sql,
no_pivot_sql,
rename_func,
@ -22,16 +23,25 @@ def _lower_func(sql: str) -> str:
return sql[:index].lower() + sql[index:]
def _quantile_sql(self, e):
def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str:
quantile = e.args["quantile"]
args = f"({self.sql(e, 'this')})"
if isinstance(quantile, exp.Array):
func = self.func("quantiles", *quantile)
else:
func = self.func("quantile", quantile)
return func + args
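The newly typed _quantile_sql keeps ClickHouse's two-call shape for parameterized aggregates: the quantile levels go in the first parentheses and the aggregated column in the second, with quantiles used when a list of levels is passed. Roughly, the generated strings look like this (a plain-Python sketch with invented values, not the generator itself):

# Shape of the SQL produced by _quantile_sql (illustrative only).
def clickhouse_quantile(column, levels):
    if isinstance(levels, (list, tuple)):
        func = f"quantiles({', '.join(str(q) for q in levels)})"
    else:
        func = f"quantile({levels})"
    return f"{func}({column})"

print(clickhouse_quantile("x", 0.5))           # quantile(0.5)(x)
print(clickhouse_quantile("x", [0.25, 0.75]))  # quantiles(0.25, 0.75)(x)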
def _parse_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc:
if len(args) == 1:
return exp.CountIf(this=seq_get(args, 0))
return exp.CombinedAggFunc(this="countIf", expressions=args, parts=("count", "If"))
class ClickHouse(Dialect):
NORMALIZE_FUNCTIONS: bool | str = False
NULL_ORDERING = "nulls_are_last"
@ -53,6 +63,7 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ATTACH": TokenType.COMMAND,
"DATE32": TokenType.DATE32,
"DATETIME64": TokenType.DATETIME64,
"DICTIONARY": TokenType.DICTIONARY,
"ENUM": TokenType.ENUM,
@ -75,6 +86,8 @@ class ClickHouse(Dialect):
"UINT32": TokenType.UINT,
"UINT64": TokenType.UBIGINT,
"UINT8": TokenType.UTINYINT,
"IPV4": TokenType.IPV4,
"IPV6": TokenType.IPV6,
}
SINGLE_TOKENS = {
@ -91,6 +104,8 @@ class ClickHouse(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,
"ARRAYSUM": exp.ArraySum.from_arg_list,
"COUNTIF": _parse_count_if,
"DATE_ADD": lambda args: exp.DateAdd(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
@ -110,6 +125,138 @@ class ClickHouse(Dialect):
"XOR": lambda args: exp.Xor(expressions=args),
}
AGG_FUNCTIONS = {
"count",
"min",
"max",
"sum",
"avg",
"any",
"stddevPop",
"stddevSamp",
"varPop",
"varSamp",
"corr",
"covarPop",
"covarSamp",
"entropy",
"exponentialMovingAverage",
"intervalLengthSum",
"kolmogorovSmirnovTest",
"mannWhitneyUTest",
"median",
"rankCorr",
"sumKahan",
"studentTTest",
"welchTTest",
"anyHeavy",
"anyLast",
"boundingRatio",
"first_value",
"last_value",
"argMin",
"argMax",
"avgWeighted",
"topK",
"topKWeighted",
"deltaSum",
"deltaSumTimestamp",
"groupArray",
"groupArrayLast",
"groupUniqArray",
"groupArrayInsertAt",
"groupArrayMovingAvg",
"groupArrayMovingSum",
"groupArraySample",
"groupBitAnd",
"groupBitOr",
"groupBitXor",
"groupBitmap",
"groupBitmapAnd",
"groupBitmapOr",
"groupBitmapXor",
"sumWithOverflow",
"sumMap",
"minMap",
"maxMap",
"skewSamp",
"skewPop",
"kurtSamp",
"kurtPop",
"uniq",
"uniqExact",
"uniqCombined",
"uniqCombined64",
"uniqHLL12",
"uniqTheta",
"quantile",
"quantiles",
"quantileExact",
"quantilesExact",
"quantileExactLow",
"quantilesExactLow",
"quantileExactHigh",
"quantilesExactHigh",
"quantileExactWeighted",
"quantilesExactWeighted",
"quantileTiming",
"quantilesTiming",
"quantileTimingWeighted",
"quantilesTimingWeighted",
"quantileDeterministic",
"quantilesDeterministic",
"quantileTDigest",
"quantilesTDigest",
"quantileTDigestWeighted",
"quantilesTDigestWeighted",
"quantileBFloat16",
"quantilesBFloat16",
"quantileBFloat16Weighted",
"quantilesBFloat16Weighted",
"simpleLinearRegression",
"stochasticLinearRegression",
"stochasticLogisticRegression",
"categoricalInformationValue",
"contingency",
"cramersV",
"cramersVBiasCorrected",
"theilsU",
"maxIntersections",
"maxIntersectionsPosition",
"meanZTest",
"quantileInterpolatedWeighted",
"quantilesInterpolatedWeighted",
"quantileGK",
"quantilesGK",
"sparkBar",
"sumCount",
"largestTriangleThreeBuckets",
}
AGG_FUNCTIONS_SUFFIXES = [
"If",
"Array",
"ArrayIf",
"Map",
"SimpleState",
"State",
"Merge",
"MergeState",
"ForEach",
"Distinct",
"OrDefault",
"OrNull",
"Resample",
"ArgMin",
"ArgMax",
]
AGG_FUNC_MAPPING = (
lambda functions, suffixes: {
f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
}
)(AGG_FUNCTIONS, AGG_FUNCTIONS_SUFFIXES)
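AGG_FUNC_MAPPING is built once by crossing every base aggregate with every combinator suffix (plus the empty suffix), so a combined name such as countIf can later be split back into its parts during parsing. The same comprehension, self-contained with a tiny sample of the real sets:

# Self-contained version of the AGG_FUNC_MAPPING construction (sample subsets only).
functions = {"count", "sum", "quantile"}
suffixes = ["If", "Array", "Merge"]

agg_func_mapping = {
    f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
}

print(agg_func_mapping["countIf"])   # ('count', 'If')
print(agg_func_mapping["count"])     # ('count', '')
print(agg_func_mapping["sumMerge"])  # ('sum', 'Merge')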
FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}
FUNCTION_PARSERS = {
@ -272,9 +419,18 @@ class ClickHouse(Dialect):
)
if isinstance(func, exp.Anonymous):
parts = self.AGG_FUNC_MAPPING.get(func.this)
params = self._parse_func_params(func)
if params:
if parts and parts[1]:
return self.expression(
exp.CombinedParameterizedAgg,
this=func.this,
expressions=func.expressions,
params=params,
parts=parts,
)
return self.expression(
exp.ParameterizedAgg,
this=func.this,
@ -282,6 +438,20 @@ class ClickHouse(Dialect):
params=params,
)
if parts:
if parts[1]:
return self.expression(
exp.CombinedAggFunc,
this=func.this,
expressions=func.expressions,
parts=parts,
)
return self.expression(
exp.AnonymousAggFunc,
this=func.this,
expressions=func.expressions,
)
return func
def _parse_func_params(
@ -329,6 +499,9 @@ class ClickHouse(Dialect):
STRUCT_DELIMITER = ("(", ")")
NVL2_SUPPORTED = False
TABLESAMPLE_REQUIRES_PARENS = False
TABLESAMPLE_SIZE_IS_ROWS = False
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
@ -348,6 +521,7 @@ class ClickHouse(Dialect):
**STRING_TYPE_MAPPING,
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.DATE32: "Date32",
exp.DataType.Type.DATETIME64: "DateTime64",
exp.DataType.Type.DOUBLE: "Float64",
exp.DataType.Type.ENUM: "Enum",
@ -372,24 +546,23 @@ class ClickHouse(Dialect):
exp.DataType.Type.UINT256: "UInt256",
exp.DataType.Type.USMALLINT: "UInt16",
exp.DataType.Type.UTINYINT: "UInt8",
exp.DataType.Type.IPV4: "IPv4",
exp.DataType.Type.IPV6: "IPv6",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.AnyValue: rename_func("any"),
exp.ApproxDistinct: rename_func("uniq"),
exp.ArraySum: rename_func("arraySum"),
exp.ArgMax: arg_max_or_min_no_count("argMax"),
exp.ArgMin: arg_max_or_min_no_count("argMin"),
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
exp.CountIf: rename_func("countIf"),
exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
exp.DateAdd: date_delta_sql("DATE_ADD"),
exp.DateDiff: date_delta_sql("DATE_DIFF"),
exp.Explode: rename_func("arrayJoin"),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.IsNan: rename_func("isNaN"),
@ -400,6 +573,7 @@ class ClickHouse(Dialect):
exp.Quantile: _quantile_sql,
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
exp.Rand: rename_func("randCanonical"),
exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.StartsWith: rename_func("startsWith"),
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
@ -485,10 +659,19 @@ class ClickHouse(Dialect):
else "",
]
def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str:
params = self.expressions(expression, key="params", flat=True)
return self.func(expression.name, *expression.expressions) + f"({params})"
def anonymousaggfunc_sql(self, expression: exp.AnonymousAggFunc) -> str:
return self.func(expression.name, *expression.expressions)
def combinedaggfunc_sql(self, expression: exp.CombinedAggFunc) -> str:
return self.anonymousaggfunc_sql(expression)
def combinedparameterizedagg_sql(self, expression: exp.CombinedParameterizedAgg) -> str:
return self.parameterizedagg_sql(expression)
def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"

@ -30,6 +30,8 @@ class Databricks(Spark):
}
class Generator(Spark.Generator):
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
TRANSFORMS = {
**Spark.Generator.TRANSFORMS,
exp.DateAdd: date_delta_sql("DATEADD"),

@ -153,6 +153,9 @@ class Dialect(metaclass=_Dialect):
ALIAS_POST_TABLESAMPLE = False
"""Determines whether or not the table alias comes after tablesample."""
TABLESAMPLE_SIZE_IS_PERCENT = False
"""Determines whether or not a size in the table sample clause represents percentage."""
NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
"""Specifies the strategy according to which identifiers should be normalized."""
@ -220,6 +223,24 @@ class Dialect(metaclass=_Dialect):
For example, such columns may be excluded from `SELECT *` queries.
"""
PREFER_CTE_ALIAS_COLUMN = False
"""
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
HAVING clause of the CTE. This flag will cause the CTE alias columns to override
any projection aliases in the subquery.
For example,
WITH y(c) AS (
SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
"""
# --- Autofilled ---
tokenizer_class = Tokenizer
@ -287,7 +308,13 @@ class Dialect(metaclass=_Dialect):
result = cls.get(dialect_name.strip())
if not result:
raise ValueError(f"Unknown dialect '{dialect_name}'.")
from difflib import get_close_matches
similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
if similar:
similar = f" Did you mean {similar}?"
raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
return result(**kwargs)
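Unknown dialect names now produce a "Did you mean ...?" hint built with difflib.get_close_matches over the registered dialect names. The standard-library call behaves like this (the candidate list below is just a sample standing in for cls.classes):

from difflib import get_close_matches

# Sample stand-in for the registered dialect names.
dialects = ["bigquery", "clickhouse", "duckdb", "mysql", "postgres", "snowflake"]

similar = get_close_matches("postgress", dialects, n=1)
print(similar)  # ['postgres'] -> rendered as " Did you mean postgres?" in the error message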
@ -506,7 +533,7 @@ def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
n = self.sql(expression, "this")
d = self.sql(expression, "expression")
return f"IF({d} <> 0, {n} / {d}, NULL)"
return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
@ -695,7 +722,7 @@ def date_add_interval_sql(
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
return self.func(
"DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
"DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
)
@ -801,22 +828,6 @@ def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
return self.func("STRPTIME", expression.this, self.format_time(expression))
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
_dialect = Dialect.get_or_raise(dialect)
time_format = self.format_time(expression)
if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
return self.sql(
exp.cast(
exp.StrToTime(this=expression.this, format=expression.args["format"]),
"date",
)
)
return self.sql(exp.cast(expression.this, "date"))
return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
@ -894,11 +905,6 @@ def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
def is_parse_json(expression: exp.Expression) -> bool:
return isinstance(expression, exp.ParseJSON) or (
isinstance(expression, exp.Cast) and expression.is_type("json")
@ -946,7 +952,70 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
expression = ts_or_ds_add_cast(expression)
return self.func(
name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
name,
exp.var(expression.text("unit").upper() or "DAY"),
expression.expression,
expression.this,
)
return _delta_sql
def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
from sqlglot.optimizer.simplify import simplify
# Makes sure the path will be evaluated correctly at runtime to include the path root.
# For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
path = expression.expression
path = exp.func(
"if",
exp.func("startswith", path, "'['"),
exp.func("concat", "'$'", path),
exp.func("concat", "'$.'", path),
)
expression.expression.replace(simplify(path))
return expression
def path_to_jsonpath(
name: str = "JSON_EXTRACT",
) -> t.Callable[[Generator, exp.GetPath], str]:
def _transform(self: Generator, expression: exp.GetPath) -> str:
return rename_func(name)(self, prepend_dollar_to_path(expression))
return _transform
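prepend_dollar_to_path wraps the extracted path in an IF(STARTSWITH(...)) expression so that at runtime [0].foo becomes $[0].foo and foo becomes $.foo, giving the JSONPath its missing root. The intent, reduced to plain Python over literal strings (the real helper builds an expression tree and simplifies it):

# Plain-Python rendering of the normalization performed by prepend_dollar_to_path.
def add_jsonpath_root(path):
    # Bracketed paths only need the '$' root; everything else also needs the '.' separator.
    return f"${path}" if path.startswith("[") else f"$.{path}"

print(add_jsonpath_root("[0].foo"))  # $[0].foo
print(add_jsonpath_root("foo.bar"))  # $.foo.bar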
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
return self.sql(exp.cast(minus_one_day, "date"))
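no_last_day_sql emulates LAST_DAY for dialects that lack it: truncate to the first of the month, add one month, subtract one day. The same arithmetic with the standard library, as a sanity check of the reasoning (dates are arbitrary examples):

import datetime

def last_day_of_month(d):
    # DATE_TRUNC('month', d): first day of d's month.
    first = d.replace(day=1)
    # DATE_ADD(first, 1, 'month'): first day of the following month (the +32 days trick rolls over safely).
    next_month_first = (first + datetime.timedelta(days=32)).replace(day=1)
    # DATE_SUB(next_month_first, 1, 'day'): last day of the original month.
    return next_month_first - datetime.timedelta(days=1)

print(last_day_of_month(datetime.date(2024, 2, 10)))  # 2024-02-29
print(last_day_of_month(datetime.date(2023, 12, 5)))  # 2023-12-31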
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
"""Remove table refs from columns in when statements."""
alias = expression.this.args.get("alias")
normalize = (
lambda identifier: self.dialect.normalize_identifier(identifier).name
if identifier
else None
)
targets = {normalize(expression.this.this)}
if alias:
targets.add(normalize(alias.this))
for when in expression.expressions:
when.transform(
lambda node: exp.column(node.this)
if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
else node,
copy=False,
)
return self.merge_sql(expression)

@ -22,6 +22,7 @@ class Doris(MySQL):
"COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
"DATE_TRUNC": parse_timestamp_trunc,
"REGEXP": exp.RegexpLike.from_arg_list,
"TO_DATE": exp.TsOrDsToDate.from_arg_list,
}
class Generator(MySQL.Generator):
@ -34,21 +35,26 @@ class Doris(MySQL):
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
LAST_DAY_SUPPORTS_DATE_PART = False
TIMESTAMP_FUNC_TYPES = set()
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.CurrentTimestamp: lambda *_: "NOW()",
exp.DateTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.Map: rename_func("ARRAY_MAP"),
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.TimeStrToDate: rename_func("TO_DATE"),
@ -63,5 +69,4 @@ class Doris(MySQL):
"FROM_UNIXTIME", e.this, time_format("doris")(self, e)
),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.Map: rename_func("ARRAY_MAP"),
}

@ -12,7 +12,6 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
timestrtotime_sql,
ts_or_ds_to_date_sql,
)
@ -99,6 +98,7 @@ class Drill(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -150,7 +150,6 @@ class Drill(Dialect):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}

@ -22,15 +22,15 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
no_timestamp_sql,
pivot_column_names,
prepend_dollar_to_path,
regexp_extract_sql,
rename_func,
str_position_sql,
str_to_time_sql,
timestamptrunc_sql,
timestrtotime_sql,
ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType
@ -141,11 +141,25 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
return f"EPOCH_MS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"MAKE_TIMESTAMP({timestamp})"
if scale == exp.UnixToTime.NANOS:
return f"TO_TIMESTAMP({timestamp} / 1000000000)"
self.unsupported(f"Unsupported scale for timestamp: {scale}.")
return ""
return f"TO_TIMESTAMP({timestamp} / POW(10, {scale}))"
def _rename_unless_within_group(
a: str, b: str
) -> t.Callable[[DuckDB.Generator, exp.Expression], str]:
return (
lambda self, expression: self.func(a, *flatten(expression.args.values()))
if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup)
else self.func(b, *flatten(expression.args.values()))
)
def _parse_struct_pack(args: t.List) -> exp.Struct:
args_with_columns_as_identifiers = [
exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args
]
return exp.Struct.from_arg_list(args_with_columns_as_identifiers)
class DuckDB(Dialect):
@ -183,6 +197,11 @@ class DuckDB(Dialect):
"TIMESTAMP_US": TokenType.TIMESTAMP,
}
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
class Parser(parser.Parser):
BITWISE = {
**parser.Parser.BITWISE,
@ -209,10 +228,12 @@ class DuckDB(Dialect):
"EPOCH_MS": lambda args: exp.UnixToTime(
this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
"JSON": exp.ParseJSON.from_arg_list,
"LIST_HAS": exp.ArrayContains.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
"MAKE_TIME": exp.TimeFromParts.from_arg_list,
"MAKE_TIMESTAMP": _parse_make_timestamp,
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
@ -234,7 +255,7 @@ class DuckDB(Dialect):
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
"STRUCT_PACK": exp.Struct.from_arg_list,
"STRUCT_PACK": _parse_struct_pack,
"STR_SPLIT": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
@ -250,6 +271,13 @@ class DuckDB(Dialect):
TokenType.ANTI,
}
PLACEHOLDER_PARSERS = {
**parser.Parser.PLACEHOLDER_PARSERS,
TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
else None,
}
def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
@ -268,7 +296,7 @@ class DuckDB(Dialect):
return this
def _parse_struct_types(self) -> t.Optional[exp.Expression]:
def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
return self._parse_field_def()
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
@ -285,6 +313,10 @@ class DuckDB(Dialect):
RENAME_TABLE_WITH_DB = False
NVL2_SUPPORTED = False
SEMI_ANTI_JOIN_WITH_SIDE = False
TABLESAMPLE_KEYWORDS = "USING SAMPLE"
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_KEY_VALUE_PAIR_SEP = ","
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -311,7 +343,7 @@ class DuckDB(Dialect):
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateSub: _date_delta_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", f"'{e.args.get('unit') or 'day'}'", e.expression, e.this
"DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
@ -322,11 +354,11 @@ class DuckDB(Dialect):
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.IsInf: rename_func("ISINF"),
exp.IsNan: rename_func("ISNAN"),
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONFormat: _json_format_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.MonthsBetween: lambda self, e: self.func(
@ -336,8 +368,8 @@ class DuckDB(Dialect):
exp.cast(e.this, "timestamp", copy=True),
),
exp.ParseJSON: rename_func("JSON"),
exp.PercentileCont: rename_func("QUANTILE_CONT"),
exp.PercentileDisc: rename_func("QUANTILE_DISC"),
exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"),
exp.PercentileDisc: _rename_unless_within_group("PERCENTILE_DISC", "QUANTILE_DISC"),
# DuckDB doesn't allow qualified columns inside of PIVOT expressions.
# See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),
@ -362,7 +394,9 @@ class DuckDB(Dialect):
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
@ -373,11 +407,10 @@ class DuckDB(Dialect):
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: lambda self, e: self.func(
"DATE_DIFF",
f"'{e.args.get('unit') or 'day'}'",
f"'{e.args.get('unit') or 'DAY'}'",
exp.cast(e.expression, "TIMESTAMP"),
exp.cast(e.this, "TIMESTAMP"),
),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
@ -410,6 +443,49 @@ class DuckDB(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
nano = expression.args.get("nano")
if nano is not None:
expression.set(
"sec", expression.args["sec"] + nano.pop() / exp.Literal.number(1000000000.0)
)
return rename_func("MAKE_TIME")(self, expression)
def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
sec = expression.args["sec"]
milli = expression.args.get("milli")
if milli is not None:
sec += milli.pop() / exp.Literal.number(1000.0)
nano = expression.args.get("nano")
if nano is not None:
sec += nano.pop() / exp.Literal.number(1000000000.0)
if milli or nano:
expression.set("sec", sec)
return rename_func("MAKE_TIMESTAMP")(self, expression)
def tablesample_sql(
self,
expression: exp.TableSample,
sep: str = " AS ",
tablesample_keyword: t.Optional[str] = None,
) -> str:
if not isinstance(expression.parent, exp.Select):
# This sample clause only applies to a single source, not the entire resulting relation
tablesample_keyword = "TABLESAMPLE"
return super().tablesample_sql(
expression, sep=sep, tablesample_keyword=tablesample_keyword
)
def getpath_sql(self, expression: exp.GetPath) -> str:
expression = prepend_dollar_to_path(expression)
return f"{self.sql(expression, 'this')} -> {self.sql(expression, 'expression')}"
def interval_sql(self, expression: exp.Interval) -> str:
multiplier: t.Optional[int] = None
unit = expression.text("unit").lower()
@ -420,11 +496,14 @@ class DuckDB(Dialect):
multiplier = 90
if multiplier:
return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})"
return super().interval_sql(expression)
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
) -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
if isinstance(expression.parent, exp.UserDefinedFunction):
return self.sql(expression, "this")
return super().columndef_sql(expression, sep)
def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f"${expression.name}" if expression.name else "?"

@ -418,13 +418,13 @@ class Hive(Dialect):
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
TABLESAMPLE_SIZE_IS_PERCENT = True
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
@ -523,7 +523,6 @@ class Hive(Dialect):
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
exp.National: lambda self, e: self.national_sql(e, prefix=""),
exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",

@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
isnull_to_is_null,
json_keyvalue_comma_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
@ -21,6 +20,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
parse_date_delta_with_interval,
path_to_jsonpath,
rename_func,
strposition_to_locate_sql,
)
@ -37,21 +37,21 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex
def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
unit = expression.text("unit")
unit = expression.text("unit").upper()
if unit == "day":
if unit == "DAY":
return f"DATE({expr})"
if unit == "week":
if unit == "WEEK":
concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
date_format = "%Y %u %w"
elif unit == "month":
elif unit == "MONTH":
concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
date_format = "%Y %c %e"
elif unit == "quarter":
elif unit == "QUARTER":
concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
date_format = "%Y %c %e"
elif unit == "year":
elif unit == "YEAR":
concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
@ -292,9 +292,15 @@ class MySQL(Dialect):
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
"MAKETIME": exp.TimeFromParts.from_arg_list,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"MONTHNAME": lambda args: exp.TimeToStr(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
format=exp.Literal.string("%B"),
@ -308,11 +314,6 @@ class MySQL(Dialect):
)
+ 1
),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"WEEK": lambda args: exp.Week(
this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1)
),
@ -441,6 +442,7 @@ class MySQL(Dialect):
}
LOG_DEFAULTS_TO_LN = True
STRING_ALIASES = True
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
@ -620,13 +622,15 @@ class MySQL(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
NULL_ORDERING_SUPPORTED = None
JOIN_HINTS = False
TABLE_HINTS = True
DUPLICATE_KEY_UPDATE_WITH_SET = False
QUERY_HINT_SEP = " "
VALUES_AS_TABLE = False
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_KEY_VALUE_PAIR_SEP = ","
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -642,15 +646,16 @@ class MySQL(Dialect):
exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Month: _remove_ts_or_ds_to_date(),
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}",
exp.ParseJSON: lambda self, e: self.sql(e, "this"),
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
[
@ -665,6 +670,7 @@ class MySQL(Dialect):
exp.StrToTime: _str_to_date_sql,
exp.Stuff: rename_func("INSERT"),
exp.TableSample: no_tablesample_sql,
exp.TimeFromParts: rename_func("MAKETIME"),
exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),

@ -53,6 +53,7 @@ def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar:
class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
LOCKING_READS_SUPPORTED = True
TABLESAMPLE_SIZE_IS_PERCENT = True
# See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@ -81,6 +82,7 @@ class Oracle(Dialect):
"WW": "%W", # Week of year (1-53)
"YY": "%y", # 15
"YYYY": "%Y", # 2015
"FF6": "%f", # only 6 digits are supported in python formats
}
class Parser(parser.Parser):
@ -91,6 +93,8 @@ class Oracle(Dialect):
**parser.Parser.FUNCTIONS,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TO_CHAR": to_char,
"TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "oracle"),
"TO_DATE": format_time_lambda(exp.StrToDate, "oracle"),
}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@ -107,6 +111,11 @@ class Oracle(Dialect):
"XMLTABLE": _parse_xml_table,
}
QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()),
}
TYPE_LITERAL_PARSERS = {
exp.DataType.Type.DATE: lambda self, this, _: self.expression(
exp.DateStrToDate, this=this
@ -153,8 +162,10 @@ class Oracle(Dialect):
COLUMN_JOIN_MARKS_SUPPORTED = True
DATA_TYPE_SPECIFIERS_ALLOWED = True
ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
LIMIT_FETCH = "FETCH"
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
SUPPORTS_SELECT_INTO = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -186,6 +197,7 @@ class Oracle(Dialect):
]
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
@ -201,6 +213,10 @@ class Oracle(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str:
this = expression.this
return self.func("CURRENT_TIMESTAMP", this) if this else "CURRENT_TIMESTAMP"
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
@ -233,8 +249,10 @@ class Oracle(Dialect):
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
"ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY,
"SAMPLE": TokenType.TABLE_SAMPLE,
"START": TokenType.BEGIN,
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
}

@ -13,11 +13,12 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
max_or_greatest,
merge_without_target_sql,
min_or_least,
no_last_day_sql,
no_map_from_entries_sql,
no_paren_current_date_sql,
no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
parse_timestamp_trunc,
rename_func,
@ -27,7 +28,6 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
trim_sql,
ts_or_ds_add_cast,
ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
@ -188,36 +188,6 @@ def _to_timestamp(args: t.List) -> exp.Expression:
return format_time_lambda(exp.StrToTime, "postgres")(args)
def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str:
def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
"""Remove table refs from columns in when statements."""
if isinstance(expression, exp.Merge):
alias = expression.this.args.get("alias")
normalize = (
lambda identifier: self.dialect.normalize_identifier(identifier).name
if identifier
else None
)
targets = {normalize(expression.this.this)}
if alias:
targets.add(normalize(alias.this))
for when in expression.expressions:
when.transform(
lambda node: exp.column(node.this)
if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
else node,
copy=False,
)
return expression
return transforms.preprocess([_remove_target_from_merge])(self, expression)
class Postgres(Dialect):
INDEX_OFFSET = 1
TYPED_DIVISION = True
@ -316,6 +286,8 @@ class Postgres(Dialect):
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
"GENERATE_SERIES": _generate_series,
"MAKE_TIME": exp.TimeFromParts.from_arg_list,
"MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"TO_TIMESTAMP": _to_timestamp,
@ -387,12 +359,18 @@ class Postgres(Dialect):
class Generator(generator.Generator):
SINGLE_STRING_INTERVAL = True
RENAME_TABLE_WITH_DB = False
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
PARAMETER_TOKEN = "$"
TABLESAMPLE_SIZE_IS_ROWS = False
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
# https://www.postgresql.org/docs/current/sql-createtable.html
SUPPORTS_UNLOGGED_TABLES = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -430,12 +408,13 @@ class Postgres(Dialect):
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
exp.LastDay: no_last_day_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Max: max_or_greatest,
exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
exp.Merge: _merge_sql,
exp.Merge: merge_without_target_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.PercentileCont: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
@ -458,16 +437,16 @@ class Postgres(Dialect):
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
exp.TimeFromParts: rename_func("MAKE_TIME"),
exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: _date_add_sql("+"),
exp.TsOrDsDiff: _date_diff_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.VariancePop: rename_func("VAR_POP"),
exp.Variance: rename_func("VAR_SAMP"),

@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_safe_divide_sql,
no_timestamp_sql,
path_to_jsonpath,
regexp_extract_sql,
rename_func,
right_to_substring_sql,
@ -99,14 +100,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
expression = ts_or_ds_add_cast(expression)
unit = exp.Literal.string(expression.text("unit") or "day")
unit = exp.Literal.string(expression.text("unit") or "DAY")
return self.func("DATE_ADD", unit, expression.expression, expression.this)
def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
this = exp.cast(expression.this, "TIMESTAMP")
expr = exp.cast(expression.expression, "TIMESTAMP")
unit = exp.Literal.string(expression.text("unit") or "day")
unit = exp.Literal.string(expression.text("unit") or "DAY")
return self.func("DATE_DIFF", unit, expr, this)
@ -138,13 +139,6 @@ def _from_unixtime(args: t.List) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
def _parse_element_at(args: t.List) -> exp.Bracket:
this = seq_get(args, 0)
index = seq_get(args, 1)
assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
return exp.Bracket(this=this, expressions=[index], offset=1, safe=True)
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Table):
if isinstance(expression.this, exp.GenerateSeries):
@ -175,15 +169,8 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str
timestamp = self.sql(expression, "this")
if scale in (None, exp.UnixToTime.SECONDS):
return rename_func("FROM_UNIXTIME")(self, expression)
if scale == exp.UnixToTime.MILLIS:
return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)"
if scale == exp.UnixToTime.MICROS:
return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)"
if scale == exp.UnixToTime.NANOS:
return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)"
self.unsupported(f"Unsupported scale for timestamp: {scale}.")
return ""
return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))"
def _to_int(expression: exp.Expression) -> exp.Expression:
@ -215,6 +202,7 @@ class Presto(Dialect):
STRICT_STRING_CONCAT = True
SUPPORTS_SEMI_ANTI_JOIN = False
TYPED_DIVISION = True
TABLESAMPLE_SIZE_IS_PERCENT = True
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
@ -258,7 +246,9 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
"ELEMENT_AT": _parse_element_at,
"ELEMENT_AT": lambda args: exp.Bracket(
this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True
),
"FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
@ -344,20 +334,20 @@ class Presto(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD",
exp.Literal.string(e.text("unit") or "day"),
exp.Literal.string(e.text("unit") or "DAY"),
_to_int(
e.expression,
),
e.this,
),
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
"DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
exp.Literal.string(e.text("unit") or "day"),
exp.Literal.string(e.text("unit") or "DAY"),
_to_int(e.expression * -1),
e.this,
),
@ -366,6 +356,7 @@ class Presto(Dialect):
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
exp.GetPath: path_to_jsonpath(),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
@ -376,6 +367,7 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Last: _first_last_sql,
exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
@ -446,7 +438,7 @@ class Presto(Dialect):
return super().bracket_sql(expression)
def struct_sql(self, expression: exp.Struct) -> str:
if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions):
if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions):
self.unsupported("Struct with key-value definitions is unsupported.")
return self.function_fallback_sql(expression)
@ -454,8 +446,8 @@ class Presto(Dialect):
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
if expression.this and unit.lower().startswith("week"):
return f"({expression.this.name} * INTERVAL '7' day)"
if expression.this and unit.startswith("WEEK"):
return f"({expression.this.name} * INTERVAL '7' DAY)"
return super().interval_sql(expression)
def transaction_sql(self, expression: exp.Transaction) -> str:

@ -9,8 +9,8 @@ from sqlglot.dialects.dialect import (
concat_ws_to_dpipe_sql,
date_delta_sql,
generatedasidentitycolumnconstraint_sql,
no_tablesample_sql,
rename_func,
ts_or_ds_to_date_sql,
)
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
@ -123,6 +123,27 @@ class Redshift(Postgres):
self._retreat(index)
return None
def _parse_query_modifiers(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
this = super()._parse_query_modifiers(this)
if this:
refs = set()
for i, join in enumerate(this.args.get("joins", [])):
refs.add(
(
this.args["from"] if i == 0 else this.args["joins"][i - 1]
).alias_or_name.lower()
)
table = join.this
if isinstance(table, exp.Table):
if table.parts[0].name.lower() in refs:
table.replace(table.to_column())
return this
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
@ -144,11 +165,11 @@ class Redshift(Postgres):
class Generator(Postgres.Generator):
LOCKING_READS_SUPPORTED = False
RENAME_TABLE_WITH_DB = False
QUERY_HINTS = False
VALUES_AS_TABLE = False
TZ_TO_WITH_TIME_ZONE = True
NVL2_SUPPORTED = True
LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@ -184,9 +205,9 @@ class Redshift(Postgres):
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"),
}
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots
@ -198,6 +219,9 @@ class Redshift(Postgres):
# Redshift supports ANY_VALUE(..)
TRANSFORMS.pop(exp.AnyValue)
# Redshift supports LAST_DAY(..)
TRANSFORMS.pop(exp.LastDay)
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
def with_properties(self, properties: exp.Properties) -> str:

@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import (
rename_func,
timestamptrunc_sql,
timestrtotime_sql,
ts_or_ds_to_date_sql,
var_map_sql,
)
from sqlglot.expressions import Literal
@ -40,21 +39,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
if second_arg.is_string:
# case: <string_expr> [ , <format> ]
return format_time_lambda(exp.StrToTime, "snowflake")(args)
# case: <numeric_expr> [ , <scale> ]
if second_arg.name not in ["0", "3", "9"]:
raise ValueError(
f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
)
if second_arg.name == "0":
timescale = exp.UnixToTime.SECONDS
elif second_arg.name == "3":
timescale = exp.UnixToTime.MILLIS
elif second_arg.name == "9":
timescale = exp.UnixToTime.NANOS
return exp.UnixToTime(this=first_arg, scale=timescale)
return exp.UnixToTime(this=first_arg, scale=second_arg)
from sqlglot.optimizer.simplify import simplify_literals
@ -91,23 +76,9 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
def _parse_datediff(args: t.List) -> exp.DateDiff:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in (None, exp.UnixToTime.SECONDS):
return f"TO_TIMESTAMP({timestamp})"
if scale == exp.UnixToTime.MILLIS:
return f"TO_TIMESTAMP({timestamp}, 3)"
if scale == exp.UnixToTime.MICROS:
return f"TO_TIMESTAMP({timestamp} / 1000, 3)"
if scale == exp.UnixToTime.NANOS:
return f"TO_TIMESTAMP({timestamp}, 9)"
self.unsupported(f"Unsupported scale for timestamp: {scale}.")
return ""
return exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
)
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
@ -120,14 +91,15 @@ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
self._match(TokenType.COMMA)
expression = self._parse_bitwise()
this = _map_date_part(this)
name = this.name.upper()
if name.startswith("EPOCH"):
if name.startswith("EPOCH_MILLISECOND"):
if name == "EPOCH_MILLISECOND":
scale = 10**3
elif name.startswith("EPOCH_MICROSECOND"):
elif name == "EPOCH_MICROSECOND":
scale = 10**6
elif name.startswith("EPOCH_NANOSECOND"):
elif name == "EPOCH_NANOSECOND":
scale = 10**9
else:
scale = None
@ -204,6 +176,159 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]
return _parse
DATE_PART_MAPPING = {
"Y": "YEAR",
"YY": "YEAR",
"YYY": "YEAR",
"YYYY": "YEAR",
"YR": "YEAR",
"YEARS": "YEAR",
"YRS": "YEAR",
"MM": "MONTH",
"MON": "MONTH",
"MONS": "MONTH",
"MONTHS": "MONTH",
"D": "DAY",
"DD": "DAY",
"DAYS": "DAY",
"DAYOFMONTH": "DAY",
"WEEKDAY": "DAYOFWEEK",
"DOW": "DAYOFWEEK",
"DW": "DAYOFWEEK",
"WEEKDAY_ISO": "DAYOFWEEKISO",
"DOW_ISO": "DAYOFWEEKISO",
"DW_ISO": "DAYOFWEEKISO",
"YEARDAY": "DAYOFYEAR",
"DOY": "DAYOFYEAR",
"DY": "DAYOFYEAR",
"W": "WEEK",
"WK": "WEEK",
"WEEKOFYEAR": "WEEK",
"WOY": "WEEK",
"WY": "WEEK",
"WEEK_ISO": "WEEKISO",
"WEEKOFYEARISO": "WEEKISO",
"WEEKOFYEAR_ISO": "WEEKISO",
"Q": "QUARTER",
"QTR": "QUARTER",
"QTRS": "QUARTER",
"QUARTERS": "QUARTER",
"H": "HOUR",
"HH": "HOUR",
"HR": "HOUR",
"HOURS": "HOUR",
"HRS": "HOUR",
"M": "MINUTE",
"MI": "MINUTE",
"MIN": "MINUTE",
"MINUTES": "MINUTE",
"MINS": "MINUTE",
"S": "SECOND",
"SEC": "SECOND",
"SECONDS": "SECOND",
"SECS": "SECOND",
"MS": "MILLISECOND",
"MSEC": "MILLISECOND",
"MILLISECONDS": "MILLISECOND",
"US": "MICROSECOND",
"USEC": "MICROSECOND",
"MICROSECONDS": "MICROSECOND",
"NS": "NANOSECOND",
"NSEC": "NANOSECOND",
"NANOSEC": "NANOSECOND",
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"EPOCH": "EPOCH_SECOND",
"EPOCH_SECONDS": "EPOCH_SECOND",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
"TZH": "TIMEZONE_HOUR",
"TZM": "TIMEZONE_MINUTE",
}
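This table lets the _map_date_part helper defined just below canonicalize Snowflake's many date-part spellings before they reach expressions such as DATEADD, DATE_TRUNC or LAST_DAY. The lookup itself is an upper-cased dict get that falls through for already-canonical parts; a minimal standalone version (sample entries only):

# Minimal standalone version of the _map_date_part lookup (sample mapping entries).
date_part_mapping = {"YY": "YEAR", "MM": "MONTH", "DOW": "DAYOFWEEK", "MS": "MILLISECOND"}

def map_date_part(part):
    # Unknown spellings fall through unchanged, mirroring the exp.var fallback below.
    return date_part_mapping.get(part.upper(), part)

print(map_date_part("yy"))       # YEAR
print(map_date_part("dow"))      # DAYOFWEEK
print(map_date_part("QUARTER"))  # QUARTER (already canonical, passed through)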
@t.overload
def _map_date_part(part: exp.Expression) -> exp.Var:
pass
@t.overload
def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
pass
def _map_date_part(part):
mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None
return exp.var(mapped) if mapped else part
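# For example, given the mapping above, _map_date_part(exp.var("dd")) returns exp.var("DAY"),
# while an unrecognized part (or None) is passed through unchanged.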
def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
trunc = date_trunc_to_time(args)
trunc.set("unit", _map_date_part(trunc.args["unit"]))
return trunc
def _parse_colon_get_path(
self: parser.Parser, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
while True:
path = self._parse_bitwise()
# The cast :: operator has a lower precedence than the extraction operator :, so
# we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
if isinstance(path, exp.Cast):
target_type = path.to
path = path.this
else:
target_type = None
if isinstance(path, exp.Expression):
path = exp.Literal.string(path.sql(dialect="snowflake"))
# The extraction operator : is left-associative
this = self.expression(exp.GetPath, this=this, expression=path)
if target_type:
this = exp.cast(this, target_type)
if not self._match(TokenType.COLON):
break
if self._match_set(self.RANGE_PARSERS):
this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this
return this
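# Rough sketch of the intent: in col:path::INT the cast is re-attached around the whole
# GetPath node (not the path literal), and chained : lookups nest left-associatively as noted above.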
def _parse_timestamp_from_parts(args: t.List) -> exp.Func:
if len(args) == 2:
# Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
# so we parse this into Anonymous for now instead of introducing complexity
return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
return exp.TimestampFromParts.from_arg_list(args)
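# A two-argument TIMESTAMP_FROM_PARTS(date, time) call is kept as an Anonymous function,
# while the full year/month/day/... form becomes exp.TimestampFromParts.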
def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
"""
Snowflake doesn't allow columns referenced in UNPIVOT to be qualified,
so we need to unqualify them.
Example:
>>> from sqlglot import parse_one
>>> expr = parse_one("SELECT * FROM m_sales UNPIVOT(sales FOR month IN (m_sales.jan, feb, mar, april))")
>>> print(_unqualify_unpivot_columns(expr).sql(dialect="snowflake"))
SELECT * FROM m_sales UNPIVOT(sales FOR month IN (jan, feb, mar, april))
"""
if isinstance(expression, exp.Pivot) and expression.unpivot:
expression = transforms.unqualify_columns(expression)
return expression
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@ -211,6 +336,8 @@ class Snowflake(Dialect):
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False
PREFER_CTE_ALIAS_COLUMN = True
TABLESAMPLE_SIZE_IS_PERCENT = True
TIME_MAPPING = {
"YYYY": "%Y",
@ -276,14 +403,19 @@ class Snowflake(Dialect):
"BIT_XOR": binary_from_function(exp.BitwiseXor),
"BOOLXOR": binary_from_function(exp.Xor),
"CONVERT_TIMEZONE": _parse_convert_timezone,
"DATE_TRUNC": date_trunc_to_time,
"DATE_TRUNC": _date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=_map_date_part(seq_get(args, 0)),
),
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"FLATTEN": exp.Explode.from_arg_list,
"IFF": exp.If.from_arg_list,
"LAST_DAY": lambda args: exp.LastDay(
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
),
"LISTAGG": exp.GroupConcat.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
@ -293,6 +425,8 @@ class Snowflake(Dialect):
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TIMEDIFF": _parse_datediff,
"TIMESTAMPDIFF": _parse_datediff,
"TIMESTAMPFROMPARTS": _parse_timestamp_from_parts,
"TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts,
"TO_TIMESTAMP": _parse_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _zeroifnull_to_if,
@ -301,22 +435,17 @@ class Snowflake(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
"OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(),
}
FUNCTION_PARSERS.pop("TRIM")
COLUMN_OPERATORS = {
**parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket, this=this, expressions=[path]
),
}
TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME}
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
TokenType.COLON: _parse_colon_get_path,
}
ALTER_PARSERS = {
@ -344,6 +473,7 @@ class Snowflake(Dialect):
SHOW_PARSERS = {
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"COLUMNS": _show_parser("COLUMNS"),
}
STAGED_FILE_SINGLE_TOKENS = {
@ -351,8 +481,18 @@ class Snowflake(Dialect):
TokenType.MOD,
TokenType.SLASH,
}
FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
if is_map:
# Keys are strings in Snowflake's objects, see also:
# - https://docs.snowflake.com/en/sql-reference/data-types-semistructured
# - https://docs.snowflake.com/en/sql-reference/functions/object_construct
return self._parse_slice(self._parse_string())
return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))
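# In map context (OBJECT_CONSTRUCT-style object literals) only string keys are parsed, per the
# Snowflake docs linked above; otherwise a normal, optionally aliased expression is parsed.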
def _parse_lateral(self) -> t.Optional[exp.Lateral]:
lateral = super()._parse_lateral()
if not lateral:
@ -440,6 +580,8 @@ class Snowflake(Dialect):
scope = None
scope_kind = None
like = self._parse_string() if self._match(TokenType.LIKE) else None
if self._match(TokenType.IN):
if self._match_text_seq("ACCOUNT"):
scope_kind = "ACCOUNT"
@ -451,7 +593,9 @@ class Snowflake(Dialect):
scope_kind = "TABLE"
scope = self._parse_table()
return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)
return self.expression(
exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind
)
def _parse_alter_table_swap(self) -> exp.SwapTable:
self._match_text_seq("WITH")
@ -489,8 +633,12 @@ class Snowflake(Dialect):
"MINUS": TokenType.EXCEPT,
"NCHAR VARYING": TokenType.VARCHAR,
"PUT": TokenType.COMMAND,
"REMOVE": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
"RM": TokenType.COMMAND,
"SAMPLE": TokenType.TABLE_SAMPLE,
"SQL_DOUBLE": TokenType.DOUBLE,
"SQL_VARCHAR": TokenType.VARCHAR,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@ -518,6 +666,8 @@ class Snowflake(Dialect):
SUPPORTS_TABLE_COPY = False
COLLATE_IS_FUNC = True
LIMIT_ONLY_LITERALS = True
JSON_KEY_VALUE_PAIR_SEP = ","
INSERT_OVERWRITE = " OVERWRITE INTO"
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -545,6 +695,8 @@ class Snowflake(Dialect):
),
exp.GroupConcat: rename_func("LISTAGG"),
exp.If: if_sql(name="IFF", false_value="NULL"),
exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]",
exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@ -557,6 +709,7 @@ class Snowflake(Dialect):
exp.PercentileDisc: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
),
exp.Pivot: transforms.preprocess([_unqualify_unpivot_columns]),
exp.RegexpILike: _regexpilike_sql,
exp.Rand: rename_func("RANDOM"),
exp.Select: transforms.preprocess(
@ -578,6 +731,9 @@ class Snowflake(Dialect):
*(arg for expression in e.expressions for arg in expression.flatten()),
),
exp.Stuff: rename_func("INSERT"),
exp.TimestampDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.unit, e.expression, e.this
),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
@ -589,8 +745,7 @@ class Snowflake(Dialect):
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.Xor: rename_func("BOOLXOR"),
@ -612,6 +767,14 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
milli = expression.args.get("milli")
if milli is not None:
milli_to_nano = milli.pop() * exp.Literal.number(1000000)
expression.set("nano", milli_to_nano)
return rename_func("TIMESTAMP_FROM_PARTS")(self, expression)
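# A millisecond argument is folded into the nanosecond slot (e.g. milli=3 becomes nano=3000000),
# since Snowflake's TIMESTAMP_FROM_PARTS takes nanoseconds rather than milliseconds.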
def trycast_sql(self, expression: exp.TryCast) -> str:
value = expression.this
@ -657,6 +820,9 @@ class Snowflake(Dialect):
return f"{explode}{alias}"
def show_sql(self, expression: exp.Show) -> str:
like = self.sql(expression, "like")
like = f" LIKE {like}" if like else ""
scope = self.sql(expression, "scope")
scope = f" {scope}" if scope else ""
@ -664,7 +830,7 @@ class Snowflake(Dialect):
if scope_kind:
scope_kind = f" IN {scope_kind}"
return f"SHOW {expression.name}{scope_kind}{scope}"
return f"SHOW {expression.name}{like}{scope_kind}{scope}"
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to

View file

@ -48,11 +48,8 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
if scale == exp.UnixToTime.NANOS:
return f"TIMESTAMP_SECONDS({timestamp} / 1000000000)"
self.unsupported(f"Unsupported scale for timestamp: {scale}.")
return ""
return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))"
def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
@ -93,12 +90,7 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q1'))
"""
if isinstance(expression, exp.Pivot):
expression.args["field"].transform(
lambda node: exp.column(node.output_name, quoted=node.this.quoted)
if isinstance(node, exp.Column)
else node,
copy=False,
)
expression.set("field", transforms.unqualify_columns(expression.args["field"]))
return expression
@ -234,7 +226,7 @@ class Spark2(Hive):
def struct_sql(self, expression: exp.Struct) -> str:
args = []
for arg in expression.expressions:
if isinstance(arg, self.KEY_VALUE_DEFINITONS):
if isinstance(arg, self.KEY_VALUE_DEFINITIONS):
if isinstance(arg, exp.Bracket):
args.append(exp.alias_(arg.this, arg.expressions[0].name))
else:

View file

@ -78,6 +78,7 @@ class SQLite(Dialect):
**parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
STRING_ALIASES = True
class Generator(generator.Generator):
JOIN_HINTS = False

View file

@ -175,6 +175,8 @@ class Teradata(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -214,7 +216,10 @@ class Teradata(Dialect):
return self.cast_sql(expression, safe_prefix="TRY")
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
self,
expression: exp.TableSample,
sep: str = " AS ",
tablesample_keyword: t.Optional[str] = None,
) -> str:
return f"{self.sql(expression, 'this')} SAMPLE {self.expressions(expression)}"

View file

@ -1,6 +1,7 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.dialect import merge_without_target_sql
from sqlglot.dialects.presto import Presto
@ -11,6 +12,7 @@ class Trino(Presto):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.Merge: merge_without_target_sql,
}
class Tokenizer(Presto.Tokenizer):

View file

@ -14,9 +14,10 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
parse_date_delta,
path_to_jsonpath,
rename_func,
timestrtotime_sql,
ts_or_ds_to_date_sql,
trim_sql,
)
from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
@ -105,18 +106,17 @@ def _parse_format(args: t.List) -> exp.Expression:
return exp.TimeToStr(this=this, format=fmt, culture=culture)
def _parse_eomonth(args: t.List) -> exp.Expression:
date = seq_get(args, 0)
def _parse_eomonth(args: t.List) -> exp.LastDay:
date = exp.TsOrDsToDate(this=seq_get(args, 0))
month_lag = seq_get(args, 1)
unit = DATE_DELTA_INTERVAL.get("month")
if month_lag is None:
return exp.LastDateOfMonth(this=date)
this: exp.Expression = date
else:
unit = DATE_DELTA_INTERVAL.get("month")
this = exp.DateAdd(this=date, expression=month_lag, unit=unit and exp.var(unit))
# Remove month lag argument in parser as it's compared with the number of arguments of the resulting class
args.remove(month_lag)
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
return exp.LastDay(this=this)
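# Roughly: EOMONTH(d) parses to LastDay over TsOrDsToDate(d), and EOMONTH(d, 2) first wraps
# the date in a two-month DateAdd before taking the last day.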
def _parse_hashbytes(args: t.List) -> exp.Expression:
@ -137,26 +137,27 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
return exp.func("HASHBYTES", *args)
DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"}
DATEPART_ONLY_FORMATS = {"DW", "HOUR", "QUARTER"}
def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
fmt = (
expression.args["format"]
if isinstance(expression, exp.NumberToStr)
else exp.Literal.string(
format_time(
expression.text("format"),
t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING),
)
)
)
fmt = expression.args["format"]
# There is no format for "quarter"
if fmt.name.lower() in DATEPART_ONLY_FORMATS:
return self.func("DATEPART", fmt.name, expression.this)
if not isinstance(expression, exp.NumberToStr):
if fmt.is_string:
mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING)
return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
name = (mapped_fmt or "").upper()
if name in DATEPART_ONLY_FORMATS:
return self.func("DATEPART", name, expression.this)
fmt_sql = self.sql(exp.Literal.string(mapped_fmt))
else:
fmt_sql = self.format_time(expression) or self.sql(fmt)
else:
fmt_sql = self.sql(fmt)
return self.func("FORMAT", expression.this, fmt_sql, expression.args.get("culture"))
def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
@ -239,6 +240,30 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
return expression
# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax
def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
return exp.TimestampFromParts(
year=seq_get(args, 0),
month=seq_get(args, 1),
day=seq_get(args, 2),
hour=seq_get(args, 3),
min=seq_get(args, 4),
sec=seq_get(args, 5),
milli=seq_get(args, 6),
)
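# The positional arguments of DATETIMEFROMPARTS(year, month, day, hour, minute, seconds,
# milliseconds) line up one-to-one with the slots above.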
# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax
def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
return exp.TimeFromParts(
hour=seq_get(args, 0),
min=seq_get(args, 1),
sec=seq_get(args, 2),
fractions=seq_get(args, 3),
precision=seq_get(args, 4),
)
class TSQL(Dialect):
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
@ -352,7 +377,7 @@ class TSQL(Dialect):
}
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
IDENTIFIERS = [("[", "]"), '"']
QUOTES = ["'", '"']
HEX_STRINGS = [("0x", ""), ("0X", "")]
VAR_SINGLE_TOKENS = {"@", "$", "#"}
@ -362,6 +387,7 @@ class TSQL(Dialect):
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"DECLARE": TokenType.COMMAND,
"EXEC": TokenType.COMMAND,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"NTEXT": TokenType.TEXT,
@ -397,6 +423,7 @@ class TSQL(Dialect):
"DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
"DATETIMEFROMPARTS": _parse_datetimefromparts,
"EOMONTH": _parse_eomonth,
"FORMAT": _parse_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
@ -411,6 +438,7 @@ class TSQL(Dialect):
"SUSER_NAME": exp.CurrentUser.from_arg_list,
"SUSER_SNAME": exp.CurrentUser.from_arg_list,
"SYSTEM_USER": exp.CurrentUser.from_arg_list,
"TIMEFROMPARTS": _parse_timefromparts,
}
JOIN_HINTS = {
@ -440,6 +468,7 @@ class TSQL(Dialect):
LOG_DEFAULTS_TO_LN = True
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
STRING_ALIASES = True
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@ -630,8 +659,10 @@ class TSQL(Dialect):
COMPUTED_COLUMN_WITH_TYPE = False
CTE_RECURSIVE_KEYWORD_REQUIRED = False
ENSURE_BOOLS = True
NULL_ORDERING_SUPPORTED = False
NULL_ORDERING_SUPPORTED = None
SUPPORTS_SINGLE_ARG_CONCAT = False
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
@ -667,13 +698,16 @@ class TSQL(Dialect):
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GetPath: path_to_jsonpath("JSON_VALUE"),
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
exp.Length: rename_func("LEN"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
exp.ParseJSON: lambda self, e: self.sql(e, "this"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
@ -689,9 +723,9 @@ class TSQL(Dialect):
exp.TemporaryProperty: lambda self, e: "",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
exp.Trim: trim_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
}
TRANSFORMS.pop(exp.ReturnsProperty)
@ -701,6 +735,46 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def lateral_op(self, expression: exp.Lateral) -> str:
cross_apply = expression.args.get("cross_apply")
if cross_apply is True:
return "CROSS APPLY"
if cross_apply is False:
return "OUTER APPLY"
# TODO: perhaps we can check if the parent is a Join and transpile it appropriately
self.unsupported("LATERAL clause is not supported.")
return "LATERAL"
def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
nano = expression.args.get("nano")
if nano is not None:
nano.pop()
self.unsupported("Specifying nanoseconds is not supported in TIMEFROMPARTS.")
if expression.args.get("fractions") is None:
expression.set("fractions", exp.Literal.number(0))
if expression.args.get("precision") is None:
expression.set("precision", exp.Literal.number(0))
return rename_func("TIMEFROMPARTS")(self, expression)
def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
zone = expression.args.get("zone")
if zone is not None:
zone.pop()
self.unsupported("Time zone is not supported in DATETIMEFROMPARTS.")
nano = expression.args.get("nano")
if nano is not None:
nano.pop()
self.unsupported("Specifying nanoseconds is not supported in DATETIMEFROMPARTS.")
if expression.args.get("milli") is None:
expression.set("milli", exp.Literal.number(0))
return rename_func("DATETIMEFROMPARTS")(self, expression)
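# Zone and nanosecond parts are dropped (with an "unsupported" warning) because DATETIMEFROMPARTS
# cannot express them, and a missing millisecond part defaults to 0 so the generated call
# always carries all seven arguments.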
def set_operation(self, expression: exp.Union, op: str) -> str:
limit = expression.args.get("limit")
if limit:

View file

@ -132,11 +132,10 @@ def ordered(this, desc, nulls_first):
@null_if_any
def interval(this, unit):
unit = unit.lower()
plural = unit + "s"
plural = unit + "S"
if plural in Generator.TIME_PART_SINGULARS:
unit = plural
return datetime.timedelta(**{unit: float(this)})
return datetime.timedelta(**{unit.lower(): float(this)})
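# e.g. interval("2", "DAY") resolves the plural key "DAYS" from TIME_PART_SINGULARS and
# returns datetime.timedelta(days=2.0).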
@null_if_any("this", "expression")
@ -176,6 +175,7 @@ ENV = {
"DOT": null_if_any(lambda e, this: e[this]),
"EQ": null_if_any(lambda this, e: this == e),
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
"GETPATH": null_if_any(lambda this, e: this.get(e)),
"GT": null_if_any(lambda this, e: this > e),
"GTE": null_if_any(lambda this, e: this >= e),
"IF": lambda predicate, true, false: true if predicate else false,

View file

@ -16,6 +16,7 @@ import datetime
import math
import numbers
import re
import textwrap
import typing as t
from collections import deque
from copy import deepcopy
@ -35,6 +36,8 @@ from sqlglot.helper import (
from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from typing_extensions import Literal as Lit
from sqlglot.dialects.dialect import DialectType
@ -242,6 +245,9 @@ class Expression(metaclass=_Expression):
def is_type(self, *dtypes) -> bool:
return self.type is not None and self.type.is_type(*dtypes)
def is_leaf(self) -> bool:
return not any(isinstance(v, (Expression, list)) for v in self.args.values())
@property
def meta(self) -> t.Dict[str, t.Any]:
if self._meta is None:
@ -497,7 +503,14 @@ class Expression(metaclass=_Expression):
return self.sql()
def __repr__(self) -> str:
return self._to_s()
return _to_s(self)
def to_s(self) -> str:
"""
Same as __repr__, but includes additional information which can be useful
for debugging, like empty or missing args and the AST nodes' object IDs.
"""
return _to_s(self, verbose=True)
def sql(self, dialect: DialectType = None, **opts) -> str:
"""
@ -514,30 +527,6 @@ class Expression(metaclass=_Expression):
return Dialect.get_or_raise(dialect).generate(self, **opts)
def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
args: t.Dict[str, t.Any] = {
k: ", ".join(
v._to_s(hide_missing=hide_missing, level=level + 1)
if hasattr(v, "_to_s")
else str(v)
for v in ensure_list(vs)
if v is not None
)
for k, vs in self.args.items()
}
args["comments"] = self.comments
args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing}
right = ", ".join(f"{k}: {v}" for k, v in args.items())
right += ")"
return indent + left + right
def transform(self, fun, *args, copy=True, **kwargs):
"""
Recursively visits all tree nodes (excluding already transformed ones)
@ -580,8 +569,9 @@ class Expression(metaclass=_Expression):
For example::
>>> tree = Select().select("x").from_("tbl")
>>> tree.find(Column).replace(Column(this="y"))
(COLUMN this: y)
>>> tree.find(Column).replace(column("y"))
Column(
this=Identifier(this=y, quoted=False))
>>> tree.sql()
'SELECT y FROM tbl'
@ -831,6 +821,9 @@ class Expression(metaclass=_Expression):
div.args["safe"] = safe
return div
def desc(self, nulls_first: bool = False) -> Ordered:
return Ordered(this=self.copy(), desc=True, nulls_first=nulls_first)
def __lt__(self, other: t.Any) -> LT:
return self._binop(LT, other)
@ -1109,7 +1102,7 @@ class Clone(Expression):
class Describe(Expression):
arg_types = {"this": True, "kind": False, "expressions": False}
arg_types = {"this": True, "extended": False, "kind": False, "expressions": False}
class Kill(Expression):
@ -1124,6 +1117,10 @@ class Set(Expression):
arg_types = {"expressions": False, "unset": False, "tag": False}
class Heredoc(Expression):
arg_types = {"this": True, "tag": False}
class SetItem(Expression):
arg_types = {
"this": False,
@ -1937,7 +1934,13 @@ class Join(Expression):
class Lateral(UDTF):
arg_types = {"this": True, "view": False, "outer": False, "alias": False}
arg_types = {
"this": True,
"view": False,
"outer": False,
"alias": False,
"cross_apply": False, # True -> CROSS APPLY, False -> OUTER APPLY
}
class MatchRecognize(Expression):
@ -1964,7 +1967,12 @@ class Offset(Expression):
class Order(Expression):
arg_types = {"this": False, "expressions": True, "interpolate": False}
arg_types = {
"this": False,
"expressions": True,
"interpolate": False,
"siblings": False,
}
# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier
@ -2002,6 +2010,11 @@ class AutoIncrementProperty(Property):
arg_types = {"this": True}
# https://docs.aws.amazon.com/prescriptive-guidance/latest/materialized-views-redshift/refreshing-materialized-views.html
class AutoRefreshProperty(Property):
arg_types = {"this": True}
class BlockCompressionProperty(Property):
arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
@ -2259,6 +2272,10 @@ class SortKeyProperty(Property):
arg_types = {"this": True, "compound": False}
class SqlReadWriteProperty(Property):
arg_types = {"this": True}
class SqlSecurityProperty(Property):
arg_types = {"definer": True}
@ -2543,7 +2560,6 @@ class Table(Expression):
"version": False,
"format": False,
"pattern": False,
"index": False,
"ordinality": False,
"when": False,
}
@ -2585,6 +2601,14 @@ class Table(Expression):
return parts
def to_column(self, copy: bool = True) -> Alias | Column | Dot:
parts = self.parts
col = column(*reversed(parts[0:4]), fields=parts[4:], copy=copy) # type: ignore
alias = self.args.get("alias")
if alias:
col = alias_(col, alias.this, copy=copy)
return col
class Union(Subqueryable):
arg_types = {
@ -2694,6 +2718,14 @@ class Unnest(UDTF):
"offset": False,
}
@property
def selects(self) -> t.List[Expression]:
columns = super().selects
offset = self.args.get("offset")
if offset:
columns = columns + [to_identifier("offset") if offset is True else offset]
return columns
class Update(Expression):
arg_types = {
@ -3368,7 +3400,7 @@ class Select(Subqueryable):
return Create(
this=table_expression,
kind="table",
kind="TABLE",
expression=instance,
properties=properties_expression,
)
@ -3488,7 +3520,6 @@ class TableSample(Expression):
"rows": False,
"size": False,
"seed": False,
"kind": False,
}
@ -3517,6 +3548,10 @@ class Pivot(Expression):
"include_nulls": False,
}
@property
def unpivot(self) -> bool:
return bool(self.args.get("unpivot"))
class Window(Condition):
arg_types = {
@ -3604,6 +3639,7 @@ class DataType(Expression):
BOOLEAN = auto()
CHAR = auto()
DATE = auto()
DATE32 = auto()
DATEMULTIRANGE = auto()
DATERANGE = auto()
DATETIME = auto()
@ -3631,6 +3667,8 @@ class DataType(Expression):
INTERVAL = auto()
IPADDRESS = auto()
IPPREFIX = auto()
IPV4 = auto()
IPV6 = auto()
JSON = auto()
JSONB = auto()
LONGBLOB = auto()
@ -3729,6 +3767,7 @@ class DataType(Expression):
Type.TIMESTAMP_MS,
Type.TIMESTAMP_NS,
Type.DATE,
Type.DATE32,
Type.DATETIME,
Type.DATETIME64,
}
@ -4100,6 +4139,12 @@ class Alias(Expression):
return self.alias
# BigQuery requires the UNPIVOT column list aliases to be either strings or ints, but
# other dialects require identifiers. This enables us to transpile between them easily.
class PivotAlias(Alias):
pass
class Aliases(Expression):
arg_types = {"this": True, "expressions": True}
@ -4108,6 +4153,11 @@ class Aliases(Expression):
return self.expressions
# https://docs.aws.amazon.com/redshift/latest/dg/query-super.html
class AtIndex(Expression):
arg_types = {"this": True, "expression": True}
class AtTimeZone(Expression):
arg_types = {"this": True, "zone": True}
@ -4154,16 +4204,16 @@ class TimeUnit(Expression):
arg_types = {"unit": False}
UNABBREVIATED_UNIT_NAME = {
"d": "day",
"h": "hour",
"m": "minute",
"ms": "millisecond",
"ns": "nanosecond",
"q": "quarter",
"s": "second",
"us": "microsecond",
"w": "week",
"y": "year",
"D": "DAY",
"H": "HOUR",
"M": "MINUTE",
"MS": "MILLISECOND",
"NS": "NANOSECOND",
"Q": "QUARTER",
"S": "SECOND",
"US": "MICROSECOND",
"W": "WEEK",
"Y": "YEAR",
}
VAR_LIKE = (Column, Literal, Var)
@ -4171,9 +4221,11 @@ class TimeUnit(Expression):
def __init__(self, **args):
unit = args.get("unit")
if isinstance(unit, self.VAR_LIKE):
args["unit"] = Var(this=self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name)
args["unit"] = Var(
this=(self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper()
)
elif isinstance(unit, Week):
unit.set("this", Var(this=unit.this.name))
unit.set("this", Var(this=unit.this.name.upper()))
super().__init__(**args)
@ -4301,6 +4353,20 @@ class Anonymous(Func):
is_var_len_args = True
class AnonymousAggFunc(AggFunc):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/combinators
class CombinedAggFunc(AnonymousAggFunc):
arg_types = {"this": True, "expressions": False, "parts": True}
class CombinedParameterizedAgg(ParameterizedAgg):
arg_types = {"this": True, "expressions": True, "params": True, "parts": True}
# https://docs.snowflake.com/en/sql-reference/functions/hll
# https://docs.aws.amazon.com/redshift/latest/dg/r_HLL_function.html
class Hll(AggFunc):
@ -4381,7 +4447,7 @@ class ArraySort(Func):
class ArraySum(Func):
pass
arg_types = {"this": True, "expression": False}
class ArrayUnionAgg(AggFunc):
@ -4498,7 +4564,7 @@ class Count(AggFunc):
class CountIf(AggFunc):
pass
_sql_names = ["COUNT_IF", "COUNTIF"]
class CurrentDate(Func):
@ -4537,6 +4603,17 @@ class DateDiff(Func, TimeUnit):
class DateTrunc(Func):
arg_types = {"unit": True, "this": True, "zone": False}
def __init__(self, **args):
unit = args.get("unit")
if isinstance(unit, TimeUnit.VAR_LIKE):
args["unit"] = Literal.string(
(TimeUnit.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper()
)
elif isinstance(unit, Week):
unit.set("this", Literal.string(unit.this.name.upper()))
super().__init__(**args)
@property
def unit(self) -> Expression:
return self.args["unit"]
@ -4582,8 +4659,9 @@ class MonthsBetween(Func):
arg_types = {"this": True, "expression": True, "roundoff": False}
class LastDateOfMonth(Func):
pass
class LastDay(Func, TimeUnit):
_sql_names = ["LAST_DAY", "LAST_DAY_OF_MONTH"]
arg_types = {"this": True, "unit": False}
class Extract(Func):
@ -4627,10 +4705,22 @@ class TimeTrunc(Func, TimeUnit):
class DateFromParts(Func):
_sql_names = ["DATEFROMPARTS"]
_sql_names = ["DATE_FROM_PARTS", "DATEFROMPARTS"]
arg_types = {"year": True, "month": True, "day": True}
class TimeFromParts(Func):
_sql_names = ["TIME_FROM_PARTS", "TIMEFROMPARTS"]
arg_types = {
"hour": True,
"min": True,
"sec": True,
"nano": False,
"fractions": False,
"precision": False,
}
class DateStrToDate(Func):
pass
@ -4754,6 +4844,16 @@ class JSONObject(Func):
}
class JSONObjectAgg(AggFunc):
arg_types = {
"expressions": False,
"null_handling": False,
"unique_keys": False,
"return_type": False,
"encoding": False,
}
# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_ARRAY.html
class JSONArray(Func):
arg_types = {
@ -4841,6 +4941,15 @@ class ParseJSON(Func):
is_var_len_args = True
# https://docs.snowflake.com/en/sql-reference/functions/get_path
class GetPath(Func):
arg_types = {"this": True, "expression": True}
@property
def output_name(self) -> str:
return self.expression.output_name
class Least(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -5026,7 +5135,7 @@ class RegexpReplace(Func):
arg_types = {
"this": True,
"expression": True,
"replacement": True,
"replacement": False,
"position": False,
"occurrence": False,
"parameters": False,
@ -5052,8 +5161,10 @@ class Repeat(Func):
arg_types = {"this": True, "times": True}
# https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16
# In T-SQL ROUND, a non-zero third argument means truncation instead of rounding
class Round(Func):
arg_types = {"this": True, "decimals": False}
arg_types = {"this": True, "decimals": False, "truncate": False}
class RowNumber(Func):
@ -5228,6 +5339,10 @@ class TsOrDsToDate(Func):
arg_types = {"this": True, "format": False}
class TsOrDsToTime(Func):
pass
class TsOrDiToDi(Func):
pass
@ -5236,6 +5351,11 @@ class Unhex(Func):
pass
# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#unix_date
class UnixDate(Func):
pass
class UnixToStr(Func):
arg_types = {"this": True, "format": False}
@ -5245,10 +5365,16 @@ class UnixToStr(Func):
class UnixToTime(Func):
arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False}
SECONDS = Literal.string("seconds")
MILLIS = Literal.string("millis")
MICROS = Literal.string("micros")
NANOS = Literal.string("nanos")
SECONDS = Literal.number(0)
DECIS = Literal.number(1)
CENTIS = Literal.number(2)
MILLIS = Literal.number(3)
DECIMILLIS = Literal.number(4)
CENTIMILLIS = Literal.number(5)
MICROS = Literal.number(6)
DECIMICROS = Literal.number(7)
CENTIMICROS = Literal.number(8)
NANOS = Literal.number(9)
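# These scale constants are the power-of-ten exponent relative to seconds, e.g. MILLIS (3)
# means the input is expressed in 10**3 ticks per second; generators use them in
# expressions like POW(10, scale).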
class UnixToTimeStr(Func):
@ -5256,8 +5382,7 @@ class UnixToTimeStr(Func):
class TimestampFromParts(Func):
"""Constructs a timestamp given its constituent parts."""
_sql_names = ["TIMESTAMP_FROM_PARTS", "TIMESTAMPFROMPARTS"]
arg_types = {
"year": True,
"month": True,
@ -5265,6 +5390,9 @@ class TimestampFromParts(Func):
"hour": True,
"min": True,
"sec": True,
"nano": False,
"zone": False,
"milli": False,
}
@ -5358,9 +5486,9 @@ def maybe_parse(
Example:
>>> maybe_parse("1")
(LITERAL this: 1, is_string: False)
Literal(this=1, is_string=False)
>>> maybe_parse(to_identifier("x"))
(IDENTIFIER this: x, quoted: False)
Identifier(this=x, quoted=False)
Args:
sql_or_expression: the SQL code string or an expression
@ -5407,6 +5535,39 @@ def maybe_copy(instance, copy=True):
return instance.copy() if copy and instance else instance
def _to_s(node: t.Any, verbose: bool = False, level: int = 0) -> str:
"""Generate a textual representation of an Expression tree"""
indent = "\n" + (" " * (level + 1))
delim = f",{indent}"
if isinstance(node, Expression):
args = {k: v for k, v in node.args.items() if (v is not None and v != []) or verbose}
if (node.type or verbose) and not isinstance(node, DataType):
args["_type"] = node.type
if node.comments or verbose:
args["_comments"] = node.comments
if verbose:
args["_id"] = id(node)
# Inline leaves for a more compact representation
if node.is_leaf():
indent = ""
delim = ", "
items = delim.join([f"{k}={_to_s(v, verbose, level + 1)}" for k, v in args.items()])
return f"{node.__class__.__name__}({indent}{items})"
if isinstance(node, list):
items = delim.join(_to_s(i, verbose, level + 1) for i in node)
items = f"{indent}{items}" if items else ""
return f"[{items}]"
# Indent multiline strings to match the current level
return indent.join(textwrap.dedent(str(node).strip("\n")).splitlines())
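# Leaf nodes (no Expression or list children) are inlined on one line, e.g.
# Identifier(this=x, quoted=False), while nested nodes are spread across indented lines.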
def _is_wrong_expression(expression, into):
return isinstance(expression, Expression) and not isinstance(expression, into)
@ -5816,7 +5977,7 @@ def delete(
def insert(
expression: ExpOrStr,
into: ExpOrStr,
columns: t.Optional[t.Sequence[ExpOrStr]] = None,
columns: t.Optional[t.Sequence[str | Identifier]] = None,
overwrite: t.Optional[bool] = None,
returning: t.Optional[ExpOrStr] = None,
dialect: DialectType = None,
@ -5847,15 +6008,7 @@ def insert(
this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts)
if columns:
this = _apply_list_builder(
*columns,
instance=Schema(this=this),
arg="expressions",
into=Identifier,
copy=False,
dialect=dialect,
**opts,
)
this = Schema(this=this, expressions=[to_identifier(c, copy=copy) for c in columns])
insert = Insert(this=this, expression=expr, overwrite=overwrite)
@ -6073,7 +6226,7 @@ def to_interval(interval: str | Literal) -> Interval:
return Interval(
this=Literal.string(interval_parts.group(1)),
unit=Var(this=interval_parts.group(2)),
unit=Var(this=interval_parts.group(2).upper()),
)
@ -6219,13 +6372,44 @@ def subquery(
return Select().from_(expression, dialect=dialect, **opts)
@t.overload
def column(
col: str | Identifier,
table: t.Optional[str | Identifier] = None,
db: t.Optional[str | Identifier] = None,
catalog: t.Optional[str | Identifier] = None,
*,
fields: t.Collection[t.Union[str, Identifier]],
quoted: t.Optional[bool] = None,
copy: bool = True,
) -> Dot:
pass
@t.overload
def column(
col: str | Identifier,
table: t.Optional[str | Identifier] = None,
db: t.Optional[str | Identifier] = None,
catalog: t.Optional[str | Identifier] = None,
*,
fields: Lit[None] = None,
quoted: t.Optional[bool] = None,
copy: bool = True,
) -> Column:
pass
def column(
col,
table=None,
db=None,
catalog=None,
*,
fields=None,
quoted=None,
copy=True,
):
"""
Build a Column.
@ -6234,18 +6418,24 @@ def column(
table: Table name.
db: Database name.
catalog: Catalog name.
fields: Additional fields using dots.
quoted: Whether to force quotes on the column's identifiers.
copy: Whether or not to copy identifiers if passed in.
Returns:
The new Column instance.
"""
return Column(
this=to_identifier(col, quoted=quoted),
table=to_identifier(table, quoted=quoted),
db=to_identifier(db, quoted=quoted),
catalog=to_identifier(catalog, quoted=quoted),
this = Column(
this=to_identifier(col, quoted=quoted, copy=copy),
table=to_identifier(table, quoted=quoted, copy=copy),
db=to_identifier(db, quoted=quoted, copy=copy),
catalog=to_identifier(catalog, quoted=quoted, copy=copy),
)
if fields:
this = Dot.build((this, *(to_identifier(field, copy=copy) for field in fields)))
return this
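# e.g. column("a", table="t") builds a plain Column, while column("a", table="t",
# fields=["f", "g"]) builds a Dot chain equivalent to t.a.f.g.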
def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
"""Cast an expression to a data type.
@ -6333,10 +6523,10 @@ def var(name: t.Optional[ExpOrStr]) -> Var:
Example:
>>> repr(var('x'))
'(VAR this: x)'
'Var(this=x)'
>>> repr(var(column('x', table='y')))
'(VAR this: x)'
'Var(this=x)'
Args:
name: The name of the var or an expression whose name will become the var.

View file

@ -68,6 +68,7 @@ class Generator:
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
@ -96,6 +97,7 @@ class Generator:
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlReadWriteProperty: lambda self, e: e.name,
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.StabilityProperty: lambda self, e: e.name,
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
@ -110,7 +112,8 @@ class Generator:
}
# Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
# True: Full Support, None: No support, False: No support in window specifications
NULL_ORDERING_SUPPORTED: t.Optional[bool] = True
# Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False
@ -133,12 +136,6 @@ class Generator:
# Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs
INTERVAL_ALLOWS_PLURAL_FORM = True
# Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
TABLESAMPLE_WITH_METHOD = True
# Whether or not to treat the number in TABLESAMPLE (50) as a percentage
TABLESAMPLE_SIZE_IS_PERCENT = False
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
@ -219,6 +216,18 @@ class Generator:
# Whether or not parentheses are required around the table sample's expression
TABLESAMPLE_REQUIRES_PARENS = True
# Whether or not a table sample clause's size needs to be followed by the ROWS keyword
TABLESAMPLE_SIZE_IS_ROWS = True
# The keyword(s) to use when generating a sample clause
TABLESAMPLE_KEYWORDS = "TABLESAMPLE"
# Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
TABLESAMPLE_WITH_METHOD = True
# The keyword to use when specifying the seed of a sample clause
TABLESAMPLE_SEED_KEYWORD = "SEED"
# Whether or not COLLATE is a function instead of a binary operator
COLLATE_IS_FUNC = False
@ -234,6 +243,27 @@ class Generator:
# Whether or not CONCAT requires >1 arguments
SUPPORTS_SINGLE_ARG_CONCAT = True
# Whether or not LAST_DAY function supports a date part argument
LAST_DAY_SUPPORTS_DATE_PART = True
# Whether or not named columns are allowed in table aliases
SUPPORTS_TABLE_ALIAS_COLUMNS = True
# Whether or not UNPIVOT aliases are Identifiers (False means they're Literals)
UNPIVOT_ALIASES_ARE_IDENTIFIERS = True
# What delimiter to use for separating JSON key/value pairs
JSON_KEY_VALUE_PAIR_SEP = ":"
# INSERT OVERWRITE TABLE x override
INSERT_OVERWRITE = " OVERWRITE TABLE"
# Whether or not the SELECT .. INTO syntax is used instead of CTAS
SUPPORTS_SELECT_INTO = False
# Whether or not UNLOGGED tables can be created
SUPPORTS_UNLOGGED_TABLES = False
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -252,15 +282,15 @@ class Generator:
}
TIME_PART_SINGULARS = {
"microseconds": "microsecond",
"seconds": "second",
"minutes": "minute",
"hours": "hour",
"days": "day",
"weeks": "week",
"months": "month",
"quarters": "quarter",
"years": "year",
"MICROSECONDS": "MICROSECOND",
"SECONDS": "SECOND",
"MINUTES": "MINUTE",
"HOURS": "HOUR",
"DAYS": "DAY",
"WEEKS": "WEEK",
"MONTHS": "MONTH",
"QUARTERS": "QUARTER",
"YEARS": "YEAR",
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@ -272,6 +302,7 @@ class Generator:
PROPERTIES_LOCATION = {
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA,
exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
exp.ChecksumProperty: exp.Properties.Location.POST_NAME,
@ -323,6 +354,7 @@ class Generator:
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
@ -370,7 +402,7 @@ class Generator:
# Expressions that need to have all CTEs under them bubbled up to them
EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set()
KEY_VALUE_DEFINITONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice)
KEY_VALUE_DEFINITIONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
@ -775,7 +807,7 @@ class Generator:
return self.sql(expression, "this")
def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind").upper()
kind = self.sql(expression, "kind")
properties = expression.args.get("properties")
properties_locs = self.locate_properties(properties) if properties else defaultdict()
@ -868,7 +900,12 @@ class Generator:
return f"{shallow}{keyword} {this}"
def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
extended = " EXTENDED" if expression.args.get("extended") else ""
return f"DESCRIBE{extended} {self.sql(expression, 'this')}"
def heredoc_sql(self, expression: exp.Heredoc) -> str:
tag = self.sql(expression, "tag")
return f"${tag}${self.sql(expression, 'this')}${tag}$"
def prepend_ctes(self, expression: exp.Expression, sql: str) -> str:
with_ = self.sql(expression, "with")
@ -895,6 +932,10 @@ class Generator:
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
if columns and not self.SUPPORTS_TABLE_ALIAS_COLUMNS:
columns = ""
self.unsupported("Named columns are not supported in table alias.")
if not alias and not self.dialect.UNNEST_COLUMN_ONLY:
alias = "_t"
@ -1027,7 +1068,7 @@ class Generator:
def fetch_sql(self, expression: exp.Fetch) -> str:
direction = expression.args.get("direction")
direction = f" {direction.upper()}" if direction else ""
direction = f" {direction}" if direction else ""
count = expression.args.get("count")
count = f" {count}" if count else ""
if expression.args.get("percent"):
@ -1318,7 +1359,7 @@ class Generator:
if isinstance(expression.this, exp.Directory):
this = " OVERWRITE" if overwrite else " INTO"
else:
this = " OVERWRITE TABLE" if overwrite else " INTO"
this = self.INSERT_OVERWRITE if overwrite else " INTO"
alternative = expression.args.get("alternative")
alternative = f" OR {alternative}" if alternative else ""
@ -1365,10 +1406,10 @@ class Generator:
return f"KILL{kind}{this}"
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
return expression.name.upper()
return expression.name
def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str:
return expression.name.upper()
return expression.name
def onconflict_sql(self, expression: exp.OnConflict) -> str:
conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
@ -1445,9 +1486,6 @@ class Generator:
pattern = f", PATTERN => {pattern}" if pattern else ""
file_format = f" (FILE_FORMAT => {file_format}{pattern})"
index = self.sql(expression, "index")
index = f" AT {index}" if index else ""
ordinality = expression.args.get("ordinality") or ""
if ordinality:
ordinality = f" WITH ORDINALITY{alias}"
@ -1457,10 +1495,13 @@ class Generator:
if when:
table = f"{table} {when}"
return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"
return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
self,
expression: exp.TableSample,
sep: str = " AS ",
tablesample_keyword: t.Optional[str] = None,
) -> str:
if self.dialect.ALIAS_POST_TABLESAMPLE and expression.this and expression.this.alias:
table = expression.this.copy()
@ -1472,30 +1513,30 @@ class Generator:
alias = ""
method = self.sql(expression, "method")
method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
method = f"{method} " if method and self.TABLESAMPLE_WITH_METHOD else ""
numerator = self.sql(expression, "bucket_numerator")
denominator = self.sql(expression, "bucket_denominator")
field = self.sql(expression, "bucket_field")
field = f" ON {field}" if field else ""
bucket = f"BUCKET {numerator} OUT OF {denominator}{field}" if numerator else ""
percent = self.sql(expression, "percent")
percent = f"{percent} PERCENT" if percent else ""
rows = self.sql(expression, "rows")
rows = f"{rows} ROWS" if rows else ""
seed = self.sql(expression, "seed")
seed = f" {self.TABLESAMPLE_SEED_KEYWORD} ({seed})" if seed else ""
size = self.sql(expression, "size")
if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
size = f"{size} PERCENT"
if size and self.TABLESAMPLE_SIZE_IS_ROWS:
size = f"{size} ROWS"
seed = self.sql(expression, "seed")
seed = f" {seed_prefix} ({seed})" if seed else ""
kind = expression.args.get("kind", "TABLESAMPLE")
percent = self.sql(expression, "percent")
if percent and not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT:
percent = f"{percent} PERCENT"
expr = f"{bucket}{percent}{rows}{size}"
expr = f"{bucket}{percent}{size}"
if self.TABLESAMPLE_REQUIRES_PARENS:
expr = f"({expr})"
return f"{this} {kind} {method}{expr}{seed}{alias}"
return (
f"{this} {tablesample_keyword or self.TABLESAMPLE_KEYWORDS} {method}{expr}{seed}{alias}"
)
def pivot_sql(self, expression: exp.Pivot) -> str:
expressions = self.expressions(expression, flat=True)
@ -1513,8 +1554,7 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
direction = "UNPIVOT" if expression.unpivot else "PIVOT"
field = self.sql(expression, "field")
include_nulls = expression.args.get("include_nulls")
if include_nulls is not None:
@ -1675,7 +1715,8 @@ class Generator:
if not on_sql and using:
on_sql = csv(*(self.sql(column) for column in using))
this_sql = self.sql(expression, "this")
this = expression.this
this_sql = self.sql(this)
if on_sql:
on_sql = self.indent(on_sql, skip_first=True)
@ -1685,6 +1726,9 @@ class Generator:
else:
on_sql = f"{space}ON {on_sql}"
elif not op_sql:
if isinstance(this, exp.Lateral) and this.args.get("cross_apply") is not None:
return f" {this_sql}"
return f", {this_sql}"
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
@ -1695,6 +1739,19 @@ class Generator:
args = f"({args})" if len(args.split(",")) > 1 else args
return f"{args} {arrow_sep} {self.sql(expression, 'this')}"
def lateral_op(self, expression: exp.Lateral) -> str:
cross_apply = expression.args.get("cross_apply")
# https://www.mssqltips.com/sqlservertip/1958/sql-server-cross-apply-and-outer-apply/
if cross_apply is True:
op = "INNER JOIN "
elif cross_apply is False:
op = "LEFT JOIN "
else:
op = ""
return f"{op}LATERAL"
def lateral_sql(self, expression: exp.Lateral) -> str:
this = self.sql(expression, "this")
@ -1708,7 +1765,7 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
return f"LATERAL {this}{alias}"
return f"{self.lateral_op(expression)} {this}{alias}"
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
this = self.sql(expression, "this")
@ -1805,7 +1862,8 @@ class Generator:
def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else this
order = self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore
siblings = "SIBLINGS " if expression.args.get("siblings") else ""
order = self.op_expressions(f"{this}ORDER {siblings}BY", expression, flat=this or flat) # type: ignore
interpolated_values = [
f"{self.sql(named_expression, 'alias')} AS {self.sql(named_expression, 'this')}"
for named_expression in expression.args.get("interpolate") or []
@ -1860,9 +1918,21 @@ class Generator:
# If the NULLS FIRST/LAST clause is unsupported, we add another sort key to simulate it
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
nulls_sort_change = ""
window = expression.find_ancestor(exp.Window, exp.Select)
if isinstance(window, exp.Window) and window.args.get("spec"):
self.unsupported(
f"'{nulls_sort_change.strip()}' translation not supported in window functions"
)
nulls_sort_change = ""
elif self.NULL_ORDERING_SUPPORTED is None:
if expression.this.is_int:
self.unsupported(
f"'{nulls_sort_change.strip()}' translation not supported with positional ordering"
)
else:
null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
nulls_sort_change = ""
with_fill = self.sql(expression, "with_fill")
with_fill = f" {with_fill}" if with_fill else ""
@ -1961,10 +2031,14 @@ class Generator:
return [locks, self.sql(expression, "sample")]
def select_sql(self, expression: exp.Select) -> str:
into = expression.args.get("into")
if not self.SUPPORTS_SELECT_INTO and into:
into.pop()
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
kind = self.sql(expression, "kind").upper()
kind = self.sql(expression, "kind")
limit = expression.args.get("limit")
top = (
self.limit_sql(limit, top=True)
@ -2005,7 +2079,19 @@ class Generator:
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
return self.prepend_ctes(expression, sql)
sql = self.prepend_ctes(expression, sql)
if not self.SUPPORTS_SELECT_INTO and into:
if into.args.get("temporary"):
table_kind = " TEMPORARY"
elif self.SUPPORTS_UNLOGGED_TABLES and into.args.get("unlogged"):
table_kind = " UNLOGGED"
else:
table_kind = ""
sql = f"CREATE{table_kind} TABLE {self.sql(into.this)} AS {sql}"
return sql
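# For dialects without SELECT ... INTO, a query such as SELECT * INTO t FROM x is emitted as
# CREATE TABLE t AS SELECT * FROM x, preserving TEMPORARY (and UNLOGGED where supported).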
def schema_sql(self, expression: exp.Schema) -> str:
this = self.sql(expression, "this")
@ -2266,29 +2352,35 @@ class Generator:
return f"{self.func('MATCH', *expression.expressions)} AGAINST({self.sql(expression, 'this')}{modifier})"
def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
return f"{self.sql(expression, 'this')}{self.JSON_KEY_VALUE_PAIR_SEP} {self.sql(expression, 'expression')}"
def formatjson_sql(self, expression: exp.FormatJson) -> str:
return f"{self.sql(expression, 'this')} FORMAT JSON"
def jsonobject_sql(self, expression: exp.JSONObject) -> str:
def jsonobject_sql(self, expression: exp.JSONObject | exp.JSONObjectAgg) -> str:
null_handling = expression.args.get("null_handling")
null_handling = f" {null_handling}" if null_handling else ""
unique_keys = expression.args.get("unique_keys")
if unique_keys is not None:
unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS"
else:
unique_keys = ""
return_type = self.sql(expression, "return_type")
return_type = f" RETURNING {return_type}" if return_type else ""
encoding = self.sql(expression, "encoding")
encoding = f" ENCODING {encoding}" if encoding else ""
return self.func(
"JSON_OBJECT",
"JSON_OBJECT" if isinstance(expression, exp.JSONObject) else "JSON_OBJECTAGG",
*expression.expressions,
suffix=f"{null_handling}{unique_keys}{return_type}{encoding})",
)
def jsonobjectagg_sql(self, expression: exp.JSONObjectAgg) -> str:
return self.jsonobject_sql(expression)
def jsonarray_sql(self, expression: exp.JSONArray) -> str:
null_handling = expression.args.get("null_handling")
null_handling = f" {null_handling}" if null_handling else ""
@ -2385,7 +2477,7 @@ class Generator:
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
if not self.INTERVAL_ALLOWS_PLURAL_FORM:
unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit)
unit = self.TIME_PART_SINGULARS.get(unit, unit)
unit = f" {unit}" if unit else ""
if self.SINGLE_STRING_INTERVAL:
@ -2436,9 +2528,25 @@ class Generator:
alias = f" AS {alias}" if alias else ""
return f"{self.sql(expression, 'this')}{alias}"
def pivotalias_sql(self, expression: exp.PivotAlias) -> str:
alias = expression.args["alias"]
identifier_alias = isinstance(alias, exp.Identifier)
if identifier_alias and not self.UNPIVOT_ALIASES_ARE_IDENTIFIERS:
alias.replace(exp.Literal.string(alias.output_name))
elif not identifier_alias and self.UNPIVOT_ALIASES_ARE_IDENTIFIERS:
alias.replace(exp.to_identifier(alias.output_name))
return self.alias_sql(expression)
def aliases_sql(self, expression: exp.Aliases) -> str:
return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
def atindex_sql(self, expression: exp.AtTimeZone) -> str:
this = self.sql(expression, "this")
index = self.sql(expression, "expression")
return f"{this} AT {index}"
def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
this = self.sql(expression, "this")
zone = self.sql(expression, "zone")
@ -2500,7 +2608,7 @@ class Generator:
return self.binary(expression, "COLLATE")
def command_sql(self, expression: exp.Command) -> str:
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
return f"{self.sql(expression, 'this')} {expression.text('expression').strip()}"
def comment_sql(self, expression: exp.Comment) -> str:
this = self.sql(expression, "this")
@ -3102,6 +3210,47 @@ class Generator:
cond_for_null = arg.is_(exp.null())
return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg])))
def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str:
this = expression.this
if isinstance(this, exp.TsOrDsToTime) or this.is_type(exp.DataType.Type.TIME):
return self.sql(this)
return self.sql(exp.cast(this, "time"))
def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str:
this = expression.this
time_format = self.format_time(expression)
if time_format and time_format not in (self.dialect.TIME_FORMAT, self.dialect.DATE_FORMAT):
return self.sql(
exp.cast(exp.StrToTime(this=this, format=expression.args["format"]), "date")
)
if isinstance(this, exp.TsOrDsToDate) or this.is_type(exp.DataType.Type.DATE):
return self.sql(this)
return self.sql(exp.cast(this, "date"))
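# e.g. a TsOrDsToDate over a non-date value becomes CAST(value AS DATE); with an explicit,
# non-default format the value is first run through StrToTime before the date cast.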
def unixdate_sql(self, expression: exp.UnixDate) -> str:
return self.sql(
exp.func(
"DATEDIFF",
expression.this,
exp.cast(exp.Literal.string("1970-01-01"), "date"),
"day",
)
)
def lastday_sql(self, expression: exp.LastDay) -> str:
if self.LAST_DAY_SUPPORTS_DATE_PART:
return self.function_fallback_sql(expression)
unit = expression.text("unit")
if unit and unit != "MONTH":
self.unsupported("Date parts are not supported in LAST_DAY.")
return self.func("LAST_DAY", expression.this)
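# For dialects where LAST_DAY_SUPPORTS_DATE_PART is False, only LAST_DAY(this) is emitted;
# a unit other than MONTH is dropped with an "unsupported" warning.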
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify

View file

@ -129,13 +129,10 @@ def lineage(
if isinstance(column, int)
else next(
(select for select in scope.expression.selects if select.alias_or_name == column),
exp.Star() if scope.expression.is_star else None,
exp.Star() if scope.expression.is_star else scope.expression,
)
)
if not select:
raise ValueError(f"Could not find {column} in {scope.expression}")
if isinstance(scope.expression, exp.Union):
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
@ -194,6 +191,8 @@ def lineage(
# if the select is a star add all scope sources as downstreams
if select.is_star:
for source in scope.sources.values():
if isinstance(source, Scope):
source = source.expression
node.downstream.append(Node(name=select.sql(), source=source, expression=source))
# Find all columns that went into creating this one to list their lineage nodes.
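For context, a minimal usage sketch of the lineage API these hunks touch; the table name and schema below are invented, and the printed downstream names depend on how the query gets qualified:

    from sqlglot.lineage import lineage

    # trace column "a" through a derived table
    node = lineage(
        "a",
        "SELECT a FROM (SELECT a FROM t) AS sub",
        schema={"t": {"a": "int"}},
    )
    print(node.name, [d.name for d in node.downstream])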

View file

@ -195,6 +195,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.StrPosition,
exp.TsOrDiToDi,
},
exp.DataType.Type.JSON: {
exp.ParseJSON,
},
exp.DataType.Type.TIMESTAMP: {
exp.CurrentTime,
exp.CurrentTimestamp,
@ -275,6 +278,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
}
NESTED_TYPES = {
@ -477,7 +481,12 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
@t.no_type_check
def _annotate_by_args(
self, expression: E, *args: str, promote: bool = False, array: bool = False
self,
expression: E,
*args: str,
promote: bool = False,
array: bool = False,
struct: bool = False,
) -> E:
self._annotate_args(expression)
@ -506,6 +515,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
),
)
if struct:
expressions = [
expr.type
if not expr.args.get("alias")
else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
for expr in expressions
]
self._set_type(
expression,
exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
)
return expression
def _annotate_timeunit(
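A hedged sketch of what the new struct annotation enables (BigQuery STRUCT syntax assumed; the exact rendered type may vary by version):

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    expr = annotate_types(parse_one("SELECT STRUCT(1 AS a, 'x' AS b)", read="bigquery"))
    # the projection should now carry a nested type, e.g. STRUCT<a INT, b VARCHAR>
    print(expr.selects[0].type.sql())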

View file

@ -30,13 +30,18 @@ def pushdown_predicates(expression, dialect=None):
where = select.args.get("where")
if where:
selected_sources = scope.selected_sources
join_index = {
join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
}
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
pushdown(where.this, selected_sources, scope_ref_count, dialect)
pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@ -53,7 +58,7 @@ def pushdown_predicates(expression, dialect=None):
return expression
def pushdown(condition, sources, scope_ref_count, dialect):
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
if not condition:
return
@ -67,21 +72,28 @@ def pushdown(condition, sources, scope_ref_count, dialect):
)
if cnf_like:
pushdown_cnf(predicates, sources, scope_ref_count)
pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
else:
pushdown_dnf(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, scope, scope_ref_count):
def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
"""
If the predicates are in CNF like form, we can simply replace each block in the parent.
"""
join_index = join_index or {}
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join):
predicate.replace(exp.true())
node.on(predicate, copy=False)
break
name = node.alias_or_name
predicate_tables = exp.column_table_names(predicate, name)
# Don't push the predicate if it references tables that appear in later joins
this_index = join_index[name]
if all(join_index.get(table, -1) < this_index for table in predicate_tables):
predicate.replace(exp.true())
node.on(predicate, copy=False)
break
if isinstance(node, exp.Select):
predicate.replace(exp.true())
inner_predicate = replace_aliases(node, predicate)
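As an illustration of the rule above (the filter on b may move into b's join, but only because it does not reference tables joined later), a small sketch with invented table names:

    from sqlglot import parse_one
    from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

    sql = "SELECT * FROM a JOIN b ON a.id = b.id WHERE b.x = 1"
    # the predicate on b is eligible to be pushed into b's JOIN ... ON clause
    print(pushdown_predicates(parse_one(sql)).sql())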
@ -112,9 +124,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
conditions = {}
# for every pushdown table, find all related conditions in all predicates
# combine them with ORs
# (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
# pushdown all predicates to their respective nodes
for table in sorted(pushdown_tables):
for predicate in predicates:
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
@ -122,23 +132,9 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
if table not in nodes:
continue
predicate_condition = None
for column in predicate.find_all(exp.Column):
if column.table == table:
condition = column.find_ancestor(exp.Condition)
predicate_condition = (
exp.and_(predicate_condition, condition)
if predicate_condition
else condition
)
if predicate_condition:
conditions[table] = (
exp.or_(conditions[table], predicate_condition)
if table in conditions
else predicate_condition
)
conditions[table] = (
exp.or_(conditions[table], predicate) if table in conditions else predicate
)
for name, node in nodes.items():
if name not in conditions:

View file

@ -43,9 +43,8 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
alias_count = source_column_alias_count.get(scope, 0)
if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
# we select from a pivoted source in the parent scope.
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
if scope.expression.args.get("distinct"):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
@ -78,7 +77,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
# Push the selected columns down to the next scope
for name, (node, source) in scope.selected_sources.items():
if isinstance(source, Scope):
columns = selects.get(name) or set()
columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
referenced_columns[source].update(columns)
column_aliases = node.alias_column_names

View file

@ -3,10 +3,11 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
pushdown_cte_alias_columns as pushdown_cte_alias_columns_func,
qualify_columns as qualify_columns_func,
quote_identifiers as quote_identifiers_func,
validate_qualify_columns as validate_qualify_columns_func,
@ -22,6 +23,7 @@ def qualify(
catalog: t.Optional[str] = None,
schema: t.Optional[dict | Schema] = None,
expand_alias_refs: bool = True,
expand_stars: bool = True,
infer_schema: t.Optional[bool] = None,
isolate_tables: bool = False,
qualify_columns: bool = True,
@ -47,6 +49,9 @@ def qualify(
catalog: Default catalog name for tables.
schema: Schema to infer column names and types.
expand_alias_refs: Whether or not to expand references to aliases.
expand_stars: Whether or not to expand star queries. This is a necessary step
for most of the optimizer's rules to work; do not set to False unless you
know what you're doing!
infer_schema: Whether or not to infer the schema if missing.
isolate_tables: Whether or not to isolate table selects.
qualify_columns: Whether or not to qualify columns.
@ -66,9 +71,16 @@ def qualify(
if isolate_tables:
expression = isolate_table_selects(expression, schema=schema)
if Dialect.get_or_raise(dialect).PREFER_CTE_ALIAS_COLUMN:
expression = pushdown_cte_alias_columns_func(expression)
if qualify_columns:
expression = qualify_columns_func(
expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
expression,
schema,
expand_alias_refs=expand_alias_refs,
expand_stars=expand_stars,
infer_schema=infer_schema,
)
if quote_identifiers:
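A hedged usage sketch of qualify() with the new expand_stars flag (the schema below is made up):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    expr = qualify(
        sqlglot.parse_one("SELECT * FROM t"),
        schema={"t": {"a": "int", "b": "int"}},
        expand_stars=False,  # keep the star instead of expanding it into t.a, t.b
    )
    print(expr.sql())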

View file

@ -17,6 +17,7 @@ def qualify_columns(
expression: exp.Expression,
schema: t.Dict | Schema,
expand_alias_refs: bool = True,
expand_stars: bool = True,
infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
"""
@ -33,10 +34,16 @@ def qualify_columns(
expression: Expression to qualify.
schema: Database schema.
expand_alias_refs: Whether or not to expand references to aliases.
expand_stars: Whether or not to expand star queries. This is a necessary step
for most of the optimizer's rules to work; do not set to False unless you
know what you're doing!
infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
"""
schema = ensure_schema(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
@ -57,7 +64,8 @@ def qualify_columns(
_expand_alias_refs(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
if expand_stars:
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
qualify_outputs(scope)
_expand_group_by(scope)
@ -68,21 +76,41 @@ def qualify_columns(
def validate_qualify_columns(expression: E) -> E:
"""Raise an `OptimizeError` if any columns aren't qualified"""
unqualified_columns = []
all_unqualified_columns = []
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
unqualified_columns.extend(scope.unqualified_columns)
unqualified_columns = scope.unqualified_columns
if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
column = scope.external_columns[0]
raise OptimizeError(
f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
)
for_table = f" for table: '{column.table}'" if column.table else ""
raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
# New columns produced by the UNPIVOT can't be qualified, but there may be columns
# under the UNPIVOT's IN clause that can and should be qualified. We recompute
# this list here to ensure those in the former category will be excluded.
unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
all_unqualified_columns.extend(unqualified_columns)
if all_unqualified_columns:
raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
if unqualified_columns:
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
return expression
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
name_column = []
field = unpivot.args.get("field")
if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
name_column.append(field.this)
value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
return itertools.chain(name_column, value_columns)
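A minimal sketch of qualification plus validation as used here (table and schema are invented):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

    expr = qualify_columns(parse_one("SELECT a FROM t"), schema={"t": {"a": "int"}})
    validate_qualify_columns(expr)  # raises OptimizeError if any column stays unresolved
    print(expr.sql())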
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
"""
Remove table column aliases.
@ -216,6 +244,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
replace_columns(expression.args.get("group"), literal_index=True)
replace_columns(expression.args.get("having"), resolve_table=True)
replace_columns(expression.args.get("qualify"), resolve_table=True)
scope.clear_cache()
@ -353,18 +382,25 @@ def _expand_stars(
replace_columns: t.Dict[int, t.Dict[str, str]] = {}
coalesced_columns = set()
# TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
pivot_columns = None
pivot_output_columns = None
pivot_exclude_columns = None
pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
if pivot.unpivot:
pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
has_pivoted_source = pivot and not pivot.args.get("unpivot")
if pivot and has_pivoted_source:
pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
field = pivot.args.get("field")
if isinstance(field, exp.In):
pivot_exclude_columns = {
c.output_name for e in field.expressions for c in e.find_all(exp.Column)
}
else:
pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
if not pivot_output_columns:
pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
if not pivot_output_columns:
pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
for expression in scope.expression.selects:
if isinstance(expression, exp.Star):
@ -384,47 +420,54 @@ def _expand_stars(
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
columns = columns or scope.outer_column_list
if pseudocolumns:
columns = [name for name in columns if name.upper() not in pseudocolumns]
if columns and "*" not in columns:
table_id = id(table)
columns_to_exclude = except_columns.get(table_id) or set()
if not columns or "*" in columns:
return
if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
implicit_columns = [col for col in columns if col not in pivot_columns]
table_id = id(table)
columns_to_exclude = except_columns.get(table_id) or set()
if pivot:
if pivot_output_columns and pivot_exclude_columns:
pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
pivot_columns.extend(pivot_output_columns)
else:
pivot_columns = pivot.alias_column_names
if pivot_columns:
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
for name in implicit_columns + pivot_output_columns
for name in pivot_columns
if name not in columns_to_exclude
)
continue
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
if name in coalesced_columns:
continue
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
if name in coalesced_columns:
continue
coalesced_columns.add(name)
tables = using_column_tables[name]
coalesce = [exp.column(name, table=table) for table in tables]
coalesced_columns.add(name)
tables = using_column_tables[name]
coalesce = [exp.column(name, table=table) for table in tables]
new_selections.append(
alias(
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
alias=name,
copy=False,
)
new_selections.append(
alias(
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
alias=name,
copy=False,
)
elif name not in columns_to_exclude:
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table=table)
new_selections.append(
alias(column, alias_, copy=False) if alias_ != name else column
)
else:
return
)
elif name not in columns_to_exclude:
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table=table)
new_selections.append(
alias(column, alias_, copy=False) if alias_ != name else column
)
# Ensures we don't overwrite the initial selections with an empty list
if new_selections:
@ -472,6 +515,9 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
):
if selection is None:
break
if isinstance(selection, exp.Subquery):
if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
@ -495,6 +541,38 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
)
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
"""
Pushes down the CTE alias columns into the projection. This step is useful in Snowflake,
where the CTE alias columns can be referenced in the HAVING clause.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Args:
expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
"""
for cte in expression.find_all(exp.CTE):
if cte.alias_column_names:
new_expressions = []
for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
if isinstance(projection, exp.Alias):
projection.set("alias", _alias)
else:
projection = alias(projection, alias=_alias)
new_expressions.append(projection)
cte.this.set("expressions", new_expressions)
return expression
class Resolver:
"""
Helper for resolving columns.

View file

@ -72,11 +72,15 @@ def qualify_tables(
if not source.args.get("catalog") and source.args.get("db"):
source.set("catalog", catalog)
pivots = pivots = source.args.get("pivots")
if not source.alias:
# Don't add the pivot's alias to the pivoted table, use the table's name instead
if pivots and pivots[0].alias == name:
name = source.name
# Mutates the source by attaching an alias to it
alias(source, name or source.name or next_alias_name(), copy=False, table=True)
pivots = source.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set(
"alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))

View file

@ -539,11 +539,23 @@ def _traverse_union(scope):
# The last scope to be yield should be the top most scope
left = None
for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
for left in _traverse_scope(
scope.branch(
scope.expression.left,
outer_column_list=scope.outer_column_list,
scope_type=ScopeType.UNION,
)
):
yield left
right = None
for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
for right in _traverse_scope(
scope.branch(
scope.expression.right,
outer_column_list=scope.outer_column_list,
scope_type=ScopeType.UNION,
)
):
yield right
scope.union_scopes = [left, right]

View file

@ -100,6 +100,7 @@ def simplify(
node = simplify_parens(node)
node = simplify_datetrunc(node, dialect)
node = sort_comparison(node)
node = simplify_startswith(node)
if root:
expression.replace(node)
@ -776,6 +777,26 @@ def simplify_conditionals(expression):
return expression
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
"""
Reduces a prefix check to either TRUE or FALSE if both the string and the
prefix are statically known.
Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
"""
if (
isinstance(expression, exp.StartsWith)
and expression.this.is_string
and expression.expression.is_string
):
return exp.convert(expression.name.startswith(expression.expression.name))
return expression
DateRange = t.Tuple[datetime.date, datetime.date]
@ -1160,7 +1181,7 @@ def gen(expression: t.Any) -> str:
GEN_MAP = {
exp.Add: lambda e: _binary(e, "+"),
exp.And: lambda e: _binary(e, "AND"),
exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}",
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",

View file

@ -12,6 +12,8 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import TrieResult, in_trie, new_trie
if t.TYPE_CHECKING:
from typing_extensions import Literal
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
@ -193,6 +195,7 @@ class Parser(metaclass=_Parser):
TokenType.DATETIME,
TokenType.DATETIME64,
TokenType.DATE,
TokenType.DATE32,
TokenType.INT4RANGE,
TokenType.INT4MULTIRANGE,
TokenType.INT8RANGE,
@ -232,6 +235,8 @@ class Parser(metaclass=_Parser):
TokenType.INET,
TokenType.IPADDRESS,
TokenType.IPPREFIX,
TokenType.IPV4,
TokenType.IPV6,
TokenType.UNKNOWN,
TokenType.NULL,
*ENUM_TYPE_TOKENS,
@ -669,6 +674,7 @@ class Parser(metaclass=_Parser):
PROPERTY_PARSERS: t.Dict[str, t.Callable] = {
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"AUTO": lambda self: self._parse_auto_property(),
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs),
@ -680,6 +686,7 @@ class Parser(metaclass=_Parser):
exp.CollateProperty, **kwargs
),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"CONTAINS": lambda self: self._parse_contains_property(),
"COPY": lambda self: self._parse_copy_property(),
"DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
"DEFINER": lambda self: self._parse_definer(),
@ -710,6 +717,7 @@ class Parser(metaclass=_Parser):
"LOG": lambda self, **kwargs: self._parse_log(**kwargs),
"MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty),
"MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs),
"MODIFIES": lambda self: self._parse_modifies_property(),
"MULTISET": lambda self: self.expression(exp.SetProperty, multi=True),
"NO": lambda self: self._parse_no_property(),
"ON": lambda self: self._parse_on_property(),
@ -721,6 +729,7 @@ class Parser(metaclass=_Parser):
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True),
"RANGE": lambda self: self._parse_dict_range(this="RANGE"),
"READS": lambda self: self._parse_reads_property(),
"REMOTE": lambda self: self._parse_remote_with_connection(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
@ -841,6 +850,7 @@ class Parser(metaclass=_Parser):
"DECODE": lambda self: self._parse_decode(),
"EXTRACT": lambda self: self._parse_extract(),
"JSON_OBJECT": lambda self: self._parse_json_object(),
"JSON_OBJECTAGG": lambda self: self._parse_json_object(agg=True),
"JSON_TABLE": lambda self: self._parse_json_table(),
"MATCH": lambda self: self._parse_match_against(),
"OPENJSON": lambda self: self._parse_open_json(),
@ -925,6 +935,8 @@ class Parser(metaclass=_Parser):
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}
JSON_KEY_VALUE_SEPARATOR_TOKENS = {TokenType.COLON, TokenType.COMMA, TokenType.IS}
FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
@ -954,6 +966,9 @@ class Parser(metaclass=_Parser):
# Whether the TRIM function expects the characters to trim as its first argument
TRIM_PATTERN_FIRST = False
# Whether or not string aliases are supported `SELECT COUNT(*) 'count'`
STRING_ALIASES = False
# Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand)
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
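A hedged sketch of what MODIFIERS_ATTACHED_TO_UNION controls (default dialect assumed):

    import sqlglot
    from sqlglot import exp

    union = sqlglot.parse_one("SELECT 1 UNION ALL SELECT 2 LIMIT 1")
    # with the flag enabled, LIMIT should live on the Union node itself
    # rather than on the right-hand SELECT
    print(type(union).__name__, union.args.get("limit") is not None)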
@ -1193,7 +1208,9 @@ class Parser(metaclass=_Parser):
self._advance(index - self._index)
def _parse_command(self) -> exp.Command:
return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
return self.expression(
exp.Command, this=self._prev.text.upper(), expression=self._parse_string()
)
def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
start = self._prev
@ -1353,26 +1370,27 @@ class Parser(metaclass=_Parser):
# exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature)
extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
expression = self._match(TokenType.ALIAS) and self._parse_heredoc()
if self._match(TokenType.COMMAND):
expression = self._parse_as_command(self._prev)
else:
begin = self._match(TokenType.BEGIN)
return_ = self._match_text_seq("RETURN")
if self._match(TokenType.STRING, advance=False):
# Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property
# # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement
expression = self._parse_string()
extend_props(self._parse_properties())
if not expression:
if self._match(TokenType.COMMAND):
expression = self._parse_as_command(self._prev)
else:
expression = self._parse_statement()
begin = self._match(TokenType.BEGIN)
return_ = self._match_text_seq("RETURN")
end = self._match_text_seq("END")
if self._match(TokenType.STRING, advance=False):
# Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property
# # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement
expression = self._parse_string()
extend_props(self._parse_properties())
else:
expression = self._parse_statement()
if return_:
expression = self.expression(exp.Return, this=expression)
end = self._match_text_seq("END")
if return_:
expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
@ -1426,7 +1444,7 @@ class Parser(metaclass=_Parser):
exp.Create,
comments=comments,
this=this,
kind=create_token.text,
kind=create_token.text.upper(),
replace=replace,
unique=unique,
expression=expression,
@ -1849,9 +1867,21 @@ class Parser(metaclass=_Parser):
return self.expression(exp.WithDataProperty, no=no, statistics=statistics)
def _parse_no_property(self) -> t.Optional[exp.NoPrimaryIndexProperty]:
def _parse_contains_property(self) -> t.Optional[exp.SqlReadWriteProperty]:
if self._match_text_seq("SQL"):
return self.expression(exp.SqlReadWriteProperty, this="CONTAINS SQL")
return None
def _parse_modifies_property(self) -> t.Optional[exp.SqlReadWriteProperty]:
if self._match_text_seq("SQL", "DATA"):
return self.expression(exp.SqlReadWriteProperty, this="MODIFIES SQL DATA")
return None
def _parse_no_property(self) -> t.Optional[exp.Expression]:
if self._match_text_seq("PRIMARY", "INDEX"):
return exp.NoPrimaryIndexProperty()
if self._match_text_seq("SQL"):
return self.expression(exp.SqlReadWriteProperty, this="NO SQL")
return None
def _parse_on_property(self) -> t.Optional[exp.Expression]:
@ -1861,6 +1891,11 @@ class Parser(metaclass=_Parser):
return exp.OnCommitProperty(delete=True)
return self.expression(exp.OnProperty, this=self._parse_schema(self._parse_id_var()))
def _parse_reads_property(self) -> t.Optional[exp.SqlReadWriteProperty]:
if self._match_text_seq("SQL", "DATA"):
return self.expression(exp.SqlReadWriteProperty, this="READS SQL DATA")
return None
def _parse_distkey(self) -> exp.DistKeyProperty:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
@ -1920,10 +1955,13 @@ class Parser(metaclass=_Parser):
def _parse_describe(self) -> exp.Describe:
kind = self._match_set(self.CREATABLES) and self._prev.text
extended = self._match_text_seq("EXTENDED")
this = self._parse_table(schema=True)
properties = self._parse_properties()
expressions = properties.expressions if properties else None
return self.expression(exp.Describe, this=this, kind=kind, expressions=expressions)
return self.expression(
exp.Describe, this=this, extended=extended, kind=kind, expressions=expressions
)
def _parse_insert(self) -> exp.Insert:
comments = ensure_list(self._prev_comments)
@ -2164,13 +2202,13 @@ class Parser(metaclass=_Parser):
def _parse_value(self) -> exp.Tuple:
if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_conjunction)
expressions = self._parse_csv(self._parse_expression)
self._match_r_paren()
return self.expression(exp.Tuple, expressions=expressions)
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# https://prestodb.io/docs/current/sql/values.html
return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
return self.expression(exp.Tuple, expressions=[self._parse_expression()])
def _parse_projections(self) -> t.List[exp.Expression]:
return self._parse_expressions()
@ -2212,7 +2250,7 @@ class Parser(metaclass=_Parser):
kind = (
self._match(TokenType.ALIAS)
and self._match_texts(("STRUCT", "VALUE"))
and self._prev.text
and self._prev.text.upper()
)
if distinct:
@ -2261,7 +2299,7 @@ class Parser(metaclass=_Parser):
if table
else self._parse_select(nested=True, parse_set_operation=False)
)
this = self._parse_set_operations(self._parse_query_modifiers(this))
this = self._parse_query_modifiers(self._parse_set_operations(this))
self._match_r_paren()
@ -2304,7 +2342,7 @@ class Parser(metaclass=_Parser):
)
def _parse_cte(self) -> exp.CTE:
alias = self._parse_table_alias()
alias = self._parse_table_alias(self.ID_VAR_TOKENS)
if not alias or not alias.this:
self.raise_error("Expected CTE to have alias")
@ -2490,13 +2528,14 @@ class Parser(metaclass=_Parser):
)
def _parse_lateral(self) -> t.Optional[exp.Lateral]:
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY)
if not cross_apply and self._match_pair(TokenType.OUTER, TokenType.APPLY):
cross_apply = False
if outer_apply or cross_apply:
if cross_apply is not None:
this = self._parse_select(table=True)
view = None
outer = not cross_apply
outer = None
elif self._match(TokenType.LATERAL):
this = self._parse_select(table=True)
view = self._match(TokenType.VIEW)
@ -2529,7 +2568,14 @@ class Parser(metaclass=_Parser):
else:
table_alias = self._parse_table_alias()
return self.expression(exp.Lateral, this=this, view=view, outer=outer, alias=table_alias)
return self.expression(
exp.Lateral,
this=this,
view=view,
outer=outer,
alias=table_alias,
cross_apply=cross_apply,
)
def _parse_join_parts(
self,
@ -2563,9 +2609,6 @@ class Parser(metaclass=_Parser):
if not skip_join_token and not join and not outer_apply and not cross_apply:
return None
if outer_apply:
side = Token(TokenType.LEFT, "LEFT")
kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)}
if method:
@ -2755,8 +2798,10 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
if self._match_text_seq("AT"):
this.set("index", self._parse_id_var())
if isinstance(this, exp.Table) and self._match_text_seq("AT"):
return self.expression(
exp.AtIndex, this=this.to_column(copy=False), expression=self._parse_id_var()
)
this.set("hints", self._parse_table_hints())
@ -2865,15 +2910,10 @@ class Parser(metaclass=_Parser):
bucket_denominator = None
bucket_field = None
percent = None
rows = None
size = None
seed = None
kind = (
self._prev.text if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
)
method = self._parse_var(tokens=(TokenType.ROW,))
method = self._parse_var(tokens=(TokenType.ROW,), upper=True)
matched_l_paren = self._match(TokenType.L_PAREN)
if self.TABLESAMPLE_CSV:
@ -2895,16 +2935,16 @@ class Parser(metaclass=_Parser):
bucket_field = self._parse_field()
elif self._match_set((TokenType.PERCENT, TokenType.MOD)):
percent = num
elif self._match(TokenType.ROWS):
rows = num
elif num:
elif self._match(TokenType.ROWS) or not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT:
size = num
else:
percent = num
if matched_l_paren:
self._match_r_paren()
if self._match(TokenType.L_PAREN):
method = self._parse_var()
method = self._parse_var(upper=True)
seed = self._match(TokenType.COMMA) and self._parse_number()
self._match_r_paren()
elif self._match_texts(("SEED", "REPEATABLE")):
@ -2918,10 +2958,8 @@ class Parser(metaclass=_Parser):
bucket_denominator=bucket_denominator,
bucket_field=bucket_field,
percent=percent,
rows=rows,
size=size,
seed=seed,
kind=kind,
)
def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]:
@ -2946,6 +2984,27 @@ class Parser(metaclass=_Parser):
exp.Pivot, this=this, expressions=expressions, using=using, group=group
)
def _parse_pivot_in(self) -> exp.In:
def _parse_aliased_expression() -> t.Optional[exp.Expression]:
this = self._parse_conjunction()
self._match(TokenType.ALIAS)
alias = self._parse_field()
if alias:
return self.expression(exp.PivotAlias, this=this, alias=alias)
return this
value = self._parse_column()
if not self._match_pair(TokenType.IN, TokenType.L_PAREN):
self.raise_error("Expecting IN (")
aliased_expressions = self._parse_csv(_parse_aliased_expression)
self._match_r_paren()
return self.expression(exp.In, this=value, expressions=aliased_expressions)
def _parse_pivot(self) -> t.Optional[exp.Pivot]:
index = self._index
include_nulls = None
@ -2964,7 +3023,6 @@ class Parser(metaclass=_Parser):
return None
expressions = []
field = None
if not self._match(TokenType.L_PAREN):
self._retreat(index)
@ -2981,12 +3039,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.FOR):
self.raise_error("Expecting FOR")
value = self._parse_column()
if not self._match(TokenType.IN):
self.raise_error("Expecting IN")
field = self._parse_in(value, alias=True)
field = self._parse_pivot_in()
self._match_r_paren()
@ -3132,14 +3185,19 @@ class Parser(metaclass=_Parser):
def _parse_order(
self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
) -> t.Optional[exp.Expression]:
siblings = None
if not skip_order_token and not self._match(TokenType.ORDER_BY):
return this
if not self._match(TokenType.ORDER_SIBLINGS_BY):
return this
siblings = True
return self.expression(
exp.Order,
this=this,
expressions=self._parse_csv(self._parse_ordered),
interpolate=self._parse_interpolate(),
siblings=siblings,
)
def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]:
@ -3213,7 +3271,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST"
direction = self._prev.text.upper() if direction else "FIRST"
count = self._parse_field(tokens=self.FETCH_TOKENS)
percent = self._match(TokenType.PERCENT)
@ -3398,10 +3456,10 @@ class Parser(metaclass=_Parser):
return this
return self.expression(exp.Escape, this=this, expression=self._parse_string())
def _parse_interval(self) -> t.Optional[exp.Interval]:
def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Interval]:
index = self._index
if not self._match(TokenType.INTERVAL):
if not self._match(TokenType.INTERVAL) and match_interval:
return None
if self._match(TokenType.STRING, advance=False):
@ -3409,11 +3467,19 @@ class Parser(metaclass=_Parser):
else:
this = self._parse_term()
if not this:
if not this or (
isinstance(this, exp.Column)
and not this.table
and not this.this.quoted
and this.name.upper() == "IS"
):
self._retreat(index)
return None
unit = self._parse_function() or self._parse_var(any_token=True)
unit = self._parse_function() or (
not self._match(TokenType.ALIAS, advance=False)
and self._parse_var(any_token=True, upper=True)
)
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
# each INTERVAL expression into this canonical form so it's easy to transpile
@ -3429,7 +3495,7 @@ class Parser(metaclass=_Parser):
self._retreat(self._index - 1)
this = exp.Literal.string(parts[0])
unit = self.expression(exp.Var, this=parts[1])
unit = self.expression(exp.Var, this=parts[1].upper())
return self.expression(exp.Interval, this=this, unit=unit)
@ -3489,6 +3555,12 @@ class Parser(metaclass=_Parser):
def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
interval = parse_interval and self._parse_interval()
if interval:
# Convert INTERVAL 'val_1' unit_1 ... 'val_n' unit_n into a sum of intervals
while self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
interval = self.expression( # type: ignore
exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
)
return interval
index = self._index
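A hedged sketch of the interval normalization described above (default dialect assumed; exact output formatting may differ):

    import sqlglot
    from sqlglot import exp

    one = sqlglot.parse_one("INTERVAL '5 day'")
    print(one.sql())  # value and unit are split, e.g. INTERVAL '5' DAY

    # chained intervals should fold into a sum of single-unit intervals
    total = sqlglot.parse_one("INTERVAL '1' YEAR '2' MONTH")
    print(isinstance(total, exp.Add), total.sql())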
@ -3552,10 +3624,10 @@ class Parser(metaclass=_Parser):
type_token = self._prev.token_type
if type_token == TokenType.PSEUDO_TYPE:
return self.expression(exp.PseudoType, this=self._prev.text)
return self.expression(exp.PseudoType, this=self._prev.text.upper())
if type_token == TokenType.OBJECT_IDENTIFIER:
return self.expression(exp.ObjectIdentifier, this=self._prev.text)
return self.expression(exp.ObjectIdentifier, this=self._prev.text.upper())
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token in self.STRUCT_TYPE_TOKENS
@ -3587,7 +3659,7 @@ class Parser(metaclass=_Parser):
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_types)
expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True))
else:
expressions = self._parse_csv(
lambda: self._parse_types(
@ -3662,10 +3734,19 @@ class Parser(metaclass=_Parser):
return this
def _parse_struct_types(self) -> t.Optional[exp.Expression]:
def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
index = self._index
this = self._parse_type(parse_interval=False) or self._parse_id_var()
self._match(TokenType.COLON)
return self._parse_column_def(this)
column_def = self._parse_column_def(this)
if type_required and (
(isinstance(this, exp.Column) and this.this is column_def) or this is column_def
):
self._retreat(index)
return self._parse_types()
return column_def
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_text_seq("AT", "TIME", "ZONE"):
@ -4025,6 +4106,12 @@ class Parser(metaclass=_Parser):
return exp.AutoIncrementColumnConstraint()
def _parse_auto_property(self) -> t.Optional[exp.AutoRefreshProperty]:
if not self._match_text_seq("REFRESH"):
self._retreat(self._index - 1)
return None
return self.expression(exp.AutoRefreshProperty, this=self._parse_var(upper=True))
def _parse_compress(self) -> exp.CompressColumnConstraint:
if self._match(TokenType.L_PAREN, advance=False):
return self.expression(
@ -4230,8 +4317,10 @@ class Parser(metaclass=_Parser):
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
return self._parse_field()
def _parse_period_for_system_time(self) -> exp.PeriodForSystemTimeConstraint:
self._match(TokenType.TIMESTAMP_SNAPSHOT)
def _parse_period_for_system_time(self) -> t.Optional[exp.PeriodForSystemTimeConstraint]:
if not self._match(TokenType.TIMESTAMP_SNAPSHOT):
self._retreat(self._index - 1)
return None
id_vars = self._parse_wrapped_id_vars()
return self.expression(
@ -4257,22 +4346,17 @@ class Parser(metaclass=_Parser):
options = self._parse_key_constraint_options()
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
return this
bracket_kind = self._prev.token_type
if self._match(TokenType.COLON):
expressions: t.List[exp.Expression] = [
self.expression(exp.Slice, expression=self._parse_conjunction())
]
else:
expressions = self._parse_csv(
lambda: self._parse_slice(
self._parse_alias(self._parse_conjunction(), explicit=True)
)
)
expressions = self._parse_csv(
lambda: self._parse_bracket_key_value(is_map=bracket_kind == TokenType.L_BRACE)
)
if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
self.raise_error("Expected ]")
@ -4313,7 +4397,10 @@ class Parser(metaclass=_Parser):
default = self._parse_conjunction()
if not self._match(TokenType.END):
self.raise_error("Expected END after CASE", self._prev)
if isinstance(default, exp.Interval) and default.this.sql().upper() == "END":
default = exp.column("interval")
else:
self.raise_error("Expected END after CASE", self._prev)
return self._parse_window(
self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default)
@ -4514,7 +4601,7 @@ class Parser(metaclass=_Parser):
def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]:
self._match_text_seq("KEY")
key = self._parse_column()
self._match_set((TokenType.COLON, TokenType.COMMA))
self._match_set(self.JSON_KEY_VALUE_SEPARATOR_TOKENS)
self._match_text_seq("VALUE")
value = self._parse_bitwise()
@ -4536,7 +4623,15 @@ class Parser(metaclass=_Parser):
return None
def _parse_json_object(self) -> exp.JSONObject:
@t.overload
def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
...
@t.overload
def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
...
def _parse_json_object(self, agg=False):
star = self._parse_star()
expressions = (
[star]
@ -4559,7 +4654,7 @@ class Parser(metaclass=_Parser):
encoding = self._match_text_seq("ENCODING") and self._parse_var()
return self.expression(
exp.JSONObject,
exp.JSONObjectAgg if agg else exp.JSONObject,
expressions=expressions,
null_handling=null_handling,
unique_keys=unique_keys,
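A small sketch of the new JSON_OBJECTAGG parsing (default dialect assumed):

    import sqlglot
    from sqlglot import exp

    node = sqlglot.parse_one("SELECT JSON_OBJECTAGG(k VALUE v) FROM t").selects[0]
    # the aggregate now parses into its own expression type
    print(isinstance(node, exp.JSONObjectAgg))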
@ -4873,10 +4968,17 @@ class Parser(metaclass=_Parser):
self._match_r_paren(aliases)
return aliases
alias = self._parse_id_var(any_token)
alias = self._parse_id_var(any_token) or (
self.STRING_ALIASES and self._parse_string_as_identifier()
)
if alias:
return self.expression(exp.Alias, comments=comments, this=this, alias=alias)
this = self.expression(exp.Alias, comments=comments, this=this, alias=alias)
# Moves the comment next to the alias in `expr /* comment */ AS alias`
if not this.comments and this.this.comments:
this.comments = this.this.comments
this.this.comments = None
return this
@ -4915,14 +5017,19 @@ class Parser(metaclass=_Parser):
return self._parse_placeholder()
def _parse_var(
self, any_token: bool = False, tokens: t.Optional[t.Collection[TokenType]] = None
self,
any_token: bool = False,
tokens: t.Optional[t.Collection[TokenType]] = None,
upper: bool = False,
) -> t.Optional[exp.Expression]:
if (
(any_token and self._advance_any())
or self._match(TokenType.VAR)
or (self._match_set(tokens) if tokens else False)
):
return self.expression(exp.Var, this=self._prev.text)
return self.expression(
exp.Var, this=self._prev.text.upper() if upper else self._prev.text
)
return self._parse_placeholder()
def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]:
@ -5418,6 +5525,42 @@ class Parser(metaclass=_Parser):
condition=condition,
)
def _parse_heredoc(self) -> t.Optional[exp.Heredoc]:
if self._match(TokenType.HEREDOC_STRING):
return self.expression(exp.Heredoc, this=self._prev.text)
if not self._match_text_seq("$"):
return None
tags = ["$"]
tag_text = None
if self._is_connected():
self._advance()
tags.append(self._prev.text.upper())
else:
self.raise_error("No closing $ found")
if tags[-1] != "$":
if self._is_connected() and self._match_text_seq("$"):
tag_text = tags[-1]
tags.append("$")
else:
self.raise_error("No closing $ found")
heredoc_start = self._curr
while self._curr:
if self._match_text_seq(*tags, advance=False):
this = self._find_sql(heredoc_start, self._prev)
self._advance(len(tags))
return self.expression(exp.Heredoc, this=this, tag=tag_text)
self._advance()
self.raise_error(f"No closing {''.join(tags)} found")
return None
def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict
) -> t.Optional[t.Callable]:
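A hedged sketch of the heredoc parsing above, assuming Postgres dollar-quoting is tokenized as a heredoc string:

    import sqlglot

    sql = "CREATE FUNCTION f() RETURNS INT LANGUAGE SQL AS $$ SELECT 1 $$"
    expr = sqlglot.parse_one(sql, read="postgres")
    # the function body between the $$ ... $$ tags should be captured as the
    # CREATE statement's expression
    print(expr.args.get("expression"))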

View file

@ -215,12 +215,13 @@ class MappingSchema(AbstractMappingSchema, Schema):
normalize: bool = True,
) -> None:
self.dialect = dialect
self.visible = visible or {}
self.visible = {} if visible is None else visible
self.normalize = normalize
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
self._depth = 0
schema = {} if schema is None else schema
super().__init__(self._normalize(schema or {}))
super().__init__(self._normalize(schema) if self.normalize else schema)
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
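A minimal sketch of the normalize flag now being honored (table and column names invented):

    from sqlglot.schema import MappingSchema

    # with normalize=False the mapping keys are stored verbatim,
    # so lookups must use the exact same casing
    s = MappingSchema({"Orders": {"OrderId": "int"}}, normalize=False)
    print(s.column_names("Orders"))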

View file

@ -147,6 +147,7 @@ class TokenType(AutoName):
DATETIME = auto()
DATETIME64 = auto()
DATE = auto()
DATE32 = auto()
INT4RANGE = auto()
INT4MULTIRANGE = auto()
INT8RANGE = auto()
@ -182,6 +183,8 @@ class TokenType(AutoName):
INET = auto()
IPADDRESS = auto()
IPPREFIX = auto()
IPV4 = auto()
IPV6 = auto()
ENUM = auto()
ENUM8 = auto()
ENUM16 = auto()
@ -296,6 +299,7 @@ class TokenType(AutoName):
ON = auto()
OPERATOR = auto()
ORDER_BY = auto()
ORDER_SIBLINGS_BY = auto()
ORDERED = auto()
ORDINALITY = auto()
OUTER = auto()

View file

@ -255,7 +255,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
if not arrays:
if expression.args.get("from"):
expression.join(series, copy=False)
expression.join(series, copy=False, join_type="CROSS")
else:
expression.from_(series, copy=False)
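A hedged end-to-end sketch of the transform patched above: Spark's EXPLODE becomes a CROSS JOIN UNNEST when transpiling to Presto (exact aliases and helper columns may differ):

    import sqlglot

    print(sqlglot.transpile("SELECT explode(col) FROM t", read="spark", write="presto")[0])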