Merging upstream version 20.11.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 1bce3d0317
commit e71ccc03da

141 changed files with 66644 additions and 54334 deletions

@@ -87,13 +87,11 @@ def parse(


 @t.overload
-def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
-    ...
+def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ...


 @t.overload
-def parse_one(sql: str, **opts) -> Expression:
-    ...
+def parse_one(sql: str, **opts) -> Expression: ...


 def parse_one(

@@ -4,10 +4,13 @@ import typing as t

 import sqlglot

+if t.TYPE_CHECKING:
+    from typing_extensions import Literal as Lit  # noqa
+
 # A little hack for backwards compatibility with Python 3.7.
 # For example, we might want a TypeVar for objects that support comparison e.g. SupportsRichComparisonT from typeshed.
 # But Python 3.7 doesn't support Protocols, so we'd also need typing_extensions, which we don't want as a dependency.
 A = t.TypeVar("A", bound=t.Any)

+B = t.TypeVar("B", bound="sqlglot.exp.Binary")
 E = t.TypeVar("E", bound="sqlglot.exp.Expression")
 T = t.TypeVar("T")

@@ -144,9 +144,11 @@ class Column:
     ) -> Column:
         ensured_column = None if column is None else cls.ensure_col(column)
         ensure_expression_values = {
-            k: [Column.ensure_col(x).expression for x in v]
-            if is_iterable(v)
-            else Column.ensure_col(v).expression
+            k: (
+                [Column.ensure_col(x).expression for x in v]
+                if is_iterable(v)
+                else Column.ensure_col(v).expression
+            )
             for k, v in kwargs.items()
             if v is not None
         }

@@ -140,12 +140,10 @@ class DataFrame:
         return cte, name

     @t.overload
-    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
-        ...
+    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...

     @t.overload
-    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
-        ...
+    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...

     def _ensure_list_of_columns(self, cols):
         return Column.ensure_cols(ensure_list(cols))

@@ -496,9 +494,11 @@ class DataFrame:
         join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
         # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
         select_column_names = [
-            column.alias_or_name
-            if not isinstance(column.expression.this, exp.Star)
-            else column.sql()
+            (
+                column.alias_or_name
+                if not isinstance(column.expression.this, exp.Star)
+                else column.sql()
+            )
             for column in self_columns + other_columns
         ]
         select_column_names = [

@@ -552,9 +552,11 @@ class DataFrame:
         ), "The length of items in ascending must equal the number of columns provided"
         col_and_ascending = list(zip(columns, ascending))
         order_by_columns = [
-            exp.Ordered(this=col.expression, desc=not asc)
-            if i not in pre_ordered_col_indexes
-            else columns[i].column_expression
+            (
+                exp.Ordered(this=col.expression, desc=not asc)
+                if i not in pre_ordered_col_indexes
+                else columns[i].column_expression
+            )
             for i, (col, asc) in enumerate(col_and_ascending)
         ]
         return self.copy(expression=self.expression.order_by(*order_by_columns))

@@ -661,7 +661,7 @@ def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:

 def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
     tz_column = tz if isinstance(tz, Column) else lit(tz)
-    return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column)
+    return Column.invoke_expression_over_column(timestamp, expression.FromTimeZone, zone=tz_column)


 def timestamp_seconds(col: ColumnOrName) -> Column:

@@ -7,11 +7,11 @@ from sqlglot.dataframe.sql.column import Column
 from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
 from sqlglot.helper import ensure_list

-NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
-
 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.session import SparkSession

+NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
+

 def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
     expr = ensure_list(expr)

@@ -82,9 +82,11 @@ class SparkSession:
         ]

         sel_columns = [
-            F.col(name).cast(data_type).alias(name).expression
-            if data_type is not None
-            else F.col(name).expression
+            (
+                F.col(name).cast(data_type).alias(name).expression
+                if data_type is not None
+                else F.col(name).expression
+            )
             for name, data_type in column_mapping.items()
         ]

@@ -90,9 +90,11 @@ class WindowSpec:
                 **kwargs,
                 **{
                     "start_side": "PRECEDING",
-                    "start": "UNBOUNDED"
-                    if start <= Window.unboundedPreceding
-                    else F.lit(start).expression,
+                    "start": (
+                        "UNBOUNDED"
+                        if start <= Window.unboundedPreceding
+                        else F.lit(start).expression
+                    ),
                 },
             }
         if end == Window.currentRow:

@@ -102,9 +104,9 @@ class WindowSpec:
                 **kwargs,
                 **{
                     "end_side": "FOLLOWING",
-                    "end": "UNBOUNDED"
-                    if end >= Window.unboundedFollowing
-                    else F.lit(end).expression,
+                    "end": (
+                        "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression
+                    ),
                 },
             }
         return kwargs

@@ -5,7 +5,6 @@ import re
 import typing as t

 from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot._typing import E
 from sqlglot.dialects.dialect import (
     Dialect,
     NormalizationStrategy,

@@ -30,7 +29,7 @@ from sqlglot.helper import seq_get, split_num_words
 from sqlglot.tokens import TokenType

 if t.TYPE_CHECKING:
-    from typing_extensions import Literal
+    from sqlglot._typing import E, Lit

 logger = logging.getLogger("sqlglot")

@@ -47,9 +46,11 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
                 exp.alias_(value, column_name)
                 for value, column_name in zip(
                     t.expressions,
-                    alias.columns
-                    if alias and alias.columns
-                    else (f"_c{i}" for i in range(len(t.expressions))),
+                    (
+                        alias.columns
+                        if alias and alias.columns
+                        else (f"_c{i}" for i in range(len(t.expressions)))
+                    ),
                 )
             ]
         )

@@ -473,12 +474,10 @@ class BigQuery(Dialect):
             return table

         @t.overload
-        def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
-            ...
+        def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...

         @t.overload
-        def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
-            ...
+        def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...

         def _parse_json_object(self, agg=False):
             json_object = super()._parse_json_object()

@@ -546,9 +545,11 @@ class BigQuery(Dialect):
             exp.ArrayContains: _array_contains_sql,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
-            exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
-            if e.args.get("default")
-            else f"COLLATE {self.sql(e, 'this')}",
+            exp.CollateProperty: lambda self, e: (
+                f"DEFAULT COLLATE {self.sql(e, 'this')}"
+                if e.args.get("default")
+                else f"COLLATE {self.sql(e, 'this')}"
+            ),
             exp.CountIf: rename_func("COUNTIF"),
             exp.Create: _create_sql,
             exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),

@@ -560,6 +561,9 @@ class BigQuery(Dialect):
             exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"),
             exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
             exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
+            exp.FromTimeZone: lambda self, e: self.func(
+                "DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'"
+            ),
             exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
             exp.GetPath: path_to_jsonpath(),
             exp.GroupConcat: rename_func("STRING_AGG"),

@@ -595,9 +599,9 @@ class BigQuery(Dialect):
             exp.SHA2: lambda self, e: self.func(
                 f"SHA256" if e.text("length") == "256" else "SHA512", e.this
             ),
-            exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
-            if e.name == "IMMUTABLE"
-            else "NOT DETERMINISTIC",
+            exp.StabilityProperty: lambda self, e: (
+                f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
+            ),
             exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.StrToTime: lambda self, e: self.func(
                 "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")

@@ -88,6 +88,8 @@ class ClickHouse(Dialect):
             "UINT8": TokenType.UTINYINT,
             "IPV4": TokenType.IPV4,
             "IPV6": TokenType.IPV6,
+            "AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION,
+            "SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION,
         }

         SINGLE_TOKENS = {

@@ -548,6 +550,8 @@ class ClickHouse(Dialect):
             exp.DataType.Type.UTINYINT: "UInt8",
             exp.DataType.Type.IPV4: "IPv4",
             exp.DataType.Type.IPV6: "IPv6",
+            exp.DataType.Type.AGGREGATEFUNCTION: "AggregateFunction",
+            exp.DataType.Type.SIMPLEAGGREGATEFUNCTION: "SimpleAggregateFunction",
         }

         TRANSFORMS = {

@@ -651,12 +655,16 @@ class ClickHouse(Dialect):

         def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
             return super().after_limit_modifiers(expression) + [
-                self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
-                if expression.args.get("settings")
-                else "",
-                self.seg("FORMAT ") + self.sql(expression, "format")
-                if expression.args.get("format")
-                else "",
+                (
+                    self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
+                    if expression.args.get("settings")
+                    else ""
+                ),
+                (
+                    self.seg("FORMAT ") + self.sql(expression, "format")
+                    if expression.args.get("format")
+                    else ""
+                ),
             ]

         def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str:

@@ -5,7 +5,6 @@ from enum import Enum, auto
 from functools import reduce

 from sqlglot import exp
-from sqlglot._typing import E
 from sqlglot.errors import ParseError
 from sqlglot.generator import Generator
 from sqlglot.helper import AutoName, flatten, seq_get

@@ -14,11 +13,12 @@ from sqlglot.time import TIMEZONES, format_time
 from sqlglot.tokens import Token, Tokenizer, TokenType
 from sqlglot.trie import new_trie

-B = t.TypeVar("B", bound=exp.Binary)
-
 DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
 DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]

+if t.TYPE_CHECKING:
+    from sqlglot._typing import B, E
+

 class Dialects(str, Enum):
     """Dialects supported by SQLGLot."""

@@ -381,9 +381,11 @@ class Dialect(metaclass=_Dialect):
         ):
             expression.set(
                 "this",
-                expression.this.upper()
-                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
-                else expression.this.lower(),
+                (
+                    expression.this.upper()
+                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
+                    else expression.this.lower()
+                ),
             )

         return expression

@@ -877,9 +879,11 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
         Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
         """
         agg_all_unquoted = agg.transform(
-            lambda node: exp.Identifier(this=node.name, quoted=False)
-            if isinstance(node, exp.Identifier)
-            else node
+            lambda node: (
+                exp.Identifier(this=node.name, quoted=False)
+                if isinstance(node, exp.Identifier)
+                else node
+            )
         )
         names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

@@ -999,10 +1003,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
     """Remove table refs from columns in when statements."""
     alias = expression.this.args.get("alias")

-    normalize = (
-        lambda identifier: self.dialect.normalize_identifier(identifier).name
-        if identifier
-        else None
+    normalize = lambda identifier: (
+        self.dialect.normalize_identifier(identifier).name if identifier else None
     )

     targets = {normalize(expression.this.this)}

@@ -1012,9 +1014,11 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:

     for when in expression.expressions:
         when.transform(
-            lambda node: exp.column(node.this)
-            if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
-            else node,
+            lambda node: (
+                exp.column(node.this)
+                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+                else node
+            ),
             copy=False,
         )

@@ -148,8 +148,8 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
 def _rename_unless_within_group(
     a: str, b: str
 ) -> t.Callable[[DuckDB.Generator, exp.Expression], str]:
-    return (
-        lambda self, expression: self.func(a, *flatten(expression.args.values()))
+    return lambda self, expression: (
+        self.func(a, *flatten(expression.args.values()))
         if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup)
         else self.func(b, *flatten(expression.args.values()))
     )

@@ -273,9 +273,11 @@ class DuckDB(Dialect):

         PLACEHOLDER_PARSERS = {
             **parser.Parser.PLACEHOLDER_PARSERS,
-            TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
-            if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
-            else None,
+            TokenType.PARAMETER: lambda self: (
+                self.expression(exp.Placeholder, this=self._prev.text)
+                if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
+                else None
+            ),
         }

         def _parse_types(

@@ -321,9 +323,11 @@ class DuckDB(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
-            exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
-            if e.expressions and e.expressions[0].find(exp.Select)
-            else inline_array_sql(self, e),
+            exp.Array: lambda self, e: (
+                self.func("ARRAY", e.expressions[0])
+                if e.expressions and e.expressions[0].find(exp.Select)
+                else inline_array_sql(self, e)
+            ),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
             exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"),

@@ -397,9 +397,11 @@ class Hive(Dialect):

             if this and not schema:
                 return this.transform(
-                    lambda node: node.replace(exp.DataType.build("text"))
-                    if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
-                    else node,
+                    lambda node: (
+                        node.replace(exp.DataType.build("text"))
+                        if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
+                        else node
+                    ),
                     copy=False,
                 )

@@ -409,9 +411,11 @@ class Hive(Dialect):
             self,
         ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
             return (
-                self._parse_csv(self._parse_conjunction)
-                if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
-                else [],
+                (
+                    self._parse_csv(self._parse_conjunction)
+                    if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
+                    else []
+                ),
                 super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)),
             )

@@ -483,9 +487,9 @@ class Hive(Dialect):
             exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
             exp.Min: min_or_least,
             exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
-            exp.NotNullColumnConstraint: lambda self, e: ""
-            if e.args.get("allow_null")
-            else "NOT NULL",
+            exp.NotNullColumnConstraint: lambda self, e: (
+                "" if e.args.get("allow_null") else "NOT NULL"
+            ),
             exp.VarMap: var_map_sql,
             exp.Create: _create_sql,
             exp.Quantile: rename_func("PERCENTILE"),

@@ -166,6 +166,7 @@ class Oracle(Dialect):
         TABLESAMPLE_KEYWORDS = "SAMPLE"
         LAST_DAY_SUPPORTS_DATE_PART = False
         SUPPORTS_SELECT_INTO = True
+        TZ_TO_WITH_TIME_ZONE = True

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,

@@ -179,6 +180,8 @@ class Oracle(Dialect):
             exp.DataType.Type.NVARCHAR: "NVARCHAR2",
             exp.DataType.Type.NCHAR: "NCHAR",
             exp.DataType.Type.TEXT: "CLOB",
+            exp.DataType.Type.TIMETZ: "TIME",
+            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
             exp.DataType.Type.BINARY: "BLOB",
             exp.DataType.Type.VARBINARY: "BLOB",
         }

@@ -282,6 +282,12 @@ class Postgres(Dialect):
         VAR_SINGLE_TOKENS = {"$"}

     class Parser(parser.Parser):
+        PROPERTY_PARSERS = {
+            **parser.Parser.PROPERTY_PARSERS,
+            "SET": lambda self: self.expression(exp.SetConfigProperty, this=self._parse_set()),
+        }
+        PROPERTY_PARSERS.pop("INPUT", None)
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "DATE_TRUNC": parse_timestamp_trunc,

@@ -385,9 +391,11 @@ class Postgres(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             exp.AnyValue: any_value_to_max_sql,
-            exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
-            if isinstance(seq_get(e.expressions, 0), exp.Select)
-            else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
+            exp.Array: lambda self, e: (
+                f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+                if isinstance(seq_get(e.expressions, 0), exp.Select)
+                else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]"
+            ),
             exp.ArrayConcat: rename_func("ARRAY_CAT"),
             exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
             exp.ArrayContains: lambda self, e: self.binary(e, "@>"),

@@ -396,6 +404,7 @@ class Postgres(Dialect):
             exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
             exp.CurrentDate: no_paren_current_date_sql,
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+            exp.CurrentUser: lambda *_: "CURRENT_USER",
             exp.DateAdd: _date_add_sql("+"),
             exp.DateDiff: _date_diff_sql,
             exp.DateStrToDate: datestrtodate_sql,

@@ -356,6 +356,7 @@ class Presto(Dialect):
             exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
             exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
             exp.First: _first_last_sql,
+            exp.FromTimeZone: lambda self, e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
             exp.GetPath: path_to_jsonpath(),
             exp.Group: transforms.preprocess([transforms.unalias_group]),
             exp.GroupConcat: lambda self, e: self.func(

@@ -3,7 +3,6 @@ from __future__ import annotations
 import typing as t

 from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot._typing import E
 from sqlglot.dialects.dialect import (
     Dialect,
     NormalizationStrategy,

@@ -25,6 +24,9 @@ from sqlglot.expressions import Literal
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType

+if t.TYPE_CHECKING:
+    from sqlglot._typing import E
+

 def _check_int(s: str) -> bool:
     if s[0] in ("-", "+"):

@@ -297,10 +299,7 @@ def _parse_colon_get_path(
         if not self._match(TokenType.COLON):
             break

-        if self._match_set(self.RANGE_PARSERS):
-            this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this
-
-    return this
+    return self._parse_range(this)


 def _parse_timestamp_from_parts(args: t.List) -> exp.Func:

@@ -376,7 +375,7 @@ class Snowflake(Dialect):
             and isinstance(expression.parent, exp.Table)
             and expression.name.lower() == "dual"
         ):
-            return t.cast(E, expression)
+            return expression  # type: ignore

         return super().quote_identifier(expression, identify=identify)

@@ -471,6 +470,10 @@ class Snowflake(Dialect):
         }

         SHOW_PARSERS = {
+            "SCHEMAS": _show_parser("SCHEMAS"),
+            "TERSE SCHEMAS": _show_parser("SCHEMAS"),
+            "OBJECTS": _show_parser("OBJECTS"),
+            "TERSE OBJECTS": _show_parser("OBJECTS"),
             "PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
             "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
             "COLUMNS": _show_parser("COLUMNS"),

@@ -580,21 +583,35 @@ class Snowflake(Dialect):
             scope = None
             scope_kind = None

+            # will identity SHOW TERSE SCHEMAS but not SHOW TERSE PRIMARY KEYS
+            # which is syntactically valid but has no effect on the output
+            terse = self._tokens[self._index - 2].text.upper() == "TERSE"
+
             like = self._parse_string() if self._match(TokenType.LIKE) else None

             if self._match(TokenType.IN):
                 if self._match_text_seq("ACCOUNT"):
                     scope_kind = "ACCOUNT"
                 elif self._match_set(self.DB_CREATABLES):
-                    scope_kind = self._prev.text
+                    scope_kind = self._prev.text.upper()
                     if self._curr:
-                        scope = self._parse_table()
+                        scope = self._parse_table_parts()
             elif self._curr:
-                scope_kind = "TABLE"
-                scope = self._parse_table()
+                scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE"
+                scope = self._parse_table_parts()

             return self.expression(
-                exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind
+                exp.Show,
+                **{
+                    "terse": terse,
+                    "this": this,
+                    "like": like,
+                    "scope": scope,
+                    "scope_kind": scope_kind,
+                    "starts_with": self._match_text_seq("STARTS", "WITH") and self._parse_string(),
+                    "limit": self._parse_limit(),
+                    "from": self._parse_string() if self._match(TokenType.FROM) else None,
+                },
             )

         def _parse_alter_table_swap(self) -> exp.SwapTable:

@@ -690,6 +707,9 @@ class Snowflake(Dialect):
             exp.DayOfYear: rename_func("DAYOFYEAR"),
             exp.Explode: rename_func("FLATTEN"),
             exp.Extract: rename_func("DATE_PART"),
+            exp.FromTimeZone: lambda self, e: self.func(
+                "CONVERT_TIMEZONE", e.args.get("zone"), "'UTC'", e.this
+            ),
             exp.GenerateSeries: lambda self, e: self.func(
                 "ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step")
             ),

@@ -820,6 +840,7 @@ class Snowflake(Dialect):
             return f"{explode}{alias}"

         def show_sql(self, expression: exp.Show) -> str:
+            terse = "TERSE " if expression.args.get("terse") else ""
             like = self.sql(expression, "like")
             like = f" LIKE {like}" if like else ""

@@ -830,7 +851,19 @@ class Snowflake(Dialect):
             if scope_kind:
                 scope_kind = f" IN {scope_kind}"

-            return f"SHOW {expression.name}{like}{scope_kind}{scope}"
+            starts_with = self.sql(expression, "starts_with")
+            if starts_with:
+                starts_with = f" STARTS WITH {starts_with}"
+
+            limit = self.sql(expression, "limit")
+
+            from_ = self.sql(expression, "from")
+            if from_:
+                from_ = f" FROM {from_}"
+
+            return (
+                f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}"
+            )

         def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
             # Other dialects don't support all of the following parameters, so we need to

@@ -884,3 +917,6 @@ class Snowflake(Dialect):

         def with_properties(self, properties: exp.Properties) -> str:
             return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ")
+
+        def cluster_sql(self, expression: exp.Cluster) -> str:
+            return f"CLUSTER BY ({self.expressions(expression, flat=True)})"

@@ -80,9 +80,9 @@ class Spark(Spark2):
             exp.TimestampAdd: lambda self, e: self.func(
                 "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
             ),
-            exp.TryCast: lambda self, e: self.trycast_sql(e)
-            if e.args.get("safe")
-            else self.cast_sql(e),
+            exp.TryCast: lambda self, e: (
+                self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
+            ),
         }
         TRANSFORMS.pop(exp.AnyValue)
         TRANSFORMS.pop(exp.DateDiff)

@@ -129,10 +129,20 @@ class Spark2(Hive):
             "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
             "STRING": _parse_as_cast("string"),
             "TIMESTAMP": _parse_as_cast("timestamp"),
-            "TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args)
-            if len(args) == 1
-            else format_time_lambda(exp.StrToTime, "spark")(args),
+            "TO_TIMESTAMP": lambda args: (
+                _parse_as_cast("timestamp")(args)
+                if len(args) == 1
+                else format_time_lambda(exp.StrToTime, "spark")(args)
+            ),
             "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
+            "TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
+                this=exp.cast_unless(
+                    seq_get(args, 0) or exp.Var(this=""),
+                    exp.DataType.build("timestamp"),
+                    exp.DataType.build("timestamp"),
+                ),
+                zone=seq_get(args, 1),
+            ),
             "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
             "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
         }

@@ -188,6 +198,7 @@ class Spark2(Hive):
             exp.DayOfYear: rename_func("DAYOFYEAR"),
             exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
             exp.From: transforms.preprocess([_unalias_pivot]),
+            exp.FromTimeZone: lambda self, e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
             exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.Map: _map_sql,

@@ -255,10 +266,12 @@ class Spark2(Hive):
         def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
             return super().columndef_sql(
                 expression,
-                sep=": "
-                if isinstance(expression.parent, exp.DataType)
-                and expression.parent.is_type("struct")
-                else sep,
+                sep=(
+                    ": "
+                    if isinstance(expression.parent, exp.DataType)
+                    and expression.parent.is_type("struct")
+                    else sep
+                ),
             )

         class Tokenizer(Hive.Tokenizer):

@@ -38,3 +38,4 @@ class Tableau(Dialect):
             **parser.Parser.FUNCTIONS,
             "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
         }
+        NO_PAREN_IF_COMMANDS = False

@@ -76,9 +76,11 @@ def _format_time_lambda(
         format=exp.Literal.string(
             format_time(
                 args[0].name.lower(),
-                {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
-                if full_format_mapping
-                else TSQL.TIME_MAPPING,
+                (
+                    {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
+                    if full_format_mapping
+                    else TSQL.TIME_MAPPING
+                ),
             )
         ),
     )

@@ -264,6 +266,15 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
     )


+def _parse_len(args: t.List) -> exp.Length:
+    this = seq_get(args, 0)
+
+    if this and not this.is_string:
+        this = exp.cast(this, exp.DataType.Type.TEXT)
+
+    return exp.Length(this=this)
+
+
 class TSQL(Dialect):
     NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
     TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"

@@ -431,7 +442,7 @@ class TSQL(Dialect):
             "IIF": exp.If.from_arg_list,
             "ISNULL": exp.Coalesce.from_arg_list,
             "JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
-            "LEN": exp.Length.from_arg_list,
+            "LEN": _parse_len,
             "REPLICATE": exp.Repeat.from_arg_list,
             "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
             "SYSDATETIME": exp.CurrentTimestamp.from_arg_list,

@@ -469,6 +480,7 @@ class TSQL(Dialect):

         ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
         STRING_ALIASES = True
+        NO_PAREN_IF_COMMANDS = False

         def _parse_projections(self) -> t.List[exp.Expression]:
             """

@@ -478,9 +490,11 @@ class TSQL(Dialect):
             See: https://learn.microsoft.com/en-us/sql/t-sql/queries/select-clause-transact-sql?view=sql-server-ver16#syntax
             """
             return [
-                exp.alias_(projection.expression, projection.this.this, copy=False)
-                if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column)
-                else projection
+                (
+                    exp.alias_(projection.expression, projection.this.this, copy=False)
+                    if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column)
+                    else projection
+                )
                 for projection in super()._parse_projections()
             ]

@@ -702,7 +716,6 @@ class TSQL(Dialect):
             exp.GroupConcat: _string_agg_sql,
             exp.If: rename_func("IIF"),
             exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
-            exp.Length: rename_func("LEN"),
             exp.Max: max_or_greatest,
             exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
             exp.Min: min_or_least,

@@ -922,3 +935,11 @@ class TSQL(Dialect):
             this = self.sql(expression, "this")
             expressions = self.expressions(expression, flat=True, sep=" ")
             return f"CONSTRAINT {this} {expressions}"
+
+        def length_sql(self, expression: exp.Length) -> str:
+            this = expression.this
+            if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT):
+                this_sql = self.sql(this, "this")
+            else:
+                this_sql = self.sql(this)
+            return self.func("LEN", this_sql)

@@ -392,9 +392,9 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
     names = {e.name.lower() for e in e.expressions}

     e = e.transform(
-        lambda n: exp.var(n.name)
-        if isinstance(n, exp.Identifier) and n.name.lower() in names
-        else n
+        lambda n: (
+            exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n
+        )
     )

     return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"

@@ -438,9 +438,9 @@ class Python(Dialect):
             exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
             exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
             exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
-            exp.Is: lambda self, e: self.binary(e, "==")
-            if isinstance(e.this, exp.Literal)
-            else self.binary(e, "is"),
+            exp.Is: lambda self, e: (
+                self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is")
+            ),
             exp.Lambda: _lambda_sql,
             exp.Not: lambda self, e: f"not {self.sql(e.this)}",
             exp.Null: lambda *_: "None",

@@ -23,7 +23,6 @@ from copy import deepcopy
 from enum import auto
 from functools import reduce

-from sqlglot._typing import E
 from sqlglot.errors import ErrorLevel, ParseError
 from sqlglot.helper import (
     AutoName,

@@ -36,8 +35,7 @@ from sqlglot.helper import (
 from sqlglot.tokens import Token

 if t.TYPE_CHECKING:
-    from typing_extensions import Literal as Lit
-
+    from sqlglot._typing import E, Lit
     from sqlglot.dialects.dialect import DialectType

@@ -389,7 +387,7 @@ class Expression(metaclass=_Expression):
         ancestor = self.parent
         while ancestor and not isinstance(ancestor, expression_types):
             ancestor = ancestor.parent
-        return t.cast(E, ancestor)
+        return ancestor  # type: ignore

     @property
     def parent_select(self) -> t.Optional[Select]:

@@ -555,12 +553,10 @@ class Expression(metaclass=_Expression):
         return new_node

     @t.overload
-    def replace(self, expression: E) -> E:
-        ...
+    def replace(self, expression: E) -> E: ...

     @t.overload
-    def replace(self, expression: None) -> None:
-        ...
+    def replace(self, expression: None) -> None: ...

     def replace(self, expression):
         """

@@ -781,13 +777,16 @@ class Expression(metaclass=_Expression):
             this=maybe_copy(self, copy),
             expressions=[convert(e, copy=copy) for e in expressions],
             query=maybe_parse(query, copy=copy, **opts) if query else None,
-            unnest=Unnest(
-                expressions=[
-                    maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest)
-                ]
-            )
-            if unnest
-            else None,
+            unnest=(
+                Unnest(
+                    expressions=[
+                        maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts)
+                        for e in ensure_list(unnest)
+                    ]
+                )
+                if unnest
+                else None
+            ),
         )

     def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between:

@@ -926,7 +925,7 @@ class DerivedTable(Expression):
 class Unionable(Expression):
     def union(
         self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
-    ) -> Unionable:
+    ) -> Union:
         """
         Builds a UNION expression.

@@ -1134,9 +1133,12 @@ class SetItem(Expression):
 class Show(Expression):
     arg_types = {
         "this": True,
+        "terse": False,
         "target": False,
         "offset": False,
+        "starts_with": False,
         "limit": False,
+        "from": False,
         "like": False,
         "where": False,
         "db": False,

@@ -1274,9 +1276,14 @@ class AlterColumn(Expression):
         "using": False,
         "default": False,
         "drop": False,
+        "comment": False,
     }


+class RenameColumn(Expression):
+    arg_types = {"this": True, "to": True, "exists": False}
+
+
 class RenameTable(Expression):
     pass

@@ -1402,7 +1409,7 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):


 class GeneratedAsRowColumnConstraint(ColumnConstraintKind):
-    arg_types = {"start": True, "hidden": False}
+    arg_types = {"start": False, "hidden": False}


 # https://dev.mysql.com/doc/refman/8.0/en/create-table.html

@@ -1667,6 +1674,7 @@ class Index(Expression):
         "unique": False,
         "primary": False,
         "amp": False,  # teradata
+        "include": False,
         "partition_by": False,  # teradata
         "where": False,  # postgres partial indexes
     }

@@ -2016,7 +2024,13 @@ class AutoRefreshProperty(Property):


 class BlockCompressionProperty(Property):
-    arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
+    arg_types = {
+        "autotemp": False,
+        "always": False,
+        "default": False,
+        "manual": False,
+        "never": False,
+    }


 class CharacterSetProperty(Property):

@@ -2089,6 +2103,10 @@ class FreespaceProperty(Property):
     arg_types = {"this": True, "percent": False}


+class InheritsProperty(Property):
+    arg_types = {"expressions": True}
+
+
 class InputModelProperty(Property):
     arg_types = {"this": True}

@@ -2099,11 +2117,11 @@ class OutputModelProperty(Property):

 class IsolatedLoadingProperty(Property):
     arg_types = {
-        "no": True,
-        "concurrent": True,
-        "for_all": True,
-        "for_insert": True,
-        "for_none": True,
+        "no": False,
+        "concurrent": False,
+        "for_all": False,
+        "for_insert": False,
+        "for_none": False,
     }

@@ -2264,6 +2282,10 @@ class SetProperty(Property):
     arg_types = {"multi": True}


+class SetConfigProperty(Property):
+    arg_types = {"this": True}
+
+
 class SettingsProperty(Property):
     arg_types = {"expressions": True}

@@ -2407,13 +2429,16 @@ class Tuple(Expression):
             this=maybe_copy(self, copy),
             expressions=[convert(e, copy=copy) for e in expressions],
             query=maybe_parse(query, copy=copy, **opts) if query else None,
-            unnest=Unnest(
-                expressions=[
-                    maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest)
-                ]
-            )
-            if unnest
-            else None,
+            unnest=(
+                Unnest(
+                    expressions=[
+                        maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts)
+                        for e in ensure_list(unnest)
+                    ]
+                )
+                if unnest
+                else None
+            ),
         )

@@ -3631,6 +3656,8 @@ class DataType(Expression):

     class Type(AutoName):
         ARRAY = auto()
+        AGGREGATEFUNCTION = auto()
+        SIMPLEAGGREGATEFUNCTION = auto()
         BIGDECIMAL = auto()
         BIGINT = auto()
         BIGSERIAL = auto()

@@ -4162,6 +4189,10 @@ class AtTimeZone(Expression):
     arg_types = {"this": True, "zone": True}


+class FromTimeZone(Expression):
+    arg_types = {"this": True, "zone": True}
+
+
 class Between(Predicate):
     arg_types = {"this": True, "low": True, "high": True}

@@ -5456,8 +5487,7 @@ def maybe_parse(
     prefix: t.Optional[str] = None,
     copy: bool = False,
     **opts,
-) -> E:
-    ...
+) -> E: ...


 @t.overload

@@ -5469,8 +5499,7 @@ def maybe_parse(
     prefix: t.Optional[str] = None,
     copy: bool = False,
     **opts,
-) -> E:
-    ...
+) -> E: ...


 def maybe_parse(

@@ -5522,13 +5551,11 @@ def maybe_parse(


 @t.overload
-def maybe_copy(instance: None, copy: bool = True) -> None:
-    ...
+def maybe_copy(instance: None, copy: bool = True) -> None: ...


 @t.overload
-def maybe_copy(instance: E, copy: bool = True) -> E:
-    ...
+def maybe_copy(instance: E, copy: bool = True) -> E: ...


 def maybe_copy(instance, copy=True):

@@ -6151,15 +6178,13 @@ SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")


 @t.overload
-def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
-    ...
+def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ...


 @t.overload
 def to_identifier(
     name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
-) -> Identifier:
-    ...
+) -> Identifier: ...


 def to_identifier(name, quoted=None, copy=True):

@@ -6231,13 +6256,11 @@ def to_interval(interval: str | Literal) -> Interval:


 @t.overload
-def to_table(sql_path: str | Table, **kwargs) -> Table:
-    ...
+def to_table(sql_path: str | Table, **kwargs) -> Table: ...


 @t.overload
-def to_table(sql_path: None, **kwargs) -> None:
-    ...
+def to_table(sql_path: None, **kwargs) -> None: ...


 def to_table(

@@ -6562,6 +6585,34 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
     )


+def rename_column(
+    table_name: str | Table,
+    old_column_name: str | Column,
+    new_column_name: str | Column,
+    exists: t.Optional[bool] = None,
+) -> AlterTable:
+    """Build ALTER TABLE... RENAME COLUMN... expression
+
+    Args:
+        table_name: Name of the table
+        old_column: The old name of the column
+        new_column: The new name of the column
+        exists: Whether or not to add the `IF EXISTS` clause
+
+    Returns:
+        Alter table expression
+    """
+    table = to_table(table_name)
+    old_column = to_column(old_column_name)
+    new_column = to_column(new_column_name)
+    return AlterTable(
+        this=table,
+        actions=[
+            RenameColumn(this=old_column, to=new_column, exists=exists),
+        ],
+    )
+
+
 def convert(value: t.Any, copy: bool = False) -> Expression:
     """Convert a python value into an expression object.

@@ -6581,7 +6632,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
     if isinstance(value, bool):
         return Boolean(this=value)
     if value is None or (isinstance(value, float) and math.isnan(value)):
-        return NULL
+        return null()
     if isinstance(value, numbers.Number):
         return Literal.number(value)
     if isinstance(value, datetime.datetime):

@@ -6674,9 +6725,11 @@ def table_name(table: Table | str, dialect: DialectType = None, identify: bool =
         raise ValueError(f"Cannot parse {table}")

     return ".".join(
-        part.sql(dialect=dialect, identify=True, copy=False)
-        if identify or not SAFE_IDENTIFIER_RE.match(part.name)
-        else part.name
+        (
+            part.sql(dialect=dialect, identify=True, copy=False)
+            if identify or not SAFE_IDENTIFIER_RE.match(part.name)
+            else part.name
+        )
         for part in table.parts
     )

@@ -6942,9 +6995,3 @@ def null() -> Null:
     Returns a Null expression.
     """
     return Null()
-
-
-# TODO: deprecate this
-TRUE = Boolean(this=True)
-FALSE = Boolean(this=False)
-NULL = Null()

@@ -77,6 +77,7 @@ class Generator:
         exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
         exp.ExternalProperty: lambda self, e: "EXTERNAL",
         exp.HeapProperty: lambda self, e: "HEAP",
+        exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})",
         exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
         exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
         exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}",

@@ -96,6 +97,7 @@ class Generator:
         exp.ReturnsProperty: lambda self, e: self.naked_property(e),
         exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
         exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
+        exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
         exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
         exp.SqlReadWriteProperty: lambda self, e: e.name,
         exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",

@@ -323,6 +325,7 @@ class Generator:
         exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
         exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
         exp.HeapProperty: exp.Properties.Location.POST_WITH,
+        exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA,
         exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA,
         exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
         exp.JournalProperty: exp.Properties.Location.POST_NAME,

@@ -353,6 +356,7 @@ class Generator:
         exp.Set: exp.Properties.Location.POST_SCHEMA,
         exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
         exp.SetProperty: exp.Properties.Location.POST_CREATE,
+        exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA,
         exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
         exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
         exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,

@@ -568,9 +572,11 @@ class Generator:

     def wrap(self, expression: exp.Expression | str) -> str:
         this_sql = self.indent(
-            self.sql(expression)
-            if isinstance(expression, (exp.Select, exp.Union))
-            else self.sql(expression, "this"),
+            (
+                self.sql(expression)
+                if isinstance(expression, (exp.Select, exp.Union))
+                else self.sql(expression, "this")
+            ),
             level=1,
             pad=0,
         )

@@ -605,9 +611,11 @@ class Generator:
         lines = sql.split("\n")

         return "\n".join(
-            line
-            if (skip_first and i == 0) or (skip_last and i == len(lines) - 1)
-            else f"{' ' * (level * self._indent + pad)}{line}"
+            (
+                line
+                if (skip_first and i == 0) or (skip_last and i == len(lines) - 1)
+                else f"{' ' * (level * self._indent + pad)}{line}"
+            )
             for i, line in enumerate(lines)
         )

@@ -775,7 +783,7 @@ class Generator:
     def generatedasrowcolumnconstraint_sql(
         self, expression: exp.GeneratedAsRowColumnConstraint
     ) -> str:
-        start = "START" if expression.args["start"] else "END"
+        start = "START" if expression.args.get("start") else "END"
         hidden = " HIDDEN" if expression.args.get("hidden") else ""
         return f"GENERATED ALWAYS AS ROW {start}{hidden}"

@@ -1111,7 +1119,10 @@ class Generator:
         partition_by = self.expressions(expression, key="partition_by", flat=True)
         partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
         where = self.sql(expression, "where")
-        return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}{where}"
+        include = self.expressions(expression, key="include", flat=True)
+        if include:
+            include = f" INCLUDE ({include})"
+        return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{include}{partition_by}{where}"

     def identifier_sql(self, expression: exp.Identifier) -> str:
         text = expression.name

@@ -2017,9 +2028,11 @@ class Generator:
     def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
         return [
             self.sql(expression, "qualify"),
-            self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
-            if expression.args.get("windows")
-            else "",
+            (
+                self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
+                if expression.args.get("windows")
+                else ""
+            ),
             self.sql(expression, "distribute"),
             self.sql(expression, "sort"),
             self.sql(expression, "cluster"),

@@ -2552,6 +2565,11 @@ class Generator:
         zone = self.sql(expression, "zone")
         return f"{this} AT TIME ZONE {zone}"

+    def fromtimezone_sql(self, expression: exp.FromTimeZone) -> str:
+        this = self.sql(expression, "this")
+        zone = self.sql(expression, "zone")
+        return f"{this} AT TIME ZONE {zone} AT TIME ZONE 'UTC'"
+
     def add_sql(self, expression: exp.Add) -> str:
         return self.binary(expression, "+")

@@ -2669,6 +2687,10 @@ class Generator:
         if default:
             return f"ALTER COLUMN {this} SET DEFAULT {default}"

+        comment = self.sql(expression, "comment")
+        if comment:
+            return f"ALTER COLUMN {this} COMMENT {comment}"
+
         if not expression.args.get("drop"):
             self.unsupported("Unsupported ALTER COLUMN syntax")

@@ -2683,6 +2705,12 @@ class Generator:
         this = self.sql(expression, "this")
         return f"RENAME TO {this}"

+    def renamecolumn_sql(self, expression: exp.RenameColumn) -> str:
+        exists = " IF EXISTS" if expression.args.get("exists") else ""
+        old_column = self.sql(expression, "this")
+        new_column = self.sql(expression, "to")
+        return f"RENAME COLUMN{exists} {old_column} TO {new_column}"
+
     def altertable_sql(self, expression: exp.AlterTable) -> str:
         actions = expression.args["actions"]

@@ -53,13 +53,11 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:


 @t.overload
-def ensure_list(value: t.Collection[T]) -> t.List[T]:
-    ...
+def ensure_list(value: t.Collection[T]) -> t.List[T]: ...


 @t.overload
-def ensure_list(value: T) -> t.List[T]:
-    ...
+def ensure_list(value: T) -> t.List[T]: ...


 def ensure_list(value):

@@ -81,13 +79,11 @@ def ensure_list(value):


 @t.overload
-def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
-    ...
+def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...


 @t.overload
-def ensure_collection(value: T) -> t.Collection[T]:
-    ...
+def ensure_collection(value: T) -> t.Collection[T]: ...


 def ensure_collection(value):

sqlglot/jsonpath.py (new file, 215 lines)

@@ -0,0 +1,215 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot.errors import ParseError
+from sqlglot.expressions import SAFE_IDENTIFIER_RE
+from sqlglot.tokens import Token, Tokenizer, TokenType
+
+if t.TYPE_CHECKING:
+    from sqlglot._typing import Lit
+
+
+class JSONPathTokenizer(Tokenizer):
+    SINGLE_TOKENS = {
+        "(": TokenType.L_PAREN,
+        ")": TokenType.R_PAREN,
+        "[": TokenType.L_BRACKET,
+        "]": TokenType.R_BRACKET,
+        ":": TokenType.COLON,
+        ",": TokenType.COMMA,
+        "-": TokenType.DASH,
+        ".": TokenType.DOT,
+        "?": TokenType.PLACEHOLDER,
+        "@": TokenType.PARAMETER,
+        "'": TokenType.QUOTE,
+        '"': TokenType.QUOTE,
+        "$": TokenType.DOLLAR,
+        "*": TokenType.STAR,
+    }
+
+    KEYWORDS = {
+        "..": TokenType.DOT,
+    }
+
+    IDENTIFIER_ESCAPES = ["\\"]
+    STRING_ESCAPES = ["\\"]
+
+
+JSONPathNode = t.Dict[str, t.Any]
+
+
+def _node(kind: str, value: t.Any = None, **kwargs: t.Any) -> JSONPathNode:
+    node = {"kind": kind, **kwargs}
+
+    if value is not None:
+        node["value"] = value
+
+    return node
+
+
+def parse(path: str) -> t.List[JSONPathNode]:
+    """Takes in a JSONPath string and converts into a list of nodes."""
+    tokens = JSONPathTokenizer().tokenize(path)
+    size = len(tokens)
+
+    i = 0
+
+    def _curr() -> t.Optional[TokenType]:
+        return tokens[i].token_type if i < size else None
+
+    def _prev() -> Token:
+        return tokens[i - 1]
+
+    def _advance() -> Token:
+        nonlocal i
+        i += 1
+        return _prev()
+
+    def _error(msg: str) -> str:
+        return f"{msg} at index {i}: {path}"
+
+    @t.overload
+    def _match(token_type: TokenType, raise_unmatched: Lit[True] = True) -> Token:
+        pass
+
+    @t.overload
+    def _match(token_type: TokenType, raise_unmatched: Lit[False] = False) -> t.Optional[Token]:
+        pass
+
+    def _match(token_type, raise_unmatched=False):
+        if _curr() == token_type:
+            return _advance()
+        if raise_unmatched:
+            raise ParseError(_error(f"Expected {token_type}"))
+        return None
+
+    def _parse_literal() -> t.Any:
+        token = _match(TokenType.STRING) or _match(TokenType.IDENTIFIER)
+        if token:
+            return token.text
+        if _match(TokenType.STAR):
+            return _node("wildcard")
+        if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN):
+            script = _prev().text == "("
+            start = i
+
+            while True:
+                if _match(TokenType.L_BRACKET):
+                    _parse_bracket()  # nested call which we can throw away
+                if _curr() in (TokenType.R_BRACKET, None):
+                    break
+                _advance()
+            return _node(
+                "script" if script else "filter", path[tokens[start].start : tokens[i].end]
+            )
+
+        number = "-" if _match(TokenType.DASH) else ""
+
+        token = _match(TokenType.NUMBER)
+        if token:
+            number += token.text
+
+        if number:
+            return int(number)
+        return False
+
+    def _parse_slice() -> t.Any:
+        start = _parse_literal()
+        end = _parse_literal() if _match(TokenType.COLON) else None
+        step = _parse_literal() if _match(TokenType.COLON) else None
+
+        if end is None and step is None:
+            return start
+        return _node("slice", start=start, end=end, step=step)
+
+    def _parse_bracket() -> JSONPathNode:
+        literal = _parse_slice()
+
+        if isinstance(literal, str) or literal is not False:
+            indexes = [literal]
+            while _match(TokenType.COMMA):
+                literal = _parse_slice()
+
+                if literal:
+                    indexes.append(literal)
+
+            if len(indexes) == 1:
+                if isinstance(literal, str):
+                    node = _node("key", indexes[0])
+                elif isinstance(literal, dict) and literal["kind"] in ("script", "filter"):
+                    node = _node("selector", indexes[0])
+                else:
+                    node = _node("subscript", indexes[0])
+            else:
+                node = _node("union", indexes)
+        else:
+            raise ParseError(_error("Cannot have empty segment"))
+
+        _match(TokenType.R_BRACKET, raise_unmatched=True)
+
+        return node
+
+    nodes = []
+
+    while _curr():
+        if _match(TokenType.DOLLAR):
+            nodes.append(_node("root"))
+        elif _match(TokenType.DOT):
+            recursive = _prev().text == ".."
+            value = _match(TokenType.VAR) or _match(TokenType.STAR)
+            nodes.append(
+                _node("recursive" if recursive else "child", value=value.text if value else None)
+            )
+        elif _match(TokenType.L_BRACKET):
+            nodes.append(_parse_bracket())
+        elif _match(TokenType.VAR):
+            nodes.append(_node("key", _prev().text))
+        elif _match(TokenType.STAR):
+            nodes.append(_node("wildcard"))
+        elif _match(TokenType.PARAMETER):
+            nodes.append(_node("current"))
+        else:
+            raise ParseError(_error(f"Unexpected {tokens[i].token_type}"))
+
+    return nodes
+
+
+MAPPING = {
+    "child": lambda n: f".{n['value']}" if n.get("value") is not None else "",
+    "filter": lambda n: f"?{n['value']}",
+    "key": lambda n: (
+        f".{n['value']}" if SAFE_IDENTIFIER_RE.match(n["value"]) else f'[{generate([n["value"]])}]'
+    ),
+    "recursive": lambda n: f"..{n['value']}" if n.get("value") is not None else "..",
+    "root": lambda _: "$",
+    "script": lambda n: f"({n['value']}",
+    "slice": lambda n: ":".join(
+        "" if p is False else generate([p])
+        for p in [n["start"], n["end"], n["step"]]
+        if p is not None
+    ),
+    "selector": lambda n: f"[{generate([n['value']])}]",
+    "subscript": lambda n: f"[{generate([n['value']])}]",
+    "union": lambda n: f"[{','.join(generate([p]) for p in n['value'])}]",
+    "wildcard": lambda _: "*",
+}
+
+
+def generate(
+    nodes: t.List[JSONPathNode],
+    mapping: t.Optional[t.Dict[str, t.Callable[[JSONPathNode], str]]] = None,
+) -> str:
+    mapping = MAPPING if mapping is None else mapping
+    path = []
+
+    for node in nodes:
+        if isinstance(node, dict):
+            path.append(mapping[node["kind"]](node))
+        elif isinstance(node, str):
+            escaped = node.replace('"', '\\"')
+            path.append(f'"{escaped}"')
+        else:
+            path.append(str(node))
+
+    return "".join(path)

|
|
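For context, a minimal round-trip sketch of the JSONPath helpers above (hedged: the import path sqlglot.jsonpath follows this diff, and the sample path is illustrative):

# Hedged sketch: parse a JSONPath string into nodes, then regenerate it.
from sqlglot.jsonpath import parse, generate

nodes = parse("$.store.book[0].title")
print(generate(nodes))  # expected to round-trip as "$.store.book[0].title"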
@@ -41,9 +41,9 @@ class Node:
        else:
            label = node.expression.sql(pretty=True, dialect=dialect)
            source = node.source.transform(
                lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                if n is node.expression
                else n,
                lambda n: (
                    exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
                ),
                copy=False,
            ).sql(pretty=True, dialect=dialect)
            title = f"<pre>{source}</pre>"
@@ -4,7 +4,6 @@ import functools
import typing as t

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import (
    ensure_list,
    is_date_unit,
@@ -17,7 +16,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    B = t.TypeVar("B", bound=exp.Binary)
    from sqlglot._typing import B, E

BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
BinaryCoercions = t.Dict[
@@ -479,6 +478,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_struct_value(
        self, expression: exp.Expression
    ) -> t.Optional[exp.DataType] | exp.ColumnDef:
        alias = expression.args.get("alias")
        if alias:
            return exp.ColumnDef(this=alias.copy(), kind=expression.type)

        # Case: key = value or key := value
        if expression.expression:
            return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)

        return expression.type

    @t.no_type_check
    def _annotate_by_args(
        self,
@@ -516,16 +529,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
            )

        if struct:
            expressions = [
                expr.type
                if not expr.args.get("alias")
                else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
                for expr in expressions
            ]

            self._set_type(
                expression,
                exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
                exp.DataType(
                    this=exp.DataType.Type.STRUCT,
                    expressions=[self._annotate_struct_value(expr) for expr in expressions],
                    nested=True,
                ),
            )

        return expression
@@ -3,18 +3,18 @@ from __future__ import annotations
import typing as t

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType


@t.overload
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
    ...
if t.TYPE_CHECKING:
    from sqlglot._typing import E


@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
    ...
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...


@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...


def normalize_identifiers(expression, dialect=None):
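A small usage sketch for the consolidated overloads (hedged: the exact casing depends on the dialect's resolution rules, so treat the printed results as illustrative):

# Strings are parsed into an Identifier; expressions are normalized in place.
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

print(normalize_identifiers("Foo", dialect="snowflake").sql())  # FOO (uppercase dialect)
print(normalize_identifiers("Foo", dialect="duckdb").sql())     # foo (lowercase dialect)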
@@ -4,7 +4,6 @@ import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get

@@ -12,6 +11,9 @@ from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
@@ -210,7 +212,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    if not node:
        return

    for column, *_ in walk_in_scope(node):
    for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
        if not isinstance(column, exp.Column):
            continue

@@ -525,6 +527,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
        selection = alias(
            selection,
            alias=selection.output_name or f"_col_{i}",
            copy=False,
        )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))
@@ -4,12 +4,14 @@ import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_tables(
    expression: E,
@@ -46,6 +48,18 @@ def qualify_tables(
    db = exp.parse_identifier(db, dialect=dialect) if db else None
    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None

    def _qualify(table: exp.Table) -> None:
        if isinstance(table.this, exp.Identifier):
            if not table.args.get("db"):
                table.set("db", db)
            if not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog)

    if not isinstance(expression, exp.Subqueryable):
        for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
            if isinstance(node, exp.Table):
                _qualify(node)

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            if isinstance(derived_table, exp.Subquery):
@@ -66,11 +80,7 @@ def qualify_tables(

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                if isinstance(source.this, exp.Identifier):
                    if not source.args.get("db"):
                        source.set("db", db)
                    if not source.args.get("catalog") and source.args.get("db"):
                        source.set("catalog", catalog)
                _qualify(source)

                pivots = pivots = source.args.get("pivots")
                if not source.alias:
@@ -107,5 +117,14 @@ def qualify_tables(
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    for i, e in enumerate(udtf.expressions[0].expressions):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
            else:
                for node, parent, _ in scope.walk():
                    if (
                        isinstance(node, exp.Table)
                        and not node.alias
                        and isinstance(parent, (exp.From, exp.Join))
                    ):
                        # Mutates the table by attaching an alias to it
                        alias(node, node.name, copy=False, table=True)

    return expression
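A hedged sketch of the refactored qualification path (the query and names are illustrative): the new _qualify helper attaches db/catalog both inside scopes and, for non-Subqueryable statements, via a plain AST walk.

from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables

# Unqualified tables pick up the provided db/catalog and a self-alias.
print(qualify_tables(parse_one("SELECT 1 FROM t"), db="db1", catalog="cat").sql())
# roughly: SELECT 1 FROM cat.db1.t AS t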
@@ -323,9 +323,14 @@ class Scope:
        sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
            if isinstance(self.expression, exp.Union):
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c for c in self.columns if c.table not in self.selected_sources
                ]

        return self._external_columns

    @property
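A sketch of what the union branch changes (hedged; the query is illustrative): a union scope now reports the external columns of both member scopes instead of computing them from its own column list.

from sqlglot import parse_one
from sqlglot.optimizer.scope import traverse_scope

sql = "SELECT a FROM x WHERE a IN (SELECT b FROM y UNION SELECT c FROM z)"
for scope in traverse_scope(parse_one(sql)):
    # Prints each scope's expression type and its correlated (external) columns.
    print(type(scope.expression).__name__, [c.sql() for c in scope.external_columns])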
@@ -477,11 +482,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:

    Args:
        expression (exp.Expression): expression to traverse

    Returns:
        list[Scope]: scope instances
    """
    if isinstance(expression, exp.Unionable) or (
        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
    ):
        return list(_traverse_scope(Scope(expression)))
@@ -1068,9 +1068,11 @@ def extract_interval(expression):
def date_literal(date):
    return exp.cast(
        exp.Literal.string(date),
        exp.DataType.Type.DATETIME
        if isinstance(date, datetime.datetime)
        else exp.DataType.Type.DATE,
        (
            exp.DataType.Type.DATETIME
            if isinstance(date, datetime.datetime)
            else exp.DataType.Type.DATE
        ),
    )
@@ -50,11 +50,12 @@ def unnest(select, parent_select, next_alias_name):
    ):
        return

    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None

        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (

@@ -84,12 +85,18 @@ def unnest(select, parent_select, next_alias_name):
    column = _other_operand(predicate)
    value = select.selects[0]

    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    _replace(predicate, f"NOT {on.right} IS NULL")
    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
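For reference, a hedged sketch of the rewrite this hunk affects (exact output varies by version; the point is that the join key and its null check are now built as expressions rather than from an f-string condition):

from sqlglot import parse_one
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

# An IN subquery gets converted into a grouped LEFT JOIN plus a null check.
sql = "SELECT * FROM x WHERE x.a IN (SELECT y.a FROM y)"
print(unnest_subqueries(parse_one(sql)).sql())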
@@ -12,9 +12,7 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from typing_extensions import Literal

    from sqlglot._typing import E
    from sqlglot._typing import E, Lit
    from sqlglot.dialects.dialect import Dialect, DialectType

logger = logging.getLogger("sqlglot")
@@ -148,6 +146,11 @@ class Parser(metaclass=_Parser):
        TokenType.ENUM16,
    }

    AGGREGATE_TYPE_TOKENS = {
        TokenType.AGGREGATEFUNCTION,
        TokenType.SIMPLEAGGREGATEFUNCTION,
    }

    TYPE_TOKENS = {
        TokenType.BIT,
        TokenType.BOOLEAN,

@@ -241,6 +244,7 @@ class Parser(metaclass=_Parser):
        TokenType.NULL,
        *ENUM_TYPE_TOKENS,
        *NESTED_TYPE_TOKENS,
        *AGGREGATE_TYPE_TOKENS,
    }

    SIGNED_TO_UNSIGNED_TYPE_TOKEN = {
@@ -653,9 +657,11 @@ class Parser(metaclass=_Parser):
    PLACEHOLDER_PARSERS = {
        TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
        TokenType.PARAMETER: lambda self: self._parse_parameter(),
        TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
        if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
        else None,
        TokenType.COLON: lambda self: (
            self.expression(exp.Placeholder, this=self._prev.text)
            if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
            else None
        ),
    }

    RANGE_PARSERS = {
@@ -705,6 +711,9 @@ class Parser(metaclass=_Parser):
        "IMMUTABLE": lambda self: self.expression(
            exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
        ),
        "INHERITS": lambda self: self.expression(
            exp.InheritsProperty, expressions=self._parse_wrapped_csv(self._parse_table)
        ),
        "INPUT": lambda self: self.expression(exp.InputModelProperty, this=self._parse_schema()),
        "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
        "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
@@ -822,6 +831,7 @@ class Parser(metaclass=_Parser):
    ALTER_PARSERS = {
        "ADD": lambda self: self._parse_alter_table_add(),
        "ALTER": lambda self: self._parse_alter_table_alter(),
        "CLUSTER BY": lambda self: self._parse_cluster(wrapped=True),
        "DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
        "DROP": lambda self: self._parse_alter_table_drop(),
        "RENAME": lambda self: self._parse_alter_table_rename(),
@@ -973,6 +983,9 @@ class Parser(metaclass=_Parser):
    MODIFIERS_ATTACHED_TO_UNION = True
    UNION_MODIFIERS = {"order", "limit", "offset"}

    # parses no parenthesis if statements as commands
    NO_PAREN_IF_COMMANDS = True

    __slots__ = (
        "error_level",
        "error_message_context",
@@ -1207,7 +1220,20 @@ class Parser(metaclass=_Parser):
        if index != self._index:
            self._advance(index - self._index)

    def _warn_unsupported(self) -> None:
        if len(self._tokens) <= 1:
            return

        # We use _find_sql because self.sql may comprise multiple chunks, and we're only
        # interested in emitting a warning for the one being currently processed.
        sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context]

        logger.warning(
            f"'{sql}' contains unsupported syntax. Falling back to parsing as a 'Command'."
        )

    def _parse_command(self) -> exp.Command:
        self._warn_unsupported()
        return self.expression(
            exp.Command, this=self._prev.text.upper(), expression=self._parse_string()
        )
@@ -1329,8 +1355,10 @@ class Parser(metaclass=_Parser):
        start = self._prev
        comments = self._prev_comments

        replace = start.text.upper() == "REPLACE" or self._match_pair(
            TokenType.OR, TokenType.REPLACE
        replace = (
            start.token_type == TokenType.REPLACE
            or self._match_pair(TokenType.OR, TokenType.REPLACE)
            or self._match_pair(TokenType.OR, TokenType.ALTER)
        )
        unique = self._match(TokenType.UNIQUE)
@@ -1440,6 +1468,9 @@ class Parser(metaclass=_Parser):
            exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy
        )

        if self._curr:
            return self._parse_as_command(start)

        return self.expression(
            exp.Create,
            comments=comments,
@@ -1516,11 +1547,13 @@ class Parser(metaclass=_Parser):

        return self.expression(
            exp.FileFormatProperty,
            this=self.expression(
                exp.InputOutputFormat, input_format=input_format, output_format=output_format
            )
            if input_format or output_format
            else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
            this=(
                self.expression(
                    exp.InputOutputFormat, input_format=input_format, output_format=output_format
                )
                if input_format or output_format
                else self._parse_var_or_string() or self._parse_number() or self._parse_id_var()
            ),
        )

    def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E:
@@ -1632,8 +1665,15 @@ class Parser(metaclass=_Parser):

        return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))

    def _parse_cluster(self) -> exp.Cluster:
        return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
    def _parse_cluster(self, wrapped: bool = False) -> exp.Cluster:
        return self.expression(
            exp.Cluster,
            expressions=(
                self._parse_wrapped_csv(self._parse_ordered)
                if wrapped
                else self._parse_csv(self._parse_ordered)
            ),
        )

    def _parse_clustered_by(self) -> exp.ClusteredByProperty:
        self._match_text_seq("BY")
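The new wrapped flag exists so the ALTER TABLE ... CLUSTER BY (...) entry registered above in ALTER_PARSERS can reuse this method; a hedged sketch (the statement is illustrative):

import sqlglot

# A parenthesized CLUSTER BY list should now parse as a structured
# AlterTable action rather than falling back to a generic Command.
print(repr(sqlglot.parse_one("ALTER TABLE t CLUSTER BY (a, b)")))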
@@ -2681,6 +2721,8 @@ class Parser(metaclass=_Parser):
        else:
            columns = None

        include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None

        return self.expression(
            exp.Index,
            this=index,

@@ -2690,6 +2732,7 @@ class Parser(metaclass=_Parser):
            unique=unique,
            primary=primary,
            amp=amp,
            include=include,
            partition_by=self._parse_partition_by(),
            where=self._parse_where(),
        )
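A hedged example of the syntax this enables (PostgreSQL-style covering indexes; the statement is illustrative and the exact AST shape is an assumption):

import sqlglot
from sqlglot import exp

# The INCLUDE column list should now land in the Index expression's
# "include" arg instead of breaking the parse.
ast = sqlglot.parse_one("CREATE INDEX idx ON t (a) INCLUDE (b)", read="postgres")
print(ast.find(exp.Index).args.get("include"))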
@@ -3380,8 +3423,8 @@ class Parser(metaclass=_Parser):
    def _parse_comparison(self) -> t.Optional[exp.Expression]:
        return self._parse_tokens(self._parse_range, self.COMPARISON)

    def _parse_range(self) -> t.Optional[exp.Expression]:
        this = self._parse_bitwise()
    def _parse_range(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
        this = this or self._parse_bitwise()
        negate = self._match(TokenType.NOT)

        if self._match_set(self.RANGE_PARSERS):
@@ -3535,14 +3578,21 @@ class Parser(metaclass=_Parser):
        return self._parse_tokens(self._parse_factor, self.TERM)

    def _parse_factor(self) -> t.Optional[exp.Expression]:
        if self.EXPONENT:
            factor = self._parse_tokens(self._parse_exponent, self.FACTOR)
        else:
            factor = self._parse_tokens(self._parse_unary, self.FACTOR)
        if isinstance(factor, exp.Div):
            factor.args["typed"] = self.dialect.TYPED_DIVISION
            factor.args["safe"] = self.dialect.SAFE_DIVISION
        return factor
        parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary
        this = parse_method()

        while self._match_set(self.FACTOR):
            this = self.expression(
                self.FACTOR[self._prev.token_type],
                this=this,
                comments=self._prev_comments,
                expression=parse_method(),
            )
            if isinstance(this, exp.Div):
                this.args["typed"] = self.dialect.TYPED_DIVISION
                this.args["safe"] = self.dialect.SAFE_DIVISION

        return this

    def _parse_exponent(self) -> t.Optional[exp.Expression]:
        return self._parse_tokens(self._parse_unary, self.EXPONENT)
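The explicit loop replaces _parse_tokens so division semantics can be attached to every Div produced while folding factors, not just the outermost one; a hedged sketch (flag values depend on the dialect):

import sqlglot
from sqlglot import exp

# Each Div in a chain like a / b / c now carries the dialect's
# typed/safe division flags.
ast = sqlglot.parse_one("SELECT a / b / c")
for div in ast.find_all(exp.Div):
    print(div.sql(), div.args.get("typed"), div.args.get("safe"))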
@@ -3617,6 +3667,7 @@ class Parser(metaclass=_Parser):

                return exp.DataType.build(type_name, udt=True)
            else:
                self._retreat(self._index - 1)
                return None
        else:
            return None
@@ -3631,6 +3682,7 @@ class Parser(metaclass=_Parser):

        nested = type_token in self.NESTED_TYPE_TOKENS
        is_struct = type_token in self.STRUCT_TYPE_TOKENS
        is_aggregate = type_token in self.AGGREGATE_TYPE_TOKENS
        expressions = None
        maybe_func = False
@@ -3645,6 +3697,18 @@ class Parser(metaclass=_Parser):
            )
        elif type_token in self.ENUM_TYPE_TOKENS:
            expressions = self._parse_csv(self._parse_equality)
        elif is_aggregate:
            func_or_ident = self._parse_function(anonymous=True) or self._parse_id_var(
                any_token=False, tokens=(TokenType.VAR,)
            )
            if not func_or_ident or not self._match(TokenType.COMMA):
                return None
            expressions = self._parse_csv(
                lambda: self._parse_types(
                    check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
                )
            )
            expressions.insert(0, func_or_ident)
        else:
            expressions = self._parse_csv(self._parse_type_size)
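This branch exists for ClickHouse-style aggregate types, whose first argument is a function name rather than a type; a hedged example (the DDL is illustrative):

import sqlglot

# AggregateFunction(sum, UInt64) parses into a DataType whose first
# expression is the aggregate function identifier.
sql = "CREATE TABLE t (x AggregateFunction(sum, UInt64))"
print(sqlglot.parse_one(sql, read="clickhouse").sql(dialect="clickhouse"))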
@@ -4413,6 +4477,10 @@ class Parser(metaclass=_Parser):
            self._match_r_paren()
        else:
            index = self._index - 1

            if self.NO_PAREN_IF_COMMANDS and index == 0:
                return self._parse_as_command(self._prev)

            condition = self._parse_conjunction()

            if not condition:
@@ -4624,12 +4692,10 @@ class Parser(metaclass=_Parser):
        return None

    @t.overload
    def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
        ...
    def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...

    @t.overload
    def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
        ...
    def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...

    def _parse_json_object(self, agg=False):
        star = self._parse_star()
@@ -4974,11 +5040,12 @@ class Parser(metaclass=_Parser):

        if alias:
            this = self.expression(exp.Alias, comments=comments, this=this, alias=alias)
            column = this.this

            # Moves the comment next to the alias in `expr /* comment */ AS alias`
            if not this.comments and this.this.comments:
                this.comments = this.this.comments
                this.this.comments = None
            if not this.comments and column and column.comments:
                this.comments = column.comments
                column.comments = None

        return this
@@ -5244,7 +5311,7 @@ class Parser(metaclass=_Parser):

        if self._match_text_seq("CHECK"):
            expression = self._parse_wrapped(self._parse_conjunction)
            enforced = self._match_text_seq("ENFORCED")
            enforced = self._match_text_seq("ENFORCED") or False

            return self.expression(
                exp.AddConstraint, this=this, expression=expression, enforced=enforced
@@ -5278,6 +5345,8 @@ class Parser(metaclass=_Parser):
            return self.expression(exp.AlterColumn, this=column, drop=True)
        if self._match_pair(TokenType.SET, TokenType.DEFAULT):
            return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction())
        if self._match(TokenType.COMMENT):
            return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())

        self._match_text_seq("SET", "DATA")
        return self.expression(
@@ -5298,7 +5367,18 @@ class Parser(metaclass=_Parser):
        self._retreat(index)
        return self._parse_csv(self._parse_drop_column)

    def _parse_alter_table_rename(self) -> exp.RenameTable:
    def _parse_alter_table_rename(self) -> t.Optional[exp.RenameTable | exp.RenameColumn]:
        if self._match(TokenType.COLUMN):
            exists = self._parse_exists()
            old_column = self._parse_column()
            to = self._match_text_seq("TO")
            new_column = self._parse_column()

            if old_column is None or to is None or new_column is None:
                return None

            return self.expression(exp.RenameColumn, this=old_column, to=new_column, exists=exists)

        self._match_text_seq("TO")
        return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
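With this change, RENAME COLUMN produces a structured expression instead of degrading to a Command; a hedged sketch (statement is illustrative):

import sqlglot

# Parses into an AlterTable whose action is a RenameColumn node.
ast = sqlglot.parse_one("ALTER TABLE t RENAME COLUMN a TO b")
print(ast.sql())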
@@ -5319,7 +5399,7 @@ class Parser(metaclass=_Parser):
        if parser:
            actions = ensure_list(parser(self))

            if not self._curr:
            if not self._curr and actions:
                return self.expression(
                    exp.AlterTable,
                    this=this,
@@ -5467,6 +5547,7 @@ class Parser(metaclass=_Parser):
            self._advance()
        text = self._find_sql(start, self._prev)
        size = len(start.text)
        self._warn_unsupported()
        return exp.Command(this=text[:size], expression=text[size:])

    def _parse_dict_property(self, this: str) -> exp.DictProperty:
@@ -5634,7 +5715,7 @@ class Parser(metaclass=_Parser):
        if advance:
            self._advance()
        return True
        return False
        return None

    def _match_text_seq(self, *texts, advance=True):
        index = self._index
@@ -5643,7 +5724,7 @@ class Parser(metaclass=_Parser):
                self._advance()
            else:
                self._retreat(index)
                return False
                return None

        if not advance:
            self._retreat(index)
@@ -5651,14 +5732,12 @@ class Parser(metaclass=_Parser):
        return True

    @t.overload
    def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
        ...
    def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...

    @t.overload
    def _replace_columns_with_dots(
        self, this: t.Optional[exp.Expression]
    ) -> t.Optional[exp.Expression]:
        ...
    ) -> t.Optional[exp.Expression]: ...

    def _replace_columns_with_dots(self, this):
        if isinstance(this, exp.Dot):
@@ -106,6 +106,19 @@ class Schema(abc.ABC):
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @abc.abstractmethod
    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether or not to raise in case the schema is not found.

        Returns:
            The schema of the target table.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -156,11 +169,9 @@ class AbstractMappingSchema:
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None
@@ -191,6 +191,8 @@ class TokenType(AutoName):
    FIXEDSTRING = auto()
    LOWCARDINALITY = auto()
    NESTED = auto()
    AGGREGATEFUNCTION = auto()
    SIMPLEAGGREGATEFUNCTION = auto()
    UNKNOWN = auto()

    # keywords