
Merging upstream version 17.7.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 20:46:55 +01:00
parent 87252470ef
commit 137902868c
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
93 changed files with 41580 additions and 39040 deletions

View file

@ -114,7 +114,7 @@ class Column:
return self.inverse_binary_op(exp.Or, other)
@classmethod
def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
return cls(value)
@classmethod
@ -259,7 +259,7 @@ class Column:
new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
return Column(new_expression)
def cast(self, dataType: t.Union[str, DataType]):
def cast(self, dataType: t.Union[str, DataType]) -> Column:
"""
Functionality difference: PySpark's cast accepts an instance of the DataType class;
sqlglot doesn't currently replicate this class, so it only accepts a string.

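A minimal usage sketch of the string-based cast described in the docstring (the column name is illustrative, assuming this version's sqlglot.dataframe API and that Column exposes its wrapped node via .expression):

from sqlglot.dataframe.sql import functions as F

# PySpark would also accept DoubleType(); here only the string form is supported.
price = F.col("price").cast("double")
print(price.expression.sql(dialect="spark"))  # roughly: CAST(price AS DOUBLE)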
View file

@ -600,8 +600,13 @@ def months_between(
date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
) -> Column:
if roundOff is None:
return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
return Column.invoke_expression_over_column(
date1, expression.MonthsBetween, expression=date2
)
return Column.invoke_expression_over_column(
date1, expression.MonthsBetween, expression=date2, roundoff=roundOff
)
def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
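For illustration, a sketch of the change above: months_between now builds a typed exp.MonthsBetween node instead of an anonymous MONTHS_BETWEEN call, so dialect generators can rewrite it (column names are made up):

from sqlglot.dataframe.sql import functions as F

m = F.months_between("start_dt", "end_dt", roundOff=True)
# m wraps exp.MonthsBetween(this=start_dt, expression=end_dt, roundoff=True)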
@ -614,8 +619,9 @@ def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
if format is not None:
return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", lit(format))
return Column.invoke_anonymous_function(col, "TO_TIMESTAMP")
return Column.invoke_expression_over_column(col, expression.StrToTime, format=lit(format))
return Column.ensure_col(col).cast("timestamp")
def trunc(col: ColumnOrName, format: str) -> Column:
@ -875,8 +881,16 @@ def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None)
)
def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement))
def regexp_replace(
str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None
) -> Column:
return Column.invoke_expression_over_column(
str,
expression.RegexpReplace,
expression=lit(pattern),
replacement=lit(replacement),
position=position,
)
def initcap(col: ColumnOrName) -> Column:
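A hedged sketch of the extended regexp_replace signature, with the new optional position argument (names are illustrative):

from sqlglot.dataframe.sql import functions as F

# Builds exp.RegexpReplace with literal pattern/replacement and position=2.
cleaned = F.regexp_replace("raw_text", r"\s+", " ", position=2)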
@ -1186,7 +1200,9 @@ def transform(
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
return Column.invoke_expression_over_column(
col, expression.Transform, expression=Column(f_expression)
)
def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:

View file

@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
binary_from_function,
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
@ -15,6 +16,7 @@ from sqlglot.dialects.dialect import (
min_or_least,
no_ilike_sql,
parse_date_delta_with_interval,
regexp_replace_sql,
rename_func,
timestrtotime_sql,
ts_or_ds_to_date_sql,
@ -39,7 +41,7 @@ def _date_add_sql(
def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
if not isinstance(expression.unnest().parent, exp.From):
if not expression.find_ancestor(exp.From, exp.Join):
return self.values_sql(expression)
alias = expression.args.get("alias")
@ -279,7 +281,7 @@ class BigQuery(Dialect):
),
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"DIV": binary_from_function(exp.IntDiv),
"GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
"MD5": exp.MD5Digest.from_arg_list,
"TO_HEX": _parse_to_hex,
@ -415,6 +417,7 @@ class BigQuery(Dialect):
e.args.get("position"),
e.args.get("occurrence"),
),
exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.ReturnsProperty: _returnsproperty_sql,
exp.Select: transforms.preprocess(

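A quick sketch of the DIV change: BigQuery's DIV now parses into the typed exp.IntDiv node via binary_from_function rather than a hand-rolled lambda:

import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT DIV(10, 3)", read="bigquery")
assert isinstance(ast.find(exp.IntDiv), exp.IntDiv)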
View file

@ -64,6 +64,7 @@ class ClickHouse(Dialect):
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
"UNIQ": exp.ApproxDistinct.from_arg_list,
"XOR": lambda args: exp.Xor(expressions=args),
}
FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}
@ -95,6 +96,7 @@ class ClickHouse(Dialect):
TokenType.ASOF,
TokenType.ANTI,
TokenType.SEMI,
TokenType.ARRAY,
}
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
@ -103,6 +105,7 @@ class ClickHouse(Dialect):
TokenType.ANTI,
TokenType.SETTINGS,
TokenType.FORMAT,
TokenType.ARRAY,
}
LOG_DEFAULTS_TO_LN = True
@ -160,8 +163,11 @@ class ClickHouse(Dialect):
schema: bool = False,
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
) -> t.Optional[exp.Expression]:
this = super()._parse_table(schema=schema, joins=joins, alias_tokens=alias_tokens)
this = super()._parse_table(
schema=schema, joins=joins, alias_tokens=alias_tokens, parse_bracket=parse_bracket
)
if self._match(TokenType.FINAL):
this = self.expression(exp.Final, this=this)
@ -204,8 +210,10 @@ class ClickHouse(Dialect):
self._match_set(self.JOIN_KINDS) and self._prev,
)
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
join = super()._parse_join(skip_join_token)
def _parse_join(
self, skip_join_token: bool = False, parse_bracket: bool = False
) -> t.Optional[exp.Join]:
join = super()._parse_join(skip_join_token=skip_join_token, parse_bracket=True)
if join:
join.set("global", join.args.pop("method", None))
@ -318,6 +326,7 @@ class ClickHouse(Dialect):
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
}
PROPERTIES_LOCATION = {

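A sketch of the new XOR support: ClickHouse's variadic xor() now parses into exp.Xor, which other dialects can re-emit (the MySQL output shown is inferred from the xor_sql override elsewhere in this commit):

import sqlglot

sql = "SELECT xor(TRUE, FALSE, TRUE)"
print(sqlglot.transpile(sql, read="clickhouse", write="mysql")[0])
# expected: SELECT TRUE XOR FALSE XOR TRUE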
View file

@ -12,6 +12,8 @@ from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
B = t.TypeVar("B", bound=exp.Binary)
class Dialects(str, Enum):
DIALECT = ""
@ -630,6 +632,16 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
if bad_args:
self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
return self.func(
"REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
names = []
for agg in aggregations:
@ -650,3 +662,7 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
return names
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

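binary_from_function is a small factory: given a binary expression type, it returns an arg-list parser that maps the first two arguments to this/expression. A minimal sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import binary_from_function

parse_intdiv = binary_from_function(exp.IntDiv)
node = parse_intdiv([exp.column("a"), exp.column("b")])
assert node.this.name == "a" and node.expression.name == "b"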
View file

@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
binary_from_function,
date_trunc_to_time,
datestrtodate_sql,
format_time_lambda,
@ -16,6 +17,7 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
pivot_column_names,
regexp_extract_sql,
regexp_replace_sql,
rename_func,
str_position_sql,
str_to_time_sql,
@ -103,7 +105,6 @@ class DuckDB(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~": TokenType.RLIKE,
":=": TokenType.EQ,
"//": TokenType.DIV,
"ATTACH": TokenType.COMMAND,
@ -128,6 +129,11 @@ class DuckDB(Dialect):
class Parser(parser.Parser):
CONCAT_NULL_OUTPUTS_STRING = True
BITWISE = {
**parser.Parser.BITWISE,
TokenType.TILDA: exp.RegexpLike,
}
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
@ -158,6 +164,7 @@ class DuckDB(Dialect):
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
"XOR": binary_from_function(exp.BitwiseXor),
}
TYPE_TOKENS = {
@ -190,6 +197,7 @@ class DuckDB(Dialect):
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.BitwiseXor: lambda self, e: self.func("XOR", e.this, e.expression),
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
exp.CurrentDate: lambda self, e: "CURRENT_DATE",
exp.CurrentTime: lambda self, e: "CURRENT_TIME",
@ -203,7 +211,7 @@ class DuckDB(Dialect):
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateSub: _date_delta_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
"DATE_DIFF", f"'{e.args.get('unit') or 'day'}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
@ -217,8 +225,15 @@ class DuckDB(Dialect):
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.MonthsBetween: lambda self, e: self.func(
"DATEDIFF",
"'month'",
exp.cast(e.expression, "timestamp"),
exp.cast(e.this, "timestamp"),
),
exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
exp.SafeDivide: no_safe_divide_sql,

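Combined with the new exp.MonthsBetween node, Spark's MONTHS_BETWEEN can now be transpiled to DuckDB; a sketch (the output is inferred from the transform above):

import sqlglot

sql = "SELECT MONTHS_BETWEEN(d1, d2)"
print(sqlglot.transpile(sql, read="spark", write="duckdb")[0])
# expected: SELECT DATEDIFF('month', CAST(d2 AS TIMESTAMP), CAST(d1 AS TIMESTAMP))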
View file

@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
no_trycast_sql,
regexp_extract_sql,
regexp_replace_sql,
rename_func,
right_to_substring_sql,
strposition_to_locate_sql,
@ -211,6 +212,7 @@ class Hive(Dialect):
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
"MSCK REPAIR": TokenType.COMMAND,
"REFRESH": TokenType.COMMAND,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
}
@ -270,6 +272,11 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"TRANSFORM": lambda self: self._parse_transform(),
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
"WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties(
@ -277,6 +284,40 @@ class Hive(Dialect):
),
}
def _parse_transform(self) -> exp.Transform | exp.QueryTransform:
args = self._parse_csv(self._parse_lambda)
self._match_r_paren()
row_format_before = self._parse_row_format(match_row=True)
record_writer = None
if self._match_text_seq("RECORDWRITER"):
record_writer = self._parse_string()
if not self._match(TokenType.USING):
return exp.Transform.from_arg_list(args)
command_script = self._parse_string()
self._match(TokenType.ALIAS)
schema = self._parse_schema()
row_format_after = self._parse_row_format(match_row=True)
record_reader = None
if self._match_text_seq("RECORDREADER"):
record_reader = self._parse_string()
return self.expression(
exp.QueryTransform,
expressions=args,
command_script=command_script,
schema=schema,
row_format_before=row_format_before,
record_writer=record_writer,
row_format_after=row_format_after,
record_reader=record_reader,
)
def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
@ -363,11 +404,13 @@ class Hive(Dialect):
exp.Max: max_or_greatest,
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpExtract: regexp_extract_sql,
exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.Right: right_to_substring_sql,
@ -396,7 +439,6 @@ class Hive(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
@ -410,6 +452,11 @@ class Hive(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def rowformatserdeproperty_sql(self, expression: exp.RowFormatSerdeProperty) -> str:
serde_props = self.sql(expression, "serde_properties")
serde_props = f" {serde_props}" if serde_props else ""
return f"ROW FORMAT SERDE {self.sql(expression, 'this')}{serde_props}"
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",

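_parse_transform routes Hive's SELECT TRANSFORM ... USING syntax into the new exp.QueryTransform node, while plain TRANSFORM(...) calls still fall back to exp.Transform. A round-trip sketch (table and column names are made up):

import sqlglot

sql = "SELECT TRANSFORM(name, age) USING 'cat' AS (a STRING, b STRING) FROM people"
print(sqlglot.transpile(sql, read="hive", write="hive")[0])
# expected to round-trip unchanged via querytransform_sql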
View file

@ -427,6 +427,7 @@ class MySQL(Dialect):
TABLE_HINTS = True
DUPLICATE_KEY_UPDATE_WITH_SET = False
QUERY_HINT_SEP = " "
VALUES_AS_TABLE = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -473,19 +474,32 @@ class MySQL(Dialect):
LIMIT_FETCH = "LIMIT"
# MySQL doesn't support many datatypes in cast.
# https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_cast
CAST_MAPPING = {
exp.DataType.Type.BIGINT: "SIGNED",
exp.DataType.Type.BOOLEAN: "SIGNED",
exp.DataType.Type.INT: "SIGNED",
exp.DataType.Type.TEXT: "CHAR",
exp.DataType.Type.UBIGINT: "UNSIGNED",
exp.DataType.Type.VARCHAR: "CHAR",
}
def xor_sql(self, expression: exp.Xor) -> str:
if expression.expressions:
return self.expressions(expression, sep=" XOR ")
return super().xor_sql(expression)
def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
"""(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead."""
if expression.to.this == exp.DataType.Type.BIGINT:
to = "SIGNED"
elif expression.to.this == exp.DataType.Type.UBIGINT:
to = "UNSIGNED"
else:
return super().cast_sql(expression)
to = self.CAST_MAPPING.get(expression.to.this)
return f"CAST({self.sql(expression, 'this')} AS {to})"
if to:
expression = expression.copy()
expression.to.set("this", to)
return super().cast_sql(expression)
def show_sql(self, expression: exp.Show) -> str:
this = f" {expression.name}"

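The CAST_MAPPING table generalizes the old BIGINT/UBIGINT special case to more types; a sketch of the resulting output:

import sqlglot

print(sqlglot.transpile("SELECT CAST(x AS TEXT), CAST(y AS BIGINT)", write="mysql")[0])
# expected, per CAST_MAPPING: SELECT CAST(x AS CHAR), CAST(y AS SIGNED)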
View file

@ -282,7 +282,6 @@ class Postgres(Dialect):
VAR_SINGLE_TOKENS = {"$"}
class Parser(parser.Parser):
STRICT_CAST = False
CONCAT_NULL_OUTPUTS_STRING = True
FUNCTIONS = {
@ -318,6 +317,11 @@ class Postgres(Dialect):
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
TokenType.END: lambda self: self._parse_commit_or_rollback(),
}
def _parse_factor(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_exponent, self.FACTOR)

View file

@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
binary_from_function,
date_trunc_to_time,
format_time_lambda,
if_sql,
@ -198,6 +199,10 @@ class Presto(Dialect):
**parser.Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"BITWISE_AND": binary_from_function(exp.BitwiseAnd),
"BITWISE_NOT": lambda args: exp.BitwiseNot(this=seq_get(args, 0)),
"BITWISE_OR": binary_from_function(exp.BitwiseOr),
"BITWISE_XOR": binary_from_function(exp.BitwiseXor),
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(

View file

@ -27,6 +27,11 @@ class Redshift(Postgres):
class Parser(Postgres.Parser):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS,
"ADD_MONTHS": lambda args: exp.DateAdd(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
expression=seq_get(args, 1),
unit=exp.var("month"),
),
"DATEADD": lambda args: exp.DateAdd(
this=exp.TsOrDsToDate(this=seq_get(args, 2)),
expression=seq_get(args, 1),
@ -37,7 +42,6 @@ class Redshift(Postgres):
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
unit=seq_get(args, 0),
),
"NVL": exp.Coalesce.from_arg_list,
"STRTOL": exp.FromBase.from_arg_list,
}
@ -87,6 +91,7 @@ class Redshift(Postgres):
LOCKING_READS_SUPPORTED = False
RENAME_TABLE_WITH_DB = False
QUERY_HINTS = False
VALUES_AS_TABLE = False
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@ -129,40 +134,6 @@ class Redshift(Postgres):
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
def values_sql(self, expression: exp.Values) -> str:
"""
Converts `VALUES...` expression into a series of unions.
Note: If you have a lot of unions then this will result in a large number of recursive statements to
evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
very slow.
"""
# The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
if not expression.find_ancestor(exp.From, exp.Join):
return super().values_sql(expression)
column_names = expression.alias and expression.args["alias"].columns
selects = []
rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
for i, row in enumerate(rows):
if i == 0 and column_names:
row = [
exp.alias_(value, column_name)
for value, column_name in zip(row, column_names)
]
selects.append(exp.Select(expressions=row))
subquery_expression: exp.Select | exp.Union = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(subquery_expression, select, distinct=False)
return self.subquery_sql(subquery_expression.subquery(expression.alias))
def with_properties(self, properties: exp.Properties) -> str:
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")

View file

@ -30,7 +30,7 @@ def _check_int(s: str) -> bool:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@ -137,7 +137,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
# https://docs.snowflake.com/en/sql-reference/functions/div0
def _div0_to_if(args: t.List) -> exp.Expression:
def _div0_to_if(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@ -145,13 +145,13 @@ def _div0_to_if(args: t.List) -> exp.Expression:
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
def _zeroifnull_to_if(args: t.List) -> exp.Expression:
def _zeroifnull_to_if(args: t.List) -> exp.If:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/nullifzero
def _nullifzero_to_if(args: t.List) -> exp.Expression:
def _nullifzero_to_if(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
@ -164,12 +164,21 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
def _parse_convert_timezone(args: t.List) -> exp.Expression:
def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
if len(args) == 3:
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
regexp_replace = exp.RegexpReplace.from_arg_list(args)
if not regexp_replace.args.get("replacement"):
regexp_replace.set("replacement", exp.Literal.string(""))
return regexp_replace
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
@ -223,13 +232,14 @@ class Snowflake(Dialect):
"IFF": exp.If.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
"REGEXP_REPLACE": _parse_regexp_replace,
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TIMEDIFF": _parse_datediff,
"TIMESTAMPDIFF": _parse_datediff,
"TO_ARRAY": exp.Array.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
"TO_TIMESTAMP": _parse_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _zeroifnull_to_if,
}
@ -242,7 +252,6 @@ class Snowflake(Dialect):
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
TokenType.RLIKE,
TokenType.TABLE,
}

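_parse_regexp_replace backfills the replacement argument, so Snowflake's two-argument form (which deletes matches) survives transpilation; a sketch:

import sqlglot

sql = "SELECT REGEXP_REPLACE(s, 'pat')"
print(sqlglot.transpile(sql, read="snowflake", write="duckdb")[0])
# expected: SELECT REGEXP_REPLACE(s, 'pat', '')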
View file

@ -2,9 +2,11 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, parser, transforms
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
binary_from_function,
create_with_partitions_sql,
format_time_lambda,
pivot_column_names,
rename_func,
trim_sql,
@ -108,47 +110,36 @@ class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
"AGGREGATE": exp.Reduce.from_arg_list,
"DAYOFWEEK": lambda args: exp.DayOfWeek(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFMONTH": lambda args: exp.DayOfMonth(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFYEAR": lambda args: exp.DayOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"BOOLEAN": _parse_as_cast("boolean"),
"DATE": _parse_as_cast("date"),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=exp.var(seq_get(args, 0))
),
"DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DOUBLE": _parse_as_cast("double"),
"FLOAT": _parse_as_cast("float"),
"IIF": exp.If.from_arg_list,
"INT": _parse_as_cast("int"),
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift),
"SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
"STRING": _parse_as_cast("string"),
"TIMESTAMP": _parse_as_cast("timestamp"),
"TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args)
if len(args) == 1
else format_time_lambda(exp.StrToTime, "spark")(args),
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
**Hive.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@ -207,6 +198,13 @@ class Spark2(Hive):
exp.Map: _map_sql,
exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
exp.Reduce: rename_func("AGGREGATE"),
exp.RegexpReplace: lambda self, e: self.func(
"REGEXP_REPLACE",
e.this,
e.expression,
e.args["replacement"],
e.args.get("position"),
),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
@ -224,6 +222,7 @@ class Spark2(Hive):
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
TRANSFORMS.pop(exp.Left)
TRANSFORMS.pop(exp.MonthsBetween)
TRANSFORMS.pop(exp.Right)
WRAP_DERIVED_VALUES = False

View file

@ -20,6 +20,8 @@ class StarRocks(MySQL):
}
class Generator(MySQL.Generator):
CAST_MAPPING = {}
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",

View file

@ -138,7 +138,8 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
if isinstance(expression, exp.NumberToStr)
else exp.Literal.string(
format_time(
expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING)
expression.text("format"),
t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING),
)
)
)
@ -314,7 +315,9 @@ class TSQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@ -365,6 +368,55 @@ class TSQL(Dialect):
CONCAT_NULL_OUTPUTS_STRING = True
def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback:
"""Applies to SQL Server and Azure SQL Database
COMMIT [ { TRAN | TRANSACTION }
[ transaction_name | @tran_name_variable ] ]
[ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ]
ROLLBACK { TRAN | TRANSACTION }
[ transaction_name | @tran_name_variable
| savepoint_name | @savepoint_variable ]
"""
rollback = self._prev.token_type == TokenType.ROLLBACK
self._match_texts({"TRAN", "TRANSACTION"})
this = self._parse_id_var()
if rollback:
return self.expression(exp.Rollback, this=this)
durability = None
if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
self._match_text_seq("DELAYED_DURABILITY")
self._match(TokenType.EQ)
if self._match_text_seq("OFF"):
durability = False
else:
self._match(TokenType.ON)
durability = True
self._match_r_paren()
return self.expression(exp.Commit, this=this, durability=durability)
def _parse_transaction(self) -> exp.Transaction | exp.Command:
"""Applies to SQL Server and Azure SQL Database
BEGIN { TRAN | TRANSACTION }
[ { transaction_name | @tran_name_variable }
[ WITH MARK [ 'description' ] ]
]
"""
if self._match_texts(("TRAN", "TRANSACTION")):
transaction = self.expression(exp.Transaction, this=self._parse_id_var())
if self._match_text_seq("WITH", "MARK"):
transaction.set("mark", self._parse_string())
return transaction
return self._parse_as_command(self._prev)
def _parse_system_time(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("FOR", "SYSTEM_TIME"):
return None
@ -496,7 +548,9 @@ class TSQL(Dialect):
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
"HASHBYTES",
exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
e.this,
),
exp.TimeToStr: _format_sql,
}
@ -539,3 +593,26 @@ class TSQL(Dialect):
into = self.sql(expression, "into")
into = self.seg(f"INTO {into}") if into else ""
return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}"
def transaction_sql(self, expression: exp.Transaction) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
mark = self.sql(expression, "mark")
mark = f" WITH MARK {mark}" if mark else ""
return f"BEGIN TRANSACTION{this}{mark}"
def commit_sql(self, expression: exp.Commit) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
durability = expression.args.get("durability")
durability = (
f" WITH (DELAYED_DURABILITY = {'ON' if durability else 'OFF'})"
if durability is not None
else ""
)
return f"COMMIT TRANSACTION{this}{durability}"
def rollback_sql(self, expression: exp.Rollback) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
return f"ROLLBACK TRANSACTION{this}"

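With these parser and generator hooks, T-SQL transaction statements round-trip instead of degrading to opaque commands; a sketch (the transaction name is made up):

import sqlglot

for sql in (
    "BEGIN TRAN my_txn WITH MARK 'backup point'",
    "COMMIT TRAN my_txn WITH (DELAYED_DURABILITY = ON)",
    "ROLLBACK TRAN my_txn",
):
    print(sqlglot.transpile(sql, read="tsql", write="tsql")[0])
# expected: BEGIN TRANSACTION my_txn WITH MARK 'backup point'
#           COMMIT TRANSACTION my_txn WITH (DELAYED_DURABILITY = ON)
#           ROLLBACK TRANSACTION my_txn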
View file

@ -759,12 +759,24 @@ class Condition(Expression):
)
def isin(
self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts
self,
*expressions: t.Any,
query: t.Optional[ExpOrStr] = None,
unnest: t.Optional[ExpOrStr] | t.Collection[ExpOrStr] = None,
copy: bool = True,
**opts,
) -> In:
return In(
this=_maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
unnest=Unnest(
expressions=[
maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest)
]
)
if unnest
else None,
)
def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between:
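The new unnest parameter lets the builder API express x IN UNNEST(arr) predicates; a sketch (the exact BigQuery rendering is an assumption, not taken from this diff):

from sqlglot import condition

expr = condition("x").isin(unnest="arr")
print(expr.sql(dialect="bigquery"))  # likely: x IN UNNEST(arr)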
@ -2019,7 +2031,20 @@ class RowFormatDelimitedProperty(Property):
class RowFormatSerdeProperty(Property):
arg_types = {"this": True}
arg_types = {"this": True, "serde_properties": False}
# https://spark.apache.org/docs/3.1.2/sql-ref-syntax-qry-select-transform.html
class QueryTransform(Expression):
arg_types = {
"expressions": True,
"command_script": True,
"schema": False,
"row_format_before": False,
"record_writer": False,
"row_format_after": False,
"record_reader": False,
}
class SchemaCommentProperty(Property):
@ -2149,12 +2174,24 @@ class Tuple(Expression):
arg_types = {"expressions": False}
def isin(
self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts
self,
*expressions: t.Any,
query: t.Optional[ExpOrStr] = None,
unnest: t.Optional[ExpOrStr] | t.Collection[ExpOrStr] = None,
copy: bool = True,
**opts,
) -> In:
return In(
this=_maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
unnest=Unnest(
expressions=[
maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest)
]
)
if unnest
else None,
)
@ -3478,15 +3515,15 @@ class Command(Expression):
class Transaction(Expression):
arg_types = {"this": False, "modes": False}
arg_types = {"this": False, "modes": False, "mark": False}
class Commit(Expression):
arg_types = {"chain": False}
arg_types = {"chain": False, "this": False, "durability": False}
class Rollback(Expression):
arg_types = {"savepoint": False}
arg_types = {"savepoint": False, "this": False}
class AlterTable(Expression):
@ -3530,10 +3567,6 @@ class Or(Connector):
pass
class Xor(Connector):
pass
class BitwiseAnd(Binary):
pass
@ -3856,6 +3889,11 @@ class Abs(Func):
pass
# https://spark.apache.org/docs/latest/api/sql/index.html#transform
class Transform(Func):
arg_types = {"this": True, "expression": True}
class Anonymous(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -4098,6 +4136,10 @@ class WeekOfYear(Func):
_sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
class MonthsBetween(Func):
arg_types = {"this": True, "expression": True, "roundoff": False}
class LastDateOfMonth(Func):
pass
@ -4209,6 +4251,10 @@ class Hex(Func):
pass
class Xor(Connector, Func):
arg_types = {"this": False, "expression": False, "expressions": False}
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@ -4431,7 +4477,18 @@ class RegexpExtract(Func):
}
class RegexpLike(Func):
class RegexpReplace(Func):
arg_types = {
"this": True,
"expression": True,
"replacement": True,
"position": False,
"occurrence": False,
"parameters": False,
}
class RegexpLike(Binary, Func):
arg_types = {"this": True, "expression": True, "flag": False}

View file

@ -164,6 +164,11 @@ class Generator:
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
# Whether or not VALUES statements can be used as derived tables.
# MySQL 5 and Redshift do not allow this, so when False, it will convert
# SELECT * FROM (VALUES ...) into a series of SELECT ... UNION statements
VALUES_AS_TABLE = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -260,8 +265,9 @@ class Generator:
# Expressions whose comments are separated from them for better formatting
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Select,
exp.Drop,
exp.From,
exp.Select,
exp.Where,
exp.With,
)
@ -818,7 +824,11 @@ class Generator:
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
type_sql = (
self.TYPE_MAPPING.get(type_value, type_value.value)
if isinstance(type_value, exp.DataType.Type)
else type_value
)
nested = ""
interior = self.expressions(expression, flat=True)
values = ""
@ -1307,15 +1317,45 @@ class Generator:
return self.prepend_ctes(expression, sql)
def values_sql(self, expression: exp.Values) -> str:
args = self.expressions(expression)
alias = self.sql(expression, "alias")
values = f"VALUES{self.seg('')}{args}"
values = (
f"({values})"
if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From))
else values
)
return f"{values} AS {alias}" if alias else values
# The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
if self.VALUES_AS_TABLE or not expression.find_ancestor(exp.From, exp.Join):
args = self.expressions(expression)
alias = self.sql(expression, "alias")
values = f"VALUES{self.seg('')}{args}"
values = (
f"({values})"
if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From))
else values
)
return f"{values} AS {alias}" if alias else values
# Converts `VALUES...` expression into a series of select unions.
# Note: If you have a lot of unions then this will result in a large number of recursive statements to
# evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
# very slow.
expression = expression.copy()
column_names = expression.alias and expression.args["alias"].columns
selects = []
for i, tup in enumerate(expression.expressions):
row = tup.expressions
if i == 0 and column_names:
row = [
exp.alias_(value, column_name) for value, column_name in zip(row, column_names)
]
selects.append(exp.Select(expressions=row))
subquery_expression: exp.Select | exp.Union = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(
subquery_expression, select, distinct=False, copy=False
)
return self.subquery_sql(subquery_expression.subquery(expression.alias, copy=False))
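This moves Redshift's VALUES-to-UNION rewrite into the base generator behind the VALUES_AS_TABLE flag, so MySQL-style targets get it too; a sketch of the fallback path:

import sqlglot

sql = "SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t(x, y)"
print(sqlglot.transpile(sql, write="redshift")[0])
# expected: SELECT * FROM (SELECT 1 AS x, 'a' AS y UNION ALL SELECT 2, 'b') AS t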
def var_sql(self, expression: exp.Var) -> str:
return self.sql(expression, "this")
@ -2043,7 +2083,7 @@ class Generator:
def and_sql(self, expression: exp.And) -> str:
return self.connector_sql(expression, "AND")
def xor_sql(self, expression: exp.And) -> str:
def xor_sql(self, expression: exp.Xor) -> str:
return self.connector_sql(expression, "XOR")
def connector_sql(self, expression: exp.Connector, op: str) -> str:
@ -2507,6 +2547,21 @@ class Generator:
return self.func("ANY_VALUE", this)
def querytransform_sql(self, expression: exp.QueryTransform) -> str:
transform = self.func("TRANSFORM", *expression.expressions)
row_format_before = self.sql(expression, "row_format_before")
row_format_before = f" {row_format_before}" if row_format_before else ""
record_writer = self.sql(expression, "record_writer")
record_writer = f" RECORDWRITER {record_writer}" if record_writer else ""
using = f" USING {self.sql(expression, 'command_script')}"
schema = self.sql(expression, "schema")
schema = f" AS {schema}" if schema else ""
row_format_after = self.sql(expression, "row_format_after")
row_format_after = f" {row_format_after}" if row_format_after else ""
record_reader = self.sql(expression, "record_reader")
record_reader = f" RECORDREADER {record_reader}" if record_reader else ""
return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}"
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None

View file

@ -79,7 +79,7 @@ def lineage(
raise SqlglotError("Cannot build lineage, sql must be SELECT")
def to_node(
column_name: str,
column: str | int,
scope: Scope,
scope_name: t.Optional[str] = None,
upstream: t.Optional[Node] = None,
@ -90,26 +90,38 @@ def lineage(
for dt in scope.derived_tables
if dt.comments and dt.comments[0].startswith("source: ")
}
if isinstance(scope.expression, exp.Union):
for scope in scope.union_scopes:
node = to_node(
column_name,
scope=scope,
scope_name=scope_name,
upstream=upstream,
alias=aliases.get(scope_name),
)
return node
# Find the specific select clause that is the source of the column we want.
# This can either be a specific, named select or a generic `*` clause.
select = next(
(select for select in scope.expression.selects if select.alias_or_name == column_name),
exp.Star() if scope.expression.is_star else None,
select = (
scope.expression.selects[column]
if isinstance(column, int)
else next(
(select for select in scope.expression.selects if select.alias_or_name == column),
exp.Star() if scope.expression.is_star else None,
)
)
if not select:
raise ValueError(f"Could not find {column_name} in {scope.expression}")
raise ValueError(f"Could not find {column} in {scope.expression}")
if isinstance(scope.expression, exp.Union):
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
index = (
column
if isinstance(column, int)
else next(
i
for i, select in enumerate(scope.expression.selects)
if select.alias_or_name == column
)
)
for s in scope.union_scopes:
to_node(index, scope=s, upstream=upstream)
return upstream
if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
@ -122,7 +134,7 @@ def lineage(
# Create the node for this step in the lineage chain, and attach it to the previous one.
node = Node(
name=f"{scope_name}.{column_name}" if scope_name else column_name,
name=f"{scope_name}.{column}" if scope_name else str(column),
source=source,
expression=select,
alias=alias or "",

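A sketch of the lineage change: columns in UNION branches are now resolved by select index under a single UNION node, rather than only the last branch being returned (the query and names are made up; Node.walk is assumed to exist in this version):

from sqlglot.lineage import lineage

node = lineage("a", "SELECT a FROM (SELECT a FROM x UNION ALL SELECT a FROM y) AS sub")
for n in node.walk():
    print(n.name)  # should include a UNION node with one child per branch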
View file

@ -144,8 +144,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
name, cte = _new_cte(scope, existing_ctes, taken)
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
parent.replace(table)
table.set("joins", parent.args.get("joins"))
parent.replace(table)
return cte

View file

@ -176,6 +176,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
return (
isinstance(outer_scope.expression, exp.Select)
and not outer_scope.expression.is_star
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
@ -242,6 +243,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
alias (str)
"""
new_subquery = inner_scope.expression.args["from"].this
new_subquery.set("joins", node_to_replace.args.get("joins"))
node_to_replace.replace(new_subquery)
for join_hint in outer_scope.join_hints:
tables = join_hint.find_all(exp.Table)

View file

@ -61,6 +61,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
if remove_unused_selections:
_remove_unused_selections(scope, parent_selections, schema)
if scope.expression.is_star:
continue
# Group columns by source name
selects = defaultdict(set)
for col in scope.columns:

View file

@ -29,12 +29,13 @@ def qualify_columns(
'SELECT tbl.col AS col FROM tbl'
Args:
expression: expression to qualify
schema: Database schema
expand_alias_refs: whether or not to expand references to aliases
infer_schema: whether or not to infer the schema if missing
expression: Expression to qualify.
schema: Database schema.
expand_alias_refs: Whether or not to expand references to aliases.
infer_schema: Whether or not to infer the schema if missing.
Returns:
sqlglot.Expression: qualified expression
The qualified expression.
"""
schema = ensure_schema(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
@ -410,7 +411,9 @@ def _expand_stars(
else:
return
scope.expression.set("expressions", new_selections)
# Ensures we don't overwrite the initial selections with an empty list
if new_selections:
scope.expression.set("expressions", new_selections)
def _add_except_columns(

View file

@ -124,8 +124,8 @@ class Scope:
self._ctes.append(node)
elif (
isinstance(node, exp.Subquery)
and isinstance(parent, (exp.From, exp.Join))
and _is_subquery_scope(node)
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
and _is_derived_table(node)
):
self._derived_tables.append(node)
elif isinstance(node, exp.Subqueryable):
@ -610,13 +610,13 @@ def _traverse_ctes(scope):
scope.sources.update(sources)
def _is_subquery_scope(expression: exp.Subquery) -> bool:
def _is_derived_table(expression: exp.Subquery) -> bool:
"""
We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope.
If an alias is present, it shadows all names under the Subquery, so that's an
exception to this rule.
We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
as it doesn't introduce a new scope. If an alias is present, it shadows all names
under the Subquery, so that's one exception to this rule.
"""
return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias)
return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))
def _traverse_tables(scope):
@ -654,7 +654,10 @@ def _traverse_tables(scope):
else:
sources[source_name] = expression
expressions.extend(join.this for join in expression.args.get("joins") or [])
# Make sure to not include the joins twice
if expression is not scope.expression:
expressions.extend(join.this for join in expression.args.get("joins") or [])
continue
if not isinstance(expression, exp.DerivedTable):
@ -664,10 +667,11 @@ def _traverse_tables(scope):
lateral_sources = sources
scope_type = ScopeType.UDTF
scopes = scope.udtf_scopes
elif _is_subquery_scope(expression):
elif _is_derived_table(expression):
lateral_sources = None
scope_type = ScopeType.DERIVED_TABLE
scopes = scope.derived_table_scopes
expressions.extend(join.this for join in expression.args.get("joins") or [])
else:
# Makes sure we check for possible sources in nested table constructs
expressions.append(expression.this)
@ -735,10 +739,16 @@ def walk_in_scope(expression, bfs=True):
isinstance(node, exp.CTE)
or (
isinstance(node, exp.Subquery)
and isinstance(parent, (exp.From, exp.Join))
and _is_subquery_scope(node)
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
and _is_derived_table(node)
)
or isinstance(node, exp.UDTF)
or isinstance(node, exp.Subqueryable)
):
prune = True
if isinstance(node, (exp.Subquery, exp.UDTF)):
# The following args are not actually in the inner scope, so we should visit them
for key in ("joins", "laterals", "pivots"):
for arg in node.args.get(key) or []:
yield from walk_in_scope(arg, bfs=bfs)

View file

@ -327,6 +327,7 @@ class Parser(metaclass=_Parser):
TokenType.PRIMARY_KEY,
TokenType.RANGE,
TokenType.REPLACE,
TokenType.RLIKE,
TokenType.ROW,
TokenType.UNNEST,
TokenType.VAR,
@ -338,6 +339,7 @@ class Parser(metaclass=_Parser):
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.WINDOW,
TokenType.XOR,
*TYPE_TOKENS,
*SUBQUERY_PREDICATES,
}
@ -505,7 +507,6 @@ class Parser(metaclass=_Parser):
TokenType.DESC: lambda self: self._parse_describe(),
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.FROM: lambda self: exp.select("*").from_(
t.cast(exp.From, self._parse_from(skip_from_token=True))
),
@ -716,7 +717,7 @@ class Parser(metaclass=_Parser):
FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
FUNCTION_PARSERS = {
"ANY_VALUE": lambda self: self._parse_any_value(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONCAT": lambda self: self._parse_concat(),
@ -1144,6 +1145,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Drop,
comments=start.comments,
exists=self._parse_exists(),
this=self._parse_table(schema=True),
kind=kind,
@ -1233,11 +1235,14 @@ class Parser(metaclass=_Parser):
expression = self._parse_ddl_select()
if create_token.token_type == TokenType.TABLE:
# exp.Properties.Location.POST_EXPRESSION
extend_props(self._parse_properties())
indexes = []
while True:
index = self._parse_index()
# exp.Properties.Location.POST_EXPRESSION and POST_INDEX
# exp.Properties.Location.POST_INDEX
extend_props(self._parse_properties())
if not index:
@ -1384,7 +1389,6 @@ class Parser(metaclass=_Parser):
def _parse_with_property(
self,
) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]:
self._match(TokenType.WITH)
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_property)
@ -1781,7 +1785,17 @@ class Parser(metaclass=_Parser):
return None
if self._match_text_seq("SERDE"):
return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string())
this = self._parse_string()
serde_properties = None
if self._match(TokenType.SERDE_PROPERTIES):
serde_properties = self.expression(
exp.SerdeProperties, expressions=self._parse_wrapped_csv(self._parse_property)
)
return self.expression(
exp.RowFormatSerdeProperty, this=this, serde_properties=serde_properties
)
self._match_text_seq("DELIMITED")
@ -2251,7 +2265,9 @@ class Parser(metaclass=_Parser):
self._match_set(self.JOIN_KINDS) and self._prev,
)
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
def _parse_join(
self, skip_join_token: bool = False, parse_bracket: bool = False
) -> t.Optional[exp.Join]:
if self._match(TokenType.COMMA):
return self.expression(exp.Join, this=self._parse_table())
@ -2275,7 +2291,7 @@ class Parser(metaclass=_Parser):
if outer_apply:
side = Token(TokenType.LEFT, "LEFT")
kwargs: t.Dict[str, t.Any] = {"this": self._parse_table()}
kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)}
if method:
kwargs["method"] = method.text
@ -2411,6 +2427,7 @@ class Parser(metaclass=_Parser):
schema: bool = False,
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
if lateral:
@ -2430,7 +2447,9 @@ class Parser(metaclass=_Parser):
subquery.set("pivots", self._parse_pivots())
return subquery
this: exp.Expression = self._parse_table_parts(schema=schema)
bracket = parse_bracket and self._parse_bracket(None)
bracket = self.expression(exp.Table, this=bracket) if bracket else None
this: exp.Expression = bracket or self._parse_table_parts(schema=schema)
if schema:
return self._parse_schema(this=this)
@ -2758,8 +2777,15 @@ class Parser(metaclass=_Parser):
self, this: t.Optional[exp.Expression] = None, top: bool = False
) -> t.Optional[exp.Expression]:
if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN)
expression = self._parse_number() if top else self._parse_term()
comments = self._prev_comments
if top:
limit_paren = self._match(TokenType.L_PAREN)
expression = self._parse_number()
if limit_paren:
self._match_r_paren()
else:
expression = self._parse_term()
if self._match(TokenType.COMMA):
offset = expression
@ -2767,10 +2793,9 @@ class Parser(metaclass=_Parser):
else:
offset = None
limit_exp = self.expression(exp.Limit, this=this, expression=expression, offset=offset)
if limit_paren:
self._match_r_paren()
limit_exp = self.expression(
exp.Limit, this=this, expression=expression, offset=offset, comments=comments
)
return limit_exp
@ -2803,7 +2828,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.OFFSET):
return this
count = self._parse_number()
count = self._parse_term()
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
@ -3320,7 +3345,7 @@ class Parser(metaclass=_Parser):
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
self._match_r_paren(this)
self._match(TokenType.R_PAREN, expression=this)
return self._parse_window(this)
def _parse_function_parameter(self) -> t.Optional[exp.Expression]:
@ -4076,7 +4101,10 @@ class Parser(metaclass=_Parser):
self, this: t.Optional[exp.Expression], alias: bool = False
) -> t.Optional[exp.Expression]:
if self._match_pair(TokenType.FILTER, TokenType.L_PAREN):
this = self.expression(exp.Filter, this=this, expression=self._parse_where())
self._match(TokenType.WHERE)
this = self.expression(
exp.Filter, this=this, expression=self._parse_where(skip_where_token=True)
)
self._match_r_paren()
# T-SQL allows the OVER (...) syntax after WITHIN GROUP.
@ -4351,7 +4379,7 @@ class Parser(metaclass=_Parser):
self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False))
)
def _parse_transaction(self) -> exp.Transaction:
def _parse_transaction(self) -> exp.Transaction | exp.Command:
this = None
if self._match_texts(self.TRANSACTION_KIND):
this = self._prev.text

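One user-visible effect of the parser changes above: the WHERE keyword inside FILTER(...) is now optional, since _parse_where is invoked with skip_where_token=True after an optional WHERE match. A sketch:

import sqlglot

print(sqlglot.parse_one("SELECT COUNT(*) FILTER(x > 1) FROM t").sql())
# expected: SELECT COUNT(*) FILTER(WHERE x > 1) FROM t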
View file

@ -3,6 +3,7 @@ from __future__ import annotations
import typing as t
from enum import auto
from sqlglot.errors import TokenError
from sqlglot.helper import AutoName
from sqlglot.trie import TrieResult, in_trie, new_trie
@ -800,7 +801,7 @@ class Tokenizer(metaclass=_Tokenizer):
start = max(self._current - 50, 0)
end = min(self._current + 50, self.size - 1)
context = self.sql[start:end]
raise ValueError(f"Error tokenizing '{context}'") from e
raise TokenError(f"Error tokenizing '{context}'") from e
return self.tokens
@ -1097,7 +1098,7 @@ class Tokenizer(metaclass=_Tokenizer):
try:
int(text, base)
except:
raise RuntimeError(
raise TokenError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
else:
@ -1140,7 +1141,7 @@ class Tokenizer(metaclass=_Tokenizer):
if self._current + 1 < self.size:
self._advance(2)
else:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._current}")
raise TokenError(f"Missing {delimiter} from {self._line}:{self._current}")
else:
if self._chars(delim_size) == delimiter:
if delim_size > 1:
@ -1148,7 +1149,7 @@ class Tokenizer(metaclass=_Tokenizer):
break
if self._end:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")
current = self._current - 1
self._advance(alnum=True)
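Tokenizer failures now raise the library-specific TokenError instead of bare ValueError/RuntimeError, so callers can catch it precisely; a sketch:

import sqlglot
from sqlglot.errors import TokenError

try:
    sqlglot.parse_one("SELECT 'unterminated")
except TokenError as e:
    print(e)  # e.g. "Missing ' from 1:8", per the messages above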