Merging upstream version 20.11.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Author: Daniel Baumann <daniel@debian.org>
Date: 2025-02-13 21:19:58 +01:00
parent 1bce3d0317
commit e71ccc03da
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
141 changed files with 66644 additions and 54334 deletions


@ -87,13 +87,11 @@ def parse(
@t.overload
def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
...
def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ...
@t.overload
def parse_one(sql: str, **opts) -> Expression:
...
def parse_one(sql: str, **opts) -> Expression: ...
def parse_one(
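The collapsed one-line `...` overload bodies are purely stylistic. For orientation, a quick sketch of how the two overloads are exercised (my example, not part of the diff):

    import sqlglot
    from sqlglot import exp

    # First overload: into= narrows the static return type to the requested node.
    select = sqlglot.parse_one("SELECT a FROM t", into=exp.Select)

    # Second overload: without into=, the result is typed as a plain Expression.
    expr = sqlglot.parse_one("SELECT a FROM t")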


@ -4,10 +4,13 @@ import typing as t
import sqlglot
if t.TYPE_CHECKING:
from typing_extensions import Literal as Lit # noqa
# A little hack for backwards compatibility with Python 3.7.
# For example, we might want a TypeVar for objects that support comparison, e.g. SupportsRichComparisonT from typeshed.
# But Python 3.7 doesn't support Protocols, so we'd also need typing_extensions, which we don't want as a dependency.
A = t.TypeVar("A", bound=t.Any)
B = t.TypeVar("B", bound="sqlglot.exp.Binary")
E = t.TypeVar("E", bound="sqlglot.exp.Expression")
T = t.TypeVar("T")
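For context on the comment above, the Protocol-based bound it rules out on 3.7 would look roughly like this on Python 3.8+ (a sketch; SupportsLess is an illustrative name, not something sqlglot defines):

    import typing as t

    class SupportsLess(t.Protocol):
        def __lt__(self, other: t.Any) -> bool: ...

    C = t.TypeVar("C", bound=SupportsLess)

    def smallest(items: t.Sequence[C]) -> C:
        return min(items)  # every item is known to support <

Since t.Protocol only landed in 3.8, the module falls back to a plain bound=t.Any.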


@ -144,9 +144,11 @@ class Column:
) -> Column:
ensured_column = None if column is None else cls.ensure_col(column)
ensure_expression_values = {
k: [Column.ensure_col(x).expression for x in v]
if is_iterable(v)
else Column.ensure_col(v).expression
k: (
[Column.ensure_col(x).expression for x in v]
if is_iterable(v)
else Column.ensure_col(v).expression
)
for k, v in kwargs.items()
if v is not None
}
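This hunk, like most in this merge, is a mechanical reformat: conditional expressions inside comprehensions gain explicit wrapping parentheses, presumably from a formatter upgrade (this is the style recent Black releases enforce). In miniature:

    items = {"a": 1, "b": 0}

    # before:
    values = {k: str(v) if v else "missing" for k, v in items.items()}

    # after (identical semantics, parenthesized conditional):
    values = {k: (str(v) if v else "missing") for k, v in items.items()}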


@ -140,12 +140,10 @@ class DataFrame:
return cte, name
@t.overload
def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
...
def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
@t.overload
def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
...
def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
@ -496,9 +494,11 @@ class DataFrame:
join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
# To match Spark's behavior, only the join clause gets deduplicated, and it is put at the front of the column list
select_column_names = [
column.alias_or_name
if not isinstance(column.expression.this, exp.Star)
else column.sql()
(
column.alias_or_name
if not isinstance(column.expression.this, exp.Star)
else column.sql()
)
for column in self_columns + other_columns
]
select_column_names = [
@ -552,9 +552,11 @@ class DataFrame:
), "The length of items in ascending must equal the number of columns provided"
col_and_ascending = list(zip(columns, ascending))
order_by_columns = [
exp.Ordered(this=col.expression, desc=not asc)
if i not in pre_ordered_col_indexes
else columns[i].column_expression
(
exp.Ordered(this=col.expression, desc=not asc)
if i not in pre_ordered_col_indexes
else columns[i].column_expression
)
for i, (col, asc) in enumerate(col_and_ascending)
]
return self.copy(expression=self.expression.order_by(*order_by_columns))


@ -661,7 +661,7 @@ def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
tz_column = tz if isinstance(tz, Column) else lit(tz)
return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column)
return Column.invoke_expression_over_column(timestamp, expression.FromTimeZone, zone=tz_column)
def timestamp_seconds(col: ColumnOrName) -> Column:


@ -7,11 +7,11 @@ from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.helper import ensure_list
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.session import SparkSession
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
expr = ensure_list(expr)


@ -82,9 +82,11 @@ class SparkSession:
]
sel_columns = [
F.col(name).cast(data_type).alias(name).expression
if data_type is not None
else F.col(name).expression
(
F.col(name).cast(data_type).alias(name).expression
if data_type is not None
else F.col(name).expression
)
for name, data_type in column_mapping.items()
]


@ -90,9 +90,11 @@ class WindowSpec:
**kwargs,
**{
"start_side": "PRECEDING",
"start": "UNBOUNDED"
if start <= Window.unboundedPreceding
else F.lit(start).expression,
"start": (
"UNBOUNDED"
if start <= Window.unboundedPreceding
else F.lit(start).expression
),
},
}
if end == Window.currentRow:
@ -102,9 +104,9 @@ class WindowSpec:
**kwargs,
**{
"end_side": "FOLLOWING",
"end": "UNBOUNDED"
if end >= Window.unboundedFollowing
else F.lit(end).expression,
"end": (
"UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression
),
},
}
return kwargs
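Per the hunk above, bounds at or beyond unboundedPreceding/unboundedFollowing collapse to the UNBOUNDED keyword. A usage sketch, assuming the dataframe API mirrors PySpark for these names:

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.window import Window

    w = (
        Window.partitionBy("grp")
        .orderBy("ts")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    # Renders as ... ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    running_total = F.sum("x").over(w)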


@ -5,7 +5,6 @@ import re
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
@ -30,7 +29,7 @@ from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
from typing_extensions import Literal
from sqlglot._typing import E, Lit
logger = logging.getLogger("sqlglot")
@ -47,9 +46,11 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
exp.alias_(value, column_name)
for value, column_name in zip(
t.expressions,
alias.columns
if alias and alias.columns
else (f"_c{i}" for i in range(len(t.expressions))),
(
alias.columns
if alias and alias.columns
else (f"_c{i}" for i in range(len(t.expressions)))
),
)
]
)
@ -473,12 +474,10 @@ class BigQuery(Dialect):
return table
@t.overload
def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
...
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
@t.overload
def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
...
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg=False):
json_object = super()._parse_json_object()
@ -546,9 +545,11 @@ class BigQuery(Dialect):
exp.ArrayContains: _array_contains_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
if e.args.get("default")
else f"COLLATE {self.sql(e, 'this')}",
exp.CollateProperty: lambda self, e: (
f"DEFAULT COLLATE {self.sql(e, 'this')}"
if e.args.get("default")
else f"COLLATE {self.sql(e, 'this')}"
),
exp.CountIf: rename_func("COUNTIF"),
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
@ -560,6 +561,9 @@ class BigQuery(Dialect):
exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"),
exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
exp.FromTimeZone: lambda self, e: self.func(
"DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'"
),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: rename_func("STRING_AGG"),
@ -595,9 +599,9 @@ class BigQuery(Dialect):
exp.SHA2: lambda self, e: self.func(
f"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
exp.StabilityProperty: lambda self, e: (
f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
),
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")


@ -88,6 +88,8 @@ class ClickHouse(Dialect):
"UINT8": TokenType.UTINYINT,
"IPV4": TokenType.IPV4,
"IPV6": TokenType.IPV6,
"AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION,
"SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION,
}
SINGLE_TOKENS = {
@ -548,6 +550,8 @@ class ClickHouse(Dialect):
exp.DataType.Type.UTINYINT: "UInt8",
exp.DataType.Type.IPV4: "IPv4",
exp.DataType.Type.IPV6: "IPv6",
exp.DataType.Type.AGGREGATEFUNCTION: "AggregateFunction",
exp.DataType.Type.SIMPLEAGGREGATEFUNCTION: "SimpleAggregateFunction",
}
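With the new token and type mappings, ClickHouse aggregate-state columns should round-trip. A sketch, assuming the parser-side handling of the inner function/type arguments (added in the parser hunks later in this commit):

    import sqlglot

    sqlglot.transpile(
        "CREATE TABLE t (s SimpleAggregateFunction(sum, UInt64))",
        read="clickhouse",
        write="clickhouse",
    )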
TRANSFORMS = {
@ -651,12 +655,16 @@ class ClickHouse(Dialect):
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
return super().after_limit_modifiers(expression) + [
self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
if expression.args.get("settings")
else "",
self.seg("FORMAT ") + self.sql(expression, "format")
if expression.args.get("format")
else "",
(
self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
if expression.args.get("settings")
else ""
),
(
self.seg("FORMAT ") + self.sql(expression, "format")
if expression.args.get("format")
else ""
),
]
def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str:


@ -5,7 +5,6 @@ from enum import Enum, auto
from functools import reduce
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
@ -14,11 +13,12 @@ from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
B = t.TypeVar("B", bound=exp.Binary)
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
if t.TYPE_CHECKING:
from sqlglot._typing import B, E
class Dialects(str, Enum):
"""Dialects supported by SQLGLot."""
@ -381,9 +381,11 @@ class Dialect(metaclass=_Dialect):
):
expression.set(
"this",
expression.this.upper()
if self.normalization_strategy is NormalizationStrategy.UPPERCASE
else expression.this.lower(),
(
expression.this.upper()
if self.normalization_strategy is NormalizationStrategy.UPPERCASE
else expression.this.lower()
),
)
return expression
@ -877,9 +879,11 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
Otherwise, we'd end up with `col_avg(`foo`)` (note the nested backtick quoting).
"""
agg_all_unquoted = agg.transform(
lambda node: exp.Identifier(this=node.name, quoted=False)
if isinstance(node, exp.Identifier)
else node
lambda node: (
exp.Identifier(this=node.name, quoted=False)
if isinstance(node, exp.Identifier)
else node
)
)
names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
@ -999,10 +1003,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
"""Remove table refs from columns in when statements."""
alias = expression.this.args.get("alias")
normalize = (
lambda identifier: self.dialect.normalize_identifier(identifier).name
if identifier
else None
normalize = lambda identifier: (
self.dialect.normalize_identifier(identifier).name if identifier else None
)
targets = {normalize(expression.this.this)}
@ -1012,9 +1014,11 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
for when in expression.expressions:
when.transform(
lambda node: exp.column(node.this)
if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
else node,
lambda node: (
exp.column(node.this)
if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
else node
),
copy=False,
)


@ -148,8 +148,8 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
def _rename_unless_within_group(
a: str, b: str
) -> t.Callable[[DuckDB.Generator, exp.Expression], str]:
return (
lambda self, expression: self.func(a, *flatten(expression.args.values()))
return lambda self, expression: (
self.func(a, *flatten(expression.args.values()))
if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup)
else self.func(b, *flatten(expression.args.values()))
)
@ -273,9 +273,11 @@ class DuckDB(Dialect):
PLACEHOLDER_PARSERS = {
**parser.Parser.PLACEHOLDER_PARSERS,
TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
else None,
TokenType.PARAMETER: lambda self: (
self.expression(exp.Placeholder, this=self._prev.text)
if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
else None
),
}
def _parse_types(
@ -321,9 +323,11 @@ class DuckDB(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if e.expressions and e.expressions[0].find(exp.Select)
else inline_array_sql(self, e),
exp.Array: lambda self, e: (
self.func("ARRAY", e.expressions[0])
if e.expressions and e.expressions[0].find(exp.Select)
else inline_array_sql(self, e)
),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"),


@ -397,9 +397,11 @@ class Hive(Dialect):
if this and not schema:
return this.transform(
lambda node: node.replace(exp.DataType.build("text"))
if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
else node,
lambda node: (
node.replace(exp.DataType.build("text"))
if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
else node
),
copy=False,
)
@ -409,9 +411,11 @@ class Hive(Dialect):
self,
) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
return (
self._parse_csv(self._parse_conjunction)
if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
else [],
(
self._parse_csv(self._parse_conjunction)
if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
else []
),
super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)),
)
@ -483,9 +487,9 @@ class Hive(Dialect):
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
exp.NotNullColumnConstraint: lambda self, e: ""
if e.args.get("allow_null")
else "NOT NULL",
exp.NotNullColumnConstraint: lambda self, e: (
"" if e.args.get("allow_null") else "NOT NULL"
),
exp.VarMap: var_map_sql,
exp.Create: _create_sql,
exp.Quantile: rename_func("PERCENTILE"),


@ -166,6 +166,7 @@ class Oracle(Dialect):
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
SUPPORTS_SELECT_INTO = True
TZ_TO_WITH_TIME_ZONE = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -179,6 +180,8 @@ class Oracle(Dialect):
exp.DataType.Type.NVARCHAR: "NVARCHAR2",
exp.DataType.Type.NCHAR: "NCHAR",
exp.DataType.Type.TEXT: "CLOB",
exp.DataType.Type.TIMETZ: "TIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.VARBINARY: "BLOB",
}
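Combined with TZ_TO_WITH_TIME_ZONE above, the TZ-aware types should render with the spelled-out modifier. A sketch:

    import sqlglot

    sqlglot.transpile("SELECT CAST(x AS TIMESTAMPTZ)", write="oracle")
    # expected: ['SELECT CAST(x AS TIMESTAMP WITH TIME ZONE)']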


@ -282,6 +282,12 @@ class Postgres(Dialect):
VAR_SINGLE_TOKENS = {"$"}
class Parser(parser.Parser):
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
"SET": lambda self: self.expression(exp.SetConfigProperty, this=self._parse_set()),
}
PROPERTY_PARSERS.pop("INPUT", None)
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
@ -385,9 +391,11 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
if isinstance(seq_get(e.expressions, 0), exp.Select)
else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
exp.Array: lambda self, e: (
f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
if isinstance(seq_get(e.expressions, 0), exp.Select)
else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]"
),
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
@ -396,6 +404,7 @@ class Postgres(Dialect):
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.CurrentUser: lambda *_: "CURRENT_USER",
exp.DateAdd: _date_add_sql("+"),
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: datestrtodate_sql,


@ -356,6 +356,7 @@ class Presto(Dialect):
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
exp.FromTimeZone: lambda self, e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
exp.GetPath: path_to_jsonpath(),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(


@ -3,7 +3,6 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
@ -25,6 +24,9 @@ from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
def _check_int(s: str) -> bool:
if s[0] in ("-", "+"):
@ -297,10 +299,7 @@ def _parse_colon_get_path(
if not self._match(TokenType.COLON):
break
if self._match_set(self.RANGE_PARSERS):
this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this
return this
return self._parse_range(this)
def _parse_timestamp_from_parts(args: t.List) -> exp.Func:
@ -376,7 +375,7 @@ class Snowflake(Dialect):
and isinstance(expression.parent, exp.Table)
and expression.name.lower() == "dual"
):
return t.cast(E, expression)
return expression # type: ignore
return super().quote_identifier(expression, identify=identify)
@ -471,6 +470,10 @@ class Snowflake(Dialect):
}
SHOW_PARSERS = {
"SCHEMAS": _show_parser("SCHEMAS"),
"TERSE SCHEMAS": _show_parser("SCHEMAS"),
"OBJECTS": _show_parser("OBJECTS"),
"TERSE OBJECTS": _show_parser("OBJECTS"),
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"COLUMNS": _show_parser("COLUMNS"),
@ -580,21 +583,35 @@ class Snowflake(Dialect):
scope = None
scope_kind = None
# will identify SHOW TERSE SCHEMAS but not SHOW TERSE PRIMARY KEYS
# which is syntactically valid but has no effect on the output
terse = self._tokens[self._index - 2].text.upper() == "TERSE"
like = self._parse_string() if self._match(TokenType.LIKE) else None
if self._match(TokenType.IN):
if self._match_text_seq("ACCOUNT"):
scope_kind = "ACCOUNT"
elif self._match_set(self.DB_CREATABLES):
scope_kind = self._prev.text
scope_kind = self._prev.text.upper()
if self._curr:
scope = self._parse_table()
scope = self._parse_table_parts()
elif self._curr:
scope_kind = "TABLE"
scope = self._parse_table()
scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE"
scope = self._parse_table_parts()
return self.expression(
exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind
exp.Show,
**{
"terse": terse,
"this": this,
"like": like,
"scope": scope,
"scope_kind": scope_kind,
"starts_with": self._match_text_seq("STARTS", "WITH") and self._parse_string(),
"limit": self._parse_limit(),
"from": self._parse_string() if self._match(TokenType.FROM) else None,
},
)
def _parse_alter_table_swap(self) -> exp.SwapTable:
@ -690,6 +707,9 @@ class Snowflake(Dialect):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.Explode: rename_func("FLATTEN"),
exp.Extract: rename_func("DATE_PART"),
exp.FromTimeZone: lambda self, e: self.func(
"CONVERT_TIMEZONE", e.args.get("zone"), "'UTC'", e.this
),
exp.GenerateSeries: lambda self, e: self.func(
"ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step")
),
@ -820,6 +840,7 @@ class Snowflake(Dialect):
return f"{explode}{alias}"
def show_sql(self, expression: exp.Show) -> str:
terse = "TERSE " if expression.args.get("terse") else ""
like = self.sql(expression, "like")
like = f" LIKE {like}" if like else ""
@ -830,7 +851,19 @@ class Snowflake(Dialect):
if scope_kind:
scope_kind = f" IN {scope_kind}"
return f"SHOW {expression.name}{like}{scope_kind}{scope}"
starts_with = self.sql(expression, "starts_with")
if starts_with:
starts_with = f" STARTS WITH {starts_with}"
limit = self.sql(expression, "limit")
from_ = self.sql(expression, "from")
if from_:
from_ = f" FROM {from_}"
return (
f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}"
)
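A sketch of what the extended SHOW handling should enable (expected output inferred from the parser and generator changes above):

    import sqlglot

    ast = sqlglot.parse_one("SHOW TERSE OBJECTS IN db1.schema1 LIMIT 10", read="snowflake")
    ast.sql(dialect="snowflake")
    # expected: 'SHOW TERSE OBJECTS IN SCHEMA db1.schema1 LIMIT 10'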
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to
@ -884,3 +917,6 @@ class Snowflake(Dialect):
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ")
def cluster_sql(self, expression: exp.Cluster) -> str:
return f"CLUSTER BY ({self.expressions(expression, flat=True)})"


@ -80,9 +80,9 @@ class Spark(Spark2):
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
),
exp.TryCast: lambda self, e: self.trycast_sql(e)
if e.args.get("safe")
else self.cast_sql(e),
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
),
}
TRANSFORMS.pop(exp.AnyValue)
TRANSFORMS.pop(exp.DateDiff)


@ -129,10 +129,20 @@ class Spark2(Hive):
"SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
"STRING": _parse_as_cast("string"),
"TIMESTAMP": _parse_as_cast("timestamp"),
"TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args)
if len(args) == 1
else format_time_lambda(exp.StrToTime, "spark")(args),
"TO_TIMESTAMP": lambda args: (
_parse_as_cast("timestamp")(args)
if len(args) == 1
else format_time_lambda(exp.StrToTime, "spark")(args)
),
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
this=exp.cast_unless(
seq_get(args, 0) or exp.Var(this=""),
exp.DataType.build("timestamp"),
exp.DataType.build("timestamp"),
),
zone=seq_get(args, 1),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
}
@ -188,6 +198,7 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
exp.FromTimeZone: lambda self, e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
@ -255,10 +266,12 @@ class Spark2(Hive):
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
return super().columndef_sql(
expression,
sep=": "
if isinstance(expression.parent, exp.DataType)
and expression.parent.is_type("struct")
else sep,
sep=(
": "
if isinstance(expression.parent, exp.DataType)
and expression.parent.is_type("struct")
else sep
),
)
class Tokenizer(Hive.Tokenizer):


@ -38,3 +38,4 @@ class Tableau(Dialect):
**parser.Parser.FUNCTIONS,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}
NO_PAREN_IF_COMMANDS = False


@ -76,9 +76,11 @@ def _format_time_lambda(
format=exp.Literal.string(
format_time(
args[0].name.lower(),
{**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
else TSQL.TIME_MAPPING,
(
{**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
else TSQL.TIME_MAPPING
),
)
),
)
@ -264,6 +266,15 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
)
def _parse_len(args: t.List) -> exp.Length:
this = seq_get(args, 0)
if this and not this.is_string:
this = exp.cast(this, exp.DataType.Type.TEXT)
return exp.Length(this=this)
class TSQL(Dialect):
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
@ -431,7 +442,7 @@ class TSQL(Dialect):
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"LEN": exp.Length.from_arg_list,
"LEN": _parse_len,
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
@ -469,6 +480,7 @@ class TSQL(Dialect):
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
STRING_ALIASES = True
NO_PAREN_IF_COMMANDS = False
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@ -478,9 +490,11 @@ class TSQL(Dialect):
See: https://learn.microsoft.com/en-us/sql/t-sql/queries/select-clause-transact-sql?view=sql-server-ver16#syntax
"""
return [
exp.alias_(projection.expression, projection.this.this, copy=False)
if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column)
else projection
(
exp.alias_(projection.expression, projection.this.this, copy=False)
if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column)
else projection
)
for projection in super()._parse_projections()
]
@ -702,7 +716,6 @@ class TSQL(Dialect):
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
exp.Length: rename_func("LEN"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
@ -922,3 +935,11 @@ class TSQL(Dialect):
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True, sep=" ")
return f"CONSTRAINT {this} {expressions}"
def length_sql(self, expression: exp.Length) -> str:
this = expression.this
if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT):
this_sql = self.sql(this, "this")
else:
this_sql = self.sql(this)
return self.func("LEN", this_sql)
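Taken together, _parse_len and length_sql should make LEN round-trip while exposing its implicit string conversion to other dialects. A sketch:

    import sqlglot

    sqlglot.transpile("SELECT LEN(123)", read="tsql", write="postgres")
    # expected: ['SELECT LENGTH(CAST(123 AS TEXT))']

    sqlglot.transpile("SELECT LEN(123)", read="tsql", write="tsql")
    # expected: ['SELECT LEN(123)']  (length_sql unwraps the cast _parse_len added)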


@ -392,9 +392,9 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
names = {e.name.lower() for e in e.expressions}
e = e.transform(
lambda n: exp.var(n.name)
if isinstance(n, exp.Identifier) and n.name.lower() in names
else n
lambda n: (
exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n
)
)
return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
@ -438,9 +438,9 @@ class Python(Dialect):
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
exp.Is: lambda self, e: self.binary(e, "==")
if isinstance(e.this, exp.Literal)
else self.binary(e, "is"),
exp.Is: lambda self, e: (
self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is")
),
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",


@ -23,7 +23,6 @@ from copy import deepcopy
from enum import auto
from functools import reduce
from sqlglot._typing import E
from sqlglot.errors import ErrorLevel, ParseError
from sqlglot.helper import (
AutoName,
@ -36,8 +35,7 @@ from sqlglot.helper import (
from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from typing_extensions import Literal as Lit
from sqlglot._typing import E, Lit
from sqlglot.dialects.dialect import DialectType
@ -389,7 +387,7 @@ class Expression(metaclass=_Expression):
ancestor = self.parent
while ancestor and not isinstance(ancestor, expression_types):
ancestor = ancestor.parent
return t.cast(E, ancestor)
return ancestor # type: ignore
@property
def parent_select(self) -> t.Optional[Select]:
@ -555,12 +553,10 @@ class Expression(metaclass=_Expression):
return new_node
@t.overload
def replace(self, expression: E) -> E:
...
def replace(self, expression: E) -> E: ...
@t.overload
def replace(self, expression: None) -> None:
...
def replace(self, expression: None) -> None: ...
def replace(self, expression):
"""
@ -781,13 +777,16 @@ class Expression(metaclass=_Expression):
this=maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
unnest=Unnest(
expressions=[
maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest)
]
)
if unnest
else None,
unnest=(
Unnest(
expressions=[
maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts)
for e in ensure_list(unnest)
]
)
if unnest
else None
),
)
def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between:
@ -926,7 +925,7 @@ class DerivedTable(Expression):
class Unionable(Expression):
def union(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
) -> Unionable:
) -> Union:
"""
Builds a UNION expression.
@ -1134,9 +1133,12 @@ class SetItem(Expression):
class Show(Expression):
arg_types = {
"this": True,
"terse": False,
"target": False,
"offset": False,
"starts_with": False,
"limit": False,
"from": False,
"like": False,
"where": False,
"db": False,
@ -1274,9 +1276,14 @@ class AlterColumn(Expression):
"using": False,
"default": False,
"drop": False,
"comment": False,
}
class RenameColumn(Expression):
arg_types = {"this": True, "to": True, "exists": False}
class RenameTable(Expression):
pass
@ -1402,7 +1409,7 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
class GeneratedAsRowColumnConstraint(ColumnConstraintKind):
arg_types = {"start": True, "hidden": False}
arg_types = {"start": False, "hidden": False}
# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
@ -1667,6 +1674,7 @@ class Index(Expression):
"unique": False,
"primary": False,
"amp": False, # teradata
"include": False,
"partition_by": False, # teradata
"where": False, # postgres partial indexes
}
@ -2016,7 +2024,13 @@ class AutoRefreshProperty(Property):
class BlockCompressionProperty(Property):
arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
arg_types = {
"autotemp": False,
"always": False,
"default": False,
"manual": False,
"never": False,
}
class CharacterSetProperty(Property):
@ -2089,6 +2103,10 @@ class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
class InheritsProperty(Property):
arg_types = {"expressions": True}
class InputModelProperty(Property):
arg_types = {"this": True}
@ -2099,11 +2117,11 @@ class OutputModelProperty(Property):
class IsolatedLoadingProperty(Property):
arg_types = {
"no": True,
"concurrent": True,
"for_all": True,
"for_insert": True,
"for_none": True,
"no": False,
"concurrent": False,
"for_all": False,
"for_insert": False,
"for_none": False,
}
@ -2264,6 +2282,10 @@ class SetProperty(Property):
arg_types = {"multi": True}
class SetConfigProperty(Property):
arg_types = {"this": True}
class SettingsProperty(Property):
arg_types = {"expressions": True}
@ -2407,13 +2429,16 @@ class Tuple(Expression):
this=maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
unnest=Unnest(
expressions=[
maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest)
]
)
if unnest
else None,
unnest=(
Unnest(
expressions=[
maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts)
for e in ensure_list(unnest)
]
)
if unnest
else None
),
)
@ -3631,6 +3656,8 @@ class DataType(Expression):
class Type(AutoName):
ARRAY = auto()
AGGREGATEFUNCTION = auto()
SIMPLEAGGREGATEFUNCTION = auto()
BIGDECIMAL = auto()
BIGINT = auto()
BIGSERIAL = auto()
@ -4162,6 +4189,10 @@ class AtTimeZone(Expression):
arg_types = {"this": True, "zone": True}
class FromTimeZone(Expression):
arg_types = {"this": True, "zone": True}
class Between(Predicate):
arg_types = {"this": True, "low": True, "high": True}
@ -5456,8 +5487,7 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E:
...
) -> E: ...
@t.overload
@ -5469,8 +5499,7 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E:
...
) -> E: ...
def maybe_parse(
@ -5522,13 +5551,11 @@ def maybe_parse(
@t.overload
def maybe_copy(instance: None, copy: bool = True) -> None:
...
def maybe_copy(instance: None, copy: bool = True) -> None: ...
@t.overload
def maybe_copy(instance: E, copy: bool = True) -> E:
...
def maybe_copy(instance: E, copy: bool = True) -> E: ...
def maybe_copy(instance, copy=True):
@ -6151,15 +6178,13 @@ SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
@t.overload
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
...
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ...
@t.overload
def to_identifier(
name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
) -> Identifier:
...
) -> Identifier: ...
def to_identifier(name, quoted=None, copy=True):
@ -6231,13 +6256,11 @@ def to_interval(interval: str | Literal) -> Interval:
@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table:
...
def to_table(sql_path: str | Table, **kwargs) -> Table: ...
@t.overload
def to_table(sql_path: None, **kwargs) -> None:
...
def to_table(sql_path: None, **kwargs) -> None: ...
def to_table(
@ -6562,6 +6585,34 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
)
def rename_column(
table_name: str | Table,
old_column_name: str | Column,
new_column_name: str | Column,
exists: t.Optional[bool] = None,
) -> AlterTable:
"""Build ALTER TABLE... RENAME COLUMN... expression
Args:
table_name: Name of the table
old_column_name: The old name of the column
new_column_name: The new name of the column
exists: Whether or not to add the `IF EXISTS` clause
Returns:
Alter table expression
"""
table = to_table(table_name)
old_column = to_column(old_column_name)
new_column = to_column(new_column_name)
return AlterTable(
this=table,
actions=[
RenameColumn(this=old_column, to=new_column, exists=exists),
],
)
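A usage sketch for the new builder (expected output inferred from renamecolumn_sql, added to the generator in this same commit):

    from sqlglot import exp

    exp.rename_column("t", "c1", "c2", exists=True).sql()
    # expected: 'ALTER TABLE t RENAME COLUMN IF EXISTS c1 TO c2'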
def convert(value: t.Any, copy: bool = False) -> Expression:
"""Convert a python value into an expression object.
@ -6581,7 +6632,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
if isinstance(value, bool):
return Boolean(this=value)
if value is None or (isinstance(value, float) and math.isnan(value)):
return NULL
return null()
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, datetime.datetime):
@ -6674,9 +6725,11 @@ def table_name(table: Table | str, dialect: DialectType = None, identify: bool =
raise ValueError(f"Cannot parse {table}")
return ".".join(
part.sql(dialect=dialect, identify=True, copy=False)
if identify or not SAFE_IDENTIFIER_RE.match(part.name)
else part.name
(
part.sql(dialect=dialect, identify=True, copy=False)
if identify or not SAFE_IDENTIFIER_RE.match(part.name)
else part.name
)
for part in table.parts
)
@ -6942,9 +6995,3 @@ def null() -> Null:
Returns a Null expression.
"""
return Null()
# TODO: deprecate this
TRUE = Boolean(this=True)
FALSE = Boolean(this=False)
NULL = Null()


@ -77,6 +77,7 @@ class Generator:
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.HeapProperty: lambda self, e: "HEAP",
exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}",
@ -96,6 +97,7 @@ class Generator:
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlReadWriteProperty: lambda self, e: e.name,
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
@ -323,6 +325,7 @@ class Generator:
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
exp.HeapProperty: exp.Properties.Location.POST_WITH,
exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA,
exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
exp.JournalProperty: exp.Properties.Location.POST_NAME,
@ -353,6 +356,7 @@ class Generator:
exp.Set: exp.Properties.Location.POST_SCHEMA,
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
@ -568,9 +572,11 @@ class Generator:
def wrap(self, expression: exp.Expression | str) -> str:
this_sql = self.indent(
self.sql(expression)
if isinstance(expression, (exp.Select, exp.Union))
else self.sql(expression, "this"),
(
self.sql(expression)
if isinstance(expression, (exp.Select, exp.Union))
else self.sql(expression, "this")
),
level=1,
pad=0,
)
@ -605,9 +611,11 @@ class Generator:
lines = sql.split("\n")
return "\n".join(
line
if (skip_first and i == 0) or (skip_last and i == len(lines) - 1)
else f"{' ' * (level * self._indent + pad)}{line}"
(
line
if (skip_first and i == 0) or (skip_last and i == len(lines) - 1)
else f"{' ' * (level * self._indent + pad)}{line}"
)
for i, line in enumerate(lines)
)
@ -775,7 +783,7 @@ class Generator:
def generatedasrowcolumnconstraint_sql(
self, expression: exp.GeneratedAsRowColumnConstraint
) -> str:
start = "START" if expression.args["start"] else "END"
start = "START" if expression.args.get("start") else "END"
hidden = " HIDDEN" if expression.args.get("hidden") else ""
return f"GENERATED ALWAYS AS ROW {start}{hidden}"
@ -1111,7 +1119,10 @@ class Generator:
partition_by = self.expressions(expression, key="partition_by", flat=True)
partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
where = self.sql(expression, "where")
return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}{where}"
include = self.expressions(expression, key="include", flat=True)
if include:
include = f" INCLUDE ({include})"
return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{include}{partition_by}{where}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
@ -2017,9 +2028,11 @@ class Generator:
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return [
self.sql(expression, "qualify"),
self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else "",
(
self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else ""
),
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
@ -2552,6 +2565,11 @@ class Generator:
zone = self.sql(expression, "zone")
return f"{this} AT TIME ZONE {zone}"
def fromtimezone_sql(self, expression: exp.FromTimeZone) -> str:
this = self.sql(expression, "this")
zone = self.sql(expression, "zone")
return f"{this} AT TIME ZONE {zone} AT TIME ZONE 'UTC'"
def add_sql(self, expression: exp.Add) -> str:
return self.binary(expression, "+")
@ -2669,6 +2687,10 @@ class Generator:
if default:
return f"ALTER COLUMN {this} SET DEFAULT {default}"
comment = self.sql(expression, "comment")
if comment:
return f"ALTER COLUMN {this} COMMENT {comment}"
if not expression.args.get("drop"):
self.unsupported("Unsupported ALTER COLUMN syntax")
@ -2683,6 +2705,12 @@ class Generator:
this = self.sql(expression, "this")
return f"RENAME TO {this}"
def renamecolumn_sql(self, expression: exp.RenameColumn) -> str:
exists = " IF EXISTS" if expression.args.get("exists") else ""
old_column = self.sql(expression, "this")
new_column = self.sql(expression, "to")
return f"RENAME COLUMN{exists} {old_column} TO {new_column}"
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]


@ -53,13 +53,11 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
@t.overload
def ensure_list(value: t.Collection[T]) -> t.List[T]:
...
def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
@t.overload
def ensure_list(value: T) -> t.List[T]:
...
def ensure_list(value: T) -> t.List[T]: ...
def ensure_list(value):
@ -81,13 +79,11 @@ def ensure_list(value):
@t.overload
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
...
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
@t.overload
def ensure_collection(value: T) -> t.Collection[T]:
...
def ensure_collection(value: T) -> t.Collection[T]: ...
def ensure_collection(value):

sqlglot/jsonpath.py (new file, 215 lines)

@ -0,0 +1,215 @@
from __future__ import annotations
import typing as t
from sqlglot.errors import ParseError
from sqlglot.expressions import SAFE_IDENTIFIER_RE
from sqlglot.tokens import Token, Tokenizer, TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import Lit
class JSONPathTokenizer(Tokenizer):
SINGLE_TOKENS = {
"(": TokenType.L_PAREN,
")": TokenType.R_PAREN,
"[": TokenType.L_BRACKET,
"]": TokenType.R_BRACKET,
":": TokenType.COLON,
",": TokenType.COMMA,
"-": TokenType.DASH,
".": TokenType.DOT,
"?": TokenType.PLACEHOLDER,
"@": TokenType.PARAMETER,
"'": TokenType.QUOTE,
'"': TokenType.QUOTE,
"$": TokenType.DOLLAR,
"*": TokenType.STAR,
}
KEYWORDS = {
"..": TokenType.DOT,
}
IDENTIFIER_ESCAPES = ["\\"]
STRING_ESCAPES = ["\\"]
JSONPathNode = t.Dict[str, t.Any]
def _node(kind: str, value: t.Any = None, **kwargs: t.Any) -> JSONPathNode:
node = {"kind": kind, **kwargs}
if value is not None:
node["value"] = value
return node
def parse(path: str) -> t.List[JSONPathNode]:
"""Takes in a JSONPath string and converts into a list of nodes."""
tokens = JSONPathTokenizer().tokenize(path)
size = len(tokens)
i = 0
def _curr() -> t.Optional[TokenType]:
return tokens[i].token_type if i < size else None
def _prev() -> Token:
return tokens[i - 1]
def _advance() -> Token:
nonlocal i
i += 1
return _prev()
def _error(msg: str) -> str:
return f"{msg} at index {i}: {path}"
@t.overload
def _match(token_type: TokenType, raise_unmatched: Lit[True] = True) -> Token:
pass
@t.overload
def _match(token_type: TokenType, raise_unmatched: Lit[False] = False) -> t.Optional[Token]:
pass
def _match(token_type, raise_unmatched=False):
if _curr() == token_type:
return _advance()
if raise_unmatched:
raise ParseError(_error(f"Expected {token_type}"))
return None
def _parse_literal() -> t.Any:
token = _match(TokenType.STRING) or _match(TokenType.IDENTIFIER)
if token:
return token.text
if _match(TokenType.STAR):
return _node("wildcard")
if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN):
script = _prev().text == "("
start = i
while True:
if _match(TokenType.L_BRACKET):
_parse_bracket() # nested call which we can throw away
if _curr() in (TokenType.R_BRACKET, None):
break
_advance()
return _node(
"script" if script else "filter", path[tokens[start].start : tokens[i].end]
)
number = "-" if _match(TokenType.DASH) else ""
token = _match(TokenType.NUMBER)
if token:
number += token.text
if number:
return int(number)
return False
def _parse_slice() -> t.Any:
start = _parse_literal()
end = _parse_literal() if _match(TokenType.COLON) else None
step = _parse_literal() if _match(TokenType.COLON) else None
if end is None and step is None:
return start
return _node("slice", start=start, end=end, step=step)
def _parse_bracket() -> JSONPathNode:
literal = _parse_slice()
if isinstance(literal, str) or literal is not False:
indexes = [literal]
while _match(TokenType.COMMA):
literal = _parse_slice()
if literal:
indexes.append(literal)
if len(indexes) == 1:
if isinstance(literal, str):
node = _node("key", indexes[0])
elif isinstance(literal, dict) and literal["kind"] in ("script", "filter"):
node = _node("selector", indexes[0])
else:
node = _node("subscript", indexes[0])
else:
node = _node("union", indexes)
else:
raise ParseError(_error("Cannot have empty segment"))
_match(TokenType.R_BRACKET, raise_unmatched=True)
return node
nodes = []
while _curr():
if _match(TokenType.DOLLAR):
nodes.append(_node("root"))
elif _match(TokenType.DOT):
recursive = _prev().text == ".."
value = _match(TokenType.VAR) or _match(TokenType.STAR)
nodes.append(
_node("recursive" if recursive else "child", value=value.text if value else None)
)
elif _match(TokenType.L_BRACKET):
nodes.append(_parse_bracket())
elif _match(TokenType.VAR):
nodes.append(_node("key", _prev().text))
elif _match(TokenType.STAR):
nodes.append(_node("wildcard"))
elif _match(TokenType.PARAMETER):
nodes.append(_node("current"))
else:
raise ParseError(_error(f"Unexpected {tokens[i].token_type}"))
return nodes
MAPPING = {
"child": lambda n: f".{n['value']}" if n.get("value") is not None else "",
"filter": lambda n: f"?{n['value']}",
"key": lambda n: (
f".{n['value']}" if SAFE_IDENTIFIER_RE.match(n["value"]) else f'[{generate([n["value"]])}]'
),
"recursive": lambda n: f"..{n['value']}" if n.get("value") is not None else "..",
"root": lambda _: "$",
"script": lambda n: f"({n['value']}",
"slice": lambda n: ":".join(
"" if p is False else generate([p])
for p in [n["start"], n["end"], n["step"]]
if p is not None
),
"selector": lambda n: f"[{generate([n['value']])}]",
"subscript": lambda n: f"[{generate([n['value']])}]",
"union": lambda n: f"[{','.join(generate([p]) for p in n['value'])}]",
"wildcard": lambda _: "*",
}
def generate(
nodes: t.List[JSONPathNode],
mapping: t.Optional[t.Dict[str, t.Callable[[JSONPathNode], str]]] = None,
) -> str:
mapping = MAPPING if mapping is None else mapping
path = []
for node in nodes:
if isinstance(node, dict):
path.append(mapping[node["kind"]](node))
elif isinstance(node, str):
escaped = node.replace('"', '\\"')
path.append(f'"{escaped}"')
else:
path.append(str(node))
return "".join(path)


@ -41,9 +41,9 @@ class Node:
else:
label = node.expression.sql(pretty=True, dialect=dialect)
source = node.source.transform(
lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
if n is node.expression
else n,
lambda n: (
exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
),
copy=False,
).sql(pretty=True, dialect=dialect)
title = f"<pre>{source}</pre>"


@ -4,7 +4,6 @@ import functools
import typing as t
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import (
ensure_list,
is_date_unit,
@ -17,7 +16,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
B = t.TypeVar("B", bound=exp.Binary)
from sqlglot._typing import B, E
BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
BinaryCoercions = t.Dict[
@ -479,6 +478,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, target_type)
return self._annotate_args(expression)
@t.no_type_check
def _annotate_struct_value(
self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
alias = expression.args.get("alias")
if alias:
return exp.ColumnDef(this=alias.copy(), kind=expression.type)
# Case: key = value or key := value
if expression.expression:
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
return expression.type
@t.no_type_check
def _annotate_by_args(
self,
@ -516,16 +529,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
if struct:
expressions = [
expr.type
if not expr.args.get("alias")
else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
for expr in expressions
]
self._set_type(
expression,
exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
exp.DataType(
this=exp.DataType.Type.STRUCT,
expressions=[self._annotate_struct_value(expr) for expr in expressions],
nested=True,
),
)
return expression


@ -3,18 +3,18 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
@t.overload
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
...
if t.TYPE_CHECKING:
from sqlglot._typing import E
@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
...
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
def normalize_identifiers(expression, dialect=None):


@ -4,7 +4,6 @@ import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
@ -12,6 +11,9 @@ from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
from sqlglot._typing import E
def qualify_columns(
expression: exp.Expression,
@ -210,7 +212,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not node:
return
for column, *_ in walk_in_scope(node):
for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
if not isinstance(column, exp.Column):
continue
@ -525,6 +527,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
selection = alias(
selection,
alias=selection.output_name or f"_col_{i}",
copy=False,
)
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))


@ -4,12 +4,14 @@ import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
if t.TYPE_CHECKING:
from sqlglot._typing import E
def qualify_tables(
expression: E,
@ -46,6 +48,18 @@ def qualify_tables(
db = exp.parse_identifier(db, dialect=dialect) if db else None
catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
def _qualify(table: exp.Table) -> None:
if isinstance(table.this, exp.Identifier):
if not table.args.get("db"):
table.set("db", db)
if not table.args.get("catalog") and table.args.get("db"):
table.set("catalog", catalog)
if not isinstance(expression, exp.Subqueryable):
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
if isinstance(node, exp.Table):
_qualify(node)
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
if isinstance(derived_table, exp.Subquery):
@ -66,11 +80,7 @@ def qualify_tables(
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", db)
if not source.args.get("catalog") and source.args.get("db"):
source.set("catalog", catalog)
_qualify(source)
pivots = source.args.get("pivots")
if not source.alias:
@ -107,5 +117,14 @@ def qualify_tables(
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
else:
for node, parent, _ in scope.walk():
if (
isinstance(node, exp.Table)
and not node.alias
and isinstance(parent, (exp.From, exp.Join))
):
# Mutates the table by attaching an alias to it
alias(node, node.name, copy=False, table=True)
return expression
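The refactor extracts _qualify so non-SELECT statements get the same treatment; the documented behavior is unchanged. A sketch:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables

    qualify_tables(parse_one("SELECT 1 FROM tbl"), db="db").sql()
    # expected: 'SELECT 1 FROM db.tbl AS tbl'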


@ -323,9 +323,14 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
self._external_columns = [
c for c in self.columns if c.table not in self.selected_sources
]
if isinstance(self.expression, exp.Union):
left, right = self.union_scopes
self._external_columns = left.external_columns + right.external_columns
else:
self._external_columns = [
c for c in self.columns if c.table not in self.selected_sources
]
return self._external_columns
@property
@ -477,11 +482,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Args:
expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
"""
if isinstance(expression, exp.Unionable) or (
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
):
return list(_traverse_scope(Scope(expression)))


@ -1068,9 +1068,11 @@ def extract_interval(expression):
def date_literal(date):
return exp.cast(
exp.Literal.string(date),
exp.DataType.Type.DATETIME
if isinstance(date, datetime.datetime)
else exp.DataType.Type.DATE,
(
exp.DataType.Type.DATETIME
if isinstance(date, datetime.datetime)
else exp.DataType.Type.DATE
),
)


@ -50,11 +50,12 @@ def unnest(select, parent_select, next_alias_name):
):
return
clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
# This subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
column = exp.column(select.selects[0].alias_or_name, alias)
clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
clause_parent_select = clause.parent_select if clause else None
if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
@ -84,12 +85,18 @@ def unnest(select, parent_select, next_alias_name):
column = _other_operand(predicate)
value = select.selects[0]
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
_replace(predicate, f"NOT {on.right} IS NULL")
join_key = exp.column(value.alias, alias)
join_key_not_null = join_key.is_(exp.null()).not_()
if isinstance(clause, exp.Join):
_replace(predicate, exp.true())
parent_select.where(join_key_not_null, copy=False)
else:
_replace(predicate, join_key_not_null)
parent_select.join(
select.group_by(value.this, copy=False),
on=on,
on=column.eq(join_key),
join_type="LEFT",
join_alias=alias,
copy=False,


@ -12,9 +12,7 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import TrieResult, in_trie, new_trie
if t.TYPE_CHECKING:
from typing_extensions import Literal
from sqlglot._typing import E
from sqlglot._typing import E, Lit
from sqlglot.dialects.dialect import Dialect, DialectType
logger = logging.getLogger("sqlglot")
@ -148,6 +146,11 @@ class Parser(metaclass=_Parser):
TokenType.ENUM16,
}
AGGREGATE_TYPE_TOKENS = {
TokenType.AGGREGATEFUNCTION,
TokenType.SIMPLEAGGREGATEFUNCTION,
}
TYPE_TOKENS = {
TokenType.BIT,
TokenType.BOOLEAN,
@ -241,6 +244,7 @@ class Parser(metaclass=_Parser):
TokenType.NULL,
*ENUM_TYPE_TOKENS,
*NESTED_TYPE_TOKENS,
*AGGREGATE_TYPE_TOKENS,
}
SIGNED_TO_UNSIGNED_TYPE_TOKEN = {
@ -653,9 +657,11 @@ class Parser(metaclass=_Parser):
PLACEHOLDER_PARSERS = {
TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
TokenType.PARAMETER: lambda self: self._parse_parameter(),
TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
else None,
TokenType.COLON: lambda self: (
self.expression(exp.Placeholder, this=self._prev.text)
if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
else None
),
}
RANGE_PARSERS = {
@ -705,6 +711,9 @@ class Parser(metaclass=_Parser):
"IMMUTABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"INHERITS": lambda self: self.expression(
exp.InheritsProperty, expressions=self._parse_wrapped_csv(self._parse_table)
),
"INPUT": lambda self: self.expression(exp.InputModelProperty, this=self._parse_schema()),
"JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
@ -822,6 +831,7 @@ class Parser(metaclass=_Parser):
ALTER_PARSERS = {
"ADD": lambda self: self._parse_alter_table_add(),
"ALTER": lambda self: self._parse_alter_table_alter(),
"CLUSTER BY": lambda self: self._parse_cluster(wrapped=True),
"DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
"DROP": lambda self: self._parse_alter_table_drop(),
"RENAME": lambda self: self._parse_alter_table_rename(),
@@ -973,6 +983,9 @@ class Parser(metaclass=_Parser):
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
# Whether IF statements without parentheses are parsed as commands
NO_PAREN_IF_COMMANDS = True
__slots__ = (
"error_level",
"error_message_context",
@@ -1207,7 +1220,20 @@ class Parser(metaclass=_Parser):
if index != self._index:
self._advance(index - self._index)
def _warn_unsupported(self) -> None:
if len(self._tokens) <= 1:
return
# We use _find_sql because self.sql may comprise multiple chunks, and we're only
# interested in emitting a warning for the one currently being processed.
sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context]
logger.warning(
f"'{sql}' contains unsupported syntax. Falling back to parsing as a 'Command'."
)
def _parse_command(self) -> exp.Command:
self._warn_unsupported()
return self.expression(
exp.Command, this=self._prev.text.upper(), expression=self._parse_string()
)
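Illustrative usage (assumes GRANT still tokenizes as a command in this version; the log text comes from _warn_unsupported above):

import logging
from sqlglot import parse_one

logging.basicConfig(level=logging.WARNING)
parse_one("GRANT SELECT ON t TO u")
# WARNING:sqlglot:'GRANT SELECT ON t TO u' contains unsupported syntax. Falling back to parsing as a 'Command'.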
@@ -1329,8 +1355,10 @@ class Parser(metaclass=_Parser):
start = self._prev
comments = self._prev_comments
replace = start.text.upper() == "REPLACE" or self._match_pair(
TokenType.OR, TokenType.REPLACE
replace = (
start.token_type == TokenType.REPLACE
or self._match_pair(TokenType.OR, TokenType.REPLACE)
or self._match_pair(TokenType.OR, TokenType.ALTER)
)
unique = self._match(TokenType.UNIQUE)
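A sketch of the widened replace detection (CREATE OR ALTER is T-SQL syntax; the replace arg on exp.Create is assumed to be set when the pair matches):

from sqlglot import parse_one

create = parse_one("CREATE OR ALTER VIEW v AS SELECT 1", read="tsql")
print(create.args.get("replace"))  # True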
@@ -1440,6 +1468,9 @@ class Parser(metaclass=_Parser):
exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy
)
if self._curr:
return self._parse_as_command(start)
return self.expression(
exp.Create,
comments=comments,
@@ -1516,11 +1547,13 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.FileFormatProperty,
this=self.expression(
exp.InputOutputFormat, input_format=input_format, output_format=output_format
)
if input_format or output_format
else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
this=(
self.expression(
exp.InputOutputFormat, input_format=input_format, output_format=output_format
)
if input_format or output_format
else self._parse_var_or_string() or self._parse_number() or self._parse_id_var()
),
)
def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E:
@@ -1632,8 +1665,15 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))
def _parse_cluster(self) -> exp.Cluster:
return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
def _parse_cluster(self, wrapped: bool = False) -> exp.Cluster:
return self.expression(
exp.Cluster,
expressions=(
self._parse_wrapped_csv(self._parse_ordered)
if wrapped
else self._parse_csv(self._parse_ordered)
),
)
def _parse_clustered_by(self) -> exp.ClusteredByProperty:
self._match_text_seq("BY")
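A sketch of the wrapped form reached through the new "CLUSTER BY" entry in ALTER_PARSERS (Snowflake-style syntax; generated SQL may format differently):

from sqlglot import parse_one

print(parse_one("ALTER TABLE t CLUSTER BY (a, b)").sql())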
@@ -2681,6 +2721,8 @@ class Parser(metaclass=_Parser):
else:
columns = None
include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None
return self.expression(
exp.Index,
this=index,
@@ -2690,6 +2732,7 @@ class Parser(metaclass=_Parser):
unique=unique,
primary=primary,
amp=amp,
include=include,
partition_by=self._parse_partition_by(),
where=self._parse_where(),
)
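A sketch of the INCLUDE column list being captured (Postgres covering-index syntax assumed; checked on the tree to avoid depending on generator support):

from sqlglot import exp, parse_one

sql = "CREATE INDEX idx ON t (a) INCLUDE (b, c)"
# The include arg should hold the wrapped identifiers b and c.
print(parse_one(sql, read="postgres").find(exp.Index).args.get("include"))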
@@ -3380,8 +3423,8 @@ class Parser(metaclass=_Parser):
def _parse_comparison(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_range, self.COMPARISON)
def _parse_range(self) -> t.Optional[exp.Expression]:
this = self._parse_bitwise()
def _parse_range(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
this = this or self._parse_bitwise()
negate = self._match(TokenType.NOT)
if self._match_set(self.RANGE_PARSERS):
@@ -3535,14 +3578,21 @@ class Parser(metaclass=_Parser):
return self._parse_tokens(self._parse_factor, self.TERM)
def _parse_factor(self) -> t.Optional[exp.Expression]:
if self.EXPONENT:
factor = self._parse_tokens(self._parse_exponent, self.FACTOR)
else:
factor = self._parse_tokens(self._parse_unary, self.FACTOR)
if isinstance(factor, exp.Div):
factor.args["typed"] = self.dialect.TYPED_DIVISION
factor.args["safe"] = self.dialect.SAFE_DIVISION
return factor
parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary
this = parse_method()
while self._match_set(self.FACTOR):
this = self.expression(
self.FACTOR[self._prev.token_type],
this=this,
comments=self._prev_comments,
expression=parse_method(),
)
if isinstance(this, exp.Div):
this.args["typed"] = self.dialect.TYPED_DIVISION
this.args["safe"] = self.dialect.SAFE_DIVISION
return this
def _parse_exponent(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_unary, self.EXPONENT)
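A sketch of the effect on chained division; with the loop above, every Div in the chain carries the flags, whose values come from the dialect's TYPED_DIVISION and SAFE_DIVISION settings:

from sqlglot import exp, parse_one

for div in parse_one("SELECT a / b / c").find_all(exp.Div):
    print(div.args.get("typed"), div.args.get("safe"))  # False False on the default dialect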
@@ -3617,6 +3667,7 @@ class Parser(metaclass=_Parser):
return exp.DataType.build(type_name, udt=True)
else:
self._retreat(self._index - 1)
return None
else:
return None
@@ -3631,6 +3682,7 @@ class Parser(metaclass=_Parser):
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token in self.STRUCT_TYPE_TOKENS
is_aggregate = type_token in self.AGGREGATE_TYPE_TOKENS
expressions = None
maybe_func = False
@@ -3645,6 +3697,18 @@ class Parser(metaclass=_Parser):
)
elif type_token in self.ENUM_TYPE_TOKENS:
expressions = self._parse_csv(self._parse_equality)
elif is_aggregate:
func_or_ident = self._parse_function(anonymous=True) or self._parse_id_var(
any_token=False, tokens=(TokenType.VAR,)
)
if not func_or_ident or not self._match(TokenType.COMMA):
return None
expressions = self._parse_csv(
lambda: self._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
)
expressions.insert(0, func_or_ident)
else:
expressions = self._parse_csv(self._parse_type_size)
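A sketch of the aggregate type path (ClickHouse is the dialect these tokens exist for; round-trip formatting may differ slightly):

from sqlglot import parse_one

sql = "CREATE TABLE t (x AggregateFunction(sum, UInt64))"
print(parse_one(sql, read="clickhouse").sql(dialect="clickhouse"))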
@@ -4413,6 +4477,10 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
else:
index = self._index - 1
if self.NO_PAREN_IF_COMMANDS and index == 0:
return self._parse_as_command(self._prev)
condition = self._parse_conjunction()
if not condition:
@@ -4624,12 +4692,10 @@ class Parser(metaclass=_Parser):
return None
@t.overload
def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
...
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
@t.overload
def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
...
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg=False):
star = self._parse_star()
@@ -4974,11 +5040,12 @@ class Parser(metaclass=_Parser):
if alias:
this = self.expression(exp.Alias, comments=comments, this=this, alias=alias)
column = this.this
# Moves the comment next to the alias in `expr /* comment */ AS alias`
if not this.comments and this.this.comments:
this.comments = this.this.comments
this.this.comments = None
if not this.comments and column and column.comments:
this.comments = column.comments
column.comments = None
return this
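A sketch of the comment relocation; previously only this.this was checked, which broke when the aliased expression was not a simple column (output placement is approximate):

from sqlglot import parse_one

print(parse_one("SELECT a /* note */ AS b FROM t").sql())  # SELECT a AS b /* note */ FROM t, approximately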
@@ -5244,7 +5311,7 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("CHECK"):
expression = self._parse_wrapped(self._parse_conjunction)
enforced = self._match_text_seq("ENFORCED")
enforced = self._match_text_seq("ENFORCED") or False
return self.expression(
exp.AddConstraint, this=this, expression=expression, enforced=enforced
@@ -5278,6 +5345,8 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AlterColumn, this=column, drop=True)
if self._match_pair(TokenType.SET, TokenType.DEFAULT):
return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction())
if self._match(TokenType.COMMENT):
return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())
self._match_text_seq("SET", "DATA")
return self.expression(
@@ -5298,7 +5367,18 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return self._parse_csv(self._parse_drop_column)
def _parse_alter_table_rename(self) -> exp.RenameTable:
def _parse_alter_table_rename(self) -> t.Optional[exp.RenameTable | exp.RenameColumn]:
if self._match(TokenType.COLUMN):
exists = self._parse_exists()
old_column = self._parse_column()
to = self._match_text_seq("TO")
new_column = self._parse_column()
if old_column is None or to is None or new_column is None:
return None
return self.expression(exp.RenameColumn, this=old_column, to=new_column, exists=exists)
self._match_text_seq("TO")
return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
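A sketch of the two new ALTER TABLE column actions introduced above (tree shapes are approximate):

from sqlglot import parse_one

parse_one("ALTER TABLE t ALTER COLUMN c COMMENT 'note'")  # AlterColumn with a comment arg
parse_one("ALTER TABLE t RENAME COLUMN a TO b")  # exp.RenameColumn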
@@ -5319,7 +5399,7 @@ class Parser(metaclass=_Parser):
if parser:
actions = ensure_list(parser(self))
if not self._curr:
if not self._curr and actions:
return self.expression(
exp.AlterTable,
this=this,
@@ -5467,6 +5547,7 @@ class Parser(metaclass=_Parser):
self._advance()
text = self._find_sql(start, self._prev)
size = len(start.text)
self._warn_unsupported()
return exp.Command(this=text[:size], expression=text[size:])
def _parse_dict_property(self, this: str) -> exp.DictProperty:
@@ -5634,7 +5715,7 @@ class Parser(metaclass=_Parser):
if advance:
self._advance()
return True
return False
return None
def _match_text_seq(self, *texts, advance=True):
index = self._index
@@ -5643,7 +5724,7 @@ class Parser(metaclass=_Parser):
self._advance()
else:
self._retreat(index)
return False
return None
if not advance:
self._retreat(index)
@@ -5651,14 +5732,12 @@ class Parser(metaclass=_Parser):
return True
@t.overload
def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
...
def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...
@t.overload
def _replace_columns_with_dots(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
...
) -> t.Optional[exp.Expression]: ...
def _replace_columns_with_dots(self, this):
if isinstance(this, exp.Dot):

View file

@@ -106,6 +106,19 @@ class Schema(abc.ABC):
name = column if isinstance(column, str) else column.name
return name in self.column_names(table, dialect=dialect, normalize=normalize)
@abc.abstractmethod
def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
"""
Returns the schema of a given table.
Args:
table: the target table.
raise_on_missing: whether to raise an error if the schema is not found.
Returns:
The schema of the target table.
"""
@property
@abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -156,11 +169,9 @@ class AbstractMappingSchema:
return [table.this.name]
return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find(
self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[t.Any]:
def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
value, trie = in_trie(self.mapping_trie, parts)
if value == TrieResult.FAILED:
return None
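A sketch of the simplified signature; the trie parameter is gone, so callers only pass the table (MappingSchema and exp.to_table are public APIs):

from sqlglot import exp
from sqlglot.schema import MappingSchema

schema = MappingSchema({"db": {"t": {"a": "INT"}}})
print(schema.find(exp.to_table("db.t")))  # {'a': 'INT'}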

View file

@@ -191,6 +191,8 @@ class TokenType(AutoName):
FIXEDSTRING = auto()
LOWCARDINALITY = auto()
NESTED = auto()
AGGREGATEFUNCTION = auto()
SIMPLEAGGREGATEFUNCTION = auto()
UNKNOWN = auto()
# keywords