
Merging upstream version 18.17.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:09:41 +01:00
parent fdf9ca761f
commit 04c9be45a8
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
90 changed files with 46581 additions and 43319 deletions

View file

@@ -58,6 +58,12 @@ parser.add_argument(
default="IMMEDIATE",
help="IGNORE, WARN, RAISE, IMMEDIATE (default)",
)
parser.add_argument(
"--version",
action="version",
version=sqlglot.__version__,
help="Display the SQLGlot version",
)
args = parser.parse_args()

View file

@@ -84,11 +84,11 @@ def min(col: ColumnOrName) -> Column:
def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "MAX_BY", ord)
return Column.invoke_expression_over_column(col, expression.ArgMax, expression=ord)
def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "MIN_BY", ord)
return Column.invoke_expression_over_column(col, expression.ArgMin, expression=ord)
def count(col: ColumnOrName) -> Column:
@@ -1113,7 +1113,7 @@ def reverse(col: ColumnOrName) -> Column:
def flatten(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "FLATTEN")
return Column.invoke_expression_over_column(col, expression.Flatten)
def map_keys(col: ColumnOrName) -> Column:
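The max_by/min_by/flatten helpers above now build typed expression nodes instead of anonymous function calls; a minimal sketch of the effect (the rendered name comes from exp.ArgMax's _sql_names, output comment is indicative, not verbatim):

from sqlglot.dataframe.sql import functions as F

# max_by now produces an exp.ArgMax node rather than an anonymous MAX_BY call,
# so dialect generators can rename it (MAX_BY, ARG_MAX, argMax, ...):
col = F.max_by("x", "y")
print(col.expression.sql())  # expected roughly: ARG_MAX(x, y)
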

View file

@@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
arg_max_or_min_no_count,
binary_from_function,
date_add_interval_sql,
datestrtodate_sql,
@@ -434,8 +435,13 @@ class BigQuery(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
if e.args.get("default")
else f"COLLATE {self.sql(e, 'this')}",
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
@@ -632,6 +638,13 @@ class BigQuery(Dialect):
"within",
}
def eq_sql(self, expression: exp.EQ) -> str:
# Operands of = cannot be NULL in BigQuery
if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null):
return "NULL"
return self.binary(expression, "=")
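A quick sketch of the new eq_sql behavior (default read dialect; output indicative):

import sqlglot

# BigQuery's `=` can't take NULL operands, so a comparison against a literal
# NULL is folded to NULL at generation time:
print(sqlglot.transpile("SELECT a = NULL", write="bigquery")[0])
# expected: SELECT NULL
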
def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
parent = expression.parent

View file

@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arg_max_or_min_no_count,
inline_array_sql,
no_pivot_sql,
rename_func,
@@ -373,8 +374,11 @@ class ClickHouse(Dialect):
exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.AnyValue: rename_func("any"),
exp.ApproxDistinct: rename_func("uniq"),
exp.ArgMax: arg_max_or_min_no_count("argMax"),
exp.ArgMin: arg_max_or_min_no_count("argMin"),
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
@@ -418,6 +422,33 @@ class ClickHouse(Dialect):
"NAMED COLLECTION",
}
def _any_to_has(
self,
expression: exp.EQ | exp.NEQ,
default: t.Callable[[t.Any], str],
prefix: str = "",
) -> str:
if isinstance(expression.left, exp.Any):
arr = expression.left
this = expression.right
elif isinstance(expression.right, exp.Any):
arr = expression.right
this = expression.left
else:
return default(expression)
return prefix + self.func("has", arr.this.unnest(), this)
def eq_sql(self, expression: exp.EQ) -> str:
return self._any_to_has(expression, super().eq_sql)
def neq_sql(self, expression: exp.NEQ) -> str:
return self._any_to_has(expression, super().neq_sql, "NOT ")
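A hedged sketch of the `= ANY(array)` rewrite (the reading dialect and exact output are illustrative):

import sqlglot

# _any_to_has turns equality against ANY(array) into ClickHouse's has():
print(sqlglot.transpile("SELECT 1 = ANY(ARRAY[1, 2])", read="postgres", write="clickhouse")[0])
# expected roughly: SELECT has([1, 2], 1)
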
def regexpilike_sql(self, expression: exp.RegexpILike) -> str:
# Manually add a flag to make the search case-insensitive
regex = self.func("CONCAT", "'(?i)'", expression.expression)
return f"match({self.format_args(expression.this, regex)})"
def datatype_sql(self, expression: exp.DataType) -> str:
# String is the standard ClickHouse type, every other variant is just an alias.
# Additionally, any supplied length parameter will be ignored.

View file

@@ -10,7 +10,7 @@ from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
@@ -595,6 +595,19 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
)
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
if not expression.expression:
return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
if expression.text("expression").lower() in TIMEZONES:
return self.sql(
exp.AtTimeZone(
this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
zone=expression.expression,
)
)
return self.function_fallback_sql(expression)
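A minimal sketch of what no_timestamp_sql emits, building the node directly (DuckDB is one adopter further below; the column name and exact output are illustrative):

from sqlglot import exp
from sqlglot.dialects.duckdb import DuckDB

# TIMESTAMP(x) lowers to a cast; a recognized zone argument adds AT TIME ZONE:
node = exp.Timestamp(this=exp.column("x"), expression=exp.Literal.string("America/New_York"))
print(DuckDB().generate(node))
# expected roughly: CAST(x AS TIMESTAMP) AT TIME ZONE 'America/New_York'
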
def locate_to_strposition(args: t.List) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
@@ -691,9 +704,13 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
_dialect = Dialect.get_or_raise(dialect)
time_format = self.format_time(expression)
if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))
return self.sql(exp.cast(self.sql(expression, "this"), "date"))
return self.sql(
exp.cast(
exp.StrToTime(this=expression.this, format=expression.args["format"]),
"date",
)
)
return self.sql(exp.cast(expression.this, "date"))
return _ts_or_ds_to_date_sql
@@ -725,7 +742,9 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
bad_args = list(
filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
)
if bad_args:
self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
@@ -756,15 +775,6 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
return names
def simplify_literal(expression: E) -> E:
if not isinstance(expression.expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
simplify(expression.expression)
return expression
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -804,3 +814,21 @@ def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
expression = expression.copy()
expression.set("with", expression.expression.args["with"].pop())
return self.insert_sql(expression)
def generatedasidentitycolumnconstraint_sql(
self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
start = self.sql(expression, "start") or "1"
increment = self.sql(expression, "increment") or "1"
return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
if expression.args.get("count"):
self.unsupported(f"Only two arguments are supported in function {name}.")
return self.func(name, expression.this, expression.expression)
return _arg_max_or_min_sql
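With exp.ArgMax/exp.ArgMin plus this helper, MAX_BY-style aggregates transpile by plain rename; a minimal sketch (output comment is indicative):

import sqlglot

# DuckDB's ARG_MAX and ClickHouse's argMax now share the exp.ArgMax node:
print(sqlglot.transpile("SELECT ARG_MAX(x, y) FROM t", read="duckdb", write="clickhouse")[0])
# expected: SELECT argMax(x, y) FROM t
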

View file

@@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
arg_max_or_min_no_count,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
binary_from_function,
@@ -18,9 +19,9 @@ from sqlglot.dialects.dialect import (
no_comment_column_constraint_sql,
no_properties_sql,
no_safe_divide_sql,
no_timestamp_sql,
pivot_column_names,
regexp_extract_sql,
regexp_replace_sql,
rename_func,
str_position_sql,
str_to_time_sql,
@@ -172,6 +173,12 @@ class DuckDB(Dialect):
this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
),
"REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
"REGEXP_REPLACE": lambda args: exp.RegexpReplace(
this=seq_get(args, 0),
expression=seq_get(args, 1),
replacement=seq_get(args, 2),
modifiers=seq_get(args, 3),
),
"STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
"STRING_SPLIT": exp.Split.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
@@ -243,6 +250,8 @@ class DuckDB(Dialect):
if e.expressions and e.expressions[0].find(exp.Select)
else inline_array_sql(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.BitwiseXor: rename_func("XOR"),
@@ -287,7 +296,13 @@ class DuckDB(Dialect):
exp.PercentileDisc: rename_func("QUANTILE_DISC"),
exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
exp.RegexpReplace: regexp_replace_sql,
exp.RegexpReplace: lambda self, e: self.func(
"REGEXP_REPLACE",
e.this,
e.expression,
e.args.get("replacement"),
e.args.get("modifiers"),
),
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
exp.SafeDivide: no_safe_divide_sql,
@@ -298,6 +313,7 @@ class DuckDB(Dialect):
exp.StrToTime: str_to_time_sql,
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
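A quick sketch of the REGEXP_REPLACE change above (the 'g' modifier is illustrative):

import sqlglot

# The fourth argument (modifiers) now survives a DuckDB round trip:
print(sqlglot.transpile("SELECT REGEXP_REPLACE(s, 'a', 'b', 'g')", read="duckdb", write="duckdb")[0])
# expected: SELECT REGEXP_REPLACE(s, 'a', 'b', 'g')
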

View file

@@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
arg_max_or_min_no_count,
create_with_partitions_sql,
format_time_lambda,
if_sql,
@@ -106,11 +107,16 @@ def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})"
return f"({sec_diff}){factor}" if factor else sec_diff
sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
months_between = unit in DIFF_MONTH_SWITCH
sql_func = "MONTHS_BETWEEN" if months_between else "DATEDIFF"
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})"
if months_between:
# MONTHS_BETWEEN returns a float, so we need to truncate the fractional part
diff_sql = f"CAST({diff_sql} AS INT)"
return f"{diff_sql}{multiplier_sql}"
@@ -426,6 +432,8 @@ class Hive(Dialect):
exp.Property: _property_sql,
exp.AnyValue: rename_func("FIRST"),
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySize: rename_func("SIZE"),

View file

@@ -21,7 +21,6 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
parse_date_delta_with_interval,
rename_func,
simplify_literal,
strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
@@ -689,6 +688,8 @@ class MySQL(Dialect):
LIMIT_FETCH = "LIMIT"
LIMIT_ONLY_LITERALS = True
# MySQL doesn't support many datatypes in cast.
# https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_cast
CAST_MAPPING = {
@@ -712,16 +713,6 @@ class MySQL(Dialect):
result = f"{result} UNSIGNED"
return result
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
# MySQL requires simple literal values for its LIMIT clause.
expression = simplify_literal(expression.copy())
return super().limit_sql(expression, top=top)
def offset_sql(self, expression: exp.Offset) -> str:
# MySQL requires simple literal values for its OFFSET clause.
expression = simplify_literal(expression.copy())
return super().offset_sql(expression)
def xor_sql(self, expression: exp.Xor) -> str:
if expression.expressions:
return self.expressions(expression, sep=" XOR ")

View file

@@ -20,7 +20,6 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
parse_timestamp_trunc,
rename_func,
simplify_literal,
str_position_sql,
struct_extract_sql,
timestamptrunc_sql,
@@ -49,7 +48,7 @@ def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | ex
this = self.sql(expression, "this")
unit = expression.args.get("unit")
expression = simplify_literal(expression).expression
expression = self._simplify_unless_literal(expression.expression)
if not isinstance(expression, exp.Literal):
self.unsupported("Cannot add non literal")

View file

@@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
no_ilike_sql,
no_pivot_sql,
no_safe_divide_sql,
no_timestamp_sql,
regexp_extract_sql,
rename_func,
right_to_substring_sql,
@@ -69,9 +70,10 @@ def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str:
if expression.parent:
for schema in expression.parent.find_all(exp.Schema):
if isinstance(schema.parent, exp.Property):
column_defs = schema.find_all(exp.ColumnDef)
if column_defs and isinstance(schema.parent, exp.Property):
expression = expression.copy()
expression.expressions.extend(schema.expressions)
expression.expressions.extend(column_defs)
return self.schema_sql(expression)
@@ -252,6 +254,7 @@ class Presto(Dialect):
TZ_TO_WITH_TIME_ZONE = True
NVL2_SUPPORTED = False
STRUCT_DELIMITER = ("(", ")")
LIMIT_ONLY_LITERALS = True
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@@ -277,6 +280,8 @@ class Presto(Dialect):
exp.AnyValue: rename_func("ARBITRARY"),
exp.ApproxDistinct: _approx_distinct_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
@@ -348,6 +353,7 @@ class Presto(Dialect):
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.Timestamp: no_timestamp_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
@@ -367,7 +373,6 @@ class Presto(Dialect):
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
exp.Timestamp: transforms.preprocess([transforms.timestamp_to_cast]),
exp.Xor: bool_xor_sql,
}
@@ -418,3 +423,15 @@ class Presto(Dialect):
self.sql(expression, "offset"),
self.sql(limit),
]
def create_sql(self, expression: exp.Create) -> str:
"""
Presto doesn't support CREATE VIEW with expressions (e.g. in `CREATE VIEW x (cola)`,
the column list `(cola)` is the expression), so we need to remove them
"""
kind = expression.args["kind"]
schema = expression.this
if kind == "VIEW" and schema.expressions:
expression = expression.copy()
expression.this.set("expressions", None)
return super().create_sql(expression)
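A minimal sketch of create_sql (default read dialect; output comment is indicative):

import sqlglot

# The column list on a CREATE VIEW is dropped for Presto:
print(sqlglot.transpile("CREATE VIEW v (cola) AS SELECT 1", write="presto")[0])
# expected: CREATE VIEW v AS SELECT 1
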

View file

@@ -6,6 +6,7 @@ from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
concat_to_dpipe_sql,
concat_ws_to_dpipe_sql,
generatedasidentitycolumnconstraint_sql,
rename_func,
ts_or_ds_to_date_sql,
)
@@ -171,8 +172,10 @@ class Redshift(Postgres):
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.FromBase: rename_func("STRTOL"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]

View file

@@ -262,6 +262,7 @@ class Snowflake(Dialect):
),
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"FLATTEN": exp.Explode.from_arg_list,
"IFF": exp.If.from_arg_list,
"LISTAGG": exp.GroupConcat.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
@@ -308,6 +309,7 @@ class Snowflake(Dialect):
expressions=self._parse_csv(self._parse_id_var),
unset=True,
),
"SWAP": lambda self: self._parse_alter_table_swap(),
}
STATEMENT_PARSERS = {
@@ -325,6 +327,22 @@ class Snowflake(Dialect):
TokenType.MOD,
TokenType.SLASH,
}
FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
def _parse_lateral(self) -> t.Optional[exp.Lateral]:
lateral = super()._parse_lateral()
if not lateral:
return lateral
if isinstance(lateral.this, exp.Explode):
table_alias = lateral.args.get("alias")
columns = [exp.to_identifier(col) for col in self.FLATTEN_COLUMNS]
if table_alias and not table_alias.args.get("columns"):
table_alias.set("columns", columns)
elif not table_alias:
exp.alias_(lateral, "_flattened", table=columns, copy=False)
return lateral
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
@@ -389,6 +407,10 @@ class Snowflake(Dialect):
return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)
def _parse_alter_table_swap(self) -> exp.SwapTable:
self._match_text_seq("WITH")
return self.expression(exp.SwapTable, this=self._parse_table(schema=True))
class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
@@ -438,6 +460,8 @@ class Snowflake(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
@@ -451,7 +475,10 @@ class Snowflake(Dialect):
),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.Explode: rename_func("FLATTEN"),
exp.Extract: rename_func("DATE_PART"),
exp.GenerateSeries: lambda self, e: self.func(
"ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step")
@@ -520,6 +547,12 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def log_sql(self, expression: exp.Log) -> str:
if not expression.expression:
return self.func("LN", expression.this)
return super().log_sql(expression)
def unnest_sql(self, expression: exp.Unnest) -> str:
selects = ["value"]
unnest_alias = expression.args.get("alias")
@@ -596,3 +629,7 @@ class Snowflake(Dialect):
increment = expression.args.get("increment")
increment = f" INCREMENT {increment}" if increment else ""
return f"AUTOINCREMENT{start}{increment}"
def swaptable_sql(self, expression: exp.SwapTable) -> str:
this = self.sql(expression, "this")
return f"SWAP WITH {this}"

View file

@@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least, rename_func
from sqlglot.tokens import TokenType
@@ -150,6 +150,7 @@ class Teradata(Dialect):
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
class Generator(generator.Generator):
LIMIT_IS_TOP = True
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
@@ -168,6 +169,8 @@ class Teradata(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess(

View file

@@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
any_value_to_max_sql,
generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
move_insert_cte_sql,
@@ -603,6 +604,7 @@ class TSQL(Dialect):
exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.DOUBLE: "FLOAT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.TEXT: "VARCHAR(MAX)",
exp.DataType.Type.TIMESTAMP: "DATETIME2",
exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET",
exp.DataType.Type.VARIANT: "SQL_VARIANT",
@@ -617,6 +619,7 @@ class TSQL(Dialect):
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.Insert: move_insert_cte_sql,
@@ -778,11 +781,3 @@ class TSQL(Dialect):
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True, sep=" ")
return f"CONSTRAINT {this} {expressions}"
# https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
start = self.sql(expression, "start") or "1"
increment = self.sql(expression, "increment") or "1"
return f"IDENTITY({start}, {increment})"

View file

@@ -23,7 +23,7 @@ from enum import auto
from functools import reduce
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.errors import ErrorLevel, ParseError
from sqlglot.helper import (
AutoName,
camel_to_snake_case,
@@ -120,14 +120,14 @@ class Expression(metaclass=_Expression):
return hash((self.__class__, self.hashable_args))
@property
def this(self):
def this(self) -> t.Any:
"""
Retrieves the argument with key "this".
"""
return self.args.get("this")
@property
def expression(self):
def expression(self) -> t.Any:
"""
Retrieves the argument with key "expression".
"""
@@ -1235,6 +1235,10 @@ class RenameTable(Expression):
pass
class SwapTable(Expression):
pass
class Comment(Expression):
arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
@@ -1979,7 +1983,7 @@ class ChecksumProperty(Property):
class CollateProperty(Property):
arg_types = {"this": True}
arg_types = {"this": True, "default": False}
class CopyGrantsProperty(Property):
@@ -2607,11 +2611,11 @@ class Union(Subqueryable):
return self.this.unnest().selects
@property
def left(self):
def left(self) -> Expression:
return self.this
@property
def right(self):
def right(self) -> Expression:
return self.expression
@@ -3700,7 +3704,9 @@ class DataType(Expression):
return DataType(this=DataType.Type.UNKNOWN, **kwargs)
try:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
data_type_exp = parse_one(
dtype, read=dialect, into=DataType, error_level=ErrorLevel.IGNORE
)
except ParseError:
if udt:
return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs)
@@ -3804,11 +3810,11 @@ class Binary(Condition):
arg_types = {"this": True, "expression": True}
@property
def left(self):
def left(self) -> Expression:
return self.this
@property
def right(self):
def right(self) -> Expression:
return self.expression
@@ -4063,10 +4069,25 @@ class TimeUnit(Expression):
arg_types = {"unit": False}
UNABBREVIATED_UNIT_NAME = {
"d": "day",
"h": "hour",
"m": "minute",
"ms": "millisecond",
"ns": "nanosecond",
"q": "quarter",
"s": "second",
"us": "microsecond",
"w": "week",
"y": "year",
}
VAR_LIKE = (Column, Literal, Var)
def __init__(self, **args):
unit = args.get("unit")
if isinstance(unit, (Column, Literal)):
args["unit"] = Var(this=unit.name)
if isinstance(unit, self.VAR_LIKE):
args["unit"] = Var(this=self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name)
elif isinstance(unit, Week):
unit.set("this", Var(this=unit.this.name))
@@ -4168,6 +4189,24 @@ class Abs(Func):
pass
class ArgMax(AggFunc):
arg_types = {"this": True, "expression": True, "count": False}
_sql_names = ["ARG_MAX", "ARGMAX", "MAX_BY"]
class ArgMin(AggFunc):
arg_types = {"this": True, "expression": True, "count": False}
_sql_names = ["ARG_MIN", "ARGMIN", "MIN_BY"]
class ApproxTopK(AggFunc):
arg_types = {"this": True, "expression": False, "counters": False}
class Flatten(Func):
pass
# https://spark.apache.org/docs/latest/api/sql/index.html#transform
class Transform(Func):
arg_types = {"this": True, "expression": True}
@@ -4540,8 +4579,10 @@ class Exp(Func):
pass
# https://docs.snowflake.com/en/sql-reference/functions/flatten
class Explode(Func):
pass
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
class ExplodeOuter(Explode):
@@ -4698,6 +4739,8 @@ class JSONArrayContains(Binary, Predicate, Func):
class ParseJSON(Func):
# BigQuery, Snowflake have PARSE_JSON, Presto has JSON_PARSE
_sql_names = ["PARSE_JSON", "JSON_PARSE"]
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
class Least(Func):
@@ -4758,6 +4801,16 @@ class Lower(Func):
class Map(Func):
arg_types = {"keys": False, "values": False}
@property
def keys(self) -> t.List[Expression]:
keys = self.args.get("keys")
return keys.expressions if keys else []
@property
def values(self) -> t.List[Expression]:
values = self.args.get("values")
return values.expressions if values else []
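A minimal sketch of the new accessors (Presto's MAP(keys, values) is one way to build the node; outputs are indicative):

from sqlglot import exp, parse_one

# exp.Map now exposes its key/value expression lists directly:
m = parse_one("SELECT MAP(ARRAY['a'], ARRAY[1])", read="presto").find(exp.Map)
print([k.sql() for k in m.keys], [v.sql() for v in m.values])
# expected roughly: ["'a'"] ['1']
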
class MapFromEntries(Func):
pass
@@ -4870,6 +4923,7 @@ class RegexpReplace(Func):
"position": False,
"occurrence": False,
"parameters": False,
"modifiers": False,
}
@@ -4877,7 +4931,7 @@ class RegexpLike(Binary, Func):
arg_types = {"this": True, "expression": True, "flag": False}
class RegexpILike(Func):
class RegexpILike(Binary, Func):
arg_types = {"this": True, "expression": True, "flag": False}

View file

@@ -11,6 +11,9 @@ from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer, TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
logger = logging.getLogger("sqlglot")
@@ -141,6 +144,9 @@ class Generator:
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
# Whether or not limit and fetch allow expressions or only literals
LIMIT_ONLY_LITERALS = False
# Whether or not a table is allowed to be renamed with a db
RENAME_TABLE_WITH_DB = True
@@ -341,6 +347,12 @@ class Generator:
exp.With,
)
# Expressions that should not have their comments generated in maybe_comment
EXCLUDE_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Binary,
exp.Union,
)
# Expressions that can remain unwrapped when appearing in the context of an INTERVAL
UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Column,
@@ -501,7 +513,7 @@ class Generator:
else None
)
if not comments or isinstance(expression, exp.Binary):
if not comments or isinstance(expression, self.EXCLUDE_COMMENTS):
return sql
comments_sql = " ".join(
@@ -879,6 +891,10 @@ class Generator:
alias = self.sql(expression, "this")
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
if not alias and not self.UNNEST_COLUMN_ONLY:
alias = "_t"
return f"{alias}{columns}"
def bitstring_sql(self, expression: exp.BitString) -> str:
@@ -1611,9 +1627,6 @@ class Generator:
def lateral_sql(self, expression: exp.Lateral) -> str:
this = self.sql(expression, "this")
if isinstance(expression.this, exp.Subquery):
return f"LATERAL {this}"
if expression.args.get("view"):
alias = expression.args["alias"]
columns = self.expressions(alias, key="columns", flat=True)
@@ -1629,18 +1642,19 @@ class Generator:
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
this = self.sql(expression, "this")
args = ", ".join(
sql
for sql in (
self.sql(expression, "offset"),
self.sql(expression, "expression"),
)
if sql
self.sql(self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e)
for e in (expression.args.get(k) for k in ("offset", "expression"))
if e
)
return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"
expression = expression.expression
expression = (
self._simplify_unless_literal(expression) if self.LIMIT_ONLY_LITERALS else expression
)
return f"{this}{self.seg('OFFSET')} {self.sql(expression)}"
def setitem_sql(self, expression: exp.SetItem) -> str:
kind = self.sql(expression, "kind")
@@ -1895,12 +1909,13 @@ class Generator:
def schema_sql(self, expression: exp.Schema) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else ""
sql = self.schema_columns_sql(expression)
return f"{this}{sql}"
return f"{this} {sql}" if this and sql else this or sql
def schema_columns_sql(self, expression: exp.Schema) -> str:
return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
if expression.expressions:
return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
return ""
def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
@@ -2708,8 +2723,8 @@ class Generator:
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
return f"{property_name} {self.sql(expression, 'this')}"
def set_operation(self, expression: exp.Expression, op: str) -> str:
this = self.sql(expression, "this")
def set_operation(self, expression: exp.Union, op: str) -> str:
this = self.maybe_comment(self.sql(expression, "this"), comments=expression.comments)
op = self.seg(op)
return self.query_modifiers(
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
@@ -2912,6 +2927,14 @@ class Generator:
parameters = self.sql(expression, "params_struct")
return self.func("PREDICT", model, table, parameters or None)
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
expression = simplify(expression.copy())
return expression
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None

View file

@@ -112,17 +112,34 @@ def lineage(
column
if isinstance(column, int)
else next(
i
for i, select in enumerate(scope.expression.selects)
if select.alias_or_name == column
(
i
for i, select in enumerate(scope.expression.selects)
if select.alias_or_name == column or select.is_star
),
-1, # mypy will not allow a None here, but a negative index should never be returned
)
)
if index == -1:
raise ValueError(f"Could not find {column} in {scope.expression}")
for s in scope.union_scopes:
to_node(index, scope=s, upstream=upstream)
return upstream
subquery = select.unalias()
if isinstance(subquery, exp.Subquery):
upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select)
scope = t.cast(Scope, build_scope(subquery.unnest()))
for select in subquery.named_selects:
to_node(select, scope=scope, upstream=upstream)
return upstream
if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
# a version that has only the column we care about.
@@ -142,8 +159,19 @@ def lineage(
if upstream:
upstream.downstream.append(node)
# If the select is a star, add all scope sources as downstreams
if select.is_star:
for source in scope.sources.values():
node.downstream.append(Node(name=select.sql(), source=source, expression=source))
# Find all columns that went into creating this one to list their lineage nodes.
for c in set(select.find_all(exp.Column)):
source_columns = set(select.find_all(exp.Column))
# If the source is a UDTF, find the columns used in the UDTF to generate the table
if isinstance(source, exp.UDTF):
source_columns |= set(source.find_all(exp.Column))
for c in source_columns:
table = c.table
source = scope.sources.get(table)
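A hedged sketch of the star fallback (without a schema the star cannot be expanded, which is the case this branch targets; the exact node names depend on qualification details):

from sqlglot.lineage import lineage

# Lineage for a column behind an unexpandable SELECT * now resolves instead of raising:
node = lineage("x", "SELECT * FROM t")
print(node.name, [d.name for d in node.downstream])
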

View file

@@ -6,7 +6,7 @@ import typing as t
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.helper import ensure_list, seq_get, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
@@ -271,6 +271,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Bracket: lambda self, e: self._annotate_bracket(e),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
@@ -287,6 +288,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
@@ -524,3 +526,24 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, datatype)
return expression
def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
self._annotate_args(expression)
bracket_arg = expression.expressions[0]
this = expression.this
if isinstance(bracket_arg, exp.Slice):
self._set_type(expression, this.type)
elif this.type.is_type(exp.DataType.Type.ARRAY):
contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
self._set_type(expression, contained_type)
elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
index = this.keys.index(bracket_arg)
value = seq_get(this.values, index)
value_type = value.type if value else exp.DataType.Type.UNKNOWN
self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
else:
self._set_type(expression, exp.DataType.Type.UNKNOWN)
return expression
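A minimal sketch of the bracket annotation (default dialect; output comment is indicative):

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# Indexing an ARRAY now yields the element type instead of UNKNOWN:
expr = annotate_types(parse_one("SELECT ARRAY[1, 2][1] AS elem"))
print(expr.selects[0].type.sql())  # expected roughly: INT
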

View file

@@ -69,7 +69,11 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
_replace_int_predicate(expression.left)
_replace_int_predicate(expression.right)
elif isinstance(expression, (exp.Where, exp.Having, exp.If)):
elif isinstance(expression, (exp.Where, exp.Having)) or (
# We can't replace num in CASE x WHEN num ..., because it's not the full predicate
isinstance(expression, exp.If)
and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
):
_replace_int_predicate(expression.this)
return expression

View file

@@ -70,6 +70,7 @@ def simplify(expression, constant_propagation=False):
node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
node = simplify_conditionals(node)
if constant_propagation:
node = propagate_constants(node, root)
@@ -477,9 +478,11 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
return expression
if l.__class__ in INVERSE_DATE_OPS:
l = t.cast(exp.IntervalOp, l)
a = l.this
b = l.interval()
else:
l = t.cast(exp.Binary, l)
a, b = l.left, l.right
if not a_predicate(a) and b_predicate(b):
@@ -695,6 +698,32 @@ def simplify_concat(expression):
return concat_type(expressions=new_args)
def simplify_conditionals(expression):
"""Simplifies expressions like IF, CASE if their condition is statically known."""
if isinstance(expression, exp.Case):
this = expression.this
for case in expression.args["ifs"]:
cond = case.this
if this:
# Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
cond = cond.replace(this.pop().eq(cond))
if always_true(cond):
return case.args["true"]
if always_false(cond):
case.pop()
if not expression.args["ifs"]:
return expression.args.get("default") or exp.null()
elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
if always_true(expression.this):
return expression.args["true"]
if always_false(expression.this):
return expression.args.get("false") or exp.null()
return expression
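A quick sketch of the new rule (default dialect; output comment is indicative):

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# A statically-true condition collapses the whole CASE:
print(simplify(parse_one("SELECT CASE WHEN TRUE THEN 1 ELSE 2 END AS c")).sql())
# expected: SELECT 1 AS c
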
DateRange = t.Tuple[datetime.date, datetime.date]
@@ -786,6 +815,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
else:
return expression
l = t.cast(exp.DateTrunc, l)
unit = l.unit.name.lower()
date = extract_date(r)
@@ -798,6 +828,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
rs = expression.expressions
if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
l = t.cast(exp.DateTrunc, l)
unit = l.unit.name.lower()
ranges = []
@@ -852,6 +883,10 @@ def always_true(expression):
)
def always_false(expression):
return is_false(expression) or is_null(expression)
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a

View file

@@ -313,6 +313,7 @@ class Parser(metaclass=_Parser):
TokenType.UNIQUE,
TokenType.UNPIVOT,
TokenType.UPDATE,
TokenType.USE,
TokenType.VOLATILE,
TokenType.WINDOW,
*CREATABLES,
@@ -629,11 +630,14 @@ class Parser(metaclass=_Parser):
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"CHARACTER SET": lambda self: self._parse_character_set(),
"CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs),
"CHARACTER SET": lambda self, **kwargs: self._parse_character_set(**kwargs),
"CHECKSUM": lambda self: self._parse_checksum(),
"CLUSTER BY": lambda self: self._parse_cluster(),
"CLUSTERED": lambda self: self._parse_clustered_by(),
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
"COLLATE": lambda self, **kwargs: self._parse_property_assignment(
exp.CollateProperty, **kwargs
),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"COPY": lambda self: self._parse_copy_property(),
"DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
@@ -1443,8 +1447,8 @@ class Parser(metaclass=_Parser):
if self._match_texts(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.text.upper()](self)
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
return self._parse_character_set(default=True)
if self._match(TokenType.DEFAULT) and self._match_texts(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.text.upper()](self, default=True)
if self._match_text_seq("COMPOUND", "SORTKEY"):
return self._parse_sortkey(compound=True)
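A hedged sketch of the generalized DEFAULT handling, using BigQuery's DEFAULT COLLATE (which pairs with the CollateProperty "default" arg added earlier; the collation value is illustrative):

import sqlglot

# DEFAULT now composes with any property parser, not just CHARACTER SET:
sql = "CREATE TABLE t (x STRING) DEFAULT COLLATE 'und:ci'"
print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])
# expected: CREATE TABLE t (x STRING) DEFAULT COLLATE 'und:ci'
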
@@ -1480,10 +1484,10 @@ class Parser(metaclass=_Parser):
else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
)
def _parse_property_assignment(self, exp_class: t.Type[E]) -> E:
def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
return self.expression(exp_class, this=self._parse_field())
return self.expression(exp_class, this=self._parse_field(), **kwargs)
def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]:
properties = []
@@ -2426,9 +2430,9 @@ class Parser(metaclass=_Parser):
table_alias: t.Optional[exp.TableAlias] = self.expression(
exp.TableAlias, this=table, columns=columns
)
elif isinstance(this, exp.Subquery) and this.alias:
# Ensures parity between the Subquery's and the Lateral's "alias" args
table_alias = this.args["alias"].copy()
elif isinstance(this, (exp.Subquery, exp.Unnest)) and this.alias:
# We move the alias from the lateral's child node to the lateral itself
table_alias = this.args["alias"].pop()
else:
table_alias = self._parse_table_alias()
@@ -2952,6 +2956,7 @@ class Parser(metaclass=_Parser):
cube = None
totals = None
index = self._index
with_ = self._match(TokenType.WITH)
if self._match(TokenType.ROLLUP):
rollup = with_ or self._parse_wrapped_csv(self._parse_column)
@@ -2966,6 +2971,8 @@ class Parser(metaclass=_Parser):
elements["totals"] = True # type: ignore
if not (grouping_sets or rollup or cube or totals):
if with_:
self._retreat(index)
break
return self.expression(exp.Group, **elements) # type: ignore
@@ -3157,6 +3164,7 @@ class Parser(metaclass=_Parser):
return self.expression(
expression,
comments=self._prev.comments,
this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
by_name=self._match_text_seq("BY", "NAME"),
@@ -3618,6 +3626,32 @@ class Parser(metaclass=_Parser):
functions: t.Optional[t.Dict[str, t.Callable]] = None,
anonymous: bool = False,
optional_parens: bool = True,
) -> t.Optional[exp.Expression]:
# This allows us to also parse {fn <function>} syntax (Snowflake, MySQL support this)
# See: https://community.snowflake.com/s/article/SQL-Escape-Sequences
fn_syntax = False
if (
self._match(TokenType.L_BRACE, advance=False)
and self._next
and self._next.text.upper() == "FN"
):
self._advance(2)
fn_syntax = True
func = self._parse_function_call(
functions=functions, anonymous=anonymous, optional_parens=optional_parens
)
if fn_syntax:
self._match(TokenType.R_BRACE)
return func
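A quick sketch of the escape-sequence support (default write dialect; output comment is indicative):

import sqlglot

# The ODBC/JDBC {fn ...} wrapper is parsed and then dropped on output:
print(sqlglot.transpile("SELECT {fn CONCAT('a', 'b')}", read="snowflake")[0])
# expected roughly: SELECT CONCAT('a', 'b')
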
def _parse_function_call(
self,
functions: t.Optional[t.Dict[str, t.Callable]] = None,
anonymous: bool = False,
optional_parens: bool = True,
) -> t.Optional[exp.Expression]:
if not self._curr:
return None
@@ -3856,6 +3890,10 @@ class Parser(metaclass=_Parser):
if not identity:
this.set("expression", self._parse_bitwise())
elif not this.args.get("start") and self._match(TokenType.NUMBER, advance=False):
args = self._parse_csv(self._parse_bitwise)
this.set("start", seq_get(args, 0))
this.set("increment", seq_get(args, 1))
self._match_r_paren()
@@ -4039,6 +4077,11 @@ class Parser(metaclass=_Parser):
)
)
if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
self.raise_error("Expected ]")
elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE:
self.raise_error("Expected }")
# https://duckdb.org/docs/sql/data_types/struct.html#creating-structs
if bracket_kind == TokenType.L_BRACE:
this = self.expression(exp.Struct, expressions=expressions)
@@ -4048,11 +4091,6 @@ class Parser(metaclass=_Parser):
expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET)
this = self.expression(exp.Bracket, this=this, expressions=expressions)
if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
self.raise_error("Expected ]")
elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE:
self.raise_error("Expected }")
self._add_comments(this)
return self._parse_bracket(this)

View file

@@ -54,3 +54,606 @@ def format_time(
chunks.append(chars)
return "".join(mapping.get(chars, chars) for chars in chunks)
TIMEZONES = {
tz.lower()
for tz in (
"Africa/Abidjan",
"Africa/Accra",
"Africa/Addis_Ababa",
"Africa/Algiers",
"Africa/Asmara",
"Africa/Asmera",
"Africa/Bamako",
"Africa/Bangui",
"Africa/Banjul",
"Africa/Bissau",
"Africa/Blantyre",
"Africa/Brazzaville",
"Africa/Bujumbura",
"Africa/Cairo",
"Africa/Casablanca",
"Africa/Ceuta",
"Africa/Conakry",
"Africa/Dakar",
"Africa/Dar_es_Salaam",
"Africa/Djibouti",
"Africa/Douala",
"Africa/El_Aaiun",
"Africa/Freetown",
"Africa/Gaborone",
"Africa/Harare",
"Africa/Johannesburg",
"Africa/Juba",
"Africa/Kampala",
"Africa/Khartoum",
"Africa/Kigali",
"Africa/Kinshasa",
"Africa/Lagos",
"Africa/Libreville",
"Africa/Lome",
"Africa/Luanda",
"Africa/Lubumbashi",
"Africa/Lusaka",
"Africa/Malabo",
"Africa/Maputo",
"Africa/Maseru",
"Africa/Mbabane",
"Africa/Mogadishu",
"Africa/Monrovia",
"Africa/Nairobi",
"Africa/Ndjamena",
"Africa/Niamey",
"Africa/Nouakchott",
"Africa/Ouagadougou",
"Africa/Porto-Novo",
"Africa/Sao_Tome",
"Africa/Timbuktu",
"Africa/Tripoli",
"Africa/Tunis",
"Africa/Windhoek",
"America/Adak",
"America/Anchorage",
"America/Anguilla",
"America/Antigua",
"America/Araguaina",
"America/Argentina/Buenos_Aires",
"America/Argentina/Catamarca",
"America/Argentina/ComodRivadavia",
"America/Argentina/Cordoba",
"America/Argentina/Jujuy",
"America/Argentina/La_Rioja",
"America/Argentina/Mendoza",
"America/Argentina/Rio_Gallegos",
"America/Argentina/Salta",
"America/Argentina/San_Juan",
"America/Argentina/San_Luis",
"America/Argentina/Tucuman",
"America/Argentina/Ushuaia",
"America/Aruba",
"America/Asuncion",
"America/Atikokan",
"America/Atka",
"America/Bahia",
"America/Bahia_Banderas",
"America/Barbados",
"America/Belem",
"America/Belize",
"America/Blanc-Sablon",
"America/Boa_Vista",
"America/Bogota",
"America/Boise",
"America/Buenos_Aires",
"America/Cambridge_Bay",
"America/Campo_Grande",
"America/Cancun",
"America/Caracas",
"America/Catamarca",
"America/Cayenne",
"America/Cayman",
"America/Chicago",
"America/Chihuahua",
"America/Ciudad_Juarez",
"America/Coral_Harbour",
"America/Cordoba",
"America/Costa_Rica",
"America/Creston",
"America/Cuiaba",
"America/Curacao",
"America/Danmarkshavn",
"America/Dawson",
"America/Dawson_Creek",
"America/Denver",
"America/Detroit",
"America/Dominica",
"America/Edmonton",
"America/Eirunepe",
"America/El_Salvador",
"America/Ensenada",
"America/Fort_Nelson",
"America/Fort_Wayne",
"America/Fortaleza",
"America/Glace_Bay",
"America/Godthab",
"America/Goose_Bay",
"America/Grand_Turk",
"America/Grenada",
"America/Guadeloupe",
"America/Guatemala",
"America/Guayaquil",
"America/Guyana",
"America/Halifax",
"America/Havana",
"America/Hermosillo",
"America/Indiana/Indianapolis",
"America/Indiana/Knox",
"America/Indiana/Marengo",
"America/Indiana/Petersburg",
"America/Indiana/Tell_City",
"America/Indiana/Vevay",
"America/Indiana/Vincennes",
"America/Indiana/Winamac",
"America/Indianapolis",
"America/Inuvik",
"America/Iqaluit",
"America/Jamaica",
"America/Jujuy",
"America/Juneau",
"America/Kentucky/Louisville",
"America/Kentucky/Monticello",
"America/Knox_IN",
"America/Kralendijk",
"America/La_Paz",
"America/Lima",
"America/Los_Angeles",
"America/Louisville",
"America/Lower_Princes",
"America/Maceio",
"America/Managua",
"America/Manaus",
"America/Marigot",
"America/Martinique",
"America/Matamoros",
"America/Mazatlan",
"America/Mendoza",
"America/Menominee",
"America/Merida",
"America/Metlakatla",
"America/Mexico_City",
"America/Miquelon",
"America/Moncton",
"America/Monterrey",
"America/Montevideo",
"America/Montreal",
"America/Montserrat",
"America/Nassau",
"America/New_York",
"America/Nipigon",
"America/Nome",
"America/Noronha",
"America/North_Dakota/Beulah",
"America/North_Dakota/Center",
"America/North_Dakota/New_Salem",
"America/Nuuk",
"America/Ojinaga",
"America/Panama",
"America/Pangnirtung",
"America/Paramaribo",
"America/Phoenix",
"America/Port-au-Prince",
"America/Port_of_Spain",
"America/Porto_Acre",
"America/Porto_Velho",
"America/Puerto_Rico",
"America/Punta_Arenas",
"America/Rainy_River",
"America/Rankin_Inlet",
"America/Recife",
"America/Regina",
"America/Resolute",
"America/Rio_Branco",
"America/Rosario",
"America/Santa_Isabel",
"America/Santarem",
"America/Santiago",
"America/Santo_Domingo",
"America/Sao_Paulo",
"America/Scoresbysund",
"America/Shiprock",
"America/Sitka",
"America/St_Barthelemy",
"America/St_Johns",
"America/St_Kitts",
"America/St_Lucia",
"America/St_Thomas",
"America/St_Vincent",
"America/Swift_Current",
"America/Tegucigalpa",
"America/Thule",
"America/Thunder_Bay",
"America/Tijuana",
"America/Toronto",
"America/Tortola",
"America/Vancouver",
"America/Virgin",
"America/Whitehorse",
"America/Winnipeg",
"America/Yakutat",
"America/Yellowknife",
"Antarctica/Casey",
"Antarctica/Davis",
"Antarctica/DumontDUrville",
"Antarctica/Macquarie",
"Antarctica/Mawson",
"Antarctica/McMurdo",
"Antarctica/Palmer",
"Antarctica/Rothera",
"Antarctica/South_Pole",
"Antarctica/Syowa",
"Antarctica/Troll",
"Antarctica/Vostok",
"Arctic/Longyearbyen",
"Asia/Aden",
"Asia/Almaty",
"Asia/Amman",
"Asia/Anadyr",
"Asia/Aqtau",
"Asia/Aqtobe",
"Asia/Ashgabat",
"Asia/Ashkhabad",
"Asia/Atyrau",
"Asia/Baghdad",
"Asia/Bahrain",
"Asia/Baku",
"Asia/Bangkok",
"Asia/Barnaul",
"Asia/Beirut",
"Asia/Bishkek",
"Asia/Brunei",
"Asia/Calcutta",
"Asia/Chita",
"Asia/Choibalsan",
"Asia/Chongqing",
"Asia/Chungking",
"Asia/Colombo",
"Asia/Dacca",
"Asia/Damascus",
"Asia/Dhaka",
"Asia/Dili",
"Asia/Dubai",
"Asia/Dushanbe",
"Asia/Famagusta",
"Asia/Gaza",
"Asia/Harbin",
"Asia/Hebron",
"Asia/Ho_Chi_Minh",
"Asia/Hong_Kong",
"Asia/Hovd",
"Asia/Irkutsk",
"Asia/Istanbul",
"Asia/Jakarta",
"Asia/Jayapura",
"Asia/Jerusalem",
"Asia/Kabul",
"Asia/Kamchatka",
"Asia/Karachi",
"Asia/Kashgar",
"Asia/Kathmandu",
"Asia/Katmandu",
"Asia/Khandyga",
"Asia/Kolkata",
"Asia/Krasnoyarsk",
"Asia/Kuala_Lumpur",
"Asia/Kuching",
"Asia/Kuwait",
"Asia/Macao",
"Asia/Macau",
"Asia/Magadan",
"Asia/Makassar",
"Asia/Manila",
"Asia/Muscat",
"Asia/Nicosia",
"Asia/Novokuznetsk",
"Asia/Novosibirsk",
"Asia/Omsk",
"Asia/Oral",
"Asia/Phnom_Penh",
"Asia/Pontianak",
"Asia/Pyongyang",
"Asia/Qatar",
"Asia/Qostanay",
"Asia/Qyzylorda",
"Asia/Rangoon",
"Asia/Riyadh",
"Asia/Saigon",
"Asia/Sakhalin",
"Asia/Samarkand",
"Asia/Seoul",
"Asia/Shanghai",
"Asia/Singapore",
"Asia/Srednekolymsk",
"Asia/Taipei",
"Asia/Tashkent",
"Asia/Tbilisi",
"Asia/Tehran",
"Asia/Tel_Aviv",
"Asia/Thimbu",
"Asia/Thimphu",
"Asia/Tokyo",
"Asia/Tomsk",
"Asia/Ujung_Pandang",
"Asia/Ulaanbaatar",
"Asia/Ulan_Bator",
"Asia/Urumqi",
"Asia/Ust-Nera",
"Asia/Vientiane",
"Asia/Vladivostok",
"Asia/Yakutsk",
"Asia/Yangon",
"Asia/Yekaterinburg",
"Asia/Yerevan",
"Atlantic/Azores",
"Atlantic/Bermuda",
"Atlantic/Canary",
"Atlantic/Cape_Verde",
"Atlantic/Faeroe",
"Atlantic/Faroe",
"Atlantic/Jan_Mayen",
"Atlantic/Madeira",
"Atlantic/Reykjavik",
"Atlantic/South_Georgia",
"Atlantic/St_Helena",
"Atlantic/Stanley",
"Australia/ACT",
"Australia/Adelaide",
"Australia/Brisbane",
"Australia/Broken_Hill",
"Australia/Canberra",
"Australia/Currie",
"Australia/Darwin",
"Australia/Eucla",
"Australia/Hobart",
"Australia/LHI",
"Australia/Lindeman",
"Australia/Lord_Howe",
"Australia/Melbourne",
"Australia/NSW",
"Australia/North",
"Australia/Perth",
"Australia/Queensland",
"Australia/South",
"Australia/Sydney",
"Australia/Tasmania",
"Australia/Victoria",
"Australia/West",
"Australia/Yancowinna",
"Brazil/Acre",
"Brazil/DeNoronha",
"Brazil/East",
"Brazil/West",
"CET",
"CST6CDT",
"Canada/Atlantic",
"Canada/Central",
"Canada/Eastern",
"Canada/Mountain",
"Canada/Newfoundland",
"Canada/Pacific",
"Canada/Saskatchewan",
"Canada/Yukon",
"Chile/Continental",
"Chile/EasterIsland",
"Cuba",
"EET",
"EST",
"EST5EDT",
"Egypt",
"Eire",
"Etc/GMT",
"Etc/GMT+0",
"Etc/GMT+1",
"Etc/GMT+10",
"Etc/GMT+11",
"Etc/GMT+12",
"Etc/GMT+2",
"Etc/GMT+3",
"Etc/GMT+4",
"Etc/GMT+5",
"Etc/GMT+6",
"Etc/GMT+7",
"Etc/GMT+8",
"Etc/GMT+9",
"Etc/GMT-0",
"Etc/GMT-1",
"Etc/GMT-10",
"Etc/GMT-11",
"Etc/GMT-12",
"Etc/GMT-13",
"Etc/GMT-14",
"Etc/GMT-2",
"Etc/GMT-3",
"Etc/GMT-4",
"Etc/GMT-5",
"Etc/GMT-6",
"Etc/GMT-7",
"Etc/GMT-8",
"Etc/GMT-9",
"Etc/GMT0",
"Etc/Greenwich",
"Etc/UCT",
"Etc/UTC",
"Etc/Universal",
"Etc/Zulu",
"Europe/Amsterdam",
"Europe/Andorra",
"Europe/Astrakhan",
"Europe/Athens",
"Europe/Belfast",
"Europe/Belgrade",
"Europe/Berlin",
"Europe/Bratislava",
"Europe/Brussels",
"Europe/Bucharest",
"Europe/Budapest",
"Europe/Busingen",
"Europe/Chisinau",
"Europe/Copenhagen",
"Europe/Dublin",
"Europe/Gibraltar",
"Europe/Guernsey",
"Europe/Helsinki",
"Europe/Isle_of_Man",
"Europe/Istanbul",
"Europe/Jersey",
"Europe/Kaliningrad",
"Europe/Kiev",
"Europe/Kirov",
"Europe/Kyiv",
"Europe/Lisbon",
"Europe/Ljubljana",
"Europe/London",
"Europe/Luxembourg",
"Europe/Madrid",
"Europe/Malta",
"Europe/Mariehamn",
"Europe/Minsk",
"Europe/Monaco",
"Europe/Moscow",
"Europe/Nicosia",
"Europe/Oslo",
"Europe/Paris",
"Europe/Podgorica",
"Europe/Prague",
"Europe/Riga",
"Europe/Rome",
"Europe/Samara",
"Europe/San_Marino",
"Europe/Sarajevo",
"Europe/Saratov",
"Europe/Simferopol",
"Europe/Skopje",
"Europe/Sofia",
"Europe/Stockholm",
"Europe/Tallinn",
"Europe/Tirane",
"Europe/Tiraspol",
"Europe/Ulyanovsk",
"Europe/Uzhgorod",
"Europe/Vaduz",
"Europe/Vatican",
"Europe/Vienna",
"Europe/Vilnius",
"Europe/Volgograd",
"Europe/Warsaw",
"Europe/Zagreb",
"Europe/Zaporozhye",
"Europe/Zurich",
"GB",
"GB-Eire",
"GMT",
"GMT+0",
"GMT-0",
"GMT0",
"Greenwich",
"HST",
"Hongkong",
"Iceland",
"Indian/Antananarivo",
"Indian/Chagos",
"Indian/Christmas",
"Indian/Cocos",
"Indian/Comoro",
"Indian/Kerguelen",
"Indian/Mahe",
"Indian/Maldives",
"Indian/Mauritius",
"Indian/Mayotte",
"Indian/Reunion",
"Iran",
"Israel",
"Jamaica",
"Japan",
"Kwajalein",
"Libya",
"MET",
"MST",
"MST7MDT",
"Mexico/BajaNorte",
"Mexico/BajaSur",
"Mexico/General",
"NZ",
"NZ-CHAT",
"Navajo",
"PRC",
"PST8PDT",
"Pacific/Apia",
"Pacific/Auckland",
"Pacific/Bougainville",
"Pacific/Chatham",
"Pacific/Chuuk",
"Pacific/Easter",
"Pacific/Efate",
"Pacific/Enderbury",
"Pacific/Fakaofo",
"Pacific/Fiji",
"Pacific/Funafuti",
"Pacific/Galapagos",
"Pacific/Gambier",
"Pacific/Guadalcanal",
"Pacific/Guam",
"Pacific/Honolulu",
"Pacific/Johnston",
"Pacific/Kanton",
"Pacific/Kiritimati",
"Pacific/Kosrae",
"Pacific/Kwajalein",
"Pacific/Majuro",
"Pacific/Marquesas",
"Pacific/Midway",
"Pacific/Nauru",
"Pacific/Niue",
"Pacific/Norfolk",
"Pacific/Noumea",
"Pacific/Pago_Pago",
"Pacific/Palau",
"Pacific/Pitcairn",
"Pacific/Pohnpei",
"Pacific/Ponape",
"Pacific/Port_Moresby",
"Pacific/Rarotonga",
"Pacific/Saipan",
"Pacific/Samoa",
"Pacific/Tahiti",
"Pacific/Tarawa",
"Pacific/Tongatapu",
"Pacific/Truk",
"Pacific/Wake",
"Pacific/Wallis",
"Pacific/Yap",
"Poland",
"Portugal",
"ROC",
"ROK",
"Singapore",
"Turkey",
"UCT",
"US/Alaska",
"US/Aleutian",
"US/Arizona",
"US/Central",
"US/East-Indiana",
"US/Eastern",
"US/Hawaii",
"US/Indiana-Starke",
"US/Michigan",
"US/Mountain",
"US/Pacific",
"US/Samoa",
"UTC",
"Universal",
"W-SU",
"WET",
"Zulu",
)
}

View file

@@ -1077,10 +1077,10 @@ class Tokenizer(metaclass=_Tokenizer):
literal = ""
while self._peek.strip() and self._peek not in self.SINGLE_TOKENS:
literal += self._peek.upper()
literal += self._peek
self._advance()
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal, ""))
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal.upper(), ""))
if token_type:
self._add(TokenType.NUMBER, number_text)

View file

@@ -164,8 +164,9 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
"""Convert explode/posexplode into unnest (used in hive -> presto)."""
def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
"""Convert explode/posexplode into unnest (used in hive -> presto)."""
if isinstance(expression, exp.Select):
from sqlglot.optimizer.scope import Scope
@@ -297,6 +298,7 @@ PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
"""Transforms percentiles by adding a WITHIN GROUP clause to them."""
if (
isinstance(expression, PERCENTILES)
and not isinstance(expression.parent, exp.WithinGroup)
@@ -311,6 +313,7 @@ def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expressi
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
"""Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
if (
isinstance(expression, exp.WithinGroup)
and isinstance(expression.this, PERCENTILES)
@@ -324,6 +327,7 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
"""Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
if isinstance(expression, exp.With) and expression.recursive:
next_name = name_sequence("_c_")
@@ -342,6 +346,7 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
"""Replace 'epoch' in casts by the equivalent date literal."""
if (
isinstance(expression, (exp.Cast, exp.TryCast))
and expression.name.lower() == "epoch"
@@ -352,16 +357,8 @@ def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
return expression
def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Timestamp) and not expression.expression:
return exp.cast(
expression.this,
to=exp.DataType.Type.TIMESTAMP,
)
return expression
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
"""Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
if isinstance(expression, exp.Select):
for join in expression.args.get("joins") or []:
on = join.args.get("on")