Merging upstream version 20.9.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 9421b254ec
commit 37a231f554
144 changed files with 78309 additions and 59609 deletions

sqlglot/dataframe/sql/functions.py
@@ -574,13 +574,13 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
 def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
     return Column.invoke_expression_over_column(
-        col, expression.DateAdd, expression=days, unit=expression.Var(this="day")
+        col, expression.DateAdd, expression=days, unit=expression.Var(this="DAY")
     )


 def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
     return Column.invoke_expression_over_column(
-        col, expression.DateSub, expression=days, unit=expression.Var(this="day")
+        col, expression.DateSub, expression=days, unit=expression.Var(this="DAY")
     )

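(Illustrative note, not part of the upstream diff: a minimal sketch of what the unit-casing change above affects, assuming sqlglot >= 20.9 is installed; the column name is made up.)

    # The dataframe helpers now attach an upper-cased unit to DateAdd/DateSub.
    from sqlglot.dataframe.sql import functions as F

    col = F.date_add("event_date", 7)
    # col wraps exp.DateAdd(..., unit=Var(this="DAY")); previously the unit
    # variable was emitted as the lower-case "day".
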
@@ -635,7 +635,7 @@ def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:


 def last_day(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "LAST_DAY")
+    return Column.invoke_expression_over_column(col, expression.LastDay)


 def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:

sqlglot/dialects/bigquery.py
@@ -16,20 +16,22 @@ from sqlglot.dialects.dialect import (
     format_time_lambda,
     if_sql,
     inline_array_sql,
-    json_keyvalue_comma_sql,
     max_or_greatest,
     min_or_least,
     no_ilike_sql,
     parse_date_delta_with_interval,
+    path_to_jsonpath,
     regexp_replace_sql,
     rename_func,
     timestrtotime_sql,
     ts_or_ds_add_cast,
     ts_or_ds_to_date_sql,
 )
 from sqlglot.helper import seq_get, split_num_words
 from sqlglot.tokens import TokenType

+if t.TYPE_CHECKING:
+    from typing_extensions import Literal
+
 logger = logging.getLogger("sqlglot")

@@ -206,12 +208,17 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s
         return f"TIMESTAMP_MILLIS({timestamp})"
     if scale == exp.UnixToTime.MICROS:
         return f"TIMESTAMP_MICROS({timestamp})"
     if scale == exp.UnixToTime.NANOS:
         # We need to cast to INT64 because that's what BQ expects
         return f"TIMESTAMP_MICROS(CAST({timestamp} / 1000 AS INT64))"

-    self.unsupported(f"Unsupported scale for timestamp: {scale}.")
-    return ""
+    return f"TIMESTAMP_SECONDS(CAST({timestamp} / POW(10, {scale}) AS INT64))"


+def _parse_time(args: t.List) -> exp.Func:
+    if len(args) == 1:
+        return exp.TsOrDsToTime(this=args[0])
+    if len(args) == 3:
+        return exp.TimeFromParts.from_arg_list(args)
+
+    return exp.Anonymous(this="TIME", expressions=args)


 class BigQuery(Dialect):

@@ -329,7 +336,13 @@ class BigQuery(Dialect):
             "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
             "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
             "DIV": binary_from_function(exp.IntDiv),
+            "FORMAT_DATE": lambda args: exp.TimeToStr(
+                this=exp.TsOrDsToDate(this=seq_get(args, 1)), format=seq_get(args, 0)
+            ),
             "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
+            "JSON_EXTRACT_SCALAR": lambda args: exp.JSONExtractScalar(
+                this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$")
+            ),
             "MD5": exp.MD5Digest.from_arg_list,
             "TO_HEX": _parse_to_hex,
             "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(

@@ -351,6 +364,7 @@ class BigQuery(Dialect):
                 this=seq_get(args, 0),
                 expression=seq_get(args, 1) or exp.Literal.string(","),
             ),
+            "TIME": _parse_time,
             "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
             "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
             "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),

@@ -361,9 +375,7 @@ class BigQuery(Dialect):
             "TIMESTAMP_MILLIS": lambda args: exp.UnixToTime(
                 this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
             ),
-            "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(
-                this=seq_get(args, 0), scale=exp.UnixToTime.SECONDS
-            ),
+            "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(this=seq_get(args, 0)),
             "TO_JSON_STRING": exp.JSONFormat.from_arg_list,
         }

@@ -460,7 +472,15 @@ class BigQuery(Dialect):

             return table

-        def _parse_json_object(self) -> exp.JSONObject:
+        @t.overload
+        def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
+            ...
+
+        @t.overload
+        def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
+            ...
+
+        def _parse_json_object(self, agg=False):
             json_object = super()._parse_json_object()
             array_kv_pair = seq_get(json_object.expressions, 0)

@@ -513,6 +533,10 @@ class BigQuery(Dialect):
         UNNEST_WITH_ORDINALITY = False
         COLLATE_IS_FUNC = True
         LIMIT_ONLY_LITERALS = True
+        SUPPORTS_TABLE_ALIAS_COLUMNS = False
+        UNPIVOT_ALIASES_ARE_IDENTIFIERS = False
+        JSON_KEY_VALUE_PAIR_SEP = ","
+        NULL_ORDERING_SUPPORTED = False

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,

@@ -525,6 +549,7 @@ class BigQuery(Dialect):
             exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
             if e.args.get("default")
             else f"COLLATE {self.sql(e, 'this')}",
+            exp.CountIf: rename_func("COUNTIF"),
             exp.Create: _create_sql,
             exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
             exp.DateAdd: date_add_interval_sql("DATE", "ADD"),

@@ -536,13 +561,13 @@ class BigQuery(Dialect):
             exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
             exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
             exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
+            exp.GetPath: path_to_jsonpath(),
             exp.GroupConcat: rename_func("STRING_AGG"),
             exp.Hex: rename_func("TO_HEX"),
             exp.If: if_sql(false_value="NULL"),
             exp.ILike: no_ilike_sql,
             exp.IntDiv: rename_func("DIV"),
             exp.JSONFormat: rename_func("TO_JSON_STRING"),
-            exp.JSONKeyValue: json_keyvalue_comma_sql,
             exp.Max: max_or_greatest,
             exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
             exp.MD5Digest: rename_func("MD5"),

@@ -578,16 +603,17 @@ class BigQuery(Dialect):
                 "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
             ),
             exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
+            exp.TimeFromParts: rename_func("TIME"),
             exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
             exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
             exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
-            exp.TimeToStr: lambda self, e: f"FORMAT_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
             exp.TsOrDsAdd: _ts_or_ds_add_sql,
             exp.TsOrDsDiff: _ts_or_ds_diff_sql,
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
+            exp.TsOrDsToTime: rename_func("TIME"),
             exp.Unhex: rename_func("FROM_HEX"),
             exp.UnixDate: rename_func("UNIX_DATE"),
             exp.UnixToTime: _unix_to_time_sql,
             exp.Values: _derived_table_values_to_unnest,
             exp.VariancePop: rename_func("VAR_POP"),

@@ -724,6 +750,26 @@ class BigQuery(Dialect):
             "within",
         }

+        def timetostr_sql(self, expression: exp.TimeToStr) -> str:
+            if isinstance(expression.this, exp.TsOrDsToDate):
+                this: exp.Expression = expression.this
+            else:
+                this = expression
+
+            return f"FORMAT_DATE({self.format_time(expression)}, {self.sql(this, 'this')})"
+
+        def struct_sql(self, expression: exp.Struct) -> str:
+            args = []
+            for expr in expression.expressions:
+                if isinstance(expr, self.KEY_VALUE_DEFINITIONS):
+                    arg = f"{self.sql(expr, 'expression')} AS {expr.this.name}"
+                else:
+                    arg = self.sql(expr)
+
+                args.append(arg)
+
+            return self.func("STRUCT", *args)
+
         def eq_sql(self, expression: exp.EQ) -> str:
             # Operands of = cannot be NULL in BigQuery
             if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null):

@@ -760,7 +806,20 @@ class BigQuery(Dialect):
             return inline_array_sql(self, expression)

         def bracket_sql(self, expression: exp.Bracket) -> str:
+            this = self.sql(expression, "this")
             expressions = expression.expressions
+
+            if len(expressions) == 1:
+                arg = expressions[0]
+                if arg.type is None:
+                    from sqlglot.optimizer.annotate_types import annotate_types
+
+                    arg = annotate_types(arg)
+
+                if arg.type and arg.type.this in exp.DataType.TEXT_TYPES:
+                    # BQ doesn't support bracket syntax with string values
+                    return f"{this}.{arg.name}"
+
             expressions_sql = ", ".join(self.sql(e) for e in expressions)
             offset = expression.args.get("offset")

@@ -768,13 +827,13 @@ class BigQuery(Dialect):
                 expressions_sql = f"OFFSET({expressions_sql})"
             elif offset == 1:
                 expressions_sql = f"ORDINAL({expressions_sql})"
-            else:
+            elif offset is not None:
                 self.unsupported(f"Unsupported array offset: {offset}")

             if expression.args.get("safe"):
                 expressions_sql = f"SAFE_{expressions_sql}"

-            return f"{self.sql(expression, 'this')}[{expressions_sql}]"
+            return f"{this}[{expressions_sql}]"

         def transaction_sql(self, *_) -> str:
             return "BEGIN TRANSACTION"

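(Illustrative note: a hedged sketch of the bracket handling above, assuming sqlglot >= 20.9; the printed result is an expectation, not a verified log.)

    import sqlglot

    # Presto's 1-based ELEMENT_AT parses to Bracket(offset=1, safe=True),
    # which the BigQuery generator renders via (SAFE_)ORDINAL.
    print(sqlglot.transpile("SELECT ELEMENT_AT(arr, 1)", read="presto", write="bigquery")[0])
    # expected along the lines of: SELECT arr[SAFE_ORDINAL(1)]
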
sqlglot/dialects/clickhouse.py
@@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     arg_max_or_min_no_count,
+    date_delta_sql,
     inline_array_sql,
     no_pivot_sql,
     rename_func,

@@ -22,16 +23,25 @@ def _lower_func(sql: str) -> str:
     return sql[:index].lower() + sql[index:]


-def _quantile_sql(self, e):
+def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str:
     quantile = e.args["quantile"]
     args = f"({self.sql(e, 'this')})"

     if isinstance(quantile, exp.Array):
         func = self.func("quantiles", *quantile)
     else:
         func = self.func("quantile", quantile)

     return func + args


+def _parse_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc:
+    if len(args) == 1:
+        return exp.CountIf(this=seq_get(args, 0))
+
+    return exp.CombinedAggFunc(this="countIf", expressions=args, parts=("count", "If"))
+
+
 class ClickHouse(Dialect):
     NORMALIZE_FUNCTIONS: bool | str = False
     NULL_ORDERING = "nulls_are_last"

@@ -53,6 +63,7 @@ class ClickHouse(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "ATTACH": TokenType.COMMAND,
+            "DATE32": TokenType.DATE32,
             "DATETIME64": TokenType.DATETIME64,
             "DICTIONARY": TokenType.DICTIONARY,
             "ENUM": TokenType.ENUM,

@@ -75,6 +86,8 @@ class ClickHouse(Dialect):
             "UINT32": TokenType.UINT,
             "UINT64": TokenType.UBIGINT,
             "UINT8": TokenType.UTINYINT,
+            "IPV4": TokenType.IPV4,
+            "IPV6": TokenType.IPV6,
         }

         SINGLE_TOKENS = {

@@ -91,6 +104,8 @@ class ClickHouse(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ANY": exp.AnyValue.from_arg_list,
+            "ARRAYSUM": exp.ArraySum.from_arg_list,
+            "COUNTIF": _parse_count_if,
             "DATE_ADD": lambda args: exp.DateAdd(
                 this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
             ),

@@ -110,6 +125,138 @@ class ClickHouse(Dialect):
             "XOR": lambda args: exp.Xor(expressions=args),
         }

+        AGG_FUNCTIONS = {
+            "count",
+            "min",
+            "max",
+            "sum",
+            "avg",
+            "any",
+            "stddevPop",
+            "stddevSamp",
+            "varPop",
+            "varSamp",
+            "corr",
+            "covarPop",
+            "covarSamp",
+            "entropy",
+            "exponentialMovingAverage",
+            "intervalLengthSum",
+            "kolmogorovSmirnovTest",
+            "mannWhitneyUTest",
+            "median",
+            "rankCorr",
+            "sumKahan",
+            "studentTTest",
+            "welchTTest",
+            "anyHeavy",
+            "anyLast",
+            "boundingRatio",
+            "first_value",
+            "last_value",
+            "argMin",
+            "argMax",
+            "avgWeighted",
+            "topK",
+            "topKWeighted",
+            "deltaSum",
+            "deltaSumTimestamp",
+            "groupArray",
+            "groupArrayLast",
+            "groupUniqArray",
+            "groupArrayInsertAt",
+            "groupArrayMovingAvg",
+            "groupArrayMovingSum",
+            "groupArraySample",
+            "groupBitAnd",
+            "groupBitOr",
+            "groupBitXor",
+            "groupBitmap",
+            "groupBitmapAnd",
+            "groupBitmapOr",
+            "groupBitmapXor",
+            "sumWithOverflow",
+            "sumMap",
+            "minMap",
+            "maxMap",
+            "skewSamp",
+            "skewPop",
+            "kurtSamp",
+            "kurtPop",
+            "uniq",
+            "uniqExact",
+            "uniqCombined",
+            "uniqCombined64",
+            "uniqHLL12",
+            "uniqTheta",
+            "quantile",
+            "quantiles",
+            "quantileExact",
+            "quantilesExact",
+            "quantileExactLow",
+            "quantilesExactLow",
+            "quantileExactHigh",
+            "quantilesExactHigh",
+            "quantileExactWeighted",
+            "quantilesExactWeighted",
+            "quantileTiming",
+            "quantilesTiming",
+            "quantileTimingWeighted",
+            "quantilesTimingWeighted",
+            "quantileDeterministic",
+            "quantilesDeterministic",
+            "quantileTDigest",
+            "quantilesTDigest",
+            "quantileTDigestWeighted",
+            "quantilesTDigestWeighted",
+            "quantileBFloat16",
+            "quantilesBFloat16",
+            "quantileBFloat16Weighted",
+            "quantilesBFloat16Weighted",
+            "simpleLinearRegression",
+            "stochasticLinearRegression",
+            "stochasticLogisticRegression",
+            "categoricalInformationValue",
+            "contingency",
+            "cramersV",
+            "cramersVBiasCorrected",
+            "theilsU",
+            "maxIntersections",
+            "maxIntersectionsPosition",
+            "meanZTest",
+            "quantileInterpolatedWeighted",
+            "quantilesInterpolatedWeighted",
+            "quantileGK",
+            "quantilesGK",
+            "sparkBar",
+            "sumCount",
+            "largestTriangleThreeBuckets",
+        }
+
+        AGG_FUNCTIONS_SUFFIXES = [
+            "If",
+            "Array",
+            "ArrayIf",
+            "Map",
+            "SimpleState",
+            "State",
+            "Merge",
+            "MergeState",
+            "ForEach",
+            "Distinct",
+            "OrDefault",
+            "OrNull",
+            "Resample",
+            "ArgMin",
+            "ArgMax",
+        ]
+
+        AGG_FUNC_MAPPING = (
+            lambda functions, suffixes: {
+                f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
+            }
+        )(AGG_FUNCTIONS, AGG_FUNCTIONS_SUFFIXES)
+
         FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}

         FUNCTION_PARSERS = {

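(Illustrative note: a standalone sketch of the AGG_FUNC_MAPPING cross-product above, runnable with plain Python; the two-element lists are stand-ins for the full sets.)

    functions = ["sum", "uniq"]
    suffixes = ["If", "State"]
    mapping = {f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions}
    assert mapping["sumIf"] == ("sum", "If")   # base name + combinator suffix
    assert mapping["uniq"] == ("uniq", "")     # bare name maps to an empty suffix
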
@@ -272,9 +419,18 @@ class ClickHouse(Dialect):
             )

             if isinstance(func, exp.Anonymous):
+                parts = self.AGG_FUNC_MAPPING.get(func.this)
                 params = self._parse_func_params(func)

                 if params:
+                    if parts and parts[1]:
+                        return self.expression(
+                            exp.CombinedParameterizedAgg,
+                            this=func.this,
+                            expressions=func.expressions,
+                            params=params,
+                            parts=parts,
+                        )
                     return self.expression(
                         exp.ParameterizedAgg,
                         this=func.this,

@@ -282,6 +438,20 @@ class ClickHouse(Dialect):
                         params=params,
                     )

+                if parts:
+                    if parts[1]:
+                        return self.expression(
+                            exp.CombinedAggFunc,
+                            this=func.this,
+                            expressions=func.expressions,
+                            parts=parts,
+                        )
+                    return self.expression(
+                        exp.AnonymousAggFunc,
+                        this=func.this,
+                        expressions=func.expressions,
+                    )
+
             return func

         def _parse_func_params(

@@ -329,6 +499,9 @@ class ClickHouse(Dialect):
         STRUCT_DELIMITER = ("(", ")")
         NVL2_SUPPORTED = False
         TABLESAMPLE_REQUIRES_PARENS = False
+        TABLESAMPLE_SIZE_IS_ROWS = False
+        TABLESAMPLE_KEYWORDS = "SAMPLE"
+        LAST_DAY_SUPPORTS_DATE_PART = False

         STRING_TYPE_MAPPING = {
             exp.DataType.Type.CHAR: "String",

@@ -348,6 +521,7 @@ class ClickHouse(Dialect):
             **STRING_TYPE_MAPPING,
             exp.DataType.Type.ARRAY: "Array",
             exp.DataType.Type.BIGINT: "Int64",
+            exp.DataType.Type.DATE32: "Date32",
             exp.DataType.Type.DATETIME64: "DateTime64",
             exp.DataType.Type.DOUBLE: "Float64",
             exp.DataType.Type.ENUM: "Enum",

@@ -372,24 +546,23 @@ class ClickHouse(Dialect):
             exp.DataType.Type.UINT256: "UInt256",
             exp.DataType.Type.USMALLINT: "UInt16",
             exp.DataType.Type.UTINYINT: "UInt8",
+            exp.DataType.Type.IPV4: "IPv4",
+            exp.DataType.Type.IPV6: "IPv6",
         }

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
-            exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
             exp.AnyValue: rename_func("any"),
             exp.ApproxDistinct: rename_func("uniq"),
+            exp.ArraySum: rename_func("arraySum"),
             exp.ArgMax: arg_max_or_min_no_count("argMax"),
             exp.ArgMin: arg_max_or_min_no_count("argMin"),
             exp.Array: inline_array_sql,
             exp.CastToStrType: rename_func("CAST"),
+            exp.CountIf: rename_func("countIf"),
             exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
-            exp.DateAdd: lambda self, e: self.func(
-                "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
-            ),
-            exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
-            ),
+            exp.DateAdd: date_delta_sql("DATE_ADD"),
+            exp.DateDiff: date_delta_sql("DATE_DIFF"),
             exp.Explode: rename_func("arrayJoin"),
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
             exp.IsNan: rename_func("isNaN"),

@@ -400,6 +573,7 @@ class ClickHouse(Dialect):
             exp.Quantile: _quantile_sql,
             exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
             exp.Rand: rename_func("randCanonical"),
+            exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
             exp.StartsWith: rename_func("startsWith"),
             exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
             exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),

@@ -485,10 +659,19 @@ class ClickHouse(Dialect):
                 else "",
             ]

-        def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
+        def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str:
             params = self.expressions(expression, key="params", flat=True)
             return self.func(expression.name, *expression.expressions) + f"({params})"

+        def anonymousaggfunc_sql(self, expression: exp.AnonymousAggFunc) -> str:
+            return self.func(expression.name, *expression.expressions)
+
+        def combinedaggfunc_sql(self, expression: exp.CombinedAggFunc) -> str:
+            return self.anonymousaggfunc_sql(expression)
+
+        def combinedparameterizedagg_sql(self, expression: exp.CombinedParameterizedAgg) -> str:
+            return self.parameterizedagg_sql(expression)
+
         def placeholder_sql(self, expression: exp.Placeholder) -> str:
             return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"

sqlglot/dialects/databricks.py
@@ -30,6 +30,8 @@ class Databricks(Spark):
         }

     class Generator(Spark.Generator):
+        TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+
         TRANSFORMS = {
             **Spark.Generator.TRANSFORMS,
             exp.DateAdd: date_delta_sql("DATEADD"),

sqlglot/dialects/dialect.py
@@ -153,6 +153,9 @@ class Dialect(metaclass=_Dialect):
     ALIAS_POST_TABLESAMPLE = False
     """Determines whether or not the table alias comes after tablesample."""

+    TABLESAMPLE_SIZE_IS_PERCENT = False
+    """Determines whether or not a size in the table sample clause represents percentage."""
+
     NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
     """Specifies the strategy according to which identifiers should be normalized."""

@@ -220,6 +223,24 @@ class Dialect(metaclass=_Dialect):
     For example, such columns may be excluded from `SELECT *` queries.
     """

+    PREFER_CTE_ALIAS_COLUMN = False
+    """
+    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
+    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
+    any projection aliases in the subquery.
+
+    For example,
+        WITH y(c) AS (
+            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
+        ) SELECT c FROM y;
+
+    will be rewritten as
+
+        WITH y(c) AS (
+            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
+        ) SELECT c FROM y;
+    """
+
     # --- Autofilled ---

     tokenizer_class = Tokenizer

@@ -287,7 +308,13 @@ class Dialect(metaclass=_Dialect):

         result = cls.get(dialect_name.strip())
         if not result:
-            raise ValueError(f"Unknown dialect '{dialect_name}'.")
+            from difflib import get_close_matches
+
+            similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
+            if similar:
+                similar = f" Did you mean {similar}?"
+
+            raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

         return result(**kwargs)

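(Illustrative note: the "Did you mean" suggestion above relies only on the stdlib; a standalone sketch with a hypothetical misspelling.)

    from difflib import get_close_matches

    registered = ["bigquery", "clickhouse", "duckdb", "postgres", "presto"]
    similar = next(iter(get_close_matches("duckbd", registered, n=1)), "")
    print(f"Unknown dialect 'duckbd'.{f' Did you mean {similar}?' if similar else ''}")
    # Unknown dialect 'duckbd'. Did you mean duckdb?
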
@@ -506,7 +533,7 @@ def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
     n = self.sql(expression, "this")
     d = self.sql(expression, "expression")
-    return f"IF({d} <> 0, {n} / {d}, NULL)"
+    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


 def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:

@@ -695,7 +722,7 @@ def date_add_interval_sql(

 def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
     return self.func(
-        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
+        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
     )

@@ -801,22 +828,6 @@ def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
     return self.func("STRPTIME", expression.this, self.format_time(expression))


-def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
-    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
-        _dialect = Dialect.get_or_raise(dialect)
-        time_format = self.format_time(expression)
-        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
-            return self.sql(
-                exp.cast(
-                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
-                    "date",
-                )
-            )
-        return self.sql(exp.cast(expression.this, "date"))
-
-    return _ts_or_ds_to_date_sql
-
-
 def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
     return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))

@@ -894,11 +905,6 @@ def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
     return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


-# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
-def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
-    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
-
-
 def is_parse_json(expression: exp.Expression) -> bool:
     return isinstance(expression, exp.ParseJSON) or (
         isinstance(expression, exp.Cast) and expression.is_type("json")

@@ -946,7 +952,70 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
             expression = ts_or_ds_add_cast(expression)

         return self.func(
-            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
+            name,
+            exp.var(expression.text("unit").upper() or "DAY"),
+            expression.expression,
+            expression.this,
         )

     return _delta_sql


+def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
+    from sqlglot.optimizer.simplify import simplify
+
+    # Makes sure the path will be evaluated correctly at runtime to include the path root.
+    # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
+    path = expression.expression
+    path = exp.func(
+        "if",
+        exp.func("startswith", path, "'['"),
+        exp.func("concat", "'$'", path),
+        exp.func("concat", "'$.'", path),
+    )
+
+    expression.expression.replace(simplify(path))
+    return expression
+
+
+def path_to_jsonpath(
+    name: str = "JSON_EXTRACT",
+) -> t.Callable[[Generator, exp.GetPath], str]:
+    def _transform(self: Generator, expression: exp.GetPath) -> str:
+        return rename_func(name)(self, prepend_dollar_to_path(expression))
+
+    return _transform
+
+
+def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
+    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
+    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
+    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
+
+    return self.sql(exp.cast(minus_one_day, "date"))
+
+
+def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
+    """Remove table refs from columns in when statements."""
+    alias = expression.this.args.get("alias")
+
+    normalize = (
+        lambda identifier: self.dialect.normalize_identifier(identifier).name
+        if identifier
+        else None
+    )
+
+    targets = {normalize(expression.this.this)}
+
+    if alias:
+        targets.add(normalize(alias.this))
+
+    for when in expression.expressions:
+        when.transform(
+            lambda node: exp.column(node.this)
+            if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+            else node,
+            copy=False,
+        )
+
+    return self.merge_sql(expression)

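(Illustrative note: a minimal sketch of the column rewrite inside merge_without_target_sql above, assuming sqlglot is installed; the names are made up.)

    from sqlglot import exp

    qualified = exp.column("id", table="t")
    print(qualified.sql())                     # t.id
    print(exp.column(qualified.this).sql())    # id  (table ref stripped)
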
sqlglot/dialects/doris.py
@@ -22,6 +22,7 @@ class Doris(MySQL):
             "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
             "DATE_TRUNC": parse_timestamp_trunc,
             "REGEXP": exp.RegexpLike.from_arg_list,
+            "TO_DATE": exp.TsOrDsToDate.from_arg_list,
         }

     class Generator(MySQL.Generator):

@@ -34,21 +35,26 @@ class Doris(MySQL):
             exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
         }

+        LAST_DAY_SUPPORTS_DATE_PART = False
+
         TIMESTAMP_FUNC_TYPES = set()

         TRANSFORMS = {
             **MySQL.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
+            exp.ArgMax: rename_func("MAX_BY"),
+            exp.ArgMin: rename_func("MIN_BY"),
             exp.ArrayAgg: rename_func("COLLECT_LIST"),
+            exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
             exp.CurrentTimestamp: lambda *_: "NOW()",
             exp.DateTrunc: lambda self, e: self.func(
                 "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
             ),
             exp.JSONExtractScalar: arrow_json_extract_sql,
             exp.JSONExtract: arrow_json_extract_sql,
+            exp.Map: rename_func("ARRAY_MAP"),
             exp.RegexpLike: rename_func("REGEXP"),
             exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
-            exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
             exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Split: rename_func("SPLIT_BY_STRING"),
             exp.TimeStrToDate: rename_func("TO_DATE"),

@@ -63,5 +69,4 @@ class Doris(MySQL):
                 "FROM_UNIXTIME", e.this, time_format("doris")(self, e)
             ),
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
-            exp.Map: rename_func("ARRAY_MAP"),
         }

sqlglot/dialects/drill.py
@@ -12,7 +12,6 @@ from sqlglot.dialects.dialect import (
     rename_func,
     str_position_sql,
     timestrtotime_sql,
-    ts_or_ds_to_date_sql,
 )

@@ -99,6 +98,7 @@ class Drill(Dialect):
         TABLE_HINTS = False
         QUERY_HINTS = False
         NVL2_SUPPORTED = False
+        LAST_DAY_SUPPORTS_DATE_PART = False

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,

@@ -150,7 +150,6 @@ class Drill(Dialect):
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.TryCast: no_trycast_sql,
             exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
-            exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
         }

sqlglot/dialects/duckdb.py
@@ -22,15 +22,15 @@ from sqlglot.dialects.dialect import (
     no_safe_divide_sql,
     no_timestamp_sql,
     pivot_column_names,
+    prepend_dollar_to_path,
     regexp_extract_sql,
     rename_func,
     str_position_sql,
     str_to_time_sql,
     timestamptrunc_sql,
     timestrtotime_sql,
-    ts_or_ds_to_date_sql,
 )
-from sqlglot.helper import seq_get
+from sqlglot.helper import flatten, seq_get
 from sqlglot.tokens import TokenType

@@ -141,11 +141,25 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
         return f"EPOCH_MS({timestamp})"
     if scale == exp.UnixToTime.MICROS:
         return f"MAKE_TIMESTAMP({timestamp})"
-    if scale == exp.UnixToTime.NANOS:
-        return f"TO_TIMESTAMP({timestamp} / 1000000000)"

-    self.unsupported(f"Unsupported scale for timestamp: {scale}.")
-    return ""
+    return f"TO_TIMESTAMP({timestamp} / POW(10, {scale}))"


+def _rename_unless_within_group(
+    a: str, b: str
+) -> t.Callable[[DuckDB.Generator, exp.Expression], str]:
+    return (
+        lambda self, expression: self.func(a, *flatten(expression.args.values()))
+        if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup)
+        else self.func(b, *flatten(expression.args.values()))
+    )
+
+
+def _parse_struct_pack(args: t.List) -> exp.Struct:
+    args_with_columns_as_identifiers = [
+        exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args
+    ]
+    return exp.Struct.from_arg_list(args_with_columns_as_identifiers)
+
+
 class DuckDB(Dialect):

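(Illustrative note: the POW(10, scale) fallback above subsumes the old per-scale branches; a standalone sketch with a hypothetical helper.)

    def to_timestamp_sql(timestamp: str, scale: int) -> str:
        # Dividing by 10**scale generalizes /1000 (millis), /1e6 (micros), /1e9 (nanos).
        return f"TO_TIMESTAMP({timestamp} / POW(10, {scale}))"

    print(to_timestamp_sql("col", 3))  # TO_TIMESTAMP(col / POW(10, 3)), i.e. / 1000
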
@@ -183,6 +197,11 @@ class DuckDB(Dialect):
             "TIMESTAMP_US": TokenType.TIMESTAMP,
         }

+        SINGLE_TOKENS = {
+            **tokens.Tokenizer.SINGLE_TOKENS,
+            "$": TokenType.PARAMETER,
+        }
+
     class Parser(parser.Parser):
         BITWISE = {
             **parser.Parser.BITWISE,

@@ -209,10 +228,12 @@ class DuckDB(Dialect):
             "EPOCH_MS": lambda args: exp.UnixToTime(
                 this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
             ),
+            "JSON": exp.ParseJSON.from_arg_list,
             "LIST_HAS": exp.ArrayContains.from_arg_list,
             "LIST_REVERSE_SORT": _sort_array_reverse,
             "LIST_SORT": exp.SortArray.from_arg_list,
             "LIST_VALUE": exp.Array.from_arg_list,
+            "MAKE_TIME": exp.TimeFromParts.from_arg_list,
             "MAKE_TIMESTAMP": _parse_make_timestamp,
             "MEDIAN": lambda args: exp.PercentileCont(
                 this=seq_get(args, 0), expression=exp.Literal.number(0.5)

@@ -234,7 +255,7 @@ class DuckDB(Dialect):
             "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "STRING_TO_ARRAY": exp.Split.from_arg_list,
             "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
-            "STRUCT_PACK": exp.Struct.from_arg_list,
+            "STRUCT_PACK": _parse_struct_pack,
             "STR_SPLIT": exp.Split.from_arg_list,
             "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "TO_TIMESTAMP": exp.UnixToTime.from_arg_list,

@@ -250,6 +271,13 @@ class DuckDB(Dialect):
             TokenType.ANTI,
         }

+        PLACEHOLDER_PARSERS = {
+            **parser.Parser.PLACEHOLDER_PARSERS,
+            TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
+            if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
+            else None,
+        }
+
         def _parse_types(
             self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
         ) -> t.Optional[exp.Expression]:

@@ -268,7 +296,7 @@ class DuckDB(Dialect):

             return this

-        def _parse_struct_types(self) -> t.Optional[exp.Expression]:
+        def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
             return self._parse_field_def()

         def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:

@@ -285,6 +313,10 @@ class DuckDB(Dialect):
         RENAME_TABLE_WITH_DB = False
         NVL2_SUPPORTED = False
         SEMI_ANTI_JOIN_WITH_SIDE = False
+        TABLESAMPLE_KEYWORDS = "USING SAMPLE"
+        TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+        LAST_DAY_SUPPORTS_DATE_PART = False
+        JSON_KEY_VALUE_PAIR_SEP = ","

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,

@@ -311,7 +343,7 @@ class DuckDB(Dialect):
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.DateSub: _date_delta_sql,
             exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", f"'{e.args.get('unit') or 'day'}'", e.expression, e.this
+                "DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
             ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",

@@ -322,11 +354,11 @@ class DuckDB(Dialect):
             exp.IntDiv: lambda self, e: self.binary(e, "//"),
             exp.IsInf: rename_func("ISINF"),
             exp.IsNan: rename_func("ISNAN"),
-            exp.JSONBExtract: arrow_json_extract_sql,
-            exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONFormat: _json_format_sql,
+            exp.JSONBExtract: arrow_json_extract_sql,
+            exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.MonthsBetween: lambda self, e: self.func(

@@ -336,8 +368,8 @@ class DuckDB(Dialect):
                 exp.cast(e.this, "timestamp", copy=True),
             ),
             exp.ParseJSON: rename_func("JSON"),
-            exp.PercentileCont: rename_func("QUANTILE_CONT"),
-            exp.PercentileDisc: rename_func("QUANTILE_DISC"),
+            exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"),
+            exp.PercentileDisc: _rename_unless_within_group("PERCENTILE_DISC", "QUANTILE_DISC"),
             # DuckDB doesn't allow qualified columns inside of PIVOT expressions.
             # See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
             exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),

@@ -362,7 +394,9 @@ class DuckDB(Dialect):
             exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
             exp.Struct: _struct_sql,
             exp.Timestamp: no_timestamp_sql,
-            exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
+            exp.TimestampDiff: lambda self, e: self.func(
+                "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
+            ),
             exp.TimestampTrunc: timestamptrunc_sql,
             exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
             exp.TimeStrToTime: timestrtotime_sql,

@@ -373,11 +407,10 @@ class DuckDB(Dialect):
             exp.TsOrDsAdd: _ts_or_ds_add_sql,
             exp.TsOrDsDiff: lambda self, e: self.func(
                 "DATE_DIFF",
-                f"'{e.args.get('unit') or 'day'}'",
+                f"'{e.args.get('unit') or 'DAY'}'",
                 exp.cast(e.expression, "TIMESTAMP"),
                 exp.cast(e.this, "TIMESTAMP"),
             ),
-            exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
             exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
             exp.UnixToTime: _unix_to_time_sql,
             exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",

@@ -410,6 +443,49 @@ class DuckDB(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

+        def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
+            nano = expression.args.get("nano")
+            if nano is not None:
+                expression.set(
+                    "sec", expression.args["sec"] + nano.pop() / exp.Literal.number(1000000000.0)
+                )
+
+            return rename_func("MAKE_TIME")(self, expression)
+
+        def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
+            sec = expression.args["sec"]
+
+            milli = expression.args.get("milli")
+            if milli is not None:
+                sec += milli.pop() / exp.Literal.number(1000.0)
+
+            nano = expression.args.get("nano")
+            if nano is not None:
+                sec += nano.pop() / exp.Literal.number(1000000000.0)
+
+            if milli or nano:
+                expression.set("sec", sec)
+
+            return rename_func("MAKE_TIMESTAMP")(self, expression)
+
+        def tablesample_sql(
+            self,
+            expression: exp.TableSample,
+            sep: str = " AS ",
+            tablesample_keyword: t.Optional[str] = None,
+        ) -> str:
+            if not isinstance(expression.parent, exp.Select):
+                # This sample clause only applies to a single source, not the entire resulting relation
+                tablesample_keyword = "TABLESAMPLE"
+
+            return super().tablesample_sql(
+                expression, sep=sep, tablesample_keyword=tablesample_keyword
+            )
+
+        def getpath_sql(self, expression: exp.GetPath) -> str:
+            expression = prepend_dollar_to_path(expression)
+            return f"{self.sql(expression, 'this')} -> {self.sql(expression, 'expression')}"
+
         def interval_sql(self, expression: exp.Interval) -> str:
             multiplier: t.Optional[int] = None
             unit = expression.text("unit").lower()

@@ -420,11 +496,14 @@ class DuckDB(Dialect):
                 multiplier = 90

             if multiplier:
-                return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
+                return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})"

             return super().interval_sql(expression)

-        def tablesample_sql(
-            self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
-        ) -> str:
-            return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
+        def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
+            if isinstance(expression.parent, exp.UserDefinedFunction):
+                return self.sql(expression, "this")
+            return super().columndef_sql(expression, sep)
+
+        def placeholder_sql(self, expression: exp.Placeholder) -> str:
+            return f"${expression.name}" if expression.name else "?"

sqlglot/dialects/hive.py
@@ -418,13 +418,13 @@ class Hive(Dialect):
     class Generator(generator.Generator):
         LIMIT_FETCH = "LIMIT"
         TABLESAMPLE_WITH_METHOD = False
-        TABLESAMPLE_SIZE_IS_PERCENT = True
         JOIN_HINTS = False
         TABLE_HINTS = False
         QUERY_HINTS = False
         INDEX_ON = "ON TABLE"
         EXTRACT_ALLOWS_QUOTES = False
         NVL2_SUPPORTED = False
+        LAST_DAY_SUPPORTS_DATE_PART = False

         EXPRESSIONS_WITHOUT_NESTED_CTES = {
             exp.Insert,

@@ -523,7 +523,6 @@ class Hive(Dialect):
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
             exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
             exp.NumberToStr: rename_func("FORMAT_NUMBER"),
-            exp.LastDateOfMonth: rename_func("LAST_DAY"),
             exp.National: lambda self, e: self.national_sql(e, prefix=""),
             exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
             exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",

sqlglot/dialects/mysql.py
@@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import (
     datestrtodate_sql,
     format_time_lambda,
     isnull_to_is_null,
-    json_keyvalue_comma_sql,
     locate_to_strposition,
     max_or_greatest,
     min_or_least,

@@ -21,6 +20,7 @@ from sqlglot.dialects.dialect import (
     no_tablesample_sql,
     no_trycast_sql,
     parse_date_delta_with_interval,
+    path_to_jsonpath,
     rename_func,
     strposition_to_locate_sql,
 )

@@ -37,21 +37,21 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex

 def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
     expr = self.sql(expression, "this")
-    unit = expression.text("unit")
+    unit = expression.text("unit").upper()

-    if unit == "day":
+    if unit == "DAY":
         return f"DATE({expr})"

-    if unit == "week":
+    if unit == "WEEK":
         concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
         date_format = "%Y %u %w"
-    elif unit == "month":
+    elif unit == "MONTH":
         concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
         date_format = "%Y %c %e"
-    elif unit == "quarter":
+    elif unit == "QUARTER":
         concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
         date_format = "%Y %c %e"
-    elif unit == "year":
+    elif unit == "YEAR":
         concat = f"CONCAT(YEAR({expr}), ' 1 1')"
         date_format = "%Y %c %e"
     else:

@@ -292,9 +292,15 @@ class MySQL(Dialect):
             "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
             "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
+            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
             "ISNULL": isnull_to_is_null,
             "LOCATE": locate_to_strposition,
+            "MAKETIME": exp.TimeFromParts.from_arg_list,
+            "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "MONTHNAME": lambda args: exp.TimeToStr(
                 this=exp.TsOrDsToDate(this=seq_get(args, 0)),
                 format=exp.Literal.string("%B"),

@@ -308,11 +314,6 @@ class MySQL(Dialect):
                 )
                 + 1
             ),
-            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
-            "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
-            "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
-            "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
-            "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "WEEK": lambda args: exp.Week(
                 this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1)
             ),

@@ -441,6 +442,7 @@ class MySQL(Dialect):
         }

         LOG_DEFAULTS_TO_LN = True
+        STRING_ALIASES = True

         def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
             this = self._parse_id_var()

@@ -620,13 +622,15 @@ class MySQL(Dialect):

     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
-        NULL_ORDERING_SUPPORTED = False
+        NULL_ORDERING_SUPPORTED = None
         JOIN_HINTS = False
         TABLE_HINTS = True
         DUPLICATE_KEY_UPDATE_WITH_SET = False
         QUERY_HINT_SEP = " "
         VALUES_AS_TABLE = False
         NVL2_SUPPORTED = False
+        LAST_DAY_SUPPORTS_DATE_PART = False
+        JSON_KEY_VALUE_PAIR_SEP = ","

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,

@@ -642,15 +646,16 @@ class MySQL(Dialect):
             exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
             exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
             exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
+            exp.GetPath: path_to_jsonpath(),
             exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
             exp.ILike: no_ilike_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
-            exp.JSONKeyValue: json_keyvalue_comma_sql,
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.Month: _remove_ts_or_ds_to_date(),
             exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
             exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}",
+            exp.ParseJSON: lambda self, e: self.sql(e, "this"),
             exp.Pivot: no_pivot_sql,
             exp.Select: transforms.preprocess(
                 [

@@ -665,6 +670,7 @@ class MySQL(Dialect):
             exp.StrToTime: _str_to_date_sql,
             exp.Stuff: rename_func("INSERT"),
             exp.TableSample: no_tablesample_sql,
+            exp.TimeFromParts: rename_func("MAKETIME"),
             exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
             exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),

sqlglot/dialects/oracle.py
@@ -53,6 +53,7 @@ def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar:
 class Oracle(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     LOCKING_READS_SUPPORTED = True
+    TABLESAMPLE_SIZE_IS_PERCENT = True

     # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
     NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE

@@ -81,6 +82,7 @@ class Oracle(Dialect):
         "WW": "%W",  # Week of year (1-53)
         "YY": "%y",  # 15
         "YYYY": "%Y",  # 2015
+        "FF6": "%f",  # only 6 digits are supported in python formats
     }

     class Parser(parser.Parser):

@@ -91,6 +93,8 @@ class Oracle(Dialect):
             **parser.Parser.FUNCTIONS,
             "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
             "TO_CHAR": to_char,
+            "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "oracle"),
+            "TO_DATE": format_time_lambda(exp.StrToDate, "oracle"),
         }

         FUNCTION_PARSERS: t.Dict[str, t.Callable] = {

@@ -107,6 +111,11 @@ class Oracle(Dialect):
             "XMLTABLE": _parse_xml_table,
         }

+        QUERY_MODIFIER_PARSERS = {
+            **parser.Parser.QUERY_MODIFIER_PARSERS,
+            TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()),
+        }
+
         TYPE_LITERAL_PARSERS = {
             exp.DataType.Type.DATE: lambda self, this, _: self.expression(
                 exp.DateStrToDate, this=this

@@ -153,8 +162,10 @@ class Oracle(Dialect):
         COLUMN_JOIN_MARKS_SUPPORTED = True
         DATA_TYPE_SPECIFIERS_ALLOWED = True
         ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False

         LIMIT_FETCH = "FETCH"
+        TABLESAMPLE_KEYWORDS = "SAMPLE"
+        LAST_DAY_SUPPORTS_DATE_PART = False
+        SUPPORTS_SELECT_INTO = True

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,

@@ -186,6 +197,7 @@ class Oracle(Dialect):
                 ]
             ),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
             exp.Substring: rename_func("SUBSTR"),
             exp.Table: lambda self, e: self.table_sql(e, sep=" "),

@@ -201,6 +213,10 @@ class Oracle(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

+        def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str:
+            this = expression.this
+            return self.func("CURRENT_TIMESTAMP", this) if this else "CURRENT_TIMESTAMP"
+
         def offset_sql(self, expression: exp.Offset) -> str:
             return f"{super().offset_sql(expression)} ROWS"

@@ -233,8 +249,10 @@ class Oracle(Dialect):
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
             "MINUS": TokenType.EXCEPT,
             "NVARCHAR2": TokenType.NVARCHAR,
+            "ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY,
+            "SAMPLE": TokenType.TABLE_SAMPLE,
             "START": TokenType.BEGIN,
             "SYSDATE": TokenType.CURRENT_TIMESTAMP,
             "TOP": TokenType.TOP,
             "VARCHAR2": TokenType.VARCHAR,
         }

sqlglot/dialects/postgres.py
@@ -13,11 +13,12 @@ from sqlglot.dialects.dialect import (
     datestrtodate_sql,
     format_time_lambda,
     max_or_greatest,
+    merge_without_target_sql,
     min_or_least,
+    no_last_day_sql,
     no_map_from_entries_sql,
     no_paren_current_date_sql,
     no_pivot_sql,
-    no_tablesample_sql,
     no_trycast_sql,
     parse_timestamp_trunc,
     rename_func,

@@ -27,7 +28,6 @@ from sqlglot.dialects.dialect import (
     timestrtotime_sql,
     trim_sql,
     ts_or_ds_add_cast,
-    ts_or_ds_to_date_sql,
 )
 from sqlglot.helper import seq_get
 from sqlglot.parser import binary_range_parser

@@ -188,36 +188,6 @@ def _to_timestamp(args: t.List) -> exp.Expression:
     return format_time_lambda(exp.StrToTime, "postgres")(args)


-def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str:
-    def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
-        """Remove table refs from columns in when statements."""
-        if isinstance(expression, exp.Merge):
-            alias = expression.this.args.get("alias")
-
-            normalize = (
-                lambda identifier: self.dialect.normalize_identifier(identifier).name
-                if identifier
-                else None
-            )
-
-            targets = {normalize(expression.this.this)}
-
-            if alias:
-                targets.add(normalize(alias.this))
-
-            for when in expression.expressions:
-                when.transform(
-                    lambda node: exp.column(node.this)
-                    if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
-                    else node,
-                    copy=False,
-                )
-
-        return expression
-
-    return transforms.preprocess([_remove_target_from_merge])(self, expression)
-
-
 class Postgres(Dialect):
     INDEX_OFFSET = 1
     TYPED_DIVISION = True

@@ -316,6 +286,8 @@ class Postgres(Dialect):
             **parser.Parser.FUNCTIONS,
             "DATE_TRUNC": parse_timestamp_trunc,
             "GENERATE_SERIES": _generate_series,
+            "MAKE_TIME": exp.TimeFromParts.from_arg_list,
+            "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
             "NOW": exp.CurrentTimestamp.from_arg_list,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
             "TO_TIMESTAMP": _to_timestamp,

@@ -387,12 +359,18 @@ class Postgres(Dialect):

     class Generator(generator.Generator):
         SINGLE_STRING_INTERVAL = True
         RENAME_TABLE_WITH_DB = False
         LOCKING_READS_SUPPORTED = True
         JOIN_HINTS = False
         TABLE_HINTS = False
         QUERY_HINTS = False
         NVL2_SUPPORTED = False
         PARAMETER_TOKEN = "$"
+        TABLESAMPLE_SIZE_IS_ROWS = False
+        TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+        SUPPORTS_SELECT_INTO = True
+        # https://www.postgresql.org/docs/current/sql-createtable.html
+        SUPPORTS_UNLOGGED_TABLES = True

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,

@@ -430,12 +408,13 @@ class Postgres(Dialect):
             exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
             exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
             exp.JSONBContains: lambda self, e: self.binary(e, "?"),
+            exp.LastDay: no_last_day_sql,
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.Max: max_or_greatest,
             exp.MapFromEntries: no_map_from_entries_sql,
             exp.Min: min_or_least,
-            exp.Merge: _merge_sql,
+            exp.Merge: merge_without_target_sql,
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.PercentileCont: transforms.preprocess(
                 [transforms.add_within_group_for_percentiles]

@@ -458,16 +437,16 @@ class Postgres(Dialect):
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.StructExtract: struct_extract_sql,
             exp.Substring: _substring_sql,
+            exp.TimeFromParts: rename_func("MAKE_TIME"),
+            exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
             exp.TimestampTrunc: timestamptrunc_sql,
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.TableSample: no_tablesample_sql,
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.Trim: trim_sql,
             exp.TryCast: no_trycast_sql,
             exp.TsOrDsAdd: _date_add_sql("+"),
             exp.TsOrDsDiff: _date_diff_sql,
-            exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
             exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
             exp.VariancePop: rename_func("VAR_POP"),
             exp.Variance: rename_func("VAR_SAMP"),

sqlglot/dialects/presto.py
@@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
     no_pivot_sql,
     no_safe_divide_sql,
     no_timestamp_sql,
+    path_to_jsonpath,
     regexp_extract_sql,
     rename_func,
     right_to_substring_sql,

@@ -99,14 +100,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)

 def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
     expression = ts_or_ds_add_cast(expression)
-    unit = exp.Literal.string(expression.text("unit") or "day")
+    unit = exp.Literal.string(expression.text("unit") or "DAY")
     return self.func("DATE_ADD", unit, expression.expression, expression.this)


 def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
     this = exp.cast(expression.this, "TIMESTAMP")
     expr = exp.cast(expression.expression, "TIMESTAMP")
-    unit = exp.Literal.string(expression.text("unit") or "day")
+    unit = exp.Literal.string(expression.text("unit") or "DAY")
     return self.func("DATE_DIFF", unit, expr, this)

@@ -138,13 +139,6 @@ def _from_unixtime(args: t.List) -> exp.Expression:
     return exp.UnixToTime.from_arg_list(args)


-def _parse_element_at(args: t.List) -> exp.Bracket:
-    this = seq_get(args, 0)
-    index = seq_get(args, 1)
-    assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
-    return exp.Bracket(this=this, expressions=[index], offset=1, safe=True)
-
-
 def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
     if isinstance(expression, exp.Table):
         if isinstance(expression.this, exp.GenerateSeries):

@ -175,15 +169,8 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str
    timestamp = self.sql(expression, "this")
    if scale in (None, exp.UnixToTime.SECONDS):
        return rename_func("FROM_UNIXTIME")(self, expression)
    if scale == exp.UnixToTime.MILLIS:
        return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)"
    if scale == exp.UnixToTime.MICROS:
        return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)"
    if scale == exp.UnixToTime.NANOS:
        return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)"

    self.unsupported(f"Unsupported scale for timestamp: {scale}.")
    return ""
    return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))"
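Editor's sketch, not part of the upstream diff: because the UnixToTime scale constants later in this change set become base-10 exponents, the single POW branch covers every precision. Assuming this sqlglot build:

import sqlglot

# Epoch microseconds (scale 6) fall through to the generic POW division;
# the exact output string may vary between sqlglot releases.
sql = sqlglot.transpile(
    "SELECT TIMESTAMP_MICROS(1700000000000000)", read="spark", write="presto"
)[0]
print(sql)  # e.g. SELECT FROM_UNIXTIME(CAST(1700000000000000 AS DOUBLE) / POW(10, 6))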
def _to_int(expression: exp.Expression) -> exp.Expression:
@ -215,6 +202,7 @@ class Presto(Dialect):
    STRICT_STRING_CONCAT = True
    SUPPORTS_SEMI_ANTI_JOIN = False
    TYPED_DIVISION = True
    TABLESAMPLE_SIZE_IS_PERCENT = True

    # https://github.com/trinodb/trino/issues/17
    # https://github.com/trinodb/trino/issues/12289
@ -258,7 +246,9 @@ class Presto(Dialect):
    "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
    "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
    "DATE_TRUNC": date_trunc_to_time,
    "ELEMENT_AT": _parse_element_at,
    "ELEMENT_AT": lambda args: exp.Bracket(
        this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True
    ),
    "FROM_HEX": exp.Unhex.from_arg_list,
    "FROM_UNIXTIME": _from_unixtime,
    "FROM_UTF8": lambda args: exp.Decode(
@ -344,20 +334,20 @@ class Presto(Dialect):
    exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
    exp.DateAdd: lambda self, e: self.func(
        "DATE_ADD",
        exp.Literal.string(e.text("unit") or "day"),
        exp.Literal.string(e.text("unit") or "DAY"),
        _to_int(
            e.expression,
        ),
        e.this,
    ),
    exp.DateDiff: lambda self, e: self.func(
        "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
        "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
    ),
    exp.DateStrToDate: datestrtodate_sql,
    exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
    exp.DateSub: lambda self, e: self.func(
        "DATE_ADD",
        exp.Literal.string(e.text("unit") or "day"),
        exp.Literal.string(e.text("unit") or "DAY"),
        _to_int(e.expression * -1),
        e.this,
    ),
@ -366,6 +356,7 @@ class Presto(Dialect):
    exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
    exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
    exp.First: _first_last_sql,
    exp.GetPath: path_to_jsonpath(),
    exp.Group: transforms.preprocess([transforms.unalias_group]),
    exp.GroupConcat: lambda self, e: self.func(
        "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
@ -376,6 +367,7 @@ class Presto(Dialect):
    exp.Initcap: _initcap_sql,
    exp.ParseJSON: rename_func("JSON_PARSE"),
    exp.Last: _first_last_sql,
    exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
    exp.Lateral: _explode_to_unnest_sql,
    exp.Left: left_to_substring_sql,
    exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
@ -446,7 +438,7 @@ class Presto(Dialect):
        return super().bracket_sql(expression)

    def struct_sql(self, expression: exp.Struct) -> str:
        if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions):
        if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions):
            self.unsupported("Struct with key-value definitions is unsupported.")
            return self.function_fallback_sql(expression)
@ -454,8 +446,8 @@ class Presto(Dialect):

    def interval_sql(self, expression: exp.Interval) -> str:
        unit = self.sql(expression, "unit")
        if expression.this and unit.lower().startswith("week"):
            return f"({expression.this.name} * INTERVAL '7' day)"
        if expression.this and unit.startswith("WEEK"):
            return f"({expression.this.name} * INTERVAL '7' DAY)"
        return super().interval_sql(expression)

    def transaction_sql(self, expression: exp.Transaction) -> str:
@ -9,8 +9,8 @@ from sqlglot.dialects.dialect import (
    concat_ws_to_dpipe_sql,
    date_delta_sql,
    generatedasidentitycolumnconstraint_sql,
    no_tablesample_sql,
    rename_func,
    ts_or_ds_to_date_sql,
)
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
@ -123,6 +123,27 @@ class Redshift(Postgres):
            self._retreat(index)
        return None

    def _parse_query_modifiers(
        self, this: t.Optional[exp.Expression]
    ) -> t.Optional[exp.Expression]:
        this = super()._parse_query_modifiers(this)

        if this:
            refs = set()

            for i, join in enumerate(this.args.get("joins", [])):
                refs.add(
                    (
                        this.args["from"] if i == 0 else this.args["joins"][i - 1]
                    ).alias_or_name.lower()
                )
                table = join.this

                if isinstance(table, exp.Table):
                    if table.parts[0].name.lower() in refs:
                        table.replace(table.to_column())
        return this

    class Tokenizer(Postgres.Tokenizer):
        BIT_STRINGS = []
        HEX_STRINGS = []
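Editor's sketch, not part of the upstream diff: this pass rewrites a joined "table" that actually refers to an earlier alias (Redshift's SUPER-style lateral references) into a column. Assuming this build:

from sqlglot import parse_one

ast = parse_one("SELECT * FROM customer AS c, c.orders AS o", read="redshift")
joined = ast.args["joins"][0].this
# The second FROM item is no longer an exp.Table but an aliased column ref.
print(type(joined).__name__)  # e.g. Alias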
@ -144,11 +165,11 @@ class Redshift(Postgres):

    class Generator(Postgres.Generator):
        LOCKING_READS_SUPPORTED = False
        RENAME_TABLE_WITH_DB = False
        QUERY_HINTS = False
        VALUES_AS_TABLE = False
        TZ_TO_WITH_TIME_ZONE = True
        NVL2_SUPPORTED = True
        LAST_DAY_SUPPORTS_DATE_PART = False

        TYPE_MAPPING = {
            **Postgres.Generator.TYPE_MAPPING,
@ -184,9 +205,9 @@ class Redshift(Postgres):
        [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
    ),
    exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
    exp.TableSample: no_tablesample_sql,
    exp.TsOrDsAdd: date_delta_sql("DATEADD"),
    exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
    exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"),
}

# Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
@ -198,6 +219,9 @@ class Redshift(Postgres):
    # Redshift supports ANY_VALUE(..)
    TRANSFORMS.pop(exp.AnyValue)

    # Redshift supports LAST_DAY(..)
    TRANSFORMS.pop(exp.LastDay)

    RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}

    def with_properties(self, properties: exp.Properties) -> str:
@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import (
    rename_func,
    timestamptrunc_sql,
    timestrtotime_sql,
    ts_or_ds_to_date_sql,
    var_map_sql,
)
from sqlglot.expressions import Literal
@ -40,21 +39,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
    if second_arg.is_string:
        # case: <string_expr> [ , <format> ]
        return format_time_lambda(exp.StrToTime, "snowflake")(args)

    # case: <numeric_expr> [ , <scale> ]
    if second_arg.name not in ["0", "3", "9"]:
        raise ValueError(
            f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
        )

    if second_arg.name == "0":
        timescale = exp.UnixToTime.SECONDS
    elif second_arg.name == "3":
        timescale = exp.UnixToTime.MILLIS
    elif second_arg.name == "9":
        timescale = exp.UnixToTime.NANOS

    return exp.UnixToTime(this=first_arg, scale=timescale)
    return exp.UnixToTime(this=first_arg, scale=second_arg)

from sqlglot.optimizer.simplify import simplify_literals
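Editor's sketch, not part of the upstream diff: with the scale forwarded verbatim, any 0-9 precision is accepted instead of only 0, 3 and 9. Assuming this build:

import sqlglot

# Scale 6 (microseconds) now parses cleanly and transpiles to the
# corresponding Spark helper; output may vary between releases.
print(sqlglot.transpile(
    "SELECT TO_TIMESTAMP(1700000000000000, 6)", read="snowflake", write="spark2"
)[0])  # e.g. SELECT TIMESTAMP_MICROS(1700000000000000)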
@ -91,23 +76,9 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:


def _parse_datediff(args: t.List) -> exp.DateDiff:
    return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))


def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale in (None, exp.UnixToTime.SECONDS):
        return f"TO_TIMESTAMP({timestamp})"
    if scale == exp.UnixToTime.MILLIS:
        return f"TO_TIMESTAMP({timestamp}, 3)"
    if scale == exp.UnixToTime.MICROS:
        return f"TO_TIMESTAMP({timestamp} / 1000, 3)"
    if scale == exp.UnixToTime.NANOS:
        return f"TO_TIMESTAMP({timestamp}, 9)"

    self.unsupported(f"Unsupported scale for timestamp: {scale}.")
    return ""
    return exp.DateDiff(
        this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
    )


# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
@ -120,14 +91,15 @@ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:

    self._match(TokenType.COMMA)
    expression = self._parse_bitwise()

    this = _map_date_part(this)
    name = this.name.upper()

    if name.startswith("EPOCH"):
        if name.startswith("EPOCH_MILLISECOND"):
        if name == "EPOCH_MILLISECOND":
            scale = 10**3
        elif name.startswith("EPOCH_MICROSECOND"):
        elif name == "EPOCH_MICROSECOND":
            scale = 10**6
        elif name.startswith("EPOCH_NANOSECOND"):
        elif name == "EPOCH_NANOSECOND":
            scale = 10**9
        else:
            scale = None
@ -204,6 +176,159 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]
    return _parse


DATE_PART_MAPPING = {
    "Y": "YEAR",
    "YY": "YEAR",
    "YYY": "YEAR",
    "YYYY": "YEAR",
    "YR": "YEAR",
    "YEARS": "YEAR",
    "YRS": "YEAR",
    "MM": "MONTH",
    "MON": "MONTH",
    "MONS": "MONTH",
    "MONTHS": "MONTH",
    "D": "DAY",
    "DD": "DAY",
    "DAYS": "DAY",
    "DAYOFMONTH": "DAY",
    "WEEKDAY": "DAYOFWEEK",
    "DOW": "DAYOFWEEK",
    "DW": "DAYOFWEEK",
    "WEEKDAY_ISO": "DAYOFWEEKISO",
    "DOW_ISO": "DAYOFWEEKISO",
    "DW_ISO": "DAYOFWEEKISO",
    "YEARDAY": "DAYOFYEAR",
    "DOY": "DAYOFYEAR",
    "DY": "DAYOFYEAR",
    "W": "WEEK",
    "WK": "WEEK",
    "WEEKOFYEAR": "WEEK",
    "WOY": "WEEK",
    "WY": "WEEK",
    "WEEK_ISO": "WEEKISO",
    "WEEKOFYEARISO": "WEEKISO",
    "WEEKOFYEAR_ISO": "WEEKISO",
    "Q": "QUARTER",
    "QTR": "QUARTER",
    "QTRS": "QUARTER",
    "QUARTERS": "QUARTER",
    "H": "HOUR",
    "HH": "HOUR",
    "HR": "HOUR",
    "HOURS": "HOUR",
    "HRS": "HOUR",
    "M": "MINUTE",
    "MI": "MINUTE",
    "MIN": "MINUTE",
    "MINUTES": "MINUTE",
    "MINS": "MINUTE",
    "S": "SECOND",
    "SEC": "SECOND",
    "SECONDS": "SECOND",
    "SECS": "SECOND",
    "MS": "MILLISECOND",
    "MSEC": "MILLISECOND",
    "MILLISECONDS": "MILLISECOND",
    "US": "MICROSECOND",
    "USEC": "MICROSECOND",
    "MICROSECONDS": "MICROSECOND",
    "NS": "NANOSECOND",
    "NSEC": "NANOSECOND",
    "NANOSEC": "NANOSECOND",
    "NSECOND": "NANOSECOND",
    "NSECONDS": "NANOSECOND",
    "NANOSECS": "NANOSECOND",
    "EPOCH": "EPOCH_SECOND",
    "EPOCH_SECONDS": "EPOCH_SECOND",
    "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
    "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
    "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
    "TZH": "TIMEZONE_HOUR",
    "TZM": "TIMEZONE_MINUTE",
}


@t.overload
def _map_date_part(part: exp.Expression) -> exp.Var:
    pass


@t.overload
def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    pass


def _map_date_part(part):
    mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None
    return exp.var(mapped) if mapped else part
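Editor's sketch, not part of the upstream diff: the mapping collapses Snowflake's many date-part aliases into one canonical unit before generation. Assuming this build:

from sqlglot import parse_one

# "yrs" is a Snowflake alias for YEAR; the parser now canonicalizes it.
ast = parse_one("SELECT DATEADD(yrs, 1, col)", read="snowflake")
print(ast.sql(dialect="snowflake"))  # e.g. SELECT DATEADD(YEAR, 1, col)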

def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    trunc = date_trunc_to_time(args)
    trunc.set("unit", _map_date_part(trunc.args["unit"]))
    return trunc
def _parse_colon_get_path(
    self: parser.Parser, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
    while True:
        path = self._parse_bitwise()

        # The cast :: operator has a lower precedence than the extraction operator :, so
        # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
        if isinstance(path, exp.Cast):
            target_type = path.to
            path = path.this
        else:
            target_type = None

        if isinstance(path, exp.Expression):
            path = exp.Literal.string(path.sql(dialect="snowflake"))

        # The extraction operator : is left-associative
        this = self.expression(exp.GetPath, this=this, expression=path)

        if target_type:
            this = exp.cast(this, target_type)

        if not self._match(TokenType.COLON):
            break

    if self._match_set(self.RANGE_PARSERS):
        this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this

    return this
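Editor's sketch, not part of the upstream diff: the precedence rearrangement means a trailing :: cast applies to the whole extraction, not to the path literal. Assuming this build:

import sqlglot

print(sqlglot.transpile(
    "SELECT col:a.b::INT FROM t", read="snowflake", write="snowflake"
)[0])  # e.g. SELECT CAST(GET_PATH(col, 'a.b') AS INT) FROM t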
def _parse_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)
def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
    """
    Snowflake doesn't allow columns referenced in UNPIVOT to be qualified,
    so we need to unqualify them.

    Example:
        >>> from sqlglot import parse_one
        >>> expr = parse_one("SELECT * FROM m_sales UNPIVOT(sales FOR month IN (m_sales.jan, feb, mar, april))")
        >>> print(_unqualify_unpivot_columns(expr).sql(dialect="snowflake"))
        SELECT * FROM m_sales UNPIVOT(sales FOR month IN (jan, feb, mar, april))
    """
    if isinstance(expression, exp.Pivot) and expression.unpivot:
        expression = transforms.unqualify_columns(expression)

    return expression


class Snowflake(Dialect):
    # https://docs.snowflake.com/en/sql-reference/identifiers-syntax
    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@ -211,6 +336,8 @@ class Snowflake(Dialect):
    TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
    SUPPORTS_USER_DEFINED_TYPES = False
    SUPPORTS_SEMI_ANTI_JOIN = False
    PREFER_CTE_ALIAS_COLUMN = True
    TABLESAMPLE_SIZE_IS_PERCENT = True

    TIME_MAPPING = {
        "YYYY": "%Y",
@ -276,14 +403,19 @@ class Snowflake(Dialect):
    "BIT_XOR": binary_from_function(exp.BitwiseXor),
    "BOOLXOR": binary_from_function(exp.Xor),
    "CONVERT_TIMEZONE": _parse_convert_timezone,
    "DATE_TRUNC": date_trunc_to_time,
    "DATE_TRUNC": _date_trunc_to_time,
    "DATEADD": lambda args: exp.DateAdd(
        this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
        this=seq_get(args, 2),
        expression=seq_get(args, 1),
        unit=_map_date_part(seq_get(args, 0)),
    ),
    "DATEDIFF": _parse_datediff,
    "DIV0": _div0_to_if,
    "FLATTEN": exp.Explode.from_arg_list,
    "IFF": exp.If.from_arg_list,
    "LAST_DAY": lambda args: exp.LastDay(
        this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
    ),
    "LISTAGG": exp.GroupConcat.from_arg_list,
    "NULLIFZERO": _nullifzero_to_if,
    "OBJECT_CONSTRUCT": _parse_object_construct,
@ -293,6 +425,8 @@ class Snowflake(Dialect):
    "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
    "TIMEDIFF": _parse_datediff,
    "TIMESTAMPDIFF": _parse_datediff,
    "TIMESTAMPFROMPARTS": _parse_timestamp_from_parts,
    "TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts,
    "TO_TIMESTAMP": _parse_to_timestamp,
    "TO_VARCHAR": exp.ToChar.from_arg_list,
    "ZEROIFNULL": _zeroifnull_to_if,
@ -301,22 +435,17 @@ class Snowflake(Dialect):
    FUNCTION_PARSERS = {
        **parser.Parser.FUNCTION_PARSERS,
        "DATE_PART": _parse_date_part,
        "OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(),
    }
    FUNCTION_PARSERS.pop("TRIM")

    COLUMN_OPERATORS = {
        **parser.Parser.COLUMN_OPERATORS,
        TokenType.COLON: lambda self, this, path: self.expression(
            exp.Bracket, this=this, expressions=[path]
        ),
    }

    TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME}

    RANGE_PARSERS = {
        **parser.Parser.RANGE_PARSERS,
        TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
        TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
        TokenType.COLON: _parse_colon_get_path,
    }

    ALTER_PARSERS = {
@ -344,6 +473,7 @@ class Snowflake(Dialect):
    SHOW_PARSERS = {
        "PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
        "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
        "COLUMNS": _show_parser("COLUMNS"),
    }

    STAGED_FILE_SINGLE_TOKENS = {
@ -351,8 +481,18 @@ class Snowflake(Dialect):
        TokenType.MOD,
        TokenType.SLASH,
    }

    FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]

    def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
        if is_map:
            # Keys are strings in Snowflake's objects, see also:
            # - https://docs.snowflake.com/en/sql-reference/data-types-semistructured
            # - https://docs.snowflake.com/en/sql-reference/functions/object_construct
            return self._parse_slice(self._parse_string())

        return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))

    def _parse_lateral(self) -> t.Optional[exp.Lateral]:
        lateral = super()._parse_lateral()
        if not lateral:
@ -440,6 +580,8 @@ class Snowflake(Dialect):
        scope = None
        scope_kind = None

        like = self._parse_string() if self._match(TokenType.LIKE) else None

        if self._match(TokenType.IN):
            if self._match_text_seq("ACCOUNT"):
                scope_kind = "ACCOUNT"
@ -451,7 +593,9 @@ class Snowflake(Dialect):
                scope_kind = "TABLE"
                scope = self._parse_table()

        return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)
        return self.expression(
            exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind
        )

    def _parse_alter_table_swap(self) -> exp.SwapTable:
        self._match_text_seq("WITH")
@ -489,8 +633,12 @@ class Snowflake(Dialect):
    "MINUS": TokenType.EXCEPT,
    "NCHAR VARYING": TokenType.VARCHAR,
    "PUT": TokenType.COMMAND,
    "REMOVE": TokenType.COMMAND,
    "RENAME": TokenType.REPLACE,
    "RM": TokenType.COMMAND,
    "SAMPLE": TokenType.TABLE_SAMPLE,
    "SQL_DOUBLE": TokenType.DOUBLE,
    "SQL_VARCHAR": TokenType.VARCHAR,
    "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
    "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
    "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@ -518,6 +666,8 @@ class Snowflake(Dialect):
    SUPPORTS_TABLE_COPY = False
    COLLATE_IS_FUNC = True
    LIMIT_ONLY_LITERALS = True
    JSON_KEY_VALUE_PAIR_SEP = ","
    INSERT_OVERWRITE = " OVERWRITE INTO"

    TRANSFORMS = {
        **generator.Generator.TRANSFORMS,
@ -545,6 +695,8 @@ class Snowflake(Dialect):
    ),
    exp.GroupConcat: rename_func("LISTAGG"),
    exp.If: if_sql(name="IFF", false_value="NULL"),
    exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]",
    exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
    exp.LogicalAnd: rename_func("BOOLAND_AGG"),
    exp.LogicalOr: rename_func("BOOLOR_AGG"),
    exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@ -557,6 +709,7 @@ class Snowflake(Dialect):
    exp.PercentileDisc: transforms.preprocess(
        [transforms.add_within_group_for_percentiles]
    ),
    exp.Pivot: transforms.preprocess([_unqualify_unpivot_columns]),
    exp.RegexpILike: _regexpilike_sql,
    exp.Rand: rename_func("RANDOM"),
    exp.Select: transforms.preprocess(
@ -578,6 +731,9 @@ class Snowflake(Dialect):
        *(arg for expression in e.expressions for arg in expression.flatten()),
    ),
    exp.Stuff: rename_func("INSERT"),
    exp.TimestampDiff: lambda self, e: self.func(
        "TIMESTAMPDIFF", e.unit, e.expression, e.this
    ),
    exp.TimestampTrunc: timestamptrunc_sql,
    exp.TimeStrToTime: timestrtotime_sql,
    exp.TimeToStr: lambda self, e: self.func(
@ -589,8 +745,7 @@ class Snowflake(Dialect):
    exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
    exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
    exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
    exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
    exp.UnixToTime: _unix_to_time_sql,
    exp.UnixToTime: rename_func("TO_TIMESTAMP"),
    exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
    exp.WeekOfYear: rename_func("WEEKOFYEAR"),
    exp.Xor: rename_func("BOOLXOR"),
@ -612,6 +767,14 @@ class Snowflake(Dialect):
    exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
    milli = expression.args.get("milli")
    if milli is not None:
        milli_to_nano = milli.pop() * exp.Literal.number(1000000)
        expression.set("nano", milli_to_nano)

    return rename_func("TIMESTAMP_FROM_PARTS")(self, expression)
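Editor's sketch, not part of the upstream diff: T-SQL supplies milliseconds while Snowflake's TIMESTAMP_FROM_PARTS takes nanoseconds, so the milli argument is scaled by 1e6 on the way out. Assuming this build:

import sqlglot

print(sqlglot.transpile(
    "SELECT DATETIMEFROMPARTS(2024, 1, 2, 3, 4, 5, 6)",
    read="tsql",
    write="snowflake",
)[0])  # e.g. SELECT TIMESTAMP_FROM_PARTS(2024, 1, 2, 3, 4, 5, 6 * 1000000)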
def trycast_sql(self, expression: exp.TryCast) -> str:
    value = expression.this
@ -657,6 +820,9 @@ class Snowflake(Dialect):
        return f"{explode}{alias}"

    def show_sql(self, expression: exp.Show) -> str:
        like = self.sql(expression, "like")
        like = f" LIKE {like}" if like else ""

        scope = self.sql(expression, "scope")
        scope = f" {scope}" if scope else ""
@ -664,7 +830,7 @@ class Snowflake(Dialect):
        if scope_kind:
            scope_kind = f" IN {scope_kind}"

        return f"SHOW {expression.name}{scope_kind}{scope}"
        return f"SHOW {expression.name}{like}{scope_kind}{scope}"

    def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
        # Other dialects don't support all of the following parameters, so we need to
@ -48,11 +48,8 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str
        return f"TIMESTAMP_MILLIS({timestamp})"
    if scale == exp.UnixToTime.MICROS:
        return f"TIMESTAMP_MICROS({timestamp})"
    if scale == exp.UnixToTime.NANOS:
        return f"TIMESTAMP_SECONDS({timestamp} / 1000000000)"

    self.unsupported(f"Unsupported scale for timestamp: {scale}.")
    return ""
    return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))"


def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
@ -93,12 +90,7 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
        SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q1'))
    """
    if isinstance(expression, exp.Pivot):
        expression.args["field"].transform(
            lambda node: exp.column(node.output_name, quoted=node.this.quoted)
            if isinstance(node, exp.Column)
            else node,
            copy=False,
        )
        expression.set("field", transforms.unqualify_columns(expression.args["field"]))

    return expression
@ -234,7 +226,7 @@ class Spark2(Hive):
    def struct_sql(self, expression: exp.Struct) -> str:
        args = []
        for arg in expression.expressions:
            if isinstance(arg, self.KEY_VALUE_DEFINITONS):
            if isinstance(arg, self.KEY_VALUE_DEFINITIONS):
                if isinstance(arg, exp.Bracket):
                    args.append(exp.alias_(arg.this, arg.expressions[0].name))
                else:
@ -78,6 +78,7 @@ class SQLite(Dialect):
        **parser.Parser.FUNCTIONS,
        "EDITDIST3": exp.Levenshtein.from_arg_list,
    }
    STRING_ALIASES = True

    class Generator(generator.Generator):
        JOIN_HINTS = False
@ -175,6 +175,8 @@ class Teradata(Dialect):
    JOIN_HINTS = False
    TABLE_HINTS = False
    QUERY_HINTS = False
    TABLESAMPLE_KEYWORDS = "SAMPLE"
    LAST_DAY_SUPPORTS_DATE_PART = False

    TYPE_MAPPING = {
        **generator.Generator.TYPE_MAPPING,
@ -214,7 +216,10 @@ class Teradata(Dialect):
        return self.cast_sql(expression, safe_prefix="TRY")

    def tablesample_sql(
        self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
        self,
        expression: exp.TableSample,
        sep: str = " AS ",
        tablesample_keyword: t.Optional[str] = None,
    ) -> str:
        return f"{self.sql(expression, 'this')} SAMPLE {self.expressions(expression)}"
@ -1,6 +1,7 @@
from __future__ import annotations

from sqlglot import exp
from sqlglot.dialects.dialect import merge_without_target_sql
from sqlglot.dialects.presto import Presto
@ -11,6 +12,7 @@ class Trino(Presto):
    TRANSFORMS = {
        **Presto.Generator.TRANSFORMS,
        exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
        exp.Merge: merge_without_target_sql,
    }

    class Tokenizer(Presto.Tokenizer):
@ -14,9 +14,10 @@ from sqlglot.dialects.dialect import (
    max_or_greatest,
    min_or_least,
    parse_date_delta,
    path_to_jsonpath,
    rename_func,
    timestrtotime_sql,
    ts_or_ds_to_date_sql,
    trim_sql,
)
from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
@ -105,18 +106,17 @@ def _parse_format(args: t.List) -> exp.Expression:
    return exp.TimeToStr(this=this, format=fmt, culture=culture)


def _parse_eomonth(args: t.List) -> exp.Expression:
    date = seq_get(args, 0)
def _parse_eomonth(args: t.List) -> exp.LastDay:
    date = exp.TsOrDsToDate(this=seq_get(args, 0))
    month_lag = seq_get(args, 1)
    unit = DATE_DELTA_INTERVAL.get("month")

    if month_lag is None:
        return exp.LastDateOfMonth(this=date)
        this: exp.Expression = date
    else:
        unit = DATE_DELTA_INTERVAL.get("month")
        this = exp.DateAdd(this=date, expression=month_lag, unit=unit and exp.var(unit))

        # Remove month lag argument in parser as it's compared with the number of arguments of the resulting class
        args.remove(month_lag)

    return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
    return exp.LastDay(this=this)
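Editor's sketch, not part of the upstream diff: EOMONTH with a lag now parses to LastDay over a DateAdd, so it can reach dialects with a native LAST_DAY. Assuming this build, roughly:

import sqlglot

print(sqlglot.transpile("SELECT EOMONTH(d, 2)", read="tsql", write="snowflake")[0])
# e.g. SELECT LAST_DAY(DATEADD(MONTH, 2, CAST(d AS DATE)))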
def _parse_hashbytes(args: t.List) -> exp.Expression:
@ -137,26 +137,27 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
    return exp.func("HASHBYTES", *args)


DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"}
DATEPART_ONLY_FORMATS = {"DW", "HOUR", "QUARTER"}


def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
    fmt = (
        expression.args["format"]
        if isinstance(expression, exp.NumberToStr)
        else exp.Literal.string(
            format_time(
                expression.text("format"),
                t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING),
            )
        )
    )
    fmt = expression.args["format"]

    # There is no format for "quarter"
    if fmt.name.lower() in DATEPART_ONLY_FORMATS:
        return self.func("DATEPART", fmt.name, expression.this)
    if not isinstance(expression, exp.NumberToStr):
        if fmt.is_string:
            mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING)

    return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
            name = (mapped_fmt or "").upper()
            if name in DATEPART_ONLY_FORMATS:
                return self.func("DATEPART", name, expression.this)

            fmt_sql = self.sql(exp.Literal.string(mapped_fmt))
        else:
            fmt_sql = self.format_time(expression) or self.sql(fmt)
    else:
        fmt_sql = self.sql(fmt)

    return self.func("FORMAT", expression.this, fmt_sql, expression.args.get("culture"))


def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
@ -239,6 +240,30 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
    return expression


# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax
def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
    return exp.TimestampFromParts(
        year=seq_get(args, 0),
        month=seq_get(args, 1),
        day=seq_get(args, 2),
        hour=seq_get(args, 3),
        min=seq_get(args, 4),
        sec=seq_get(args, 5),
        milli=seq_get(args, 6),
    )


# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax
def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
    return exp.TimeFromParts(
        hour=seq_get(args, 0),
        min=seq_get(args, 1),
        sec=seq_get(args, 2),
        fractions=seq_get(args, 3),
        precision=seq_get(args, 4),
    )


class TSQL(Dialect):
    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
    TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
@ -352,7 +377,7 @@ class TSQL(Dialect):
    }

    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]")]
        IDENTIFIERS = [("[", "]"), '"']
        QUOTES = ["'", '"']
        HEX_STRINGS = [("0x", ""), ("0X", "")]
        VAR_SINGLE_TOKENS = {"@", "$", "#"}
@ -362,6 +387,7 @@ class TSQL(Dialect):
    "DATETIME2": TokenType.DATETIME,
    "DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
    "DECLARE": TokenType.COMMAND,
    "EXEC": TokenType.COMMAND,
    "IMAGE": TokenType.IMAGE,
    "MONEY": TokenType.MONEY,
    "NTEXT": TokenType.TEXT,
@ -397,6 +423,7 @@ class TSQL(Dialect):
    "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
    "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
    "DATEPART": _format_time_lambda(exp.TimeToStr),
    "DATETIMEFROMPARTS": _parse_datetimefromparts,
    "EOMONTH": _parse_eomonth,
    "FORMAT": _parse_format,
    "GETDATE": exp.CurrentTimestamp.from_arg_list,
@ -411,6 +438,7 @@ class TSQL(Dialect):
    "SUSER_NAME": exp.CurrentUser.from_arg_list,
    "SUSER_SNAME": exp.CurrentUser.from_arg_list,
    "SYSTEM_USER": exp.CurrentUser.from_arg_list,
    "TIMEFROMPARTS": _parse_timefromparts,
}

JOIN_HINTS = {
@ -440,6 +468,7 @@ class TSQL(Dialect):
    LOG_DEFAULTS_TO_LN = True

    ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
    STRING_ALIASES = True

    def _parse_projections(self) -> t.List[exp.Expression]:
        """
@ -630,8 +659,10 @@ class TSQL(Dialect):
    COMPUTED_COLUMN_WITH_TYPE = False
    CTE_RECURSIVE_KEYWORD_REQUIRED = False
    ENSURE_BOOLS = True
    NULL_ORDERING_SUPPORTED = False
    NULL_ORDERING_SUPPORTED = None
    SUPPORTS_SINGLE_ARG_CONCAT = False
    TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
    SUPPORTS_SELECT_INTO = True

    EXPRESSIONS_WITHOUT_NESTED_CTES = {
        exp.Delete,
@ -667,13 +698,16 @@ class TSQL(Dialect):
    exp.CurrentTimestamp: rename_func("GETDATE"),
    exp.Extract: rename_func("DATEPART"),
    exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
    exp.GetPath: path_to_jsonpath("JSON_VALUE"),
    exp.GroupConcat: _string_agg_sql,
    exp.If: rename_func("IIF"),
    exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
    exp.Length: rename_func("LEN"),
    exp.Max: max_or_greatest,
    exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
    exp.Min: min_or_least,
    exp.NumberToStr: _format_sql,
    exp.ParseJSON: lambda self, e: self.sql(e, "this"),
    exp.Select: transforms.preprocess(
        [
            transforms.eliminate_distinct_on,
@ -689,9 +723,9 @@ class TSQL(Dialect):
    exp.TemporaryProperty: lambda self, e: "",
    exp.TimeStrToTime: timestrtotime_sql,
    exp.TimeToStr: _format_sql,
    exp.Trim: trim_sql,
    exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
    exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
    exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
}

TRANSFORMS.pop(exp.ReturnsProperty)
@ -701,6 +735,46 @@ class TSQL(Dialect):
    exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

def lateral_op(self, expression: exp.Lateral) -> str:
    cross_apply = expression.args.get("cross_apply")
    if cross_apply is True:
        return "CROSS APPLY"
    if cross_apply is False:
        return "OUTER APPLY"

    # TODO: perhaps we can check if the parent is a Join and transpile it appropriately
    self.unsupported("LATERAL clause is not supported.")
    return "LATERAL"

def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
    nano = expression.args.get("nano")
    if nano is not None:
        nano.pop()
        self.unsupported("Specifying nanoseconds is not supported in TIMEFROMPARTS.")

    if expression.args.get("fractions") is None:
        expression.set("fractions", exp.Literal.number(0))
    if expression.args.get("precision") is None:
        expression.set("precision", exp.Literal.number(0))

    return rename_func("TIMEFROMPARTS")(self, expression)

def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
    zone = expression.args.get("zone")
    if zone is not None:
        zone.pop()
        self.unsupported("Time zone is not supported in DATETIMEFROMPARTS.")

    nano = expression.args.get("nano")
    if nano is not None:
        nano.pop()
        self.unsupported("Specifying nanoseconds is not supported in DATETIMEFROMPARTS.")

    if expression.args.get("milli") is None:
        expression.set("milli", exp.Literal.number(0))

    return rename_func("DATETIMEFROMPARTS")(self, expression)

def set_operation(self, expression: exp.Union, op: str) -> str:
    limit = expression.args.get("limit")
    if limit:
@ -132,11 +132,10 @@ def ordered(this, desc, nulls_first):

@null_if_any
def interval(this, unit):
    unit = unit.lower()
    plural = unit + "s"
    plural = unit + "S"
    if plural in Generator.TIME_PART_SINGULARS:
        unit = plural
    return datetime.timedelta(**{unit: float(this)})
    return datetime.timedelta(**{unit.lower(): float(this)})

@null_if_any("this", "expression")
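Editor's sketch, not part of the upstream diff: units now arrive uppercase, so the plural lookup key is built with "S" and only lowercased for timedelta. A self-contained stand-in (the real code consults Generator.TIME_PART_SINGULARS):

import datetime

def interval(this, unit):
    plural = unit + "S"
    # Stand-in for the uppercase keys of Generator.TIME_PART_SINGULARS.
    if plural in {"DAYS", "HOURS", "MINUTES", "SECONDS", "WEEKS"}:
        unit = plural
    return datetime.timedelta(**{unit.lower(): float(this)})

print(interval(2, "DAY"))  # 2 days, 0:00:00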
@ -176,6 +175,7 @@ ENV = {
    "DOT": null_if_any(lambda e, this: e[this]),
    "EQ": null_if_any(lambda this, e: this == e),
    "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
    "GETPATH": null_if_any(lambda this, e: this.get(e)),
    "GT": null_if_any(lambda this, e: this > e),
    "GTE": null_if_any(lambda this, e: this >= e),
    "IF": lambda predicate, true, false: true if predicate else false,
@ -16,6 +16,7 @@ import datetime
import math
import numbers
import re
import textwrap
import typing as t
from collections import deque
from copy import deepcopy
@ -35,6 +36,8 @@ from sqlglot.helper import (
from sqlglot.tokens import Token

if t.TYPE_CHECKING:
    from typing_extensions import Literal as Lit

    from sqlglot.dialects.dialect import DialectType
@ -242,6 +245,9 @@ class Expression(metaclass=_Expression):
    def is_type(self, *dtypes) -> bool:
        return self.type is not None and self.type.is_type(*dtypes)

    def is_leaf(self) -> bool:
        return not any(isinstance(v, (Expression, list)) for v in self.args.values())

    @property
    def meta(self) -> t.Dict[str, t.Any]:
        if self._meta is None:
@ -497,7 +503,14 @@ class Expression(metaclass=_Expression):
        return self.sql()

    def __repr__(self) -> str:
        return self._to_s()
        return _to_s(self)

    def to_s(self) -> str:
        """
        Same as __repr__, but includes additional information which can be useful
        for debugging, like empty or missing args and the AST nodes' object IDs.
        """
        return _to_s(self, verbose=True)

    def sql(self, dialect: DialectType = None, **opts) -> str:
        """
@ -514,30 +527,6 @@ class Expression(metaclass=_Expression):

        return Dialect.get_or_raise(dialect).generate(self, **opts)

    def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
        indent = "" if not level else "\n"
        indent += "".join(["  "] * level)
        left = f"({self.key.upper()} "

        args: t.Dict[str, t.Any] = {
            k: ", ".join(
                v._to_s(hide_missing=hide_missing, level=level + 1)
                if hasattr(v, "_to_s")
                else str(v)
                for v in ensure_list(vs)
                if v is not None
            )
            for k, vs in self.args.items()
        }
        args["comments"] = self.comments
        args["type"] = self.type
        args = {k: v for k, v in args.items() if v or not hide_missing}

        right = ", ".join(f"{k}: {v}" for k, v in args.items())
        right += ")"

        return indent + left + right

    def transform(self, fun, *args, copy=True, **kwargs):
        """
        Recursively visits all tree nodes (excluding already transformed ones)
@ -580,8 +569,9 @@ class Expression(metaclass=_Expression):
        For example::

            >>> tree = Select().select("x").from_("tbl")
            >>> tree.find(Column).replace(Column(this="y"))
            (COLUMN this: y)
            >>> tree.find(Column).replace(column("y"))
            Column(
              this=Identifier(this=y, quoted=False))
            >>> tree.sql()
            'SELECT y FROM tbl'
@ -831,6 +821,9 @@ class Expression(metaclass=_Expression):
        div.args["safe"] = safe
        return div

    def desc(self, nulls_first: bool = False) -> Ordered:
        return Ordered(this=self.copy(), desc=True, nulls_first=nulls_first)

    def __lt__(self, other: t.Any) -> LT:
        return self._binop(LT, other)
@ -1109,7 +1102,7 @@ class Clone(Expression):


class Describe(Expression):
    arg_types = {"this": True, "kind": False, "expressions": False}
    arg_types = {"this": True, "extended": False, "kind": False, "expressions": False}


class Kill(Expression):
@ -1124,6 +1117,10 @@ class Set(Expression):
    arg_types = {"expressions": False, "unset": False, "tag": False}


class Heredoc(Expression):
    arg_types = {"this": True, "tag": False}


class SetItem(Expression):
    arg_types = {
        "this": False,
@ -1937,7 +1934,13 @@ class Join(Expression):


class Lateral(UDTF):
    arg_types = {"this": True, "view": False, "outer": False, "alias": False}
    arg_types = {
        "this": True,
        "view": False,
        "outer": False,
        "alias": False,
        "cross_apply": False,  # True -> CROSS APPLY, False -> OUTER APPLY
    }


class MatchRecognize(Expression):
@ -1964,7 +1967,12 @@ class Offset(Expression):


class Order(Expression):
    arg_types = {"this": False, "expressions": True, "interpolate": False}
    arg_types = {
        "this": False,
        "expressions": True,
        "interpolate": False,
        "siblings": False,
    }


# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier
@ -2002,6 +2010,11 @@ class AutoIncrementProperty(Property):
    arg_types = {"this": True}


# https://docs.aws.amazon.com/prescriptive-guidance/latest/materialized-views-redshift/refreshing-materialized-views.html
class AutoRefreshProperty(Property):
    arg_types = {"this": True}


class BlockCompressionProperty(Property):
    arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
@ -2259,6 +2272,10 @@ class SortKeyProperty(Property):
    arg_types = {"this": True, "compound": False}


class SqlReadWriteProperty(Property):
    arg_types = {"this": True}


class SqlSecurityProperty(Property):
    arg_types = {"definer": True}
@ -2543,7 +2560,6 @@ class Table(Expression):
    "version": False,
    "format": False,
    "pattern": False,
    "index": False,
    "ordinality": False,
    "when": False,
}
@ -2585,6 +2601,14 @@ class Table(Expression):

        return parts

    def to_column(self, copy: bool = True) -> Alias | Column | Dot:
        parts = self.parts
        col = column(*reversed(parts[0:4]), fields=parts[4:], copy=copy)  # type: ignore
        alias = self.args.get("alias")
        if alias:
            col = alias_(col, alias.this, copy=copy)
        return col
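Editor's sketch, not part of the upstream diff: to_column() reinterprets a dotted table reference as a column, which is what the Redshift parser change above relies on. Assuming this build:

from sqlglot import exp

# A three-part table name becomes a db/table-qualified column.
col = exp.to_table("a.b.c").to_column()
print(type(col).__name__, col.sql())  # Column a.b.c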
class Union(Subqueryable):
    arg_types = {
@ -2694,6 +2718,14 @@ class Unnest(UDTF):
        "offset": False,
    }

    @property
    def selects(self) -> t.List[Expression]:
        columns = super().selects
        offset = self.args.get("offset")
        if offset:
            columns = columns + [to_identifier("offset") if offset is True else offset]
        return columns


class Update(Expression):
    arg_types = {
@ -3368,7 +3400,7 @@ class Select(Subqueryable):

    return Create(
        this=table_expression,
        kind="table",
        kind="TABLE",
        expression=instance,
        properties=properties_expression,
    )
@ -3488,7 +3520,6 @@ class TableSample(Expression):
    "rows": False,
    "size": False,
    "seed": False,
    "kind": False,
}
@ -3517,6 +3548,10 @@ class Pivot(Expression):
        "include_nulls": False,
    }

    @property
    def unpivot(self) -> bool:
        return bool(self.args.get("unpivot"))


class Window(Condition):
    arg_types = {
@ -3604,6 +3639,7 @@ class DataType(Expression):
    BOOLEAN = auto()
    CHAR = auto()
    DATE = auto()
    DATE32 = auto()
    DATEMULTIRANGE = auto()
    DATERANGE = auto()
    DATETIME = auto()

@ -3631,6 +3667,8 @@ class DataType(Expression):
    INTERVAL = auto()
    IPADDRESS = auto()
    IPPREFIX = auto()
    IPV4 = auto()
    IPV6 = auto()
    JSON = auto()
    JSONB = auto()
    LONGBLOB = auto()
@ -3729,6 +3767,7 @@ class DataType(Expression):
    Type.TIMESTAMP_MS,
    Type.TIMESTAMP_NS,
    Type.DATE,
    Type.DATE32,
    Type.DATETIME,
    Type.DATETIME64,
}
@ -4100,6 +4139,12 @@ class Alias(Expression):
        return self.alias


# BigQuery requires the UNPIVOT column list aliases to be either strings or ints, but
# other dialects require identifiers. This enables us to transpile between them easily.
class PivotAlias(Alias):
    pass


class Aliases(Expression):
    arg_types = {"this": True, "expressions": True}
@ -4108,6 +4153,11 @@ class Aliases(Expression):
        return self.expressions


# https://docs.aws.amazon.com/redshift/latest/dg/query-super.html
class AtIndex(Expression):
    arg_types = {"this": True, "expression": True}


class AtTimeZone(Expression):
    arg_types = {"this": True, "zone": True}
@ -4154,16 +4204,16 @@ class TimeUnit(Expression):
    arg_types = {"unit": False}

    UNABBREVIATED_UNIT_NAME = {
        "d": "day",
        "h": "hour",
        "m": "minute",
        "ms": "millisecond",
        "ns": "nanosecond",
        "q": "quarter",
        "s": "second",
        "us": "microsecond",
        "w": "week",
        "y": "year",
        "D": "DAY",
        "H": "HOUR",
        "M": "MINUTE",
        "MS": "MILLISECOND",
        "NS": "NANOSECOND",
        "Q": "QUARTER",
        "S": "SECOND",
        "US": "MICROSECOND",
        "W": "WEEK",
        "Y": "YEAR",
    }

    VAR_LIKE = (Column, Literal, Var)
@ -4171,9 +4221,11 @@ class TimeUnit(Expression):
    def __init__(self, **args):
        unit = args.get("unit")
        if isinstance(unit, self.VAR_LIKE):
            args["unit"] = Var(this=self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name)
            args["unit"] = Var(
                this=(self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper()
            )
        elif isinstance(unit, Week):
            unit.set("this", Var(this=unit.this.name))
            unit.set("this", Var(this=unit.this.name.upper()))

        super().__init__(**args)
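Editor's sketch, not part of the upstream diff: units are upper-cased on construction, so lowercase input comes out canonical. Assuming this build:

from sqlglot import exp

node = exp.DateAdd(
    this=exp.column("x"), expression=exp.Literal.number(1), unit=exp.var("day")
)
print(node.text("unit"))  # DAY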
@ -4301,6 +4353,20 @@ class Anonymous(Func):
    is_var_len_args = True


class AnonymousAggFunc(AggFunc):
    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True


# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/combinators
class CombinedAggFunc(AnonymousAggFunc):
    arg_types = {"this": True, "expressions": False, "parts": True}


class CombinedParameterizedAgg(ParameterizedAgg):
    arg_types = {"this": True, "expressions": True, "params": True, "parts": True}


# https://docs.snowflake.com/en/sql-reference/functions/hll
# https://docs.aws.amazon.com/redshift/latest/dg/r_HLL_function.html
class Hll(AggFunc):
@ -4381,7 +4447,7 @@ class ArraySort(Func):


class ArraySum(Func):
    pass
    arg_types = {"this": True, "expression": False}


class ArrayUnionAgg(AggFunc):
@ -4498,7 +4564,7 @@ class Count(AggFunc):


class CountIf(AggFunc):
    pass
    _sql_names = ["COUNT_IF", "COUNTIF"]


class CurrentDate(Func):
@ -4537,6 +4603,17 @@ class DateDiff(Func, TimeUnit):
class DateTrunc(Func):
    arg_types = {"unit": True, "this": True, "zone": False}

    def __init__(self, **args):
        unit = args.get("unit")
        if isinstance(unit, TimeUnit.VAR_LIKE):
            args["unit"] = Literal.string(
                (TimeUnit.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper()
            )
        elif isinstance(unit, Week):
            unit.set("this", Literal.string(unit.this.name.upper()))

        super().__init__(**args)

    @property
    def unit(self) -> Expression:
        return self.args["unit"]
@ -4582,8 +4659,9 @@ class MonthsBetween(Func):
    arg_types = {"this": True, "expression": True, "roundoff": False}


class LastDateOfMonth(Func):
    pass
class LastDay(Func, TimeUnit):
    _sql_names = ["LAST_DAY", "LAST_DAY_OF_MONTH"]
    arg_types = {"this": True, "unit": False}


class Extract(Func):
|
@ -4627,10 +4705,22 @@ class TimeTrunc(Func, TimeUnit):
|
|||
|
||||
|
||||
class DateFromParts(Func):
|
||||
_sql_names = ["DATEFROMPARTS"]
|
||||
_sql_names = ["DATE_FROM_PARTS", "DATEFROMPARTS"]
|
||||
arg_types = {"year": True, "month": True, "day": True}
|
||||
|
||||
|
||||
class TimeFromParts(Func):
|
||||
_sql_names = ["TIME_FROM_PARTS", "TIMEFROMPARTS"]
|
||||
arg_types = {
|
||||
"hour": True,
|
||||
"min": True,
|
||||
"sec": True,
|
||||
"nano": False,
|
||||
"fractions": False,
|
||||
"precision": False,
|
||||
}
|
||||
|
||||
|
||||
class DateStrToDate(Func):
|
||||
pass
|
||||
|
||||
|
@ -4754,6 +4844,16 @@ class JSONObject(Func):
    }


class JSONObjectAgg(AggFunc):
    arg_types = {
        "expressions": False,
        "null_handling": False,
        "unique_keys": False,
        "return_type": False,
        "encoding": False,
    }


# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_ARRAY.html
class JSONArray(Func):
    arg_types = {
|
@ -4841,6 +4941,15 @@ class ParseJSON(Func):
|
|||
is_var_len_args = True
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/get_path
|
||||
class GetPath(Func):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
@property
|
||||
def output_name(self) -> str:
|
||||
return self.expression.output_name
|
||||
|
||||
|
||||
class Least(Func):
|
||||
arg_types = {"this": True, "expressions": False}
|
||||
is_var_len_args = True
|
||||
|
@ -5026,7 +5135,7 @@ class RegexpReplace(Func):
    arg_types = {
        "this": True,
        "expression": True,
        "replacement": True,
        "replacement": False,
        "position": False,
        "occurrence": False,
        "parameters": False,
@ -5052,8 +5161,10 @@ class Repeat(Func):
    arg_types = {"this": True, "times": True}


# https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16
# tsql third argument function == truncation if not 0
class Round(Func):
    arg_types = {"this": True, "decimals": False}
    arg_types = {"this": True, "decimals": False, "truncate": False}


class RowNumber(Func):
|
@ -5228,6 +5339,10 @@ class TsOrDsToDate(Func):
|
|||
arg_types = {"this": True, "format": False}
|
||||
|
||||
|
||||
class TsOrDsToTime(Func):
|
||||
pass
|
||||
|
||||
|
||||
class TsOrDiToDi(Func):
|
||||
pass
|
||||
|
||||
|
@ -5236,6 +5351,11 @@ class Unhex(Func):
|
|||
pass
|
||||
|
||||
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#unix_date
|
||||
class UnixDate(Func):
|
||||
pass
|
||||
|
||||
|
||||
class UnixToStr(Func):
|
||||
arg_types = {"this": True, "format": False}
|
||||
|
||||
|
@ -5245,10 +5365,16 @@ class UnixToStr(Func):
class UnixToTime(Func):
    arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False}

    SECONDS = Literal.string("seconds")
    MILLIS = Literal.string("millis")
    MICROS = Literal.string("micros")
    NANOS = Literal.string("nanos")
    SECONDS = Literal.number(0)
    DECIS = Literal.number(1)
    CENTIS = Literal.number(2)
    MILLIS = Literal.number(3)
    DECIMILLIS = Literal.number(4)
    CENTIMILLIS = Literal.number(5)
    MICROS = Literal.number(6)
    DECIMICROS = Literal.number(7)
    CENTIMICROS = Literal.number(8)
    NANOS = Literal.number(9)
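Editor's sketch, not part of the upstream diff: the scale is now the base-10 exponent of the input's precision, which is what lets the dialect generators above divide by POW(10, scale) instead of special-casing each unit.

from sqlglot import exp

assert exp.UnixToTime.MILLIS == exp.Literal.number(3)
assert exp.UnixToTime.NANOS == exp.Literal.number(9)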
class UnixToTimeStr(Func):
@ -5256,8 +5382,7 @@ class UnixToTimeStr(Func):


class TimestampFromParts(Func):
    """Constructs a timestamp given its constituent parts."""

    _sql_names = ["TIMESTAMP_FROM_PARTS", "TIMESTAMPFROMPARTS"]
    arg_types = {
        "year": True,
        "month": True,

@ -5265,6 +5390,9 @@ class TimestampFromParts(Func):
        "hour": True,
        "min": True,
        "sec": True,
        "nano": False,
        "zone": False,
        "milli": False,
    }
|
@ -5358,9 +5486,9 @@ def maybe_parse(
|
|||
|
||||
Example:
|
||||
>>> maybe_parse("1")
|
||||
(LITERAL this: 1, is_string: False)
|
||||
Literal(this=1, is_string=False)
|
||||
>>> maybe_parse(to_identifier("x"))
|
||||
(IDENTIFIER this: x, quoted: False)
|
||||
Identifier(this=x, quoted=False)
|
||||
|
||||
Args:
|
||||
sql_or_expression: the SQL code string or an expression
|
||||
|
@ -5407,6 +5535,39 @@ def maybe_copy(instance, copy=True):
    return instance.copy() if copy and instance else instance


def _to_s(node: t.Any, verbose: bool = False, level: int = 0) -> str:
    """Generate a textual representation of an Expression tree"""
    indent = "\n" + ("  " * (level + 1))
    delim = f",{indent}"

    if isinstance(node, Expression):
        args = {k: v for k, v in node.args.items() if (v is not None and v != []) or verbose}

        if (node.type or verbose) and not isinstance(node, DataType):
            args["_type"] = node.type
        if node.comments or verbose:
            args["_comments"] = node.comments

        if verbose:
            args["_id"] = id(node)

        # Inline leaves for a more compact representation
        if node.is_leaf():
            indent = ""
            delim = ", "

        items = delim.join([f"{k}={_to_s(v, verbose, level + 1)}" for k, v in args.items()])
        return f"{node.__class__.__name__}({indent}{items})"

    if isinstance(node, list):
        items = delim.join(_to_s(i, verbose, level + 1) for i in node)
        items = f"{indent}{items}" if items else ""
        return f"[{items}]"

    # Indent multiline strings to match the current level
    return indent.join(textwrap.dedent(str(node).strip("\n")).splitlines())


def _is_wrong_expression(expression, into):
    return isinstance(expression, Expression) and not isinstance(expression, into)
@ -5816,7 +5977,7 @@ def delete(
def insert(
    expression: ExpOrStr,
    into: ExpOrStr,
    columns: t.Optional[t.Sequence[ExpOrStr]] = None,
    columns: t.Optional[t.Sequence[str | Identifier]] = None,
    overwrite: t.Optional[bool] = None,
    returning: t.Optional[ExpOrStr] = None,
    dialect: DialectType = None,
@ -5847,15 +6008,7 @@ def insert(
    this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts)

    if columns:
        this = _apply_list_builder(
            *columns,
            instance=Schema(this=this),
            arg="expressions",
            into=Identifier,
            copy=False,
            dialect=dialect,
            **opts,
        )
        this = Schema(this=this, expressions=[to_identifier(c, copy=copy) for c in columns])

    insert = Insert(this=this, expression=expr, overwrite=overwrite)
@ -6073,7 +6226,7 @@ def to_interval(interval: str | Literal) -> Interval:

    return Interval(
        this=Literal.string(interval_parts.group(1)),
        unit=Var(this=interval_parts.group(2)),
        unit=Var(this=interval_parts.group(2).upper()),
    )
@ -6219,13 +6372,44 @@ def subquery(
|
|||
return Select().from_(expression, dialect=dialect, **opts)
|
||||
|
||||
|
||||
@t.overload
|
||||
def column(
|
||||
col: str | Identifier,
|
||||
table: t.Optional[str | Identifier] = None,
|
||||
db: t.Optional[str | Identifier] = None,
|
||||
catalog: t.Optional[str | Identifier] = None,
|
||||
*,
|
||||
fields: t.Collection[t.Union[str, Identifier]],
|
||||
quoted: t.Optional[bool] = None,
|
||||
copy: bool = True,
|
||||
) -> Dot:
|
||||
pass
|
||||
|
||||
|
||||
@t.overload
|
||||
def column(
|
||||
col: str | Identifier,
|
||||
table: t.Optional[str | Identifier] = None,
|
||||
db: t.Optional[str | Identifier] = None,
|
||||
catalog: t.Optional[str | Identifier] = None,
|
||||
*,
|
||||
fields: Lit[None] = None,
|
||||
quoted: t.Optional[bool] = None,
|
||||
copy: bool = True,
|
||||
) -> Column:
|
||||
pass
|
||||
|
||||
|
||||
def column(
|
||||
col,
|
||||
table=None,
|
||||
db=None,
|
||||
catalog=None,
|
||||
*,
|
||||
fields=None,
|
||||
quoted=None,
|
||||
copy=True,
|
||||
):
|
||||
"""
|
||||
Build a Column.
|
||||
|
||||
|
@ -6234,18 +6418,24 @@ def column(
|
|||
table: Table name.
|
||||
db: Database name.
|
||||
catalog: Catalog name.
|
||||
fields: Additional fields using dots.
|
||||
quoted: Whether to force quotes on the column's identifiers.
|
||||
copy: Whether or not to copy identifiers if passed in.
|
||||
|
||||
Returns:
|
||||
The new Column instance.
|
||||
"""
|
||||
return Column(
|
||||
this=to_identifier(col, quoted=quoted),
|
||||
table=to_identifier(table, quoted=quoted),
|
||||
db=to_identifier(db, quoted=quoted),
|
||||
catalog=to_identifier(catalog, quoted=quoted),
|
||||
this = Column(
|
||||
this=to_identifier(col, quoted=quoted, copy=copy),
|
||||
table=to_identifier(table, quoted=quoted, copy=copy),
|
||||
db=to_identifier(db, quoted=quoted, copy=copy),
|
||||
catalog=to_identifier(catalog, quoted=quoted, copy=copy),
|
||||
)
|
||||
|
||||
if fields:
|
||||
this = Dot.build((this, *(to_identifier(field, copy=copy) for field in fields)))
|
||||
return this
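A sketch of the new fields argument; the dotted output shape follows from the Dot.build call above:

from sqlglot import exp

# Extra fields are chained onto the column with dots:
print(exp.column("d", table="t", fields=["e", "f"]).sql())
# t.d.e.f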

def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
    """Cast an expression to a data type.

@ -6333,10 +6523,10 @@ def var(name: t.Optional[ExpOrStr]) -> Var:
    Example:
        >>> repr(var('x'))
        '(VAR this: x)'
        'Var(this=x)'

        >>> repr(var(column('x', table='y')))
        '(VAR this: x)'
        'Var(this=x)'

    Args:
        name: The name of the var or an expression whose name will become the var.

@ -68,6 +68,7 @@ class Generator:
        exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
        exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
        exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
        exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
        exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
        exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
        exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",

@ -96,6 +97,7 @@ class Generator:
        exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
        exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
        exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
        exp.SqlReadWriteProperty: lambda self, e: e.name,
        exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
        exp.StabilityProperty: lambda self, e: e.name,
        exp.TemporaryProperty: lambda self, e: f"TEMPORARY",

@ -110,7 +112,8 @@ class Generator:
    }

    # Whether or not null ordering is supported in order by
    NULL_ORDERING_SUPPORTED = True
    # True: Full Support, None: No support, False: No support in window specifications
    NULL_ORDERING_SUPPORTED: t.Optional[bool] = True

    # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
    LOCKING_READS_SUPPORTED = False

@ -133,12 +136,6 @@ class Generator:
    # Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs
    INTERVAL_ALLOWS_PLURAL_FORM = True

    # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
    TABLESAMPLE_WITH_METHOD = True

    # Whether or not to treat the number in TABLESAMPLE (50) as a percentage
    TABLESAMPLE_SIZE_IS_PERCENT = False

    # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
    LIMIT_FETCH = "ALL"

@ -219,6 +216,18 @@ class Generator:
    # Whether or not parentheses are required around the table sample's expression
    TABLESAMPLE_REQUIRES_PARENS = True

    # Whether or not a table sample clause's size needs to be followed by the ROWS keyword
    TABLESAMPLE_SIZE_IS_ROWS = True

    # The keyword(s) to use when generating a sample clause
    TABLESAMPLE_KEYWORDS = "TABLESAMPLE"

    # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
    TABLESAMPLE_WITH_METHOD = True

    # The keyword to use when specifying the seed of a sample clause
    TABLESAMPLE_SEED_KEYWORD = "SEED"

    # Whether or not COLLATE is a function instead of a binary operator
    COLLATE_IS_FUNC = False

@ -234,6 +243,27 @@ class Generator:
    # Whether or not CONCAT requires >1 arguments
    SUPPORTS_SINGLE_ARG_CONCAT = True

    # Whether or not LAST_DAY function supports a date part argument
    LAST_DAY_SUPPORTS_DATE_PART = True

    # Whether or not named columns are allowed in table aliases
    SUPPORTS_TABLE_ALIAS_COLUMNS = True

    # Whether or not UNPIVOT aliases are Identifiers (False means they're Literals)
    UNPIVOT_ALIASES_ARE_IDENTIFIERS = True

    # What delimiter to use for separating JSON key/value pairs
    JSON_KEY_VALUE_PAIR_SEP = ":"

    # INSERT OVERWRITE TABLE x override
    INSERT_OVERWRITE = " OVERWRITE TABLE"

    # Whether or not the SELECT .. INTO syntax is used instead of CTAS
    SUPPORTS_SELECT_INTO = False

    # Whether or not UNLOGGED tables can be created
    SUPPORTS_UNLOGGED_TABLES = False

    TYPE_MAPPING = {
        exp.DataType.Type.NCHAR: "CHAR",
        exp.DataType.Type.NVARCHAR: "VARCHAR",

@ -252,15 +282,15 @@ class Generator:
    }

    TIME_PART_SINGULARS = {
        "microseconds": "microsecond",
        "seconds": "second",
        "minutes": "minute",
        "hours": "hour",
        "days": "day",
        "weeks": "week",
        "months": "month",
        "quarters": "quarter",
        "years": "year",
        "MICROSECONDS": "MICROSECOND",
        "SECONDS": "SECOND",
        "MINUTES": "MINUTE",
        "HOURS": "HOUR",
        "DAYS": "DAY",
        "WEEKS": "WEEK",
        "MONTHS": "MONTH",
        "QUARTERS": "QUARTER",
        "YEARS": "YEAR",
    }

    TOKEN_MAPPING: t.Dict[TokenType, str] = {}

@ -272,6 +302,7 @@ class Generator:
    PROPERTIES_LOCATION = {
        exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
        exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
        exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA,
        exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
        exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
        exp.ChecksumProperty: exp.Properties.Location.POST_NAME,

@ -323,6 +354,7 @@ class Generator:
        exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SetProperty: exp.Properties.Location.POST_CREATE,
        exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
        exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
        exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,

@ -370,7 +402,7 @@ class Generator:
    # Expressions that need to have all CTEs under them bubbled up to them
    EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set()

    KEY_VALUE_DEFINITONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice)
    KEY_VALUE_DEFINITIONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice)

    SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"

@ -775,7 +807,7 @@ class Generator:
        return self.sql(expression, "this")

    def create_sql(self, expression: exp.Create) -> str:
        kind = self.sql(expression, "kind").upper()
        kind = self.sql(expression, "kind")
        properties = expression.args.get("properties")
        properties_locs = self.locate_properties(properties) if properties else defaultdict()

@ -868,7 +900,12 @@ class Generator:
        return f"{shallow}{keyword} {this}"

    def describe_sql(self, expression: exp.Describe) -> str:
        return f"DESCRIBE {self.sql(expression, 'this')}"
        extended = " EXTENDED" if expression.args.get("extended") else ""
        return f"DESCRIBE{extended} {self.sql(expression, 'this')}"

    def heredoc_sql(self, expression: exp.Heredoc) -> str:
        tag = self.sql(expression, "tag")
        return f"${tag}${self.sql(expression, 'this')}${tag}$"

    def prepend_ctes(self, expression: exp.Expression, sql: str) -> str:
        with_ = self.sql(expression, "with")

@ -895,6 +932,10 @@ class Generator:
        columns = self.expressions(expression, key="columns", flat=True)
        columns = f"({columns})" if columns else ""

        if columns and not self.SUPPORTS_TABLE_ALIAS_COLUMNS:
            columns = ""
            self.unsupported("Named columns are not supported in table alias.")

        if not alias and not self.dialect.UNNEST_COLUMN_ONLY:
            alias = "_t"

@ -1027,7 +1068,7 @@ class Generator:
    def fetch_sql(self, expression: exp.Fetch) -> str:
        direction = expression.args.get("direction")
        direction = f" {direction.upper()}" if direction else ""
        direction = f" {direction}" if direction else ""
        count = expression.args.get("count")
        count = f" {count}" if count else ""
        if expression.args.get("percent"):

@ -1318,7 +1359,7 @@ class Generator:
        if isinstance(expression.this, exp.Directory):
            this = " OVERWRITE" if overwrite else " INTO"
        else:
            this = " OVERWRITE TABLE" if overwrite else " INTO"
            this = self.INSERT_OVERWRITE if overwrite else " INTO"

        alternative = expression.args.get("alternative")
        alternative = f" OR {alternative}" if alternative else ""

@ -1365,10 +1406,10 @@ class Generator:
        return f"KILL{kind}{this}"

    def pseudotype_sql(self, expression: exp.PseudoType) -> str:
        return expression.name.upper()
        return expression.name

    def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str:
        return expression.name.upper()
        return expression.name

    def onconflict_sql(self, expression: exp.OnConflict) -> str:
        conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"

@ -1445,9 +1486,6 @@ class Generator:
        pattern = f", PATTERN => {pattern}" if pattern else ""
        file_format = f" (FILE_FORMAT => {file_format}{pattern})"

        index = self.sql(expression, "index")
        index = f" AT {index}" if index else ""

        ordinality = expression.args.get("ordinality") or ""
        if ordinality:
            ordinality = f" WITH ORDINALITY{alias}"

@ -1457,10 +1495,13 @@ class Generator:
        if when:
            table = f"{table} {when}"

        return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"
        return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"

    def tablesample_sql(
        self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
        self,
        expression: exp.TableSample,
        sep: str = " AS ",
        tablesample_keyword: t.Optional[str] = None,
    ) -> str:
        if self.dialect.ALIAS_POST_TABLESAMPLE and expression.this and expression.this.alias:
            table = expression.this.copy()

@ -1472,30 +1513,30 @@ class Generator:
            alias = ""

        method = self.sql(expression, "method")
        method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
        method = f"{method} " if method and self.TABLESAMPLE_WITH_METHOD else ""
        numerator = self.sql(expression, "bucket_numerator")
        denominator = self.sql(expression, "bucket_denominator")
        field = self.sql(expression, "bucket_field")
        field = f" ON {field}" if field else ""
        bucket = f"BUCKET {numerator} OUT OF {denominator}{field}" if numerator else ""
        percent = self.sql(expression, "percent")
        percent = f"{percent} PERCENT" if percent else ""
        rows = self.sql(expression, "rows")
        rows = f"{rows} ROWS" if rows else ""
        seed = self.sql(expression, "seed")
        seed = f" {self.TABLESAMPLE_SEED_KEYWORD} ({seed})" if seed else ""

        size = self.sql(expression, "size")
        if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
            size = f"{size} PERCENT"
        if size and self.TABLESAMPLE_SIZE_IS_ROWS:
            size = f"{size} ROWS"

        seed = self.sql(expression, "seed")
        seed = f" {seed_prefix} ({seed})" if seed else ""
        kind = expression.args.get("kind", "TABLESAMPLE")
        percent = self.sql(expression, "percent")
        if percent and not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT:
            percent = f"{percent} PERCENT"

        expr = f"{bucket}{percent}{rows}{size}"
        expr = f"{bucket}{percent}{size}"
        if self.TABLESAMPLE_REQUIRES_PARENS:
            expr = f"({expr})"

        return f"{this} {kind} {method}{expr}{seed}{alias}"
        return (
            f"{this} {tablesample_keyword or self.TABLESAMPLE_KEYWORDS} {method}{expr}{seed}{alias}"
        )

    def pivot_sql(self, expression: exp.Pivot) -> str:
        expressions = self.expressions(expression, flat=True)

@ -1513,8 +1554,7 @@ class Generator:
        alias = self.sql(expression, "alias")
        alias = f" AS {alias}" if alias else ""
        unpivot = expression.args.get("unpivot")
        direction = "UNPIVOT" if unpivot else "PIVOT"
        direction = "UNPIVOT" if expression.unpivot else "PIVOT"
        field = self.sql(expression, "field")
        include_nulls = expression.args.get("include_nulls")
        if include_nulls is not None:

@ -1675,7 +1715,8 @@ class Generator:
        if not on_sql and using:
            on_sql = csv(*(self.sql(column) for column in using))

        this_sql = self.sql(expression, "this")
        this = expression.this
        this_sql = self.sql(this)

        if on_sql:
            on_sql = self.indent(on_sql, skip_first=True)

@ -1685,6 +1726,9 @@ class Generator:
            else:
                on_sql = f"{space}ON {on_sql}"
        elif not op_sql:
            if isinstance(this, exp.Lateral) and this.args.get("cross_apply") is not None:
                return f" {this_sql}"

            return f", {this_sql}"

        op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"

@ -1695,6 +1739,19 @@ class Generator:
        args = f"({args})" if len(args.split(",")) > 1 else args
        return f"{args} {arrow_sep} {self.sql(expression, 'this')}"

    def lateral_op(self, expression: exp.Lateral) -> str:
        cross_apply = expression.args.get("cross_apply")

        # https://www.mssqltips.com/sqlservertip/1958/sql-server-cross-apply-and-outer-apply/
        if cross_apply is True:
            op = "INNER JOIN "
        elif cross_apply is False:
            op = "LEFT JOIN "
        else:
            op = ""

        return f"{op}LATERAL"

    def lateral_sql(self, expression: exp.Lateral) -> str:
        this = self.sql(expression, "this")

@ -1708,7 +1765,7 @@ class Generator:
        alias = self.sql(expression, "alias")
        alias = f" AS {alias}" if alias else ""
        return f"LATERAL {this}{alias}"
        return f"{self.lateral_op(expression)} {this}{alias}"

    def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
        this = self.sql(expression, "this")

@ -1805,7 +1862,8 @@ class Generator:
    def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
        this = self.sql(expression, "this")
        this = f"{this} " if this else this
        order = self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat)  # type: ignore
        siblings = "SIBLINGS " if expression.args.get("siblings") else ""
        order = self.op_expressions(f"{this}ORDER {siblings}BY", expression, flat=this or flat)  # type: ignore
        interpolated_values = [
            f"{self.sql(named_expression, 'alias')} AS {self.sql(named_expression, 'this')}"
            for named_expression in expression.args.get("interpolate") or []

@ -1860,9 +1918,21 @@ class Generator:
        # If the NULLS FIRST/LAST clause is unsupported, we add another sort key to simulate it
        if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
            null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
            this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
            nulls_sort_change = ""
            window = expression.find_ancestor(exp.Window, exp.Select)
            if isinstance(window, exp.Window) and window.args.get("spec"):
                self.unsupported(
                    f"'{nulls_sort_change.strip()}' translation not supported in window functions"
                )
                nulls_sort_change = ""
            elif self.NULL_ORDERING_SUPPORTED is None:
                if expression.this.is_int:
                    self.unsupported(
                        f"'{nulls_sort_change.strip()}' translation not supported with positional ordering"
                    )
                else:
                    null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
                    this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
                nulls_sort_change = ""

        with_fill = self.sql(expression, "with_fill")
        with_fill = f" {with_fill}" if with_fill else ""

@ -1961,10 +2031,14 @@ class Generator:
        return [locks, self.sql(expression, "sample")]

    def select_sql(self, expression: exp.Select) -> str:
        into = expression.args.get("into")
        if not self.SUPPORTS_SELECT_INTO and into:
            into.pop()

        hint = self.sql(expression, "hint")
        distinct = self.sql(expression, "distinct")
        distinct = f" {distinct}" if distinct else ""
        kind = self.sql(expression, "kind").upper()
        kind = self.sql(expression, "kind")
        limit = expression.args.get("limit")
        top = (
            self.limit_sql(limit, top=True)

@ -2005,7 +2079,19 @@ class Generator:
            self.sql(expression, "into", comment=False),
            self.sql(expression, "from", comment=False),
        )
        return self.prepend_ctes(expression, sql)

        sql = self.prepend_ctes(expression, sql)

        if not self.SUPPORTS_SELECT_INTO and into:
            if into.args.get("temporary"):
                table_kind = " TEMPORARY"
            elif self.SUPPORTS_UNLOGGED_TABLES and into.args.get("unlogged"):
                table_kind = " UNLOGGED"
            else:
                table_kind = ""
            sql = f"CREATE{table_kind} TABLE {self.sql(into.this)} AS {sql}"

        return sql
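A sketch of the CTAS fallback, assuming a target dialect with SUPPORTS_SELECT_INTO = False (duckdb is used here as a plausible example):

import sqlglot

# SELECT ... INTO is rewritten as CREATE TABLE ... AS on such dialects:
print(sqlglot.transpile("SELECT * INTO t2 FROM t1", read="tsql", write="duckdb")[0])
# CREATE TABLE t2 AS SELECT * FROM t1  (expected shape)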

    def schema_sql(self, expression: exp.Schema) -> str:
        this = self.sql(expression, "this")

@ -2266,29 +2352,35 @@ class Generator:
        return f"{self.func('MATCH', *expression.expressions)} AGAINST({self.sql(expression, 'this')}{modifier})"

    def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
        return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
        return f"{self.sql(expression, 'this')}{self.JSON_KEY_VALUE_PAIR_SEP} {self.sql(expression, 'expression')}"

    def formatjson_sql(self, expression: exp.FormatJson) -> str:
        return f"{self.sql(expression, 'this')} FORMAT JSON"

    def jsonobject_sql(self, expression: exp.JSONObject) -> str:
    def jsonobject_sql(self, expression: exp.JSONObject | exp.JSONObjectAgg) -> str:
        null_handling = expression.args.get("null_handling")
        null_handling = f" {null_handling}" if null_handling else ""

        unique_keys = expression.args.get("unique_keys")
        if unique_keys is not None:
            unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS"
        else:
            unique_keys = ""

        return_type = self.sql(expression, "return_type")
        return_type = f" RETURNING {return_type}" if return_type else ""
        encoding = self.sql(expression, "encoding")
        encoding = f" ENCODING {encoding}" if encoding else ""

        return self.func(
            "JSON_OBJECT",
            "JSON_OBJECT" if isinstance(expression, exp.JSONObject) else "JSON_OBJECTAGG",
            *expression.expressions,
            suffix=f"{null_handling}{unique_keys}{return_type}{encoding})",
        )

    def jsonobjectagg_sql(self, expression: exp.JSONObjectAgg) -> str:
        return self.jsonobject_sql(expression)
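A round-trip sketch; JSON_OBJECTAGG is wired up in the parser further down in this diff and shares the generation path above:

import sqlglot

# Expected to round-trip through the shared jsonobject_sql path:
print(sqlglot.parse_one("SELECT JSON_OBJECTAGG(k: v) FROM t").sql())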

    def jsonarray_sql(self, expression: exp.JSONArray) -> str:
        null_handling = expression.args.get("null_handling")
        null_handling = f" {null_handling}" if null_handling else ""

@ -2385,7 +2477,7 @@ class Generator:
    def interval_sql(self, expression: exp.Interval) -> str:
        unit = self.sql(expression, "unit")
        if not self.INTERVAL_ALLOWS_PLURAL_FORM:
            unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit)
            unit = self.TIME_PART_SINGULARS.get(unit, unit)
        unit = f" {unit}" if unit else ""

        if self.SINGLE_STRING_INTERVAL:

@ -2436,9 +2528,25 @@ class Generator:
        alias = f" AS {alias}" if alias else ""
        return f"{self.sql(expression, 'this')}{alias}"

    def pivotalias_sql(self, expression: exp.PivotAlias) -> str:
        alias = expression.args["alias"]
        identifier_alias = isinstance(alias, exp.Identifier)

        if identifier_alias and not self.UNPIVOT_ALIASES_ARE_IDENTIFIERS:
            alias.replace(exp.Literal.string(alias.output_name))
        elif not identifier_alias and self.UNPIVOT_ALIASES_ARE_IDENTIFIERS:
            alias.replace(exp.to_identifier(alias.output_name))

        return self.alias_sql(expression)

    def aliases_sql(self, expression: exp.Aliases) -> str:
        return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"

    def atindex_sql(self, expression: exp.AtTimeZone) -> str:
        this = self.sql(expression, "this")
        index = self.sql(expression, "expression")
        return f"{this} AT {index}"

    def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
        this = self.sql(expression, "this")
        zone = self.sql(expression, "zone")

@ -2500,7 +2608,7 @@ class Generator:
        return self.binary(expression, "COLLATE")

    def command_sql(self, expression: exp.Command) -> str:
        return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
        return f"{self.sql(expression, 'this')} {expression.text('expression').strip()}"

    def comment_sql(self, expression: exp.Comment) -> str:
        this = self.sql(expression, "this")

@ -3102,6 +3210,47 @@ class Generator:
        cond_for_null = arg.is_(exp.null())
        return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg])))

    def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str:
        this = expression.this
        if isinstance(this, exp.TsOrDsToTime) or this.is_type(exp.DataType.Type.TIME):
            return self.sql(this)

        return self.sql(exp.cast(this, "time"))

    def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str:
        this = expression.this
        time_format = self.format_time(expression)

        if time_format and time_format not in (self.dialect.TIME_FORMAT, self.dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(exp.StrToTime(this=this, format=expression.args["format"]), "date")
            )

        if isinstance(this, exp.TsOrDsToDate) or this.is_type(exp.DataType.Type.DATE):
            return self.sql(this)

        return self.sql(exp.cast(this, "date"))

    def unixdate_sql(self, expression: exp.UnixDate) -> str:
        return self.sql(
            exp.func(
                "DATEDIFF",
                expression.this,
                exp.cast(exp.Literal.string("1970-01-01"), "date"),
                "day",
            )
        )

    def lastday_sql(self, expression: exp.LastDay) -> str:
        if self.LAST_DAY_SUPPORTS_DATE_PART:
            return self.function_fallback_sql(expression)

        unit = expression.text("unit")
        if unit and unit != "MONTH":
            self.unsupported("Date parts are not supported in LAST_DAY.")

        return self.func("LAST_DAY", expression.this)

    def _simplify_unless_literal(self, expression: E) -> E:
        if not isinstance(expression, exp.Literal):
            from sqlglot.optimizer.simplify import simplify


@ -129,13 +129,10 @@ def lineage(
            if isinstance(column, int)
            else next(
                (select for select in scope.expression.selects if select.alias_or_name == column),
                exp.Star() if scope.expression.is_star else None,
                exp.Star() if scope.expression.is_star else scope.expression,
            )
        )

    if not select:
        raise ValueError(f"Could not find {column} in {scope.expression}")

    if isinstance(scope.expression, exp.Union):
        upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

@ -194,6 +191,8 @@ def lineage(
    # if the select is a star add all scope sources as downstreams
    if select.is_star:
        for source in scope.sources.values():
            if isinstance(source, Scope):
                source = source.expression
            node.downstream.append(Node(name=select.sql(), source=source, expression=source))

    # Find all columns that went into creating this one to list their lineage nodes.


@ -195,6 +195,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,

@ -275,6 +278,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
    }

    NESTED_TYPES = {

@ -477,7 +481,12 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
        self,
        expression: E,
        *args: str,
        promote: bool = False,
        array: bool = False,
        struct: bool = False,
    ) -> E:
        self._annotate_args(expression)

@ -506,6 +515,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
                ),
            )

        if struct:
            expressions = [
                expr.type
                if not expr.args.get("alias")
                else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
                for expr in expressions
            ]

            self._set_type(
                expression,
                exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
            )

        return expression

    def _annotate_timeunit(


@ -30,13 +30,18 @@ def pushdown_predicates(expression, dialect=None):
        where = select.args.get("where")
        if where:
            selected_sources = scope.selected_sources
            join_index = {
                join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
            }

            # a right join can only push down to itself and not the source FROM table
            for k, (node, source) in selected_sources.items():
                parent = node.find_ancestor(exp.Join, exp.From)
                if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                    selected_sources = {k: (node, source)}
                    break
            pushdown(where.this, selected_sources, scope_ref_count, dialect)
            pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself

@ -53,7 +58,7 @@ def pushdown_predicates(expression, dialect=None):
    return expression


def pushdown(condition, sources, scope_ref_count, dialect):
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    if not condition:
        return

@ -67,21 +72,28 @@ def pushdown(condition, sources, scope_ref_count, dialect):
    )

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count)
        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)


def pushdown_cnf(predicates, scope, scope_ref_count):
def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.
    """
    join_index = join_index or {}
    for predicate in predicates:
        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
            if isinstance(node, exp.Join):
                predicate.replace(exp.true())
                node.on(predicate, copy=False)
                break
                name = node.alias_or_name
                predicate_tables = exp.column_table_names(predicate, name)

                # Don't push the predicate if it references tables that appear in later joins
                this_index = join_index[name]
                if all(join_index.get(table, -1) < this_index for table in predicate_tables):
                    predicate.replace(exp.true())
                    node.on(predicate, copy=False)
                    break
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(node, predicate)

@ -112,9 +124,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count):

    conditions = {}

    # for every pushdown table, find all related conditions in all predicates
    # combine them with ORS
    # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
    # pushdown all predicates to their respective nodes
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)

@ -122,23 +132,9 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
            if table not in nodes:
                continue

            predicate_condition = None

            for column in predicate.find_all(exp.Column):
                if column.table == table:
                    condition = column.find_ancestor(exp.Condition)
                    predicate_condition = (
                        exp.and_(predicate_condition, condition)
                        if predicate_condition
                        else condition
                    )

            if predicate_condition:
                conditions[table] = (
                    exp.or_(conditions[table], predicate_condition)
                    if table in conditions
                    else predicate_condition
                )
            conditions[table] = (
                exp.or_(conditions[table], predicate) if table in conditions else predicate
            )

            for name, node in nodes.items():
                if name not in conditions:


@ -43,9 +43,8 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
            # we select from a pivoted source in the parent scope.
        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):

@ -78,7 +77,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
        # Push the selected columns down to the next scope
        for name, (node, source) in scope.selected_sources.items():
            if isinstance(source, Scope):
                columns = selects.get(name) or set()
                columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names


@ -3,10 +3,11 @@ from __future__ import annotations
import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
    pushdown_cte_alias_columns as pushdown_cte_alias_columns_func,
    qualify_columns as qualify_columns_func,
    quote_identifiers as quote_identifiers_func,
    validate_qualify_columns as validate_qualify_columns_func,

@ -22,6 +23,7 @@ def qualify(
    catalog: t.Optional[str] = None,
    schema: t.Optional[dict | Schema] = None,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    isolate_tables: bool = False,
    qualify_columns: bool = True,

@ -47,6 +49,9 @@ def qualify(
        catalog: Default catalog name for tables.
        schema: Schema to infer column names and types.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing.
        isolate_tables: Whether or not to isolate table selects.
        qualify_columns: Whether or not to qualify columns.

@ -66,9 +71,16 @@ def qualify(
    if isolate_tables:
        expression = isolate_table_selects(expression, schema=schema)

    if Dialect.get_or_raise(dialect).PREFER_CTE_ALIAS_COLUMN:
        expression = pushdown_cte_alias_columns_func(expression)

    if qualify_columns:
        expression = qualify_columns_func(
            expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
            expression,
            schema,
            expand_alias_refs=expand_alias_refs,
            expand_stars=expand_stars,
            infer_schema=infer_schema,
        )

    if quote_identifiers:
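A sketch of the new expand_stars flag, with a hypothetical one-table schema:

import sqlglot
from sqlglot.optimizer.qualify import qualify

qualified = qualify(
    sqlglot.parse_one("SELECT * FROM t"),
    schema={"t": {"a": "INT", "b": "INT"}},  # hypothetical schema
    expand_stars=False,
)
print(qualified.sql())  # the star is kept; tables and identifiers are still qualified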


@ -17,6 +17,7 @@ def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """

@ -33,10 +34,16 @@ def qualify_columns(
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema

@ -57,7 +64,8 @@ def qualify_columns(
        _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)

@ -68,21 +76,41 @@ def qualify_columns(

def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unqualified_columns = []
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                raise OptimizeError(
                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
                )
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

@ -216,6 +244,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


@ -353,18 +382,25 @@ def _expand_stars(
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if pivot and has_pivoted_source:
        pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
        else:
            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
            else:
                pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):

@ -384,47 +420,54 @@ def _expand_stars(
            raise OptimizeError(f"Unknown table: {table}")

        columns = resolver.get_source_columns(table, only_visible=True)
        columns = columns or scope.outer_column_list

        if pseudocolumns:
            columns = [name for name in columns if name.upper() not in pseudocolumns]

        if columns and "*" not in columns:
            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
        if not columns or "*" in columns:
            return

            if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                implicit_columns = [col for col in columns if col not in pivot_columns]
        table_id = id(table)
        columns_to_exclude = except_columns.get(table_id) or set()

        if pivot:
            if pivot_output_columns and pivot_exclude_columns:
                pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                pivot_columns.extend(pivot_output_columns)
            else:
                pivot_columns = pivot.alias_column_names

            if pivot_columns:
                new_selections.extend(
                    exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                    for name in implicit_columns + pivot_output_columns
                    for name in pivot_columns
                    if name not in columns_to_exclude
                )
                continue

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    if name in coalesced_columns:
                        continue
        for name in columns:
            if name in using_column_tables and table in using_column_tables[name]:
                if name in coalesced_columns:
                    continue

                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]
                coalesced_columns.add(name)
                tables = using_column_tables[name]
                coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                new_selections.append(
                    alias(
                        exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                        alias=name,
                        copy=False,
                    )
                elif name not in columns_to_exclude:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )
            else:
                return
                )
            elif name not in columns_to_exclude:
                alias_ = replace_columns.get(table_id, {}).get(name, name)
                column = exp.column(name, table=table)
                new_selections.append(
                    alias(column, alias_, copy=False) if alias_ != name else column
                )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:

@ -472,6 +515,9 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))

@ -495,6 +541,38 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
    )


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.


@ -72,11 +72,15 @@ def qualify_tables(
                if not source.args.get("catalog") and source.args.get("db"):
                    source.set("catalog", catalog)

                pivots = source.args.get("pivots")
                if not source.alias:
                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
                    if pivots and pivots[0].alias == name:
                        name = source.name

                    # Mutates the source by attaching an alias to it
                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)

                pivots = source.args.get("pivots")
                if pivots and not pivots[0].alias:
                    pivots[0].set(
                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))


@ -539,11 +539,23 @@ def _traverse_union(scope):

    # The last scope to be yielded should be the topmost scope
    left = None
    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
    for left in _traverse_scope(
        scope.branch(
            scope.expression.left,
            outer_column_list=scope.outer_column_list,
            scope_type=ScopeType.UNION,
        )
    ):
        yield left

    right = None
    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
    for right in _traverse_scope(
        scope.branch(
            scope.expression.right,
            outer_column_list=scope.outer_column_list,
            scope_type=ScopeType.UNION,
        )
    ):
        yield right

    scope.union_scopes = [left, right]


@ -100,6 +100,7 @@ def simplify(
    node = simplify_parens(node)
    node = simplify_datetrunc(node, dialect)
    node = sort_comparison(node)
    node = simplify_startswith(node)

    if root:
        expression.replace(node)

@ -776,6 +777,26 @@ def simplify_conditionals(expression):
    return expression


def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if (
        isinstance(expression, exp.StartsWith)
        and expression.this.is_string
        and expression.expression.is_string
    ):
        return exp.convert(expression.name.startswith(expression.expression.name))

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


@ -1160,7 +1181,7 @@ def gen(expression: t.Any) -> str:
GEN_MAP = {
    exp.Add: lambda e: _binary(e, "+"),
    exp.And: lambda e: _binary(e, "AND"),
    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
    exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",


@ -12,6 +12,8 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from typing_extensions import Literal

    from sqlglot._typing import E
    from sqlglot.dialects.dialect import Dialect, DialectType


@ -193,6 +195,7 @@ class Parser(metaclass=_Parser):
        TokenType.DATETIME,
        TokenType.DATETIME64,
        TokenType.DATE,
        TokenType.DATE32,
        TokenType.INT4RANGE,
        TokenType.INT4MULTIRANGE,
        TokenType.INT8RANGE,

@ -232,6 +235,8 @@ class Parser(metaclass=_Parser):
        TokenType.INET,
        TokenType.IPADDRESS,
        TokenType.IPPREFIX,
        TokenType.IPV4,
        TokenType.IPV6,
        TokenType.UNKNOWN,
        TokenType.NULL,
        *ENUM_TYPE_TOKENS,

@ -669,6 +674,7 @@ class Parser(metaclass=_Parser):

    PROPERTY_PARSERS: t.Dict[str, t.Callable] = {
        "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
        "AUTO": lambda self: self._parse_auto_property(),
        "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
        "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
        "CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs),

@ -680,6 +686,7 @@ class Parser(metaclass=_Parser):
            exp.CollateProperty, **kwargs
        ),
        "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
        "CONTAINS": lambda self: self._parse_contains_property(),
        "COPY": lambda self: self._parse_copy_property(),
        "DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
        "DEFINER": lambda self: self._parse_definer(),

@ -710,6 +717,7 @@ class Parser(metaclass=_Parser):
        "LOG": lambda self, **kwargs: self._parse_log(**kwargs),
        "MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty),
        "MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs),
        "MODIFIES": lambda self: self._parse_modifies_property(),
        "MULTISET": lambda self: self.expression(exp.SetProperty, multi=True),
        "NO": lambda self: self._parse_no_property(),
        "ON": lambda self: self._parse_on_property(),

@ -721,6 +729,7 @@ class Parser(metaclass=_Parser):
        "PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
        "PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True),
        "RANGE": lambda self: self._parse_dict_range(this="RANGE"),
        "READS": lambda self: self._parse_reads_property(),
        "REMOTE": lambda self: self._parse_remote_with_connection(),
        "RETURNS": lambda self: self._parse_returns(),
        "ROW": lambda self: self._parse_row(),

@ -841,6 +850,7 @@ class Parser(metaclass=_Parser):
        "DECODE": lambda self: self._parse_decode(),
        "EXTRACT": lambda self: self._parse_extract(),
        "JSON_OBJECT": lambda self: self._parse_json_object(),
        "JSON_OBJECTAGG": lambda self: self._parse_json_object(agg=True),
        "JSON_TABLE": lambda self: self._parse_json_table(),
        "MATCH": lambda self: self._parse_match_against(),
        "OPENJSON": lambda self: self._parse_open_json(),

@ -925,6 +935,8 @@ class Parser(metaclass=_Parser):
    WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
    WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}

    JSON_KEY_VALUE_SEPARATOR_TOKENS = {TokenType.COLON, TokenType.COMMA, TokenType.IS}

    FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT}

    ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}

@ -954,6 +966,9 @@ class Parser(metaclass=_Parser):
    # Whether the TRIM function expects the characters to trim as its first argument
    TRIM_PATTERN_FIRST = False

    # Whether or not string aliases are supported `SELECT COUNT(*) 'count'`
    STRING_ALIASES = False

    # Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand)
    MODIFIERS_ATTACHED_TO_UNION = True
    UNION_MODIFIERS = {"order", "limit", "offset"}

@ -1193,7 +1208,9 @@ class Parser(metaclass=_Parser):
        self._advance(index - self._index)

    def _parse_command(self) -> exp.Command:
        return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
        return self.expression(
            exp.Command, this=self._prev.text.upper(), expression=self._parse_string()
        )

    def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
        start = self._prev
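A sketch, assuming a statement the tokenizer treats as a bare command (GRANT is one such case in this version):

import sqlglot

# The command keyword is uppercased at parse time now; the body stays verbatim:
print(sqlglot.parse_one("grant SELECT ON tbl TO admin").sql())
# GRANT SELECT ON tbl TO admin  (expected shape)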

@ -1353,26 +1370,27 @@ class Parser(metaclass=_Parser):
            # exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature)
            extend_props(self._parse_properties())

            self._match(TokenType.ALIAS)
            expression = self._match(TokenType.ALIAS) and self._parse_heredoc()

            if self._match(TokenType.COMMAND):
                expression = self._parse_as_command(self._prev)
            else:
                begin = self._match(TokenType.BEGIN)
                return_ = self._match_text_seq("RETURN")

                if self._match(TokenType.STRING, advance=False):
                    # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property
                    # # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement
                    expression = self._parse_string()
                    extend_props(self._parse_properties())
            if not expression:
                if self._match(TokenType.COMMAND):
                    expression = self._parse_as_command(self._prev)
                else:
                    expression = self._parse_statement()
                    begin = self._match(TokenType.BEGIN)
                    return_ = self._match_text_seq("RETURN")

                end = self._match_text_seq("END")
                    if self._match(TokenType.STRING, advance=False):
                        # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property
                        # # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement
                        expression = self._parse_string()
                        extend_props(self._parse_properties())
                    else:
                        expression = self._parse_statement()

                if return_:
                    expression = self.expression(exp.Return, this=expression)
                    end = self._match_text_seq("END")

                    if return_:
                        expression = self.expression(exp.Return, this=expression)
        elif create_token.token_type == TokenType.INDEX:
            this = self._parse_index(index=self._parse_id_var())
        elif create_token.token_type in self.DB_CREATABLES:

@ -1426,7 +1444,7 @@ class Parser(metaclass=_Parser):
            exp.Create,
            comments=comments,
            this=this,
            kind=create_token.text,
            kind=create_token.text.upper(),
            replace=replace,
            unique=unique,
            expression=expression,

@ -1849,9 +1867,21 @@ class Parser(metaclass=_Parser):
        return self.expression(exp.WithDataProperty, no=no, statistics=statistics)

    def _parse_no_property(self) -> t.Optional[exp.NoPrimaryIndexProperty]:
    def _parse_contains_property(self) -> t.Optional[exp.SqlReadWriteProperty]:
        if self._match_text_seq("SQL"):
            return self.expression(exp.SqlReadWriteProperty, this="CONTAINS SQL")
        return None

    def _parse_modifies_property(self) -> t.Optional[exp.SqlReadWriteProperty]:
        if self._match_text_seq("SQL", "DATA"):
            return self.expression(exp.SqlReadWriteProperty, this="MODIFIES SQL DATA")
        return None

    def _parse_no_property(self) -> t.Optional[exp.Expression]:
        if self._match_text_seq("PRIMARY", "INDEX"):
            return exp.NoPrimaryIndexProperty()
        if self._match_text_seq("SQL"):
            return self.expression(exp.SqlReadWriteProperty, this="NO SQL")
        return None

    def _parse_on_property(self) -> t.Optional[exp.Expression]:

@ -1861,6 +1891,11 @@ class Parser(metaclass=_Parser):
            return exp.OnCommitProperty(delete=True)
        return self.expression(exp.OnProperty, this=self._parse_schema(self._parse_id_var()))

    def _parse_reads_property(self) -> t.Optional[exp.SqlReadWriteProperty]:
        if self._match_text_seq("SQL", "DATA"):
            return self.expression(exp.SqlReadWriteProperty, this="READS SQL DATA")
        return None

    def _parse_distkey(self) -> exp.DistKeyProperty:
        return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))

@ -1920,10 +1955,13 @@ class Parser(metaclass=_Parser):

    def _parse_describe(self) -> exp.Describe:
        kind = self._match_set(self.CREATABLES) and self._prev.text
        extended = self._match_text_seq("EXTENDED")
        this = self._parse_table(schema=True)
        properties = self._parse_properties()
        expressions = properties.expressions if properties else None
        return self.expression(exp.Describe, this=this, kind=kind, expressions=expressions)
        return self.expression(
            exp.Describe, this=this, extended=extended, kind=kind, expressions=expressions
        )
|
||||
|
||||
def _parse_insert(self) -> exp.Insert:
|
||||
comments = ensure_list(self._prev_comments)
|
||||
|
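The new `extended` flag on exp.Describe captures Spark/Hive's DESCRIBE EXTENDED form. A small sketch, assuming the Spark reader and that the generator mirrors the flag:

    import sqlglot

    desc = sqlglot.parse_one("DESCRIBE EXTENDED db.tbl", read="spark")
    assert desc.args.get("extended")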
@@ -2164,13 +2202,13 @@ class Parser(metaclass=_Parser):
 
     def _parse_value(self) -> exp.Tuple:
         if self._match(TokenType.L_PAREN):
-            expressions = self._parse_csv(self._parse_conjunction)
+            expressions = self._parse_csv(self._parse_expression)
             self._match_r_paren()
             return self.expression(exp.Tuple, expressions=expressions)
 
         # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
         # https://prestodb.io/docs/current/sql/values.html
-        return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
+        return self.expression(exp.Tuple, expressions=[self._parse_expression()])
 
     def _parse_projections(self) -> t.List[exp.Expression]:
         return self._parse_expressions()
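The Presto comment above is worth a concrete illustration: a bare VALUES list yields one single-element Tuple per row rather than one wide row. A quick check (tree shape assumed):

    import sqlglot
    from sqlglot import exp

    q = sqlglot.parse_one("SELECT * FROM (VALUES 1, 2)", read="presto")
    assert len(list(q.find_all(exp.Tuple))) == 2  # two rows, one column each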
@@ -2212,7 +2250,7 @@ class Parser(metaclass=_Parser):
         kind = (
             self._match(TokenType.ALIAS)
             and self._match_texts(("STRUCT", "VALUE"))
-            and self._prev.text
+            and self._prev.text.upper()
         )
 
         if distinct:
@@ -2261,7 +2299,7 @@ class Parser(metaclass=_Parser):
                 if table
                 else self._parse_select(nested=True, parse_set_operation=False)
             )
-            this = self._parse_set_operations(self._parse_query_modifiers(this))
+            this = self._parse_query_modifiers(self._parse_set_operations(this))
 
             self._match_r_paren()
 
@@ -2304,7 +2342,7 @@ class Parser(metaclass=_Parser):
         )
 
     def _parse_cte(self) -> exp.CTE:
-        alias = self._parse_table_alias()
+        alias = self._parse_table_alias(self.ID_VAR_TOKENS)
         if not alias or not alias.this:
             self.raise_error("Expected CTE to have alias")
 
@@ -2490,13 +2528,14 @@ class Parser(metaclass=_Parser):
         )
 
     def _parse_lateral(self) -> t.Optional[exp.Lateral]:
-        outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
         cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY)
+        if not cross_apply and self._match_pair(TokenType.OUTER, TokenType.APPLY):
+            cross_apply = False
 
-        if outer_apply or cross_apply:
+        if cross_apply is not None:
             this = self._parse_select(table=True)
             view = None
-            outer = not cross_apply
+            outer = None
         elif self._match(TokenType.LATERAL):
             this = self._parse_select(table=True)
             view = self._match(TokenType.VIEW)
@@ -2529,7 +2568,14 @@ class Parser(metaclass=_Parser):
         else:
             table_alias = self._parse_table_alias()
 
-        return self.expression(exp.Lateral, this=this, view=view, outer=outer, alias=table_alias)
+        return self.expression(
+            exp.Lateral,
+            this=this,
+            view=view,
+            outer=outer,
+            alias=table_alias,
+            cross_apply=cross_apply,
+        )
 
     def _parse_join_parts(
         self,
@@ -2563,9 +2609,6 @@ class Parser(metaclass=_Parser):
         if not skip_join_token and not join and not outer_apply and not cross_apply:
             return None
 
-        if outer_apply:
-            side = Token(TokenType.LEFT, "LEFT")
-
         kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)}
 
         if method:
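With cross_apply carrying three states (True for CROSS APPLY, False for OUTER APPLY, None for neither), OUTER APPLY no longer has to be downgraded to a LEFT join and can round-trip as written. A hedged sketch:

    import sqlglot

    sql = "SELECT * FROM t OUTER APPLY fn(t.x)"
    assert "OUTER APPLY" in sqlglot.transpile(sql, read="tsql", write="tsql")[0]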
@@ -2755,8 +2798,10 @@ class Parser(metaclass=_Parser):
         if alias:
             this.set("alias", alias)
 
-        if self._match_text_seq("AT"):
-            this.set("index", self._parse_id_var())
+        if isinstance(this, exp.Table) and self._match_text_seq("AT"):
+            return self.expression(
+                exp.AtIndex, this=this.to_column(copy=False), expression=self._parse_id_var()
+            )
 
         this.set("hints", self._parse_table_hints())
 
@@ -2865,15 +2910,10 @@ class Parser(metaclass=_Parser):
         bucket_denominator = None
         bucket_field = None
         percent = None
-        rows = None
         size = None
         seed = None
 
-        kind = (
-            self._prev.text if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
-        )
-        method = self._parse_var(tokens=(TokenType.ROW,))
-
+        method = self._parse_var(tokens=(TokenType.ROW,), upper=True)
         matched_l_paren = self._match(TokenType.L_PAREN)
 
         if self.TABLESAMPLE_CSV:
@@ -2895,16 +2935,16 @@ class Parser(metaclass=_Parser):
             bucket_field = self._parse_field()
         elif self._match_set((TokenType.PERCENT, TokenType.MOD)):
             percent = num
-        elif self._match(TokenType.ROWS):
-            rows = num
-        elif num:
+        elif self._match(TokenType.ROWS) or not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT:
             size = num
+        else:
+            percent = num
 
         if matched_l_paren:
             self._match_r_paren()
 
         if self._match(TokenType.L_PAREN):
-            method = self._parse_var()
+            method = self._parse_var(upper=True)
             seed = self._match(TokenType.COMMA) and self._parse_number()
             self._match_r_paren()
         elif self._match_texts(("SEED", "REPEATABLE")):
@@ -2918,10 +2958,8 @@ class Parser(metaclass=_Parser):
             bucket_denominator=bucket_denominator,
             bucket_field=bucket_field,
             percent=percent,
-            rows=rows,
             size=size,
             seed=seed,
-            kind=kind,
         )
 
     def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]:
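After this change a bare sample number is classified by the dialect's TABLESAMPLE_SIZE_IS_PERCENT flag, and an explicit ROWS keyword always forces the row-count interpretation into `size`. A sketch of the keyword path, which is dialect-independent:

    import sqlglot
    from sqlglot import exp

    sample = sqlglot.parse_one("SELECT * FROM t TABLESAMPLE (10 ROWS)").find(exp.TableSample)
    assert sample is not None and sample.args.get("size") is not None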
@@ -2946,6 +2984,27 @@ class Parser(metaclass=_Parser):
             exp.Pivot, this=this, expressions=expressions, using=using, group=group
         )
 
+    def _parse_pivot_in(self) -> exp.In:
+        def _parse_aliased_expression() -> t.Optional[exp.Expression]:
+            this = self._parse_conjunction()
+
+            self._match(TokenType.ALIAS)
+            alias = self._parse_field()
+            if alias:
+                return self.expression(exp.PivotAlias, this=this, alias=alias)
+
+            return this
+
+        value = self._parse_column()
+
+        if not self._match_pair(TokenType.IN, TokenType.L_PAREN):
+            self.raise_error("Expecting IN (")
+
+        aliased_expressions = self._parse_csv(_parse_aliased_expression)
+
+        self._match_r_paren()
+        return self.expression(exp.In, this=value, expressions=aliased_expressions)
+
     def _parse_pivot(self) -> t.Optional[exp.Pivot]:
         index = self._index
         include_nulls = None
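_parse_pivot_in wraps each IN entry in exp.PivotAlias when an alias is present, which BigQuery-style pivots use for output column names. A hedged example:

    import sqlglot
    from sqlglot import exp

    q = sqlglot.parse_one(
        "SELECT * FROM sales PIVOT (SUM(amount) FOR quarter IN ('Q1' AS q1, 'Q2' AS q2))",
        read="bigquery",
    )
    assert q.find(exp.PivotAlias) is not None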
@@ -2964,7 +3023,6 @@ class Parser(metaclass=_Parser):
             return None
 
         expressions = []
-        field = None
 
         if not self._match(TokenType.L_PAREN):
             self._retreat(index)
@@ -2981,12 +3039,7 @@ class Parser(metaclass=_Parser):
         if not self._match(TokenType.FOR):
             self.raise_error("Expecting FOR")
 
-        value = self._parse_column()
-
-        if not self._match(TokenType.IN):
-            self.raise_error("Expecting IN")
-
-        field = self._parse_in(value, alias=True)
+        field = self._parse_pivot_in()
 
         self._match_r_paren()
 
@@ -3132,14 +3185,19 @@ class Parser(metaclass=_Parser):
     def _parse_order(
         self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
     ) -> t.Optional[exp.Expression]:
+        siblings = None
         if not skip_order_token and not self._match(TokenType.ORDER_BY):
-            return this
+            if not self._match(TokenType.ORDER_SIBLINGS_BY):
+                return this
+
+            siblings = True
 
         return self.expression(
             exp.Order,
             this=this,
             expressions=self._parse_csv(self._parse_ordered),
             interpolate=self._parse_interpolate(),
+            siblings=siblings,
         )
 
     def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]:
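The siblings flag models Oracle's ORDER SIBLINGS BY, which sorts rows within each level of a CONNECT BY hierarchy. A hedged sketch against the Oracle reader:

    import sqlglot
    from sqlglot import exp

    q = sqlglot.parse_one(
        "SELECT employee_id FROM employees "
        "CONNECT BY PRIOR employee_id = manager_id ORDER SIBLINGS BY last_name",
        read="oracle",
    )
    order = q.find(exp.Order)
    assert order is not None and order.args.get("siblings")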
@@ -3213,7 +3271,7 @@ class Parser(metaclass=_Parser):
 
         if self._match(TokenType.FETCH):
             direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
-            direction = self._prev.text if direction else "FIRST"
+            direction = self._prev.text.upper() if direction else "FIRST"
 
             count = self._parse_field(tokens=self.FETCH_TOKENS)
             percent = self._match(TokenType.PERCENT)
@@ -3398,10 +3456,10 @@ class Parser(metaclass=_Parser):
             return this
         return self.expression(exp.Escape, this=this, expression=self._parse_string())
 
-    def _parse_interval(self) -> t.Optional[exp.Interval]:
+    def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Interval]:
         index = self._index
 
-        if not self._match(TokenType.INTERVAL):
+        if not self._match(TokenType.INTERVAL) and match_interval:
             return None
 
         if self._match(TokenType.STRING, advance=False):
@@ -3409,11 +3467,19 @@ class Parser(metaclass=_Parser):
         else:
             this = self._parse_term()
 
-        if not this:
+        if not this or (
+            isinstance(this, exp.Column)
+            and not this.table
+            and not this.this.quoted
+            and this.name.upper() == "IS"
+        ):
             self._retreat(index)
             return None
 
-        unit = self._parse_function() or self._parse_var(any_token=True)
+        unit = self._parse_function() or (
+            not self._match(TokenType.ALIAS, advance=False)
+            and self._parse_var(any_token=True, upper=True)
+        )
 
         # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
         # each INTERVAL expression into this canonical form so it's easy to transpile
@@ -3429,7 +3495,7 @@ class Parser(metaclass=_Parser):
             self._retreat(self._index - 1)
 
             this = exp.Literal.string(parts[0])
-            unit = self.expression(exp.Var, this=parts[1])
+            unit = self.expression(exp.Var, this=parts[1].upper())
 
         return self.expression(exp.Interval, this=this, unit=unit)
 
@@ -3489,6 +3555,12 @@ class Parser(metaclass=_Parser):
     def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
         interval = parse_interval and self._parse_interval()
         if interval:
+            # Convert INTERVAL 'val_1' unit_1 ... 'val_n' unit_n into a sum of intervals
+            while self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
+                interval = self.expression(  # type: ignore
+                    exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
+                )
+
             return interval
 
         index = self._index
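The new loop in _parse_type folds a chained interval such as INTERVAL '1' YEAR '2' MONTH into nested exp.Add nodes over individual intervals, so each piece can be re-rendered per dialect. A quick sketch (tree shape assumed):

    import sqlglot
    from sqlglot import exp

    q = sqlglot.parse_one("SELECT INTERVAL '1' YEAR '2' MONTH")
    add = q.find(exp.Add)
    assert add is not None and isinstance(add.expression, exp.Interval)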
@@ -3552,10 +3624,10 @@ class Parser(metaclass=_Parser):
         type_token = self._prev.token_type
 
         if type_token == TokenType.PSEUDO_TYPE:
-            return self.expression(exp.PseudoType, this=self._prev.text)
+            return self.expression(exp.PseudoType, this=self._prev.text.upper())
 
         if type_token == TokenType.OBJECT_IDENTIFIER:
-            return self.expression(exp.ObjectIdentifier, this=self._prev.text)
+            return self.expression(exp.ObjectIdentifier, this=self._prev.text.upper())
 
         nested = type_token in self.NESTED_TYPE_TOKENS
         is_struct = type_token in self.STRUCT_TYPE_TOKENS
@@ -3587,7 +3659,7 @@ class Parser(metaclass=_Parser):
 
         if nested and self._match(TokenType.LT):
             if is_struct:
-                expressions = self._parse_csv(self._parse_struct_types)
+                expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True))
             else:
                 expressions = self._parse_csv(
                     lambda: self._parse_types(
@@ -3662,10 +3734,19 @@ class Parser(metaclass=_Parser):
 
         return this
 
-    def _parse_struct_types(self) -> t.Optional[exp.Expression]:
+    def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
+        index = self._index
         this = self._parse_type(parse_interval=False) or self._parse_id_var()
         self._match(TokenType.COLON)
-        return self._parse_column_def(this)
+        column_def = self._parse_column_def(this)
+
+        if type_required and (
+            (isinstance(this, exp.Column) and this.this is column_def) or this is column_def
+        ):
+            self._retreat(index)
+            return self._parse_types()
+
+        return column_def
 
     def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
         if not self._match_text_seq("AT", "TIME", "ZONE"):
@@ -4025,6 +4106,12 @@ class Parser(metaclass=_Parser):
 
         return exp.AutoIncrementColumnConstraint()
 
+    def _parse_auto_property(self) -> t.Optional[exp.AutoRefreshProperty]:
+        if not self._match_text_seq("REFRESH"):
+            self._retreat(self._index - 1)
+            return None
+        return self.expression(exp.AutoRefreshProperty, this=self._parse_var(upper=True))
+
     def _parse_compress(self) -> exp.CompressColumnConstraint:
         if self._match(TokenType.L_PAREN, advance=False):
             return self.expression(
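_parse_auto_property backs an AUTO REFRESH property on materialized views; Redshift's syntax is the likely driver, so treat the dialect choice below as an assumption:

    import sqlglot
    from sqlglot import exp

    ddl = "CREATE MATERIALIZED VIEW mv AUTO REFRESH YES AS SELECT * FROM t"
    create = sqlglot.parse_one(ddl, read="redshift")
    assert create.find(exp.AutoRefreshProperty) is not None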
@@ -4230,8 +4317,10 @@ class Parser(metaclass=_Parser):
     def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
         return self._parse_field()
 
-    def _parse_period_for_system_time(self) -> exp.PeriodForSystemTimeConstraint:
-        self._match(TokenType.TIMESTAMP_SNAPSHOT)
+    def _parse_period_for_system_time(self) -> t.Optional[exp.PeriodForSystemTimeConstraint]:
+        if not self._match(TokenType.TIMESTAMP_SNAPSHOT):
+            self._retreat(self._index - 1)
+            return None
 
         id_vars = self._parse_wrapped_id_vars()
         return self.expression(
@@ -4257,22 +4346,17 @@ class Parser(metaclass=_Parser):
         options = self._parse_key_constraint_options()
         return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
 
+    def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
+        return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))
+
     def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
         if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
             return this
 
         bracket_kind = self._prev.token_type
-
-        if self._match(TokenType.COLON):
-            expressions: t.List[exp.Expression] = [
-                self.expression(exp.Slice, expression=self._parse_conjunction())
-            ]
-        else:
-            expressions = self._parse_csv(
-                lambda: self._parse_slice(
-                    self._parse_alias(self._parse_conjunction(), explicit=True)
-                )
-            )
+        expressions = self._parse_csv(
+            lambda: self._parse_bracket_key_value(is_map=bracket_kind == TokenType.L_BRACE)
+        )
 
         if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
             self.raise_error("Expected ]")
@@ -4313,7 +4397,10 @@ class Parser(metaclass=_Parser):
             default = self._parse_conjunction()
 
         if not self._match(TokenType.END):
-            self.raise_error("Expected END after CASE", self._prev)
+            if isinstance(default, exp.Interval) and default.this.sql().upper() == "END":
+                default = exp.column("interval")
+            else:
+                self.raise_error("Expected END after CASE", self._prev)
 
         return self._parse_window(
             self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default)
@@ -4514,7 +4601,7 @@ class Parser(metaclass=_Parser):
     def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]:
         self._match_text_seq("KEY")
         key = self._parse_column()
-        self._match_set((TokenType.COLON, TokenType.COMMA))
+        self._match_set(self.JSON_KEY_VALUE_SEPARATOR_TOKENS)
         self._match_text_seq("VALUE")
         value = self._parse_bitwise()
 
@@ -4536,7 +4623,15 @@ class Parser(metaclass=_Parser):
 
         return None
 
-    def _parse_json_object(self) -> exp.JSONObject:
+    @t.overload
+    def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
+        ...
+
+    @t.overload
+    def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
+        ...
+
+    def _parse_json_object(self, agg=False):
         star = self._parse_star()
         expressions = (
             [star]
@@ -4559,7 +4654,7 @@ class Parser(metaclass=_Parser):
         encoding = self._match_text_seq("ENCODING") and self._parse_var()
 
         return self.expression(
-            exp.JSONObject,
+            exp.JSONObjectAgg if agg else exp.JSONObject,
             expressions=expressions,
             null_handling=null_handling,
             unique_keys=unique_keys,
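With the agg overloads in place, JSON_OBJECTAGG can share all of JSON_OBJECT's argument handling while producing a distinct aggregate node. A hedged example using Oracle's KEY ... VALUE form:

    import sqlglot
    from sqlglot import exp

    q = sqlglot.parse_one("SELECT JSON_OBJECTAGG(KEY dept VALUE name) FROM emp", read="oracle")
    assert q.find(exp.JSONObjectAgg) is not None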
@@ -4873,10 +4968,17 @@ class Parser(metaclass=_Parser):
             self._match_r_paren(aliases)
             return aliases
 
-        alias = self._parse_id_var(any_token)
+        alias = self._parse_id_var(any_token) or (
+            self.STRING_ALIASES and self._parse_string_as_identifier()
+        )
 
         if alias:
-            return self.expression(exp.Alias, comments=comments, this=this, alias=alias)
+            this = self.expression(exp.Alias, comments=comments, this=this, alias=alias)
+
+            # Moves the comment next to the alias in `expr /* comment */ AS alias`
+            if not this.comments and this.this.comments:
+                this.comments = this.this.comments
+                this.this.comments = None
 
         return this
 
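STRING_ALIASES (set per dialect, e.g. Snowflake) lets a quoted string stand in for an identifier alias, and the comment-moving logic keeps a trailing comment attached to the Alias node. A hedged sketch:

    import sqlglot
    from sqlglot import exp

    q = sqlglot.parse_one("SELECT col 'my alias' FROM t", read="snowflake")
    alias = q.find(exp.Alias)
    assert alias is not None and alias.alias == "my alias"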
@@ -4915,14 +5017,19 @@ class Parser(metaclass=_Parser):
             return self._parse_placeholder()
 
     def _parse_var(
-        self, any_token: bool = False, tokens: t.Optional[t.Collection[TokenType]] = None
+        self,
+        any_token: bool = False,
+        tokens: t.Optional[t.Collection[TokenType]] = None,
+        upper: bool = False,
     ) -> t.Optional[exp.Expression]:
         if (
             (any_token and self._advance_any())
             or self._match(TokenType.VAR)
             or (self._match_set(tokens) if tokens else False)
         ):
-            return self.expression(exp.Var, this=self._prev.text)
+            return self.expression(
+                exp.Var, this=self._prev.text.upper() if upper else self._prev.text
+            )
         return self._parse_placeholder()
 
     def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]:
@@ -5418,6 +5525,42 @@ class Parser(metaclass=_Parser):
             condition=condition,
         )
 
+    def _parse_heredoc(self) -> t.Optional[exp.Heredoc]:
+        if self._match(TokenType.HEREDOC_STRING):
+            return self.expression(exp.Heredoc, this=self._prev.text)
+
+        if not self._match_text_seq("$"):
+            return None
+
+        tags = ["$"]
+        tag_text = None
+
+        if self._is_connected():
+            self._advance()
+            tags.append(self._prev.text.upper())
+        else:
+            self.raise_error("No closing $ found")
+
+        if tags[-1] != "$":
+            if self._is_connected() and self._match_text_seq("$"):
+                tag_text = tags[-1]
+                tags.append("$")
+            else:
+                self.raise_error("No closing $ found")
+
+        heredoc_start = self._curr
+
+        while self._curr:
+            if self._match_text_seq(*tags, advance=False):
+                this = self._find_sql(heredoc_start, self._prev)
+                self._advance(len(tags))
+                return self.expression(exp.Heredoc, this=this, tag=tag_text)
+
+            self._advance()
+
+        self.raise_error(f"No closing {''.join(tags)} found")
+        return None
+
     def _find_parser(
         self, parsers: t.Dict[str, t.Callable], trie: t.Dict
     ) -> t.Optional[t.Callable]:
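_parse_heredoc accepts both a pre-tokenized heredoc string and a raw $[tag]$ ... $[tag]$ scan, so tagged and untagged dollar quoting both land on exp.Heredoc. A hedged sketch with the Postgres reader:

    import sqlglot
    from sqlglot import exp

    plain = sqlglot.parse_one("SELECT $$hello$$", read="postgres")
    tagged = sqlglot.parse_one("SELECT $tag$hello$tag$", read="postgres")
    assert plain.find(exp.Heredoc) is not None
    assert tagged.find(exp.Heredoc) is not None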
@@ -215,12 +215,13 @@ class MappingSchema(AbstractMappingSchema, Schema):
         normalize: bool = True,
     ) -> None:
         self.dialect = dialect
-        self.visible = visible or {}
+        self.visible = {} if visible is None else visible
         self.normalize = normalize
         self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
         self._depth = 0
+        schema = {} if schema is None else schema
 
-        super().__init__(self._normalize(schema or {}))
+        super().__init__(self._normalize(schema) if self.normalize else schema)
 
     @classmethod
     def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
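The MappingSchema change makes normalize=False skip identifier normalization altogether instead of normalizing the initial mapping anyway. A hedged sketch of the distinction:

    from sqlglot.schema import MappingSchema

    # With normalization off, mixed-case names are stored and looked up verbatim.
    schema = MappingSchema({"Tbl": {"Col": "int"}}, normalize=False)
    assert schema.column_names("Tbl") == ["Col"]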
@@ -147,6 +147,7 @@ class TokenType(AutoName):
     DATETIME = auto()
     DATETIME64 = auto()
     DATE = auto()
+    DATE32 = auto()
     INT4RANGE = auto()
     INT4MULTIRANGE = auto()
     INT8RANGE = auto()
@@ -182,6 +183,8 @@ class TokenType(AutoName):
     INET = auto()
     IPADDRESS = auto()
     IPPREFIX = auto()
+    IPV4 = auto()
+    IPV6 = auto()
     ENUM = auto()
     ENUM8 = auto()
     ENUM16 = auto()
@@ -296,6 +299,7 @@ class TokenType(AutoName):
     ON = auto()
     OPERATOR = auto()
     ORDER_BY = auto()
+    ORDER_SIBLINGS_BY = auto()
     ORDERED = auto()
     ORDINALITY = auto()
     OUTER = auto()
@@ -255,7 +255,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
 
         if not arrays:
             if expression.args.get("from"):
-                expression.join(series, copy=False)
+                expression.join(series, copy=False, join_type="CROSS")
             else:
                 expression.from_(series, copy=False)
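The explicit join_type means the synthetic numbering series is attached with CROSS JOIN semantics instead of a default join. A hedged round trip through the transform:

    import sqlglot

    out = sqlglot.transpile("SELECT EXPLODE(arr) FROM t", read="spark", write="presto")[0]
    assert "CROSS JOIN" in out  # the series/UNNEST is cross-joined in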