Merging upstream version 17.7.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 87252470ef
commit 137902868c
93 changed files with 41580 additions and 39040 deletions
@@ -114,7 +114,7 @@ class Column:
         return self.inverse_binary_op(exp.Or, other)

     @classmethod
-    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
         return cls(value)

     @classmethod
@@ -259,7 +259,7 @@ class Column:
         new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
         return Column(new_expression)

-    def cast(self, dataType: t.Union[str, DataType]):
+    def cast(self, dataType: t.Union[str, DataType]) -> Column:
         """
         Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
         Sqlglot doesn't currently replicate this class so it only accepts a string
@@ -600,8 +600,13 @@ def months_between(
     date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
 ) -> Column:
     if roundOff is None:
-        return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
-    return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
+        return Column.invoke_expression_over_column(
+            date1, expression.MonthsBetween, expression=date2
+        )
+
+    return Column.invoke_expression_over_column(
+        date1, expression.MonthsBetween, expression=date2, roundoff=roundOff
+    )


 def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
@@ -614,8 +619,9 @@ def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:

 def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
     if format is not None:
-        return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", lit(format))
-    return Column.invoke_anonymous_function(col, "TO_TIMESTAMP")
+        return Column.invoke_expression_over_column(col, expression.StrToTime, format=lit(format))
+
+    return Column.ensure_col(col).cast("timestamp")


 def trunc(col: ColumnOrName, format: str) -> Column:
@@ -875,8 +881,16 @@ def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None)
     )


-def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
-    return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement))
+def regexp_replace(
+    str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None
+) -> Column:
+    return Column.invoke_expression_over_column(
+        str,
+        expression.RegexpReplace,
+        expression=lit(pattern),
+        replacement=lit(replacement),
+        position=position,
+    )


 def initcap(col: ColumnOrName) -> Column:
@@ -1186,7 +1200,9 @@ def transform(
     f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
 ) -> Column:
     f_expression = _get_lambda_from_func(f)
-    return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
+    return Column.invoke_expression_over_column(
+        col, expression.Transform, expression=Column(f_expression)
+    )


 def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
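
Note: the dataframe functions above now build typed expression nodes (exp.MonthsBetween, exp.StrToTime, exp.RegexpReplace, exp.Transform) instead of anonymous function calls, so dialect-specific transforms can pick them up. A minimal sketch of the effect, assuming the public dataframe API in this tree:

    from sqlglot.dataframe.sql import functions as F

    col = F.regexp_replace("event", "[0-9]+", "N")  # builds an exp.RegexpReplace node
    print(col.expression.sql(dialect="spark"))      # roughly: REGEXP_REPLACE(event, '[0-9]+', 'N')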
@@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot._typing import E
 from sqlglot.dialects.dialect import (
     Dialect,
+    binary_from_function,
     datestrtodate_sql,
     format_time_lambda,
     inline_array_sql,
@@ -15,6 +16,7 @@ from sqlglot.dialects.dialect import (
     min_or_least,
     no_ilike_sql,
     parse_date_delta_with_interval,
+    regexp_replace_sql,
     rename_func,
     timestrtotime_sql,
     ts_or_ds_to_date_sql,
@@ -39,7 +41,7 @@ def _date_add_sql(


 def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
-    if not isinstance(expression.unnest().parent, exp.From):
+    if not expression.find_ancestor(exp.From, exp.Join):
         return self.values_sql(expression)

     alias = expression.args.get("alias")
@@ -279,7 +281,7 @@ class BigQuery(Dialect):
             ),
             "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
             "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
-            "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
+            "DIV": binary_from_function(exp.IntDiv),
             "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
             "MD5": exp.MD5Digest.from_arg_list,
             "TO_HEX": _parse_to_hex,
@@ -415,6 +417,7 @@ class BigQuery(Dialect):
                 e.args.get("position"),
                 e.args.get("occurrence"),
             ),
+            exp.RegexpReplace: regexp_replace_sql,
             exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
             exp.ReturnsProperty: _returnsproperty_sql,
             exp.Select: transforms.preprocess(
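
Note: "DIV" now goes through the shared binary_from_function helper instead of a hand-rolled lambda; parsing behavior should be unchanged. A hedged round-trip sketch (it assumes BigQuery also renders exp.IntDiv back as DIV, which this diff does not show):

    import sqlglot

    # DIV(...) parses to exp.IntDiv and should round-trip through BigQuery
    print(sqlglot.transpile("SELECT DIV(10, 3)", read="bigquery", write="bigquery")[0])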
@@ -64,6 +64,7 @@ class ClickHouse(Dialect):
             "MAP": parse_var_map,
             "MATCH": exp.RegexpLike.from_arg_list,
             "UNIQ": exp.ApproxDistinct.from_arg_list,
+            "XOR": lambda args: exp.Xor(expressions=args),
         }

         FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}
@@ -95,6 +96,7 @@ class ClickHouse(Dialect):
             TokenType.ASOF,
             TokenType.ANTI,
             TokenType.SEMI,
+            TokenType.ARRAY,
         }

         TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
@@ -103,6 +105,7 @@ class ClickHouse(Dialect):
             TokenType.ANTI,
             TokenType.SETTINGS,
             TokenType.FORMAT,
+            TokenType.ARRAY,
         }

         LOG_DEFAULTS_TO_LN = True
@@ -160,8 +163,11 @@ class ClickHouse(Dialect):
             schema: bool = False,
             joins: bool = False,
             alias_tokens: t.Optional[t.Collection[TokenType]] = None,
+            parse_bracket: bool = False,
         ) -> t.Optional[exp.Expression]:
-            this = super()._parse_table(schema=schema, joins=joins, alias_tokens=alias_tokens)
+            this = super()._parse_table(
+                schema=schema, joins=joins, alias_tokens=alias_tokens, parse_bracket=parse_bracket
+            )

             if self._match(TokenType.FINAL):
                 this = self.expression(exp.Final, this=this)
@@ -204,8 +210,10 @@ class ClickHouse(Dialect):
             self._match_set(self.JOIN_KINDS) and self._prev,
         )

-        def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
-            join = super()._parse_join(skip_join_token)
+        def _parse_join(
+            self, skip_join_token: bool = False, parse_bracket: bool = False
+        ) -> t.Optional[exp.Join]:
+            join = super()._parse_join(skip_join_token=skip_join_token, parse_bracket=True)

             if join:
                 join.set("global", join.args.pop("method", None))
@@ -318,6 +326,7 @@ class ClickHouse(Dialect):
             exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
             exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
             exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
+            exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
         }

         PROPERTIES_LOCATION = {
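
Note: ClickHouse's variadic xor() now parses into the new exp.Xor connector, which other dialects can render natively. A hedged sketch of the intended cross-dialect behavior (output shape inferred from this hunk plus the MySQL xor_sql hunk below):

    import sqlglot

    print(sqlglot.transpile("SELECT xor(a, b, c)", read="clickhouse", write="mysql")[0])
    # expected, roughly: SELECT a XOR b XOR c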
@@ -12,6 +12,8 @@ from sqlglot.time import format_time
 from sqlglot.tokens import Token, Tokenizer, TokenType
 from sqlglot.trie import new_trie

+B = t.TypeVar("B", bound=exp.Binary)
+

 class Dialects(str, Enum):
     DIALECT = ""
@@ -630,6 +632,16 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
     )


+def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
+    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
+    if bad_args:
+        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
+
+    return self.func(
+        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
+    )
+
+
 def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
     names = []
     for agg in aggregations:
@@ -650,3 +662,7 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
     names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

     return names
+
+
+def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
+    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
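
Note: binary_from_function is a small factory that several dialects in this commit now share. It is equivalent to the per-function lambdas it replaces, as this sketch shows:

    from sqlglot import exp
    from sqlglot.dialects.dialect import binary_from_function

    parse_div = binary_from_function(exp.IntDiv)
    node = parse_div([exp.Literal.number(10), exp.Literal.number(3)])
    # same node the old lambda would have built
    assert node == exp.IntDiv(this=exp.Literal.number(10), expression=exp.Literal.number(3))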
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
     approx_count_distinct_sql,
     arrow_json_extract_scalar_sql,
     arrow_json_extract_sql,
+    binary_from_function,
     date_trunc_to_time,
     datestrtodate_sql,
     format_time_lambda,
@@ -16,6 +17,7 @@ from sqlglot.dialects.dialect import (
     no_safe_divide_sql,
     pivot_column_names,
     regexp_extract_sql,
+    regexp_replace_sql,
     rename_func,
     str_position_sql,
     str_to_time_sql,
@@ -103,7 +105,6 @@ class DuckDB(Dialect):
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
-            "~": TokenType.RLIKE,
             ":=": TokenType.EQ,
             "//": TokenType.DIV,
             "ATTACH": TokenType.COMMAND,
@@ -128,6 +129,11 @@ class DuckDB(Dialect):
     class Parser(parser.Parser):
         CONCAT_NULL_OUTPUTS_STRING = True

+        BITWISE = {
+            **parser.Parser.BITWISE,
+            TokenType.TILDA: exp.RegexpLike,
+        }
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
@@ -158,6 +164,7 @@ class DuckDB(Dialect):
             "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
             "UNNEST": exp.Explode.from_arg_list,
+            "XOR": binary_from_function(exp.BitwiseXor),
         }

         TYPE_TOKENS = {
@@ -190,6 +197,7 @@ class DuckDB(Dialect):
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArraySort: _array_sort_sql,
             exp.ArraySum: rename_func("LIST_SUM"),
+            exp.BitwiseXor: lambda self, e: self.func("XOR", e.this, e.expression),
             exp.CommentColumnConstraint: no_comment_column_constraint_sql,
             exp.CurrentDate: lambda self, e: "CURRENT_DATE",
             exp.CurrentTime: lambda self, e: "CURRENT_TIME",
@@ -203,7 +211,7 @@ class DuckDB(Dialect):
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.DateSub: _date_delta_sql,
             exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
+                "DATE_DIFF", f"'{e.args.get('unit') or 'day'}'", e.expression, e.this
             ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
@@ -217,8 +225,15 @@ class DuckDB(Dialect):
             exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.LogicalAnd: rename_func("BOOL_AND"),
+            exp.MonthsBetween: lambda self, e: self.func(
+                "DATEDIFF",
+                "'month'",
+                exp.cast(e.expression, "timestamp"),
+                exp.cast(e.this, "timestamp"),
+            ),
             exp.Properties: no_properties_sql,
             exp.RegexpExtract: regexp_extract_sql,
+            exp.RegexpReplace: regexp_replace_sql,
             exp.RegexpLike: rename_func("REGEXP_MATCHES"),
             exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
             exp.SafeDivide: no_safe_divide_sql,
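
Note: with the new exp.MonthsBetween node, DuckDB can render Spark's MONTHS_BETWEEN as a DATEDIFF over timestamps. Sketch (expected output inferred from the transform above, not taken verbatim from a test):

    import sqlglot

    print(sqlglot.transpile("SELECT MONTHS_BETWEEN(d1, d2)", read="spark", write="duckdb")[0])
    # expected: SELECT DATEDIFF('month', CAST(d2 AS TIMESTAMP), CAST(d1 AS TIMESTAMP))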
@@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
     no_safe_divide_sql,
     no_trycast_sql,
     regexp_extract_sql,
+    regexp_replace_sql,
     rename_func,
     right_to_substring_sql,
     strposition_to_locate_sql,
@@ -211,6 +212,7 @@ class Hive(Dialect):
             "ADD JAR": TokenType.COMMAND,
             "ADD JARS": TokenType.COMMAND,
             "MSCK REPAIR": TokenType.COMMAND,
+            "REFRESH": TokenType.COMMAND,
             "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
         }

@@ -270,6 +272,11 @@ class Hive(Dialect):
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }

+        FUNCTION_PARSERS = {
+            **parser.Parser.FUNCTION_PARSERS,
+            "TRANSFORM": lambda self: self._parse_transform(),
+        }
+
         PROPERTY_PARSERS = {
             **parser.Parser.PROPERTY_PARSERS,
             "WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties(
@@ -277,6 +284,40 @@ class Hive(Dialect):
             ),
         }

+        def _parse_transform(self) -> exp.Transform | exp.QueryTransform:
+            args = self._parse_csv(self._parse_lambda)
+            self._match_r_paren()
+
+            row_format_before = self._parse_row_format(match_row=True)
+
+            record_writer = None
+            if self._match_text_seq("RECORDWRITER"):
+                record_writer = self._parse_string()
+
+            if not self._match(TokenType.USING):
+                return exp.Transform.from_arg_list(args)
+
+            command_script = self._parse_string()
+
+            self._match(TokenType.ALIAS)
+            schema = self._parse_schema()
+
+            row_format_after = self._parse_row_format(match_row=True)
+            record_reader = None
+            if self._match_text_seq("RECORDREADER"):
+                record_reader = self._parse_string()
+
+            return self.expression(
+                exp.QueryTransform,
+                expressions=args,
+                command_script=command_script,
+                schema=schema,
+                row_format_before=row_format_before,
+                record_writer=record_writer,
+                row_format_after=row_format_after,
+                record_reader=record_reader,
+            )
+
         def _parse_types(
             self, check_func: bool = False, schema: bool = False
         ) -> t.Optional[exp.Expression]:
@@ -363,11 +404,13 @@ class Hive(Dialect):
             exp.Max: max_or_greatest,
             exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
             exp.Min: min_or_least,
+            exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
             exp.VarMap: var_map_sql,
             exp.Create: create_with_partitions_sql,
             exp.Quantile: rename_func("PERCENTILE"),
             exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
             exp.RegexpExtract: regexp_extract_sql,
+            exp.RegexpReplace: regexp_replace_sql,
             exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
             exp.RegexpSplit: rename_func("SPLIT"),
             exp.Right: right_to_substring_sql,
@@ -396,7 +439,6 @@ class Hive(Dialect):
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
-            exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
             exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
             exp.NumberToStr: rename_func("FORMAT_NUMBER"),
             exp.LastDateOfMonth: rename_func("LAST_DAY"),
@@ -410,6 +452,11 @@ class Hive(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

+        def rowformatserdeproperty_sql(self, expression: exp.RowFormatSerdeProperty) -> str:
+            serde_props = self.sql(expression, "serde_properties")
+            serde_props = f" {serde_props}" if serde_props else ""
+            return f"ROW FORMAT SERDE {self.sql(expression, 'this')}{serde_props}"
+
         def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
             return self.func(
                 "COLLECT_LIST",
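
Note: Hive's SELECT TRANSFORM(...) USING ... syntax is now a first-class exp.QueryTransform, so it should round-trip instead of being swallowed as a generic function call. Hedged sketch:

    import sqlglot

    sql = "SELECT TRANSFORM(a, b) USING 'cat' AS (c STRING) FROM t"
    print(sqlglot.transpile(sql, read="hive", write="hive")[0])  # expected to round-trip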
@@ -427,6 +427,7 @@ class MySQL(Dialect):
         TABLE_HINTS = True
         DUPLICATE_KEY_UPDATE_WITH_SET = False
         QUERY_HINT_SEP = " "
+        VALUES_AS_TABLE = False

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -473,19 +474,32 @@ class MySQL(Dialect):

         LIMIT_FETCH = "LIMIT"

+        # MySQL doesn't support many datatypes in cast.
+        # https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_cast
+        CAST_MAPPING = {
+            exp.DataType.Type.BIGINT: "SIGNED",
+            exp.DataType.Type.BOOLEAN: "SIGNED",
+            exp.DataType.Type.INT: "SIGNED",
+            exp.DataType.Type.TEXT: "CHAR",
+            exp.DataType.Type.UBIGINT: "UNSIGNED",
+            exp.DataType.Type.VARCHAR: "CHAR",
+        }
+
+        def xor_sql(self, expression: exp.Xor) -> str:
+            if expression.expressions:
+                return self.expressions(expression, sep=" XOR ")
+
+            return super().xor_sql(expression)
+
         def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
             return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"

         def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
             """(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead."""
-            if expression.to.this == exp.DataType.Type.BIGINT:
-                to = "SIGNED"
-            elif expression.to.this == exp.DataType.Type.UBIGINT:
-                to = "UNSIGNED"
-            else:
-                return super().cast_sql(expression)
-
-            return f"CAST({self.sql(expression, 'this')} AS {to})"
+            to = self.CAST_MAPPING.get(expression.to.this)
+
+            if to:
+                expression = expression.copy()
+                expression.to.set("this", to)
+            return super().cast_sql(expression)

         def show_sql(self, expression: exp.Show) -> str:
             this = f" {expression.name}"
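
Note: CAST_MAPPING generalizes the old BIGINT/UBIGINT special case; together with the datatype_sql change in the generator hunks below (which now accepts a plain string type), unsupported MySQL cast targets get rewritten. Sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT CAST(x AS TEXT)", write="mysql")[0])
    # expected: SELECT CAST(x AS CHAR)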
@@ -282,7 +282,6 @@ class Postgres(Dialect):
         VAR_SINGLE_TOKENS = {"$"}

     class Parser(parser.Parser):
         STRICT_CAST = False
         CONCAT_NULL_OUTPUTS_STRING = True

         FUNCTIONS = {
@@ -318,6 +317,11 @@ class Postgres(Dialect):
             TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
         }

+        STATEMENT_PARSERS = {
+            **parser.Parser.STATEMENT_PARSERS,
+            TokenType.END: lambda self: self._parse_commit_or_rollback(),
+        }
+
         def _parse_factor(self) -> t.Optional[exp.Expression]:
             return self._parse_tokens(self._parse_exponent, self.FACTOR)
@@ -5,6 +5,7 @@ import typing as t
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
+    binary_from_function,
     date_trunc_to_time,
     format_time_lambda,
     if_sql,
@@ -198,6 +199,10 @@ class Presto(Dialect):
             **parser.Parser.FUNCTIONS,
             "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "APPROX_PERCENTILE": _approx_percentile,
+            "BITWISE_AND": binary_from_function(exp.BitwiseAnd),
+            "BITWISE_NOT": lambda args: exp.BitwiseNot(this=seq_get(args, 0)),
+            "BITWISE_OR": binary_from_function(exp.BitwiseOr),
+            "BITWISE_XOR": binary_from_function(exp.BitwiseXor),
             "CARDINALITY": exp.ArraySize.from_arg_list,
             "CONTAINS": exp.ArrayContains.from_arg_list,
             "DATE_ADD": lambda args: exp.DateAdd(
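
Note: Presto's BITWISE_* functions now parse into proper bitwise expression nodes, enabling cross-dialect rendering. Sketch (DuckDB target per its new BitwiseXor transform above):

    import sqlglot

    print(sqlglot.transpile("SELECT BITWISE_XOR(a, b)", read="presto", write="duckdb")[0])
    # expected: SELECT XOR(a, b)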
@@ -27,6 +27,11 @@ class Redshift(Postgres):
     class Parser(Postgres.Parser):
         FUNCTIONS = {
             **Postgres.Parser.FUNCTIONS,
+            "ADD_MONTHS": lambda args: exp.DateAdd(
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+                expression=seq_get(args, 1),
+                unit=exp.var("month"),
+            ),
             "DATEADD": lambda args: exp.DateAdd(
                 this=exp.TsOrDsToDate(this=seq_get(args, 2)),
                 expression=seq_get(args, 1),
@@ -37,7 +42,6 @@ class Redshift(Postgres):
                 expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
                 unit=seq_get(args, 0),
             ),
-            "NVL": exp.Coalesce.from_arg_list,
             "STRTOL": exp.FromBase.from_arg_list,
         }

@@ -87,6 +91,7 @@ class Redshift(Postgres):
         LOCKING_READS_SUPPORTED = False
         RENAME_TABLE_WITH_DB = False
         QUERY_HINTS = False
+        VALUES_AS_TABLE = False

         TYPE_MAPPING = {
             **Postgres.Generator.TYPE_MAPPING,
@@ -129,40 +134,6 @@ class Redshift(Postgres):

         RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}

-        def values_sql(self, expression: exp.Values) -> str:
-            """
-            Converts `VALUES...` expression into a series of unions.
-
-            Note: If you have a lot of unions then this will result in a large number of recursive statements to
-            evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
-            very slow.
-            """
-
-            # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
-            if not expression.find_ancestor(exp.From, exp.Join):
-                return super().values_sql(expression)
-
-            column_names = expression.alias and expression.args["alias"].columns
-
-            selects = []
-            rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
-
-            for i, row in enumerate(rows):
-                if i == 0 and column_names:
-                    row = [
-                        exp.alias_(value, column_name)
-                        for value, column_name in zip(row, column_names)
-                    ]
-
-                selects.append(exp.Select(expressions=row))
-
-            subquery_expression: exp.Select | exp.Union = selects[0]
-            if len(selects) > 1:
-                for select in selects[1:]:
-                    subquery_expression = exp.union(subquery_expression, select, distinct=False)
-
-            return self.subquery_sql(subquery_expression.subquery(expression.alias))
-
         def with_properties(self, properties: exp.Properties) -> str:
             """Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
             return self.properties(properties, prefix=" ", suffix="")
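
Note: the VALUES-to-unions rewrite moves from Redshift into the base Generator behind the new VALUES_AS_TABLE flag (see the generator hunks below); Redshift now simply sets VALUES_AS_TABLE = False. Sketch of the behavior, per the removed docstring:

    import sqlglot

    sql = "SELECT * FROM (VALUES (1), (2)) AS t(a)"
    print(sqlglot.transpile(sql, write="redshift")[0])
    # expected: SELECT * FROM (SELECT 1 AS a UNION ALL SELECT 2) AS t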
@@ -30,7 +30,7 @@ def _check_int(s: str) -> bool:


 # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
+def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
     if len(args) == 2:
         first_arg, second_arg = args
         if second_arg.is_string:
@@ -137,7 +137,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:


 # https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args: t.List) -> exp.Expression:
+def _div0_to_if(args: t.List) -> exp.If:
     cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
     true = exp.Literal.number(0)
     false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -145,13 +145,13 @@ def _div0_to_if(args: t.List) -> exp.Expression:


 # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _zeroifnull_to_if(args: t.List) -> exp.Expression:
+def _zeroifnull_to_if(args: t.List) -> exp.If:
     cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
     return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))


 # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _nullifzero_to_if(args: t.List) -> exp.Expression:
+def _nullifzero_to_if(args: t.List) -> exp.If:
     cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
     return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))


@@ -164,12 +164,21 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
     return self.datatype_sql(expression)


-def _parse_convert_timezone(args: t.List) -> exp.Expression:
+def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
     if len(args) == 3:
         return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
     return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))


+def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
+    regexp_replace = exp.RegexpReplace.from_arg_list(args)
+
+    if not regexp_replace.args.get("replacement"):
+        regexp_replace.set("replacement", exp.Literal.string(""))
+
+    return regexp_replace
+
+
 class Snowflake(Dialect):
     # https://docs.snowflake.com/en/sql-reference/identifiers-syntax
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
@@ -223,13 +232,14 @@ class Snowflake(Dialect):
             "IFF": exp.If.from_arg_list,
             "NULLIFZERO": _nullifzero_to_if,
             "OBJECT_CONSTRUCT": _parse_object_construct,
+            "REGEXP_REPLACE": _parse_regexp_replace,
             "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
             "RLIKE": exp.RegexpLike.from_arg_list,
             "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
             "TIMEDIFF": _parse_datediff,
             "TIMESTAMPDIFF": _parse_datediff,
             "TO_ARRAY": exp.Array.from_arg_list,
-            "TO_TIMESTAMP": _snowflake_to_timestamp,
+            "TO_TIMESTAMP": _parse_to_timestamp,
             "TO_VARCHAR": exp.ToChar.from_arg_list,
             "ZEROIFNULL": _zeroifnull_to_if,
         }
@@ -242,7 +252,6 @@ class Snowflake(Dialect):

         FUNC_TOKENS = {
             *parser.Parser.FUNC_TOKENS,
-            TokenType.RLIKE,
             TokenType.TABLE,
         }
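
Note: Snowflake's two-argument REGEXP_REPLACE implicitly deletes matches; _parse_regexp_replace makes that default replacement explicit so other dialects keep the semantics. Sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT REGEXP_REPLACE(s, 'a')", read="snowflake", write="spark")[0])
    # expected: SELECT REGEXP_REPLACE(s, 'a', '')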
@@ -2,9 +2,11 @@ from __future__ import annotations

 import typing as t

-from sqlglot import exp, parser, transforms
+from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import (
+    binary_from_function,
     create_with_partitions_sql,
     format_time_lambda,
+    pivot_column_names,
     rename_func,
     trim_sql,
@@ -108,47 +110,36 @@ class Spark2(Hive):
     class Parser(Hive.Parser):
         FUNCTIONS = {
             **Hive.Parser.FUNCTIONS,
-            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
-            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
-            "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-            ),
-            "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-            ),
-            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
-            "IIF": exp.If.from_arg_list,
-            "AGGREGATE": exp.Reduce.from_arg_list,
-            "DAYOFWEEK": lambda args: exp.DayOfWeek(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DAYOFMONTH": lambda args: exp.DayOfMonth(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DAYOFYEAR": lambda args: exp.DayOfYear(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
-                this=seq_get(args, 1),
-                unit=exp.var(seq_get(args, 0)),
-            ),
-            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
+            "AGGREGATE": exp.Reduce.from_arg_list,
+            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+            "BOOLEAN": _parse_as_cast("boolean"),
+            "DATE": _parse_as_cast("date"),
+            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+                this=seq_get(args, 1), unit=exp.var(seq_get(args, 0))
+            ),
+            "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DOUBLE": _parse_as_cast("double"),
+            "FLOAT": _parse_as_cast("float"),
+            "IIF": exp.If.from_arg_list,
+            "INT": _parse_as_cast("int"),
+            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
+            "RLIKE": exp.RegexpLike.from_arg_list,
+            "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift),
+            "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
+            "STRING": _parse_as_cast("string"),
+            "TIMESTAMP": _parse_as_cast("timestamp"),
+            "TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args)
+            if len(args) == 1
+            else format_time_lambda(exp.StrToTime, "spark")(args),
+            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
+            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
+            "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
         }

         FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,
+            **Hive.Parser.FUNCTION_PARSERS,
             "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
             "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
             "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -207,6 +198,13 @@ class Spark2(Hive):
             exp.Map: _map_sql,
             exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
             exp.Reduce: rename_func("AGGREGATE"),
+            exp.RegexpReplace: lambda self, e: self.func(
+                "REGEXP_REPLACE",
+                e.this,
+                e.expression,
+                e.args["replacement"],
+                e.args.get("position"),
+            ),
             exp.StrToDate: _str_to_date,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimestampTrunc: lambda self, e: self.func(
@@ -224,6 +222,7 @@ class Spark2(Hive):
         TRANSFORMS.pop(exp.ArraySort)
         TRANSFORMS.pop(exp.ILike)
         TRANSFORMS.pop(exp.Left)
+        TRANSFORMS.pop(exp.MonthsBetween)
         TRANSFORMS.pop(exp.Right)

         WRAP_DERIVED_VALUES = False
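
Note: Spark's one-argument TO_TIMESTAMP is now treated as a plain cast, while the two-argument form still maps to exp.StrToTime. Sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT TO_TIMESTAMP(x)", read="spark", write="spark")[0])
    # expected: SELECT CAST(x AS TIMESTAMP)
    print(sqlglot.transpile("SELECT TO_TIMESTAMP(x, 'yyyy')", read="spark", write="spark")[0])
    # expected: SELECT TO_TIMESTAMP(x, 'yyyy')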
@@ -20,6 +20,8 @@ class StarRocks(MySQL):
     }

     class Generator(MySQL.Generator):
+        CAST_MAPPING = {}
+
         TYPE_MAPPING = {
             **MySQL.Generator.TYPE_MAPPING,
             exp.DataType.Type.TEXT: "STRING",
@@ -138,7 +138,8 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
         if isinstance(expression, exp.NumberToStr)
         else exp.Literal.string(
             format_time(
-                expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING)
+                expression.text("format"),
+                t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING),
             )
         )
     )
@@ -314,7 +315,9 @@ class TSQL(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "CHARINDEX": lambda args: exp.StrPosition(
-                this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
+                this=seq_get(args, 1),
+                substr=seq_get(args, 0),
+                position=seq_get(args, 2),
             ),
             "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
             "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@@ -365,6 +368,55 @@ class TSQL(Dialect):

         CONCAT_NULL_OUTPUTS_STRING = True

+        def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback:
+            """Applies to SQL Server and Azure SQL Database
+            COMMIT [ { TRAN | TRANSACTION }
+                [ transaction_name | @tran_name_variable ] ]
+                [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ]
+
+            ROLLBACK { TRAN | TRANSACTION }
+                [ transaction_name | @tran_name_variable
+                | savepoint_name | @savepoint_variable ]
+            """
+            rollback = self._prev.token_type == TokenType.ROLLBACK
+
+            self._match_texts({"TRAN", "TRANSACTION"})
+            this = self._parse_id_var()
+
+            if rollback:
+                return self.expression(exp.Rollback, this=this)
+
+            durability = None
+            if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
+                self._match_text_seq("DELAYED_DURABILITY")
+                self._match(TokenType.EQ)
+
+                if self._match_text_seq("OFF"):
+                    durability = False
+                else:
+                    self._match(TokenType.ON)
+                    durability = True
+
+                self._match_r_paren()
+
+            return self.expression(exp.Commit, this=this, durability=durability)
+
+        def _parse_transaction(self) -> exp.Transaction | exp.Command:
+            """Applies to SQL Server and Azure SQL Database
+            BEGIN { TRAN | TRANSACTION }
+                [ { transaction_name | @tran_name_variable }
+                [ WITH MARK [ 'description' ] ]
+                ]
+            """
+            if self._match_texts(("TRAN", "TRANSACTION")):
+                transaction = self.expression(exp.Transaction, this=self._parse_id_var())
+                if self._match_text_seq("WITH", "MARK"):
+                    transaction.set("mark", self._parse_string())
+
+                return transaction
+
+            return self._parse_as_command(self._prev)
+
         def _parse_system_time(self) -> t.Optional[exp.Expression]:
             if not self._match_text_seq("FOR", "SYSTEM_TIME"):
                 return None
@@ -496,7 +548,9 @@ class TSQL(Dialect):
             exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
             exp.SHA2: lambda self, e: self.func(
-                "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
+                "HASHBYTES",
+                exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
+                e.this,
             ),
             exp.TimeToStr: _format_sql,
         }
@@ -539,3 +593,26 @@ class TSQL(Dialect):
         into = self.sql(expression, "into")
         into = self.seg(f"INTO {into}") if into else ""
         return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}"
+
+    def transaction_sql(self, expression: exp.Transaction) -> str:
+        this = self.sql(expression, "this")
+        this = f" {this}" if this else ""
+        mark = self.sql(expression, "mark")
+        mark = f" WITH MARK {mark}" if mark else ""
+        return f"BEGIN TRANSACTION{this}{mark}"
+
+    def commit_sql(self, expression: exp.Commit) -> str:
+        this = self.sql(expression, "this")
+        this = f" {this}" if this else ""
+        durability = expression.args.get("durability")
+        durability = (
+            f" WITH (DELAYED_DURABILITY = {'ON' if durability else 'OFF'})"
+            if durability is not None
+            else ""
+        )
+        return f"COMMIT TRANSACTION{this}{durability}"
+
+    def rollback_sql(self, expression: exp.Rollback) -> str:
+        this = self.sql(expression, "this")
+        this = f" {this}" if this else ""
+        return f"ROLLBACK TRANSACTION{this}"
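
Note: with the new parser hooks and the transaction_sql/commit_sql/rollback_sql generators, T-SQL transaction statements should round-trip, including DELAYED_DURABILITY. Sketch:

    import sqlglot

    sql = "COMMIT TRANSACTION t1 WITH (DELAYED_DURABILITY = ON)"
    print(sqlglot.transpile(sql, read="tsql", write="tsql")[0])  # expected to round-trip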
@@ -759,12 +759,24 @@ class Condition(Expression):
         )

     def isin(
-        self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts
+        self,
+        *expressions: t.Any,
+        query: t.Optional[ExpOrStr] = None,
+        unnest: t.Optional[ExpOrStr] | t.Collection[ExpOrStr] = None,
+        copy: bool = True,
+        **opts,
     ) -> In:
         return In(
             this=_maybe_copy(self, copy),
             expressions=[convert(e, copy=copy) for e in expressions],
             query=maybe_parse(query, copy=copy, **opts) if query else None,
+            unnest=Unnest(
+                expressions=[
+                    maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest)
+                ]
+            )
+            if unnest
+            else None,
         )

     def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between:
@@ -2019,7 +2031,20 @@ class RowFormatDelimitedProperty(Property):


 class RowFormatSerdeProperty(Property):
-    arg_types = {"this": True}
+    arg_types = {"this": True, "serde_properties": False}
+
+
+# https://spark.apache.org/docs/3.1.2/sql-ref-syntax-qry-select-transform.html
+class QueryTransform(Expression):
+    arg_types = {
+        "expressions": True,
+        "command_script": True,
+        "schema": False,
+        "row_format_before": False,
+        "record_writer": False,
+        "row_format_after": False,
+        "record_reader": False,
+    }


 class SchemaCommentProperty(Property):
@@ -2149,12 +2174,24 @@ class Tuple(Expression):
     arg_types = {"expressions": False}

     def isin(
-        self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts
+        self,
+        *expressions: t.Any,
+        query: t.Optional[ExpOrStr] = None,
+        unnest: t.Optional[ExpOrStr] | t.Collection[ExpOrStr] = None,
+        copy: bool = True,
+        **opts,
     ) -> In:
         return In(
             this=_maybe_copy(self, copy),
             expressions=[convert(e, copy=copy) for e in expressions],
             query=maybe_parse(query, copy=copy, **opts) if query else None,
+            unnest=Unnest(
+                expressions=[
+                    maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest)
+                ]
+            )
+            if unnest
+            else None,
         )

@@ -3478,15 +3515,15 @@ class Command(Expression):


 class Transaction(Expression):
-    arg_types = {"this": False, "modes": False}
+    arg_types = {"this": False, "modes": False, "mark": False}


 class Commit(Expression):
-    arg_types = {"chain": False}
+    arg_types = {"chain": False, "this": False, "durability": False}


 class Rollback(Expression):
-    arg_types = {"savepoint": False}
+    arg_types = {"savepoint": False, "this": False}


 class AlterTable(Expression):
@@ -3530,10 +3567,6 @@ class Or(Connector):
     pass


-class Xor(Connector):
-    pass
-
-
 class BitwiseAnd(Binary):
     pass

@@ -3856,6 +3889,11 @@ class Abs(Func):
     pass


+# https://spark.apache.org/docs/latest/api/sql/index.html#transform
+class Transform(Func):
+    arg_types = {"this": True, "expression": True}
+
+
 class Anonymous(Func):
     arg_types = {"this": True, "expressions": False}
     is_var_len_args = True
@@ -4098,6 +4136,10 @@ class WeekOfYear(Func):
     _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]


+class MonthsBetween(Func):
+    arg_types = {"this": True, "expression": True, "roundoff": False}
+
+
 class LastDateOfMonth(Func):
     pass

@@ -4209,6 +4251,10 @@ class Hex(Func):
     pass


+class Xor(Connector, Func):
+    arg_types = {"this": False, "expression": False, "expressions": False}
+
+
 class If(Func):
     arg_types = {"this": True, "true": True, "false": False}

@@ -4431,7 +4477,18 @@ class RegexpExtract(Func):
     }


+class RegexpReplace(Func):
+    arg_types = {
+        "this": True,
+        "expression": True,
+        "replacement": True,
+        "position": False,
+        "occurrence": False,
+        "parameters": False,
+    }
+
+
-class RegexpLike(Func):
+class RegexpLike(Binary, Func):
     arg_types = {"this": True, "expression": True, "flag": False}
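
Note: isin() gaining an unnest parameter means array-membership tests can be built directly on expressions; the rendered SQL depends on the dialect (BigQuery, for instance, uses IN UNNEST). Hedged sketch:

    from sqlglot import condition

    cond = condition("x").isin(unnest="arr")
    print(cond.sql(dialect="bigquery"))  # expected, roughly: x IN UNNEST(arr)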
@@ -164,6 +164,11 @@ class Generator:
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
     SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")

+    # Whether or not VALUES statements can be used as derived tables.
+    # MySQL 5 and Redshift do not allow this, so when False, it will convert
+    # SELECT * VALUES into SELECT UNION
+    VALUES_AS_TABLE = True
+
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -260,8 +265,9 @@ class Generator:

     # Expressions whose comments are separated from them for better formatting
     WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
-        exp.Select,
+        exp.Drop,
         exp.From,
+        exp.Select,
         exp.Where,
         exp.With,
     )
@@ -818,7 +824,11 @@ class Generator:

     def datatype_sql(self, expression: exp.DataType) -> str:
         type_value = expression.this
-        type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
+        type_sql = (
+            self.TYPE_MAPPING.get(type_value, type_value.value)
+            if isinstance(type_value, exp.DataType.Type)
+            else type_value
+        )
         nested = ""
         interior = self.expressions(expression, flat=True)
         values = ""
@@ -1307,15 +1317,45 @@ class Generator:
         return self.prepend_ctes(expression, sql)

     def values_sql(self, expression: exp.Values) -> str:
-        args = self.expressions(expression)
-        alias = self.sql(expression, "alias")
-        values = f"VALUES{self.seg('')}{args}"
-        values = (
-            f"({values})"
-            if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From))
-            else values
-        )
-        return f"{values} AS {alias}" if alias else values
+        # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
+        if self.VALUES_AS_TABLE or not expression.find_ancestor(exp.From, exp.Join):
+            args = self.expressions(expression)
+            alias = self.sql(expression, "alias")
+            values = f"VALUES{self.seg('')}{args}"
+            values = (
+                f"({values})"
+                if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From))
+                else values
+            )
+            return f"{values} AS {alias}" if alias else values
+
+        # Converts `VALUES...` expression into a series of select unions.
+        # Note: If you have a lot of unions then this will result in a large number of recursive statements to
+        # evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
+        # very slow.
+        expression = expression.copy()
+        column_names = expression.alias and expression.args["alias"].columns
+
+        selects = []
+
+        for i, tup in enumerate(expression.expressions):
+            row = tup.expressions
+
+            if i == 0 and column_names:
+                row = [
+                    exp.alias_(value, column_name) for value, column_name in zip(row, column_names)
+                ]
+
+            selects.append(exp.Select(expressions=row))
+
+        subquery_expression: exp.Select | exp.Union = selects[0]
+        if len(selects) > 1:
+            for select in selects[1:]:
+                subquery_expression = exp.union(
+                    subquery_expression, select, distinct=False, copy=False
+                )
+
+        return self.subquery_sql(subquery_expression.subquery(expression.alias, copy=False))

     def var_sql(self, expression: exp.Var) -> str:
         return self.sql(expression, "this")
@@ -2043,7 +2083,7 @@ class Generator:
     def and_sql(self, expression: exp.And) -> str:
         return self.connector_sql(expression, "AND")

-    def xor_sql(self, expression: exp.And) -> str:
+    def xor_sql(self, expression: exp.Xor) -> str:
         return self.connector_sql(expression, "XOR")

     def connector_sql(self, expression: exp.Connector, op: str) -> str:
@@ -2507,6 +2547,21 @@ class Generator:

         return self.func("ANY_VALUE", this)

+    def querytransform_sql(self, expression: exp.QueryTransform) -> str:
+        transform = self.func("TRANSFORM", *expression.expressions)
+        row_format_before = self.sql(expression, "row_format_before")
+        row_format_before = f" {row_format_before}" if row_format_before else ""
+        record_writer = self.sql(expression, "record_writer")
+        record_writer = f" RECORDWRITER {record_writer}" if record_writer else ""
+        using = f" USING {self.sql(expression, 'command_script')}"
+        schema = self.sql(expression, "schema")
+        schema = f" AS {schema}" if schema else ""
+        row_format_after = self.sql(expression, "row_format_after")
+        row_format_after = f" {row_format_after}" if row_format_after else ""
+        record_reader = self.sql(expression, "record_reader")
+        record_reader = f" RECORDREADER {record_reader}" if record_reader else ""
+        return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}"
+

 def cached_generator(
     cache: t.Optional[t.Dict[int, str]] = None
@@ -79,7 +79,7 @@ def lineage(
         raise SqlglotError("Cannot build lineage, sql must be SELECT")

     def to_node(
-        column_name: str,
+        column: str | int,
         scope: Scope,
         scope_name: t.Optional[str] = None,
         upstream: t.Optional[Node] = None,
@@ -90,26 +90,38 @@ def lineage(
             for dt in scope.derived_tables
             if dt.comments and dt.comments[0].startswith("source: ")
         }
-        if isinstance(scope.expression, exp.Union):
-            for scope in scope.union_scopes:
-                node = to_node(
-                    column_name,
-                    scope=scope,
-                    scope_name=scope_name,
-                    upstream=upstream,
-                    alias=aliases.get(scope_name),
-                )
-            return node

         # Find the specific select clause that is the source of the column we want.
         # This can either be a specific, named select or a generic `*` clause.
-        select = next(
-            (select for select in scope.expression.selects if select.alias_or_name == column_name),
-            exp.Star() if scope.expression.is_star else None,
+        select = (
+            scope.expression.selects[column]
+            if isinstance(column, int)
+            else next(
+                (select for select in scope.expression.selects if select.alias_or_name == column),
+                exp.Star() if scope.expression.is_star else None,
+            )
         )

         if not select:
-            raise ValueError(f"Could not find {column_name} in {scope.expression}")
+            raise ValueError(f"Could not find {column} in {scope.expression}")

+        if isinstance(scope.expression, exp.Union):
+            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
+
+            index = (
+                column
+                if isinstance(column, int)
+                else next(
+                    i
+                    for i, select in enumerate(scope.expression.selects)
+                    if select.alias_or_name == column
+                )
+            )
+
+            for s in scope.union_scopes:
+                to_node(index, scope=s, upstream=upstream)
+
+            return upstream
+
         if isinstance(scope.expression, exp.Select):
             # For better ergonomics in our node labels, replace the full select with
@@ -122,7 +134,7 @@ def lineage(

         # Create the node for this step in the lineage chain, and attach it to the previous one.
         node = Node(
-            name=f"{scope_name}.{column_name}" if scope_name else column_name,
+            name=f"{scope_name}.{column}" if scope_name else str(column),
             source=source,
             expression=select,
             alias=alias or "",
|
|||
name, cte = _new_cte(scope, existing_ctes, taken)
|
||||
|
||||
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
|
||||
parent.replace(table)
|
||||
table.set("joins", parent.args.get("joins"))
|
||||
|
||||
parent.replace(table)
|
||||
return cte
|
||||
|
||||
|
||||
|
|
|
@@ -176,6 +176,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):

     return (
         isinstance(outer_scope.expression, exp.Select)
+        and not outer_scope.expression.is_star
         and isinstance(inner_select, exp.Select)
         and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
         and inner_select.args.get("from")
@@ -242,6 +243,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
         alias (str)
     """
     new_subquery = inner_scope.expression.args["from"].this
+    new_subquery.set("joins", node_to_replace.args.get("joins"))
     node_to_replace.replace(new_subquery)
     for join_hint in outer_scope.join_hints:
         tables = join_hint.find_all(exp.Table)
@@ -61,6 +61,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
         if remove_unused_selections:
             _remove_unused_selections(scope, parent_selections, schema)

+        if scope.expression.is_star:
+            continue
+
         # Group columns by source name
         selects = defaultdict(set)
         for col in scope.columns:
@@ -29,12 +29,13 @@ def qualify_columns(
         'SELECT tbl.col AS col FROM tbl'

     Args:
-        expression: expression to qualify
-        schema: Database schema
-        expand_alias_refs: whether or not to expand references to aliases
-        infer_schema: whether or not to infer the schema if missing
+        expression: Expression to qualify.
+        schema: Database schema.
+        expand_alias_refs: Whether or not to expand references to aliases.
+        infer_schema: Whether or not to infer the schema if missing.

     Returns:
-        sqlglot.Expression: qualified expression
+        The qualified expression.
     """
     schema = ensure_schema(schema)
     infer_schema = schema.empty if infer_schema is None else infer_schema
@@ -410,7 +411,9 @@ def _expand_stars(
         else:
             return

-    scope.expression.set("expressions", new_selections)
+    # Ensures we don't overwrite the initial selections with an empty list
+    if new_selections:
+        scope.expression.set("expressions", new_selections)


 def _add_except_columns(
@@ -124,8 +124,8 @@ class Scope:
             self._ctes.append(node)
         elif (
             isinstance(node, exp.Subquery)
-            and isinstance(parent, (exp.From, exp.Join))
-            and _is_subquery_scope(node)
+            and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
+            and _is_derived_table(node)
         ):
             self._derived_tables.append(node)
         elif isinstance(node, exp.Subqueryable):
@@ -610,13 +610,13 @@ def _traverse_ctes(scope):
     scope.sources.update(sources)


-def _is_subquery_scope(expression: exp.Subquery) -> bool:
+def _is_derived_table(expression: exp.Subquery) -> bool:
     """
-    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope.
-    If an alias is present, it shadows all names under the Subquery, so that's an
-    exception to this rule.
+    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
+    as it doesn't introduce a new scope. If an alias is present, it shadows all names
+    under the Subquery, so that's one exception to this rule.
     """
-    return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias)
+    return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))


 def _traverse_tables(scope):
@@ -654,7 +654,10 @@ def _traverse_tables(scope):
         else:
             sources[source_name] = expression

-        expressions.extend(join.this for join in expression.args.get("joins") or [])
+        # Make sure to not include the joins twice
+        if expression is not scope.expression:
+            expressions.extend(join.this for join in expression.args.get("joins") or [])
+
         continue

     if not isinstance(expression, exp.DerivedTable):
@@ -664,10 +667,11 @@ def _traverse_tables(scope):
             lateral_sources = sources
             scope_type = ScopeType.UDTF
             scopes = scope.udtf_scopes
-        elif _is_subquery_scope(expression):
+        elif _is_derived_table(expression):
             lateral_sources = None
             scope_type = ScopeType.DERIVED_TABLE
             scopes = scope.derived_table_scopes
+            expressions.extend(join.this for join in expression.args.get("joins") or [])
         else:
             # Makes sure we check for possible sources in nested table constructs
             expressions.append(expression.this)
@@ -735,10 +739,16 @@ def walk_in_scope(expression, bfs=True):
         isinstance(node, exp.CTE)
         or (
             isinstance(node, exp.Subquery)
-            and isinstance(parent, (exp.From, exp.Join))
-            and _is_subquery_scope(node)
+            and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
+            and _is_derived_table(node)
         )
         or isinstance(node, exp.UDTF)
         or isinstance(node, exp.Subqueryable)
     ):
         prune = True

+        if isinstance(node, (exp.Subquery, exp.UDTF)):
+            # The following args are not actually in the inner scope, so we should visit them
+            for key in ("joins", "laterals", "pivots"):
+                for arg in node.args.get(key) or []:
+                    yield from walk_in_scope(arg, bfs=bfs)
@@ -327,6 +327,7 @@ class Parser(metaclass=_Parser):
         TokenType.PRIMARY_KEY,
         TokenType.RANGE,
         TokenType.REPLACE,
+        TokenType.RLIKE,
         TokenType.ROW,
         TokenType.UNNEST,
         TokenType.VAR,
@@ -338,6 +339,7 @@ class Parser(metaclass=_Parser):
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
         TokenType.WINDOW,
+        TokenType.XOR,
         *TYPE_TOKENS,
         *SUBQUERY_PREDICATES,
     }
@@ -505,7 +507,6 @@ class Parser(metaclass=_Parser):
         TokenType.DESC: lambda self: self._parse_describe(),
         TokenType.DESCRIBE: lambda self: self._parse_describe(),
         TokenType.DROP: lambda self: self._parse_drop(),
-        TokenType.END: lambda self: self._parse_commit_or_rollback(),
         TokenType.FROM: lambda self: exp.select("*").from_(
             t.cast(exp.From, self._parse_from(skip_from_token=True))
         ),
@@ -716,7 +717,7 @@ class Parser(metaclass=_Parser):

     FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}

-    FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
+    FUNCTION_PARSERS = {
         "ANY_VALUE": lambda self: self._parse_any_value(),
         "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
         "CONCAT": lambda self: self._parse_concat(),
@@ -1144,6 +1145,7 @@ class Parser(metaclass=_Parser):

         return self.expression(
             exp.Drop,
+            comments=start.comments,
             exists=self._parse_exists(),
             this=self._parse_table(schema=True),
             kind=kind,
@@ -1233,11 +1235,14 @@ class Parser(metaclass=_Parser):
             expression = self._parse_ddl_select()

             if create_token.token_type == TokenType.TABLE:
+                # exp.Properties.Location.POST_EXPRESSION
+                extend_props(self._parse_properties())
+
                 indexes = []
                 while True:
                     index = self._parse_index()

-                    # exp.Properties.Location.POST_EXPRESSION and POST_INDEX
+                    # exp.Properties.Location.POST_INDEX
                     extend_props(self._parse_properties())

                     if not index:
@@ -1384,7 +1389,6 @@ class Parser(metaclass=_Parser):
     def _parse_with_property(
         self,
     ) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]:
-        self._match(TokenType.WITH)
         if self._match(TokenType.L_PAREN, advance=False):
             return self._parse_wrapped_csv(self._parse_property)

@@ -1781,7 +1785,17 @@ class Parser(metaclass=_Parser):
             return None

         if self._match_text_seq("SERDE"):
-            return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string())
+            this = self._parse_string()
+
+            serde_properties = None
+            if self._match(TokenType.SERDE_PROPERTIES):
+                serde_properties = self.expression(
+                    exp.SerdeProperties, expressions=self._parse_wrapped_csv(self._parse_property)
+                )
+
+            return self.expression(
+                exp.RowFormatSerdeProperty, this=this, serde_properties=serde_properties
+            )

         self._match_text_seq("DELIMITED")

@@ -2251,7 +2265,9 @@ class Parser(metaclass=_Parser):
             self._match_set(self.JOIN_KINDS) and self._prev,
         )

-    def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
+    def _parse_join(
+        self, skip_join_token: bool = False, parse_bracket: bool = False
+    ) -> t.Optional[exp.Join]:
         if self._match(TokenType.COMMA):
             return self.expression(exp.Join, this=self._parse_table())

@@ -2275,7 +2291,7 @@ class Parser(metaclass=_Parser):
         if outer_apply:
             side = Token(TokenType.LEFT, "LEFT")

-        kwargs: t.Dict[str, t.Any] = {"this": self._parse_table()}
+        kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)}

         if method:
             kwargs["method"] = method.text
@@ -2411,6 +2427,7 @@ class Parser(metaclass=_Parser):
         schema: bool = False,
         joins: bool = False,
         alias_tokens: t.Optional[t.Collection[TokenType]] = None,
+        parse_bracket: bool = False,
     ) -> t.Optional[exp.Expression]:
         lateral = self._parse_lateral()
         if lateral:
@@ -2430,7 +2447,9 @@ class Parser(metaclass=_Parser):
             subquery.set("pivots", self._parse_pivots())
             return subquery

-        this: exp.Expression = self._parse_table_parts(schema=schema)
+        bracket = parse_bracket and self._parse_bracket(None)
+        bracket = self.expression(exp.Table, this=bracket) if bracket else None
+        this: exp.Expression = bracket or self._parse_table_parts(schema=schema)

         if schema:
             return self._parse_schema(this=this)
@@ -2758,8 +2777,15 @@ class Parser(metaclass=_Parser):
         self, this: t.Optional[exp.Expression] = None, top: bool = False
     ) -> t.Optional[exp.Expression]:
         if self._match(TokenType.TOP if top else TokenType.LIMIT):
-            limit_paren = self._match(TokenType.L_PAREN)
-            expression = self._parse_number() if top else self._parse_term()
+            comments = self._prev_comments
+            if top:
+                limit_paren = self._match(TokenType.L_PAREN)
+                expression = self._parse_number()
+
+                if limit_paren:
+                    self._match_r_paren()
+            else:
+                expression = self._parse_term()

             if self._match(TokenType.COMMA):
                 offset = expression
@@ -2767,10 +2793,9 @@ class Parser(metaclass=_Parser):
             else:
                 offset = None

-            limit_exp = self.expression(exp.Limit, this=this, expression=expression, offset=offset)
-
-            if limit_paren:
-                self._match_r_paren()
+            limit_exp = self.expression(
+                exp.Limit, this=this, expression=expression, offset=offset, comments=comments
+            )

             return limit_exp

@@ -2803,7 +2828,7 @@ class Parser(metaclass=_Parser):
         if not self._match(TokenType.OFFSET):
             return this

-        count = self._parse_number()
+        count = self._parse_term()
         self._match_set((TokenType.ROW, TokenType.ROWS))
         return self.expression(exp.Offset, this=this, expression=count)

@@ -3320,7 +3345,7 @@ class Parser(metaclass=_Parser):
         else:
             this = self.expression(exp.Anonymous, this=this, expressions=args)

-        self._match_r_paren(this)
+        self._match(TokenType.R_PAREN, expression=this)
         return self._parse_window(this)

     def _parse_function_parameter(self) -> t.Optional[exp.Expression]:
@@ -4076,7 +4101,10 @@ class Parser(metaclass=_Parser):
         self, this: t.Optional[exp.Expression], alias: bool = False
     ) -> t.Optional[exp.Expression]:
         if self._match_pair(TokenType.FILTER, TokenType.L_PAREN):
-            this = self.expression(exp.Filter, this=this, expression=self._parse_where())
+            self._match(TokenType.WHERE)
+            this = self.expression(
+                exp.Filter, this=this, expression=self._parse_where(skip_where_token=True)
+            )
             self._match_r_paren()

         # T-SQL allows the OVER (...) syntax after WITHIN GROUP.
@@ -4351,7 +4379,7 @@ class Parser(metaclass=_Parser):
             self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False))
         )

-    def _parse_transaction(self) -> exp.Transaction:
+    def _parse_transaction(self) -> exp.Transaction | exp.Command:
         this = None
         if self._match_texts(self.TRANSACTION_KIND):
             this = self._prev.text
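
Note: among the parser changes, _parse_filter now treats the WHERE keyword inside FILTER(...) as optional. Sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT SUM(x) FILTER(x > 1) FROM t")[0])
    # expected: SELECT SUM(x) FILTER(WHERE x > 1) FROM t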
@@ -3,6 +3,7 @@ from __future__ import annotations
 import typing as t
 from enum import auto

+from sqlglot.errors import TokenError
 from sqlglot.helper import AutoName
 from sqlglot.trie import TrieResult, in_trie, new_trie

@@ -800,7 +801,7 @@ class Tokenizer(metaclass=_Tokenizer):
             start = max(self._current - 50, 0)
             end = min(self._current + 50, self.size - 1)
             context = self.sql[start:end]
-            raise ValueError(f"Error tokenizing '{context}'") from e
+            raise TokenError(f"Error tokenizing '{context}'") from e

         return self.tokens

@@ -1097,7 +1098,7 @@ class Tokenizer(metaclass=_Tokenizer):
             try:
                 int(text, base)
             except:
-                raise RuntimeError(
+                raise TokenError(
                     f"Numeric string contains invalid characters from {self._line}:{self._start}"
                 )
             else:
@@ -1140,7 +1141,7 @@ class Tokenizer(metaclass=_Tokenizer):
                 if self._current + 1 < self.size:
                     self._advance(2)
                 else:
-                    raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._current}")
+                    raise TokenError(f"Missing {delimiter} from {self._line}:{self._current}")
             else:
                 if self._chars(delim_size) == delimiter:
                     if delim_size > 1:
@@ -1148,7 +1149,7 @@ class Tokenizer(metaclass=_Tokenizer):
                     break

                 if self._end:
-                    raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")

                 current = self._current - 1
                 self._advance(alnum=True)
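
Note: tokenizer failures now raise sqlglot's own TokenError instead of ValueError/RuntimeError, so callers can catch a single error type. Sketch:

    from sqlglot.errors import TokenError
    from sqlglot.tokens import Tokenizer

    try:
        Tokenizer().tokenize("SELECT 'unterminated")
    except TokenError as e:
        print(e)  # roughly: Missing ' from 1:8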