Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent d4fe7bdb16
commit 90988d8258
127 changed files with 73384 additions and 73067 deletions
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -22,6 +22,7 @@ from sqlglot.expressions import (
     Expression as Expression,
     alias_ as alias,
     and_ as and_,
+    case as case,
     cast as cast,
     column as column,
     condition as condition,
@@ -82,8 +83,7 @@ def parse(
     Returns:
         The resulting syntax tree collection.
     """
-    dialect = Dialect.get_or_raise(read or dialect)()
-    return dialect.parse(sql, **opts)
+    return Dialect.get_or_raise(read or dialect).parse(sql, **opts)


 @t.overload
@@ -117,7 +117,7 @@ def parse_one(
         The syntax tree for the first parsed statement.
     """

-    dialect = Dialect.get_or_raise(read or dialect)()
+    dialect = Dialect.get_or_raise(read or dialect)

     if into:
         result = dialect.parse_into(into, sql, **opts)
@@ -157,7 +157,8 @@ def transpile(
         The list of transpiled SQL statements.
     """
     write = (read if write is None else write) if identity else write
+    write = Dialect.get_or_raise(write)
     return [
-        Dialect.get_or_raise(write)().generate(expression, copy=False, **opts) if expression else ""
+        write.generate(expression, copy=False, **opts) if expression else ""
         for expression in parse(sql, read, error_level=error_level)
     ]
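
These hunks are the heart of this merge: Dialect.get_or_raise now returns a Dialect instance rather than a class, so the trailing () disappears at every call site. A minimal sketch of the new convention, inferred from this diff rather than taken from an official snippet:

    import sqlglot
    from sqlglot.dialects.dialect import Dialect

    # get_or_raise resolves "duckdb" to an instance, so methods chain directly
    dialect = Dialect.get_or_raise("duckdb")
    expressions = dialect.parse("SELECT 1; SELECT 2")
    assert len(expressions) == 2

    # the public helpers keep their old signatures for callers
    assert sqlglot.transpile("SELECT 1", read="duckdb", write="bigquery") == ["SELECT 1"]
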
--- a/sqlglot/__main__.py
+++ b/sqlglot/__main__.py
@@ -81,7 +81,7 @@ if args.parse:
         )
     ]
 elif args.tokenize:
-    objs = sqlglot.Dialect.get_or_raise(args.read)().tokenize(sql)
+    objs = sqlglot.Dialect.get_or_raise(args.read).tokenize(sql)
 else:
     objs = sqlglot.transpile(
         sql,
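
The CLI change mirrors the library change. The equivalent call from Python, a sketch assuming the Token API is otherwise unchanged:

    import sqlglot

    tokens = sqlglot.Dialect.get_or_raise("duckdb").tokenize("SELECT 1")
    print([token.token_type for token in tokens])
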
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -297,27 +297,26 @@ class DataFrame:
             select_expressions.append(expression_select_pair)  # type: ignore
         return select_expressions

-    def sql(
-        self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
-    ) -> t.List[str]:
+    def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
         from sqlglot.dataframe.sql.session import SparkSession

-        if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
-            logger.warning(
-                f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
-            )
+        dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)

         df = self._resolve_pending_hints()
         select_expressions = df._get_select_expressions()
         output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
         replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}

         for expression_type, select_expression in select_expressions:
             select_expression = select_expression.transform(replace_id_value, replacement_mapping)
             if optimize:
-                quote_identifiers(select_expression)
+                quote_identifiers(select_expression, dialect=dialect)
                 select_expression = t.cast(
-                    exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
+                    exp.Select, optimize_func(select_expression, dialect=dialect)
                 )

             select_expression = df._replace_cte_names_with_hashes(select_expression)

             expression: t.Union[exp.Select, exp.Cache, exp.Drop]
             if expression_type == exp.Cache:
                 cache_table_name = df._create_hash_from_expression(select_expression)
@@ -330,13 +329,12 @@ class DataFrame:
                 sqlglot.schema.add_table(
                     cache_table_name,
                     {
-                        expression.alias_or_name: expression.type.sql(
-                            dialect=SparkSession().dialect
-                        )
+                        expression.alias_or_name: expression.type.sql(dialect=dialect)
                         for expression in select_expression.expressions
                     },
-                    dialect=SparkSession().dialect,
+                    dialect=dialect,
                 )

                 cache_storage_level = select_expression.args["cache_storage_level"]
                 options = [
                     exp.Literal.string("storageLevel"),
@@ -345,6 +343,7 @@ class DataFrame:
                 expression = exp.Cache(
                     this=cache_table, expression=select_expression, lazy=True, options=options
                 )

+                # We will drop the "view" if it exists before running the cache table
                 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
             elif expression_type == exp.Create:
@@ -355,18 +354,17 @@ class DataFrame:
                 select_without_ctes = select_expression.copy()
                 select_without_ctes.set("with", None)
                 expression.set("expression", select_without_ctes)

                 if select_expression.ctes:
                     expression.set("with", exp.With(expressions=select_expression.ctes))
             elif expression_type == exp.Select:
                 expression = select_expression
             else:
                 raise ValueError(f"Invalid expression type: {expression_type}")

             output_expressions.append(expression)

-        return [
-            expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
-            for expression in output_expressions
-        ]
+        return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]

     def copy(self, **kwargs) -> DataFrame:
         return DataFrame(**object_to_dict(self, **kwargs))
@@ -542,12 +540,7 @@ class DataFrame:
         """
         columns = self._ensure_and_normalize_cols(cols)
         pre_ordered_col_indexes = [
-            x
-            for x in [
-                i if isinstance(col.expression, exp.Ordered) else None
-                for i, col in enumerate(columns)
-            ]
-            if x is not None
+            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
         ]
         if ascending is None:
             ascending = [True] * len(columns)
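
These dataframe hunks all serve one refactor: resolve the dialect once at the top of sql() and thread the instance through quoting, optimization, schema registration, and final generation, instead of re-deriving SparkSession().dialect at each step. A usage sketch of the experimental dataframe API:

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, "Jack"]], ["id", "name"])
    # dialect may be a name, a Dialect subclass or an instance; it is resolved
    # once via Dialect.get_or_raise and reused for every expression below
    print(df.sql(dialect="spark")[0])
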
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -306,7 +306,7 @@ def collect_list(col: ColumnOrName) -> Column:


 def collect_set(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, expression.SetAgg)
+    return Column.invoke_expression_over_column(col, expression.ArrayUniqueAgg)


 def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -28,7 +28,7 @@ class SparkSession:
         self.known_sequence_ids = set()
         self.name_to_sequence_id_mapping = defaultdict(list)
         self.incrementing_id = 1
-        self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()
+        self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)

     def __new__(cls, *args, **kwargs) -> SparkSession:
         if cls._instance is None:
@@ -182,7 +182,7 @@ class SparkSession:

         def getOrCreate(self) -> SparkSession:
             spark = SparkSession()
-            spark.dialect = Dialect.get_or_raise(self.dialect)()
+            spark.dialect = Dialect.get_or_raise(self.dialect)
             return spark

         @classproperty
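
With the () gone, SparkSession.dialect now holds a Dialect instance. The configuration pattern recommended by the (now removed) warning in dataframe.py still applies; a sketch:

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession.builder.config("sqlframe.dialect", "databricks").getOrCreate()
    print(type(spark.dialect).__name__)  # Databricks
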
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot._typing import E
 from sqlglot.dialects.dialect import (
     Dialect,
+    NormalizationStrategy,
     arg_max_or_min_no_count,
     binary_from_function,
     date_add_interval_sql,
@@ -23,6 +24,7 @@ from sqlglot.dialects.dialect import (
     regexp_replace_sql,
     rename_func,
     timestrtotime_sql,
+    ts_or_ds_add_cast,
     ts_or_ds_to_date_sql,
 )
 from sqlglot.helper import seq_get, split_num_words
@@ -174,6 +176,44 @@ def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
     return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg)


+def _array_contains_sql(self: BigQuery.Generator, expression: exp.ArrayContains) -> str:
+    return self.sql(
+        exp.Exists(
+            this=exp.select("1")
+            .from_(exp.Unnest(expressions=[expression.left]).as_("_unnest", table=["_col"]))
+            .where(exp.column("_col").eq(expression.right))
+        )
+    )
+
+
+def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> str:
+    return date_add_interval_sql("DATE", "ADD")(self, ts_or_ds_add_cast(expression))
+
+
+def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str:
+    expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True))
+    expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True))
+    unit = expression.args.get("unit") or "DAY"
+    return self.func("DATE_DIFF", expression.this, expression.expression, unit)
+
+
+def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str:
+    scale = expression.args.get("scale")
+    timestamp = self.sql(expression, "this")
+    if scale in (None, exp.UnixToTime.SECONDS):
+        return f"TIMESTAMP_SECONDS({timestamp})"
+    if scale == exp.UnixToTime.MILLIS:
+        return f"TIMESTAMP_MILLIS({timestamp})"
+    if scale == exp.UnixToTime.MICROS:
+        return f"TIMESTAMP_MICROS({timestamp})"
+    if scale == exp.UnixToTime.NANOS:
+        # We need to cast to INT64 because that's what BQ expects
+        return f"TIMESTAMP_MICROS(CAST({timestamp} / 1000 AS INT64))"
+
+    self.unsupported(f"Unsupported scale for timestamp: {scale}.")
+    return ""
+
+
 class BigQuery(Dialect):
     UNNEST_COLUMN_ONLY = True
     SUPPORTS_USER_DEFINED_TYPES = False
@@ -181,7 +221,7 @@ class BigQuery(Dialect):
     LOG_BASE_FIRST = False

     # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
-    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

     # bigquery udfs are case sensitive
     NORMALIZE_FUNCTIONS = False
@@ -220,8 +260,7 @@ class BigQuery(Dialect):
     # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
     PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"}

-    @classmethod
-    def normalize_identifier(cls, expression: E) -> E:
+    def normalize_identifier(self, expression: E) -> E:
         if isinstance(expression, exp.Identifier):
             parent = expression.parent
             while isinstance(parent, exp.Dot):
@@ -265,7 +304,6 @@ class BigQuery(Dialect):
             "DECLARE": TokenType.COMMAND,
             "FLOAT64": TokenType.DOUBLE,
-            "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
             "INT64": TokenType.BIGINT,
             "MODEL": TokenType.MODEL,
             "NOT DETERMINISTIC": TokenType.VOLATILE,
             "RECORD": TokenType.STRUCT,
@@ -316,6 +354,15 @@ class BigQuery(Dialect):
             "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
             "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
             "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
+            "TIMESTAMP_MICROS": lambda args: exp.UnixToTime(
+                this=seq_get(args, 0), scale=exp.UnixToTime.MICROS
+            ),
+            "TIMESTAMP_MILLIS": lambda args: exp.UnixToTime(
+                this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
+            ),
+            "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(
+                this=seq_get(args, 0), scale=exp.UnixToTime.SECONDS
+            ),
             "TO_JSON_STRING": exp.JSONFormat.from_arg_list,
         }

@@ -358,6 +405,24 @@ class BigQuery(Dialect):

         NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN}

+        STATEMENT_PARSERS = {
+            **parser.Parser.STATEMENT_PARSERS,
+            TokenType.END: lambda self: self._parse_as_command(self._prev),
+            TokenType.FOR: lambda self: self._parse_for_in(),
+        }
+
+        BRACKET_OFFSETS = {
+            "OFFSET": (0, False),
+            "ORDINAL": (1, False),
+            "SAFE_OFFSET": (0, True),
+            "SAFE_ORDINAL": (1, True),
+        }
+
+        def _parse_for_in(self) -> exp.ForIn:
+            this = self._parse_range()
+            self._match_text_seq("DO")
+            return self.expression(exp.ForIn, this=this, expression=self._parse_statement())
+
         def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
             this = super()._parse_table_part(schema=schema) or self._parse_number()

@@ -419,6 +484,26 @@ class BigQuery(Dialect):

             return json_object

+        def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+            bracket = super()._parse_bracket(this)
+
+            if this is bracket:
+                return bracket
+
+            if isinstance(bracket, exp.Bracket):
+                for expression in bracket.expressions:
+                    name = expression.name.upper()
+
+                    if name not in self.BRACKET_OFFSETS:
+                        break
+
+                    offset, safe = self.BRACKET_OFFSETS[name]
+                    bracket.set("offset", offset)
+                    bracket.set("safe", safe)
+                    expression.replace(expression.expressions[0])
+
+            return bracket
+
     class Generator(generator.Generator):
         EXPLICIT_UNION = True
         INTERVAL_ALLOWS_PLURAL_FORM = False
@@ -430,12 +515,14 @@ class BigQuery(Dialect):
         NVL2_SUPPORTED = False
         UNNEST_WITH_ORDINALITY = False
         COLLATE_IS_FUNC = True
+        LIMIT_ONLY_LITERALS = True

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
+            exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
+            exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
             exp.ArrayContains: _array_contains_sql,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
             exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
@@ -498,10 +585,13 @@ class BigQuery(Dialect):
             exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
             exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeToStr: lambda self, e: f"FORMAT_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
+            exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
-            exp.TsOrDsAdd: date_add_interval_sql("DATE", "ADD"),
+            exp.TsOrDsAdd: _ts_or_ds_add_sql,
+            exp.TsOrDsDiff: _ts_or_ds_diff_sql,
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
             exp.Unhex: rename_func("FROM_HEX"),
+            exp.UnixToTime: _unix_to_time_sql,
             exp.Values: _derived_table_values_to_unnest,
             exp.VariancePop: rename_func("VAR_POP"),
         }
@@ -671,6 +761,23 @@ class BigQuery(Dialect):

             return inline_array_sql(self, expression)

+        def bracket_sql(self, expression: exp.Bracket) -> str:
+            expressions = expression.expressions
+            expressions_sql = ", ".join(self.sql(e) for e in expressions)
+            offset = expression.args.get("offset")
+
+            if offset == 0:
+                expressions_sql = f"OFFSET({expressions_sql})"
+            elif offset == 1:
+                expressions_sql = f"ORDINAL({expressions_sql})"
+            else:
+                self.unsupported(f"Unsupported array offset: {offset}")
+
+            if expression.args.get("safe"):
+                expressions_sql = f"SAFE_{expressions_sql}"
+
+            return f"{self.sql(expression, 'this')}[{expressions_sql}]"
+
         def transaction_sql(self, *_) -> str:
             return "BEGIN TRANSACTION"
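
The new TIMESTAMP_SECONDS/_MILLIS/_MICROS parsers and _unix_to_time_sql generator let epoch conversions round-trip through exp.UnixToTime with an explicit scale. A sketch that pairs this with the matching DuckDB hunk further down:

    import sqlglot

    sql = "SELECT TIMESTAMP_MILLIS(1000)"
    print(sqlglot.transpile(sql, read="bigquery", write="duckdb")[0])
    # expected, per the DuckDB _unix_to_time_sql below: SELECT EPOCH_MS(1000)
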
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -35,8 +35,8 @@ def _quantile_sql(self, e):
 class ClickHouse(Dialect):
     NORMALIZE_FUNCTIONS: bool | str = False
     NULL_ORDERING = "nulls_are_last"
-    STRICT_STRING_CONCAT = True
     SUPPORTS_USER_DEFINED_TYPES = False
+    SAFE_DIVISION = True

     ESCAPE_SEQUENCES = {
         "\\0": "\0",
@@ -63,11 +63,7 @@ class ClickHouse(Dialect):
             "FLOAT32": TokenType.FLOAT,
             "FLOAT64": TokenType.DOUBLE,
             "GLOBAL": TokenType.GLOBAL,
-            "INT16": TokenType.SMALLINT,
             "INT256": TokenType.INT256,
-            "INT32": TokenType.INT,
-            "INT64": TokenType.BIGINT,
-            "INT8": TokenType.TINYINT,
             "LOWCARDINALITY": TokenType.LOWCARDINALITY,
             "MAP": TokenType.MAP,
             "NESTED": TokenType.NESTED,
@@ -112,6 +108,7 @@ class ClickHouse(Dialect):

         FUNCTION_PARSERS = {
             **parser.Parser.FUNCTION_PARSERS,
+            "ARRAYJOIN": lambda self: self.expression(exp.Explode, this=self._parse_expression()),
             "QUANTILE": lambda self: self._parse_quantile(),
         }

@@ -223,12 +220,13 @@ class ClickHouse(Dialect):
             except ParseError:
                 # WITH <expression> AS <identifier>
                 self._retreat(index)
-                statement = self._parse_statement()

-                if statement and isinstance(statement.this, exp.Alias):
-                    self.raise_error("Expected CTE to have alias")
-
-                return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
+                return self.expression(
+                    exp.CTE,
+                    this=self._parse_field(),
+                    alias=self._parse_table_alias(),
+                    scalar=True,
+                )

         def _parse_join_parts(
             self,
@@ -385,9 +383,11 @@ class ClickHouse(Dialect):
             exp.DateDiff: lambda self, e: self.func(
                 "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
             ),
+            exp.Explode: rename_func("arrayJoin"),
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
+            exp.IsNan: rename_func("isNaN"),
             exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
             exp.Nullif: rename_func("nullIf"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.Pivot: no_pivot_sql,
             exp.Quantile: _quantile_sql,
@@ -459,19 +459,11 @@ class ClickHouse(Dialect):

             return super().datatype_sql(expression)

-        def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
-            # Clickhouse errors out if we try to cast a NULL value to TEXT
-            return self.func(
-                "CONCAT",
-                *[
-                    exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text"))
-                    for e in t.cast(t.List[exp.Condition], expression.expressions)
-                ],
-            )
-
         def cte_sql(self, expression: exp.CTE) -> str:
-            if isinstance(expression.this, exp.Alias):
-                return self.sql(expression, "this")
+            if expression.args.get("scalar"):
+                this = self.sql(expression, "this")
+                alias = self.sql(expression, "alias")
+                return f"{this} AS {alias}"

             return super().cte_sql(expression)
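
The CTE change replaces the old alias-based workaround with an explicit scalar=True flag on exp.CTE, printed back as "<expr> AS <alias>" by cte_sql. A sketch of the intended round trip:

    import sqlglot

    expression = sqlglot.parse_one("WITH 1 AS x SELECT x", read="clickhouse")
    print(expression.sql("clickhouse"))  # WITH 1 AS x SELECT x
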
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -1,13 +1,18 @@
 from __future__ import annotations

 from sqlglot import exp, transforms
-from sqlglot.dialects.dialect import parse_date_delta, timestamptrunc_sql
+from sqlglot.dialects.dialect import (
+    date_delta_sql,
+    parse_date_delta,
+    timestamptrunc_sql,
+)
 from sqlglot.dialects.spark import Spark
-from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
 from sqlglot.tokens import TokenType


 class Databricks(Spark):
+    SAFE_DIVISION = False
+
     class Parser(Spark.Parser):
         LOG_DEFAULTS_TO_LN = True
         STRICT_CAST = True
@@ -27,8 +32,8 @@ class Databricks(Spark):
     class Generator(Spark.Generator):
         TRANSFORMS = {
             **Spark.Generator.TRANSFORMS,
-            exp.DateAdd: generate_date_delta_with_unit_sql,
-            exp.DateDiff: generate_date_delta_with_unit_sql,
+            exp.DateAdd: date_delta_sql("DATEADD"),
+            exp.DateDiff: date_delta_sql("DATEDIFF"),
             exp.DatetimeAdd: lambda self, e: self.func(
                 "TIMESTAMPADD", e.text("unit"), e.expression, e.this
             ),
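
Databricks drops the TSQL helper in favor of the shared date_delta_sql factory added to dialect.py in this same merge (see below), which renders DATEADD(unit, amount, date). A sketch of the expected round trip, assuming Databricks parses DATEADD back into exp.DateAdd:

    import sqlglot

    sql = "SELECT DATEADD(DAY, 1, CAST('2020-01-01' AS DATE))"
    print(sqlglot.transpile(sql, read="databricks", write="databricks")[0])
    # expected: SELECT DATEADD(DAY, 1, CAST('2020-01-01' AS DATE))
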
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -1,14 +1,14 @@
 from __future__ import annotations

 import typing as t
-from enum import Enum
+from enum import Enum, auto
 from functools import reduce

 from sqlglot import exp
 from sqlglot._typing import E
 from sqlglot.errors import ParseError
 from sqlglot.generator import Generator
-from sqlglot.helper import flatten, seq_get
+from sqlglot.helper import AutoName, flatten, seq_get
 from sqlglot.parser import Parser
 from sqlglot.time import TIMEZONES, format_time
 from sqlglot.tokens import Token, Tokenizer, TokenType
@@ -16,6 +16,9 @@ from sqlglot.trie import new_trie

 B = t.TypeVar("B", bound=exp.Binary)

+DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
+DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
+

 class Dialects(str, Enum):
     DIALECT = ""
@@ -43,6 +46,15 @@ class Dialects(str, Enum):
     Doris = "doris"


+class NormalizationStrategy(str, AutoName):
+    """Specifies the strategy according to which identifiers should be normalized."""
+
+    LOWERCASE = auto()  # Unquoted identifiers are lowercased
+    UPPERCASE = auto()  # Unquoted identifiers are uppercased
+    CASE_SENSITIVE = auto()  # Always case-sensitive, regardless of quotes
+    CASE_INSENSITIVE = auto()  # Always case-insensitive, regardless of quotes
+
+
 class _Dialect(type):
     classes: t.Dict[str, t.Type[Dialect]] = {}
@@ -106,26 +118,8 @@ class _Dialect(type):
         klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
         klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

-        dialect_properties = {
-            **{
-                k: v
-                for k, v in vars(klass).items()
-                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
-            },
-            "TOKENIZER_CLASS": klass.tokenizer_class,
-        }
-
         if enum not in ("", "bigquery"):
-            dialect_properties["SELECT_KINDS"] = ()
-
-        # Pass required dialect properties to the tokenizer, parser and generator classes
-        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
-            for name, value in dialect_properties.items():
-                if hasattr(subclass, name):
-                    setattr(subclass, name, value)
-
-        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
-            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
+            klass.generator_class.SELECT_KINDS = ()

         if not klass.SUPPORTS_SEMI_ANTI_JOIN:
             klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
@@ -133,8 +127,6 @@ class _Dialect(type):
             TokenType.SEMI,
         }

-        klass.generator_class.can_identify = klass.can_identify
-
         return klass


@@ -148,9 +140,8 @@ class Dialect(metaclass=_Dialect):
     # Determines whether or not the table alias comes after tablesample
     ALIAS_POST_TABLESAMPLE = False

-    # Determines whether or not unquoted identifiers are resolved as uppercase
-    # When set to None, it means that the dialect treats all identifiers as case-insensitive
-    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False
+    # Specifies the strategy according to which identifiers should be normalized.
+    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE

     # Determines whether or not an unquoted identifier can start with a digit
     IDENTIFIERS_CAN_START_WITH_DIGIT = False
@@ -177,6 +168,18 @@ class Dialect(metaclass=_Dialect):
     # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
     NULL_ORDERING = "nulls_are_small"

+    # Whether the behavior of a / b depends on the types of a and b.
+    # False means a / b is always float division.
+    # True means a / b is integer division if both a and b are integers.
+    TYPED_DIVISION = False
+
+    # False means 1 / 0 throws an error.
+    # True means 1 / 0 returns null.
+    SAFE_DIVISION = False
+
+    # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
+    CONCAT_COALESCE = False
+
     DATE_FORMAT = "'%Y-%m-%d'"
     DATEINT_FORMAT = "'%Y%m%d'"
     TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
@@ -197,7 +200,8 @@ class Dialect(metaclass=_Dialect):
     # Such columns may be excluded from SELECT * queries, for example
     PSEUDOCOLUMNS: t.Set[str] = set()

-    # Autofilled
+    # --- Autofilled ---
+
     tokenizer_class = Tokenizer
     parser_class = Parser
     generator_class = Generator
@@ -211,26 +215,61 @@ class Dialect(metaclass=_Dialect):

     INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

-    def __eq__(self, other: t.Any) -> bool:
-        return type(self) == other
+    # Delimiters for quotes, identifiers and the corresponding escape characters
+    QUOTE_START = "'"
+    QUOTE_END = "'"
+    IDENTIFIER_START = '"'
+    IDENTIFIER_END = '"'

-    def __hash__(self) -> int:
-        return hash(type(self))
+    # Delimiters for bit, hex and byte literals
+    BIT_START: t.Optional[str] = None
+    BIT_END: t.Optional[str] = None
+    HEX_START: t.Optional[str] = None
+    HEX_END: t.Optional[str] = None
+    BYTE_START: t.Optional[str] = None
+    BYTE_END: t.Optional[str] = None

     @classmethod
-    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
+    def get_or_raise(cls, dialect: DialectType) -> Dialect:
+        """
+        Look up a dialect in the global dialect registry and return it if it exists.
+
+        Args:
+            dialect: The target dialect. If this is a string, it can be optionally followed by
+                additional key-value pairs that are separated by commas and are used to specify
+                dialect settings, such as whether the dialect's identifiers are case-sensitive.
+
+        Example:
+            >>> dialect = dialect_class = get_or_raise("duckdb")
+            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
+
+        Returns:
+            The corresponding Dialect instance.
+        """
         if not dialect:
-            return cls
+            return cls()
         if isinstance(dialect, _Dialect):
-            return dialect
+            return dialect()
         if isinstance(dialect, Dialect):
-            return dialect.__class__
+            return dialect
+        if isinstance(dialect, str):
+            try:
+                dialect_name, *kv_pairs = dialect.split(",")
+                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
+            except ValueError:
+                raise ValueError(
+                    f"Invalid dialect format: '{dialect}'. "
+                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
+                )

-        result = cls.get(dialect)
-        if not result:
-            raise ValueError(f"Unknown dialect '{dialect}'")
+            result = cls.get(dialect_name.strip())
+            if not result:
+                raise ValueError(f"Unknown dialect '{dialect_name}'.")

-        return result
+            return result(**kwargs)
+
+        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

     @classmethod
     def format_time(
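
get_or_raise is where dialect settings enter the system: a string like "mysql, normalization_strategy = lowercase" is split into a dialect name plus keyword arguments for Dialect.__init__ (defined in the next hunk). Usage, straight from the new docstring:

    from sqlglot.dialects.dialect import Dialect

    mysql = Dialect.get_or_raise("mysql, normalization_strategy = lowercase")
    print(mysql.normalization_strategy)  # NormalizationStrategy.LOWERCASE
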
@@ -247,36 +286,71 @@ class Dialect(metaclass=_Dialect):

         return expression

-    @classmethod
-    def normalize_identifier(cls, expression: E) -> E:
+    def __init__(self, **kwargs) -> None:
+        normalization_strategy = kwargs.get("normalization_strategy")
+
+        if normalization_strategy is None:
+            self.normalization_strategy = self.NORMALIZATION_STRATEGY
+        else:
+            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
+
+    def __eq__(self, other: t.Any) -> bool:
+        # Does not currently take dialect state into account
+        return type(self) == other
+
+    def __hash__(self) -> int:
+        # Does not currently take dialect state into account
+        return hash(type(self))
+
+    def normalize_identifier(self, expression: E) -> E:
         """
-        Normalizes an unquoted identifier to either lower or upper case, thus essentially
-        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
-        they will be normalized to lowercase regardless of being quoted or not.
+        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
+
+        For example, an identifier like FoO would be resolved as foo in Postgres, because it
+        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
+        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
+        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
+
+        There are also dialects like Spark, which are case-insensitive even when quotes are
+        present, and dialects like MySQL, whose resolution rules match those employed by the
+        underlying operating system, for example they may always be case-sensitive in Linux.
+
+        Finally, the normalization behavior of some engines can even be controlled through flags,
+        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
+
+        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
+        that it can analyze queries in the optimizer and successfully capture their semantics.
         """
-        if isinstance(expression, exp.Identifier) and (
-            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
+        if (
+            isinstance(expression, exp.Identifier)
+            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
+            and (
+                not expression.quoted
+                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
+            )
         ):
             expression.set(
                 "this",
                 expression.this.upper()
-                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
+                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                 else expression.this.lower(),
             )

         return expression

-    @classmethod
-    def case_sensitive(cls, text: str) -> bool:
+    def case_sensitive(self, text: str) -> bool:
         """Checks if text contains any case sensitive characters, based on the dialect's rules."""
-        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
+        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
             return False

-        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
+        unsafe = (
+            str.islower
+            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
+            else str.isupper
+        )
         return any(unsafe(char) for char in text)

-    @classmethod
-    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
+    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
         """Checks if text can be identified given an identify option.

         Args:
@@ -292,17 +366,16 @@ class Dialect(metaclass=_Dialect):
             return True

         if identify == "safe":
-            return not cls.case_sensitive(text)
+            return not self.case_sensitive(text)

         return False

-    @classmethod
-    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
+    def quote_identifier(self, expression: E, identify: bool = True) -> E:
         if isinstance(expression, exp.Identifier):
             name = expression.this
             expression.set(
                 "quoted",
-                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
+                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
             )

         return expression
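
Since normalize_identifier, case_sensitive, can_identify and quote_identifier are now instance methods driven by normalization_strategy, the same identifier can resolve differently per dialect. A sketch (Postgres lowercases unquoted identifiers, Snowflake uppercases them):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    print(Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name)   # foo
    print(Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name)  # FOO
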
@@ -330,14 +403,14 @@ class Dialect(metaclass=_Dialect):
     @property
     def tokenizer(self) -> Tokenizer:
         if not hasattr(self, "_tokenizer"):
-            self._tokenizer = self.tokenizer_class()
+            self._tokenizer = self.tokenizer_class(dialect=self)
         return self._tokenizer

     def parser(self, **opts) -> Parser:
-        return self.parser_class(**opts)
+        return self.parser_class(dialect=self, **opts)

     def generator(self, **opts) -> Generator:
-        return self.generator_class(**opts)
+        return self.generator_class(dialect=self, **opts)


 DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
@@ -713,7 +786,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
     return _ts_or_ds_to_date_sql


-def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
+def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
     return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))

@@ -821,3 +894,28 @@ def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
         return self.func(name, expression.this, expression.expression)

     return _arg_max_or_min_sql
+
+
+def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
+    this = expression.this.copy()
+
+    return_type = expression.return_type
+    if return_type.is_type(exp.DataType.Type.DATE):
+        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
+        # can truncate timestamp strings, because some dialects can't cast them to DATE
+        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
+
+    expression.this.replace(exp.cast(this, return_type))
+    return expression
+
+
+def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
+    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
+        if cast and isinstance(expression, exp.TsOrDsAdd):
+            expression = ts_or_ds_add_cast(expression)
+
+        return self.func(
+            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
+        )
+
+    return _delta_sql
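
ts_or_ds_add_cast and date_delta_sql are the shared plumbing behind the per-dialect TsOrDsAdd/TsOrDsDiff changes in this merge. A hypothetical wiring sketch for a dialect generator (the mapping below is illustrative, not a specific dialect's TRANSFORMS):

    from sqlglot import exp
    from sqlglot.dialects.dialect import date_delta_sql

    TRANSFORMS = {
        exp.DateAdd: date_delta_sql("DATE_ADD"),
        # cast=True first normalizes the operand via ts_or_ds_add_cast
        exp.TsOrDsAdd: date_delta_sql("DATE_ADD", cast=True),
    }
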
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -19,6 +19,7 @@ class Doris(MySQL):
     class Parser(MySQL.Parser):
         FUNCTIONS = {
             **MySQL.Parser.FUNCTIONS,
+            "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
             "DATE_TRUNC": parse_timestamp_trunc,
             "REGEXP": exp.RegexpLike.from_arg_list,
         }
@@ -47,7 +48,7 @@ class Doris(MySQL):
             exp.JSONExtract: arrow_json_extract_sql,
             exp.RegexpLike: rename_func("REGEXP"),
             exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
-            exp.SetAgg: rename_func("COLLECT_SET"),
+            exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
             exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Split: rename_func("SPLIT_BY_STRING"),
             exp.TimeStrToDate: rename_func("TO_DATE"),
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -43,6 +43,8 @@ class Drill(Dialect):
     TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
     SUPPORTS_USER_DEFINED_TYPES = False
     SUPPORTS_SEMI_ANTI_JOIN = False
+    TYPED_DIVISION = True
+    CONCAT_COALESCE = True

     TIME_MAPPING = {
         "y": "%Y",
@@ -83,7 +85,6 @@ class Drill(Dialect):

     class Parser(parser.Parser):
         STRICT_CAST = False
-        CONCAT_NULL_OUTPUTS_STRING = True

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -2,9 +2,10 @@ from __future__ import annotations

 import typing as t

-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
+    NormalizationStrategy,
     approx_count_distinct_sql,
     arg_max_or_min_no_count,
     arrow_json_extract_scalar_sql,
@@ -36,7 +37,8 @@ from sqlglot.tokens import TokenType
 def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
-    return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
+    interval = self.sql(exp.Interval(this=expression.expression, unit=unit))
+    return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}"


 def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
@@ -84,7 +86,8 @@ def _parse_date_diff(args: t.List) -> exp.Expression:

 def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
     args = [
-        f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
+        f"'{e.name or e.this.name}': {self.sql(e.expressions[0]) if isinstance(e, exp.Bracket) else self.sql(e, 'expression')}"
+        for e in expression.expressions
     ]
     return f"{{{', '.join(args)}}}"

@@ -105,17 +108,35 @@ def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str:
     return f"CAST({sql} AS TEXT)"


+def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str:
+    scale = expression.args.get("scale")
+    timestamp = self.sql(expression, "this")
+    if scale in (None, exp.UnixToTime.SECONDS):
+        return f"TO_TIMESTAMP({timestamp})"
+    if scale == exp.UnixToTime.MILLIS:
+        return f"EPOCH_MS({timestamp})"
+    if scale == exp.UnixToTime.MICROS:
+        return f"MAKE_TIMESTAMP({timestamp})"
+    if scale == exp.UnixToTime.NANOS:
+        return f"TO_TIMESTAMP({timestamp} / 1000000000)"
+
+    self.unsupported(f"Unsupported scale for timestamp: {scale}.")
+    return ""
+
+
 class DuckDB(Dialect):
     NULL_ORDERING = "nulls_are_last"
     SUPPORTS_USER_DEFINED_TYPES = False
+    SAFE_DIVISION = True
+    INDEX_OFFSET = 1
+    CONCAT_COALESCE = True

     # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
-    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             ":=": TokenType.EQ,
             "//": TokenType.DIV,
             "ATTACH": TokenType.COMMAND,
             "BINARY": TokenType.VARBINARY,
@@ -124,8 +145,6 @@ class DuckDB(Dialect):
             "CHAR": TokenType.TEXT,
             "CHARACTER VARYING": TokenType.TEXT,
             "EXCLUDE": TokenType.EXCEPT,
-            "HUGEINT": TokenType.INT128,
-            "INT1": TokenType.TINYINT,
             "LOGICAL": TokenType.BOOLEAN,
             "PIVOT_WIDER": TokenType.PIVOT,
             "SIGNED": TokenType.INT,
@@ -141,8 +160,6 @@ class DuckDB(Dialect):
         }

     class Parser(parser.Parser):
-        CONCAT_NULL_OUTPUTS_STRING = True
-
         BITWISE = {
             **parser.Parser.BITWISE,
             TokenType.TILDA: exp.RegexpLike,
@@ -150,6 +167,7 @@ class DuckDB(Dialect):

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
+            "ARRAY_HAS": exp.ArrayContains.from_arg_list,
             "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
             "ARRAY_SORT": exp.SortArray.from_arg_list,
             "ARRAY_REVERSE_SORT": _sort_array_reverse,
@@ -157,13 +175,23 @@ class DuckDB(Dialect):
             "DATE_DIFF": _parse_date_diff,
             "DATE_TRUNC": date_trunc_to_time,
             "DATETRUNC": date_trunc_to_time,
+            "DECODE": lambda args: exp.Decode(
+                this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+            ),
+            "ENCODE": lambda args: exp.Encode(
+                this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+            ),
             "EPOCH": exp.TimeToUnix.from_arg_list,
             "EPOCH_MS": lambda args: exp.UnixToTime(
-                this=exp.Div(this=seq_get(args, 0), expression=exp.Literal.number(1000))
+                this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
             ),
+            "LIST_HAS": exp.ArrayContains.from_arg_list,
             "LIST_REVERSE_SORT": _sort_array_reverse,
             "LIST_SORT": exp.SortArray.from_arg_list,
             "LIST_VALUE": exp.Array.from_arg_list,
+            "MAKE_TIMESTAMP": lambda args: exp.UnixToTime(
+                this=seq_get(args, 0), scale=exp.UnixToTime.MICROS
+            ),
             "MEDIAN": lambda args: exp.PercentileCont(
                 this=seq_get(args, 0), expression=exp.Literal.number(0.5)
             ),
@@ -192,15 +220,8 @@ class DuckDB(Dialect):
             "XOR": binary_from_function(exp.BitwiseXor),
         }

-        FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,
-            "DECODE": lambda self: self.expression(
-                exp.Decode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8")
-            ),
-            "ENCODE": lambda self: self.expression(
-                exp.Encode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8")
-            ),
-        }
+        FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
+        FUNCTION_PARSERS.pop("DECODE", None)

         TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
             TokenType.SEMI,
@@ -277,6 +298,7 @@ class DuckDB(Dialect):
             exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
             exp.Explode: rename_func("UNNEST"),
             exp.IntDiv: lambda self, e: self.binary(e, "//"),
+            exp.IsInf: rename_func("ISINF"),
             exp.IsNan: rename_func("ISNAN"),
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@@ -294,6 +316,9 @@ class DuckDB(Dialect):
             exp.ParseJSON: rename_func("JSON"),
             exp.PercentileCont: rename_func("QUANTILE_CONT"),
             exp.PercentileDisc: rename_func("QUANTILE_DISC"),
+            # DuckDB doesn't allow qualified columns inside of PIVOT expressions.
+            # See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
+            exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),
             exp.Properties: no_properties_sql,
             exp.RegexpExtract: regexp_extract_sql,
             exp.RegexpReplace: lambda self, e: self.func(
@@ -322,9 +347,15 @@ class DuckDB(Dialect):
             exp.TimeToUnix: rename_func("EPOCH"),
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
             exp.TsOrDsAdd: _ts_or_ds_add_sql,
+            exp.TsOrDsDiff: lambda self, e: self.func(
+                "DATE_DIFF",
+                f"'{e.args.get('unit') or 'day'}'",
+                exp.cast(e.expression, "TIMESTAMP"),
+                exp.cast(e.this, "TIMESTAMP"),
+            ),
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
             exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
-            exp.UnixToTime: rename_func("TO_TIMESTAMP"),
+            exp.UnixToTime: _unix_to_time_sql,
             exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
             exp.VariancePop: rename_func("VAR_POP"),
             exp.WeekOfYear: rename_func("WEEKOFYEAR"),
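
EPOCH_MS now parses into exp.UnixToTime with a MILLIS scale instead of a hand-built division, which is what lets the BigQuery example above work in both directions. A sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT EPOCH_MS(1000)", read="duckdb", write="bigquery")[0])
    # expected: SELECT TIMESTAMP_MILLIS(1000)
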
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -4,10 +4,13 @@ import typing as t

 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
+    DATE_ADD_OR_SUB,
     Dialect,
+    NormalizationStrategy,
     approx_count_distinct_sql,
     arg_max_or_min_no_count,
     create_with_partitions_sql,
+    datestrtodate_sql,
     format_time_lambda,
     if_sql,
     is_parse_json,
@@ -76,7 +79,10 @@ def _create_sql(self, expression: exp.Create) -> str:
     return create_with_partitions_sql(self, expression)


-def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str:
+    if isinstance(expression, exp.TsOrDsAdd) and not expression.unit:
+        return self.func("DATE_ADD", expression.this, expression.expression)
+
     unit = expression.text("unit").upper()
     func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))

@@ -95,7 +101,7 @@ def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
     return self.func(func, expression.this, modified_increment)


-def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
+def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff | exp.TsOrDsDiff) -> str:
     unit = expression.text("unit").upper()

     factor = TIME_DIFF_FACTOR.get(unit)
@@ -111,25 +117,31 @@ def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
     multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
     diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})"

-    if months_between:
-        # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part
-        diff_sql = f"CAST({diff_sql} AS INT)"
+    if months_between or multiplier_sql:
+        # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part.
+        # For the same reason, we want to truncate if there's a divisor present.
+        diff_sql = f"CAST({diff_sql}{multiplier_sql} AS INT)"

-    return f"{diff_sql}{multiplier_sql}"
+    return diff_sql


 def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str:
     this = expression.this
-    if is_parse_json(this) and this.this.is_string:
-        # Since FROM_JSON requires a nested type, we always wrap the json string with
-        # an array to ensure that "naked" strings like "'a'" will be handled correctly
-        wrapped_json = exp.Literal.string(f"[{this.this.name}]")

-        from_json = self.func("FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json))
-        to_json = self.func("TO_JSON", from_json)
+    if is_parse_json(this):
+        if this.this.is_string:
+            # Since FROM_JSON requires a nested type, we always wrap the json string with
+            # an array to ensure that "naked" strings like "'a'" will be handled correctly
+            wrapped_json = exp.Literal.string(f"[{this.this.name}]")

-        # This strips the [, ] delimiters of the dummy array printed by TO_JSON
-        return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1")
+            from_json = self.func(
+                "FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json)
+            )
+            to_json = self.func("TO_JSON", from_json)
+
+            # This strips the [, ] delimiters of the dummy array printed by TO_JSON
+            return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1")
+        return self.sql(this)

     return self.func("TO_JSON", this, expression.args.get("options"))
@@ -175,6 +187,8 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
     time_format = self.format_time(expression)
     if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
         return f"TO_DATE({this}, {time_format})"
+    if isinstance(expression.this, exp.TsOrDsToDate):
+        return this
     return f"TO_DATE({this})"


@@ -182,9 +196,10 @@ class Hive(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     IDENTIFIERS_CAN_START_WITH_DIGIT = True
     SUPPORTS_USER_DEFINED_TYPES = False
+    SAFE_DIVISION = True

     # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
-    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

     TIME_MAPPING = {
         "y": "%Y",
@@ -241,10 +256,10 @@ class Hive(Dialect):
             "ADD JAR": TokenType.COMMAND,
             "ADD JARS": TokenType.COMMAND,
             "MSCK REPAIR": TokenType.COMMAND,
-            "REFRESH": TokenType.COMMAND,
-            "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
+            "REFRESH": TokenType.REFRESH,
+            "TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT,
+            "VERSION AS OF": TokenType.VERSION_SNAPSHOT,
+            "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
         }

         NUMERIC_LITERALS = {
@@ -264,7 +279,7 @@ class Hive(Dialect):
             **parser.Parser.FUNCTIONS,
             "BASE64": exp.ToBase64.from_arg_list,
             "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
-            "COLLECT_SET": exp.SetAgg.from_arg_list,
+            "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
             "DATE_ADD": lambda args: exp.TsOrDsAdd(
                 this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
             ),
@@ -411,7 +426,13 @@ class Hive(Dialect):
         INDEX_ON = "ON TABLE"
         EXTRACT_ALLOWS_QUOTES = False
         NVL2_SUPPORTED = False
+        SUPPORTS_NESTED_CTES = False
+
+        EXPRESSIONS_WITHOUT_NESTED_CTES = {
+            exp.Insert,
+            exp.Select,
+            exp.Subquery,
+            exp.Union,
+        }

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -445,7 +466,7 @@ class Hive(Dialect):
             exp.With: no_recursive_cte_sql,
             exp.DateAdd: _add_date_sql,
             exp.DateDiff: _date_diff_sql,
-            exp.DateStrToDate: rename_func("TO_DATE"),
+            exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: _add_date_sql,
             exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
             exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
@@ -477,7 +498,7 @@ class Hive(Dialect):
             exp.Right: right_to_substring_sql,
             exp.SafeDivide: no_safe_divide_sql,
             exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
-            exp.SetAgg: rename_func("COLLECT_SET"),
+            exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
             exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
             exp.StrPosition: strposition_to_locate_sql,
             exp.StrToDate: _str_to_date_sql,
@@ -491,7 +512,8 @@ class Hive(Dialect):
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.ToBase64: rename_func("BASE64"),
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
-            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+            exp.TsOrDsAdd: _add_date_sql,
+            exp.TsOrDsDiff: _date_diff_sql,
             exp.TsOrDsToDate: _to_date_sql,
             exp.TryCast: no_trycast_sql,
             exp.UnixToStr: lambda self, e: self.func(
@@ -571,6 +593,8 @@ class Hive(Dialect):
                 and not expression.expressions
             ):
                 expression = exp.DataType.build("text")
+            elif expression.is_type(exp.DataType.Type.TEXT) and expression.expressions:
+                expression.set("this", exp.DataType.Type.VARCHAR)
             elif expression.this in exp.DataType.TEMPORAL_TYPES:
                 expression = exp.DataType.build(expression.this)
             elif expression.is_type("float"):
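
DATE_ADD in Hive parses to TsOrDsAdd with a DAY unit, and _add_date_sql now also accepts TsOrDsAdd (emitting bare DATE_ADD when no unit is set), so the two-argument form should round-trip unchanged. A sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT DATE_ADD('2020-01-01', 5)", read="hive", write="hive")[0])
    # expected: SELECT DATE_ADD('2020-01-01', 5)
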
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -5,6 +5,7 @@ import typing as t
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
+    NormalizationStrategy,
     arrow_json_extract_scalar_sql,
     date_add_interval_sql,
     datestrtodate_sql,
@@ -150,10 +151,18 @@ class MySQL(Dialect):
     # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
     IDENTIFIERS_CAN_START_WITH_DIGIT = True

+    # We default to treating all identifiers as case-sensitive, since it matches MySQL's
+    # behavior on Linux systems. For MacOS and Windows systems, one can override this
+    # setting by specifying `dialect="mysql, normalization_strategy = lowercase"`.
+    #
+    # See also https://dev.mysql.com/doc/refman/8.2/en/identifier-case-sensitivity.html
+    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_SENSITIVE
+
     TIME_FORMAT = "'%Y-%m-%d %T'"
     DPIPE_IS_STRING_CONCAT = False
     SUPPORTS_USER_DEFINED_TYPES = False
     SUPPORTS_SEMI_ANTI_JOIN = False
+    SAFE_DIVISION = True

     # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
     TIME_MAPPING = {
@@ -264,11 +273,6 @@ class MySQL(Dialect):
             TokenType.DPIPE: exp.Or,
         }

-        # MySQL uses || as a synonym to the logical OR operator
-        # https://dev.mysql.com/doc/refman/8.0/en/logical-operators.html#operator_or
-        BITWISE = parser.Parser.BITWISE.copy()
-        BITWISE.pop(TokenType.DPIPE)
-
         TABLE_ALIAS_TOKENS = (
             parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS
         )
@@ -451,7 +455,7 @@ class MySQL(Dialect):
            self, kind: t.Optional[str] = None
        ) -> exp.IndexColumnConstraint:
            if kind:
-               self._match_texts({"INDEX", "KEY"})
+               self._match_texts(("INDEX", "KEY"))

            this = self._parse_id_var(any_token=False)
            index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text
@@ -514,7 +518,7 @@ class MySQL(Dialect):

            log = self._parse_string() if self._match_text_seq("IN") else None

-           if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
+           if this in ("BINLOG EVENTS", "RELAYLOG EVENTS"):
               position = self._parse_number() if self._match_text_seq("FROM") else None
               db = None
           else:
@@ -671,6 +675,7 @@ class MySQL(Dialect):
             exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,
             exp.TsOrDsAdd: _date_add_sql("ADD"),
+            exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
             exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
             exp.Week: _remove_ts_or_ds_to_date(),
             exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
@@ -763,7 +768,7 @@ class MySQL(Dialect):

         target = self.sql(expression, "target")
         target = f" {target}" if target else ""
-        if expression.name in {"COLUMNS", "INDEX"}:
+        if expression.name in ("COLUMNS", "INDEX"):
             target = f" FROM{target}"
         elif expression.name == "GRANTS":
             target = f" FOR{target}"
@@ -796,6 +801,14 @@ class MySQL(Dialect):

         return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"

+        def altercolumn_sql(self, expression: exp.AlterColumn) -> str:
+            dtype = self.sql(expression, "dtype")
+            if not dtype:
+                return super().altercolumn_sql(expression)
+
+            this = self.sql(expression, "this")
+            return f"MODIFY COLUMN {this} {dtype}"
+
         def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str:
             sql = self.sql(expression, arg)
             return f" {prefix} {sql}" if sql else ""
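
The new altercolumn_sql maps a data type change onto MySQL's MODIFY COLUMN syntax. A sketch transpiling the Postgres form (the expected output is inferred from the hunk, not from a test in this diff):

    import sqlglot

    sql = "ALTER TABLE t ALTER COLUMN c TYPE BIGINT"
    print(sqlglot.transpile(sql, read="postgres", write="mysql")[0])
    # expected: ALTER TABLE t MODIFY COLUMN c BIGINT
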
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -3,7 +3,14 @@ from __future__ import annotations
 import typing as t

 from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
+from sqlglot.dialects.dialect import (
+    Dialect,
+    NormalizationStrategy,
+    format_time_lambda,
+    no_ilike_sql,
+    rename_func,
+    trim_sql,
+)
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType

@@ -30,12 +37,25 @@ def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable:
     return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)


+def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar:
+    this = seq_get(args, 0)
+
+    if this and not this.type:
+        from sqlglot.optimizer.annotate_types import annotate_types
+
+        annotate_types(this)
+        if this.is_type(*exp.DataType.TEMPORAL_TYPES):
+            return format_time_lambda(exp.TimeToStr, "oracle", default=True)(args)
+
+    return exp.ToChar.from_arg_list(args)
+
+
 class Oracle(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     LOCKING_READS_SUPPORTED = True

     # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
-    RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
+    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE

     # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
     # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
@@ -64,11 +84,13 @@ class Oracle(Dialect):
     }

     class Parser(parser.Parser):
+        ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
         WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
+            "TO_CHAR": to_char,
         }

         FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@@ -130,6 +152,7 @@ class Oracle(Dialect):
         TABLE_HINTS = False
         COLUMN_JOIN_MARKS_SUPPORTED = True
         DATA_TYPE_SPECIFIERS_ALLOWED = True
+        ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False

         LIMIT_FETCH = "FETCH"

@@ -192,6 +215,12 @@ class Oracle(Dialect):
             )
             return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"

+        def add_column_sql(self, expression: exp.AlterTable) -> str:
+            actions = self.expressions(expression, key="actions", flat=True)
+            if len(expression.args.get("actions", [])) > 1:
+                return f"ADD ({actions})"
+            return f"ADD {actions}"
+
     class Tokenizer(tokens.Tokenizer):
         VAR_SINGLE_TOKENS = {"@", "$", "#"}


@@ -4,6 +4,7 @@ import typing as t

from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    DATE_ADD_OR_SUB,
    Dialect,
    any_value_to_max_sql,
    arrow_json_extract_scalar_sql,

@@ -25,6 +26,7 @@ from sqlglot.dialects.dialect import (
    timestamptrunc_sql,
    timestrtotime_sql,
    trim_sql,
    ts_or_ds_add_cast,
    ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get

@@ -41,8 +43,11 @@ DATE_DIFF_FACTOR = {
}


def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]:
    def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, DATE_ADD_OR_SUB], str]:
    def func(self: Postgres.Generator, expression: DATE_ADD_OR_SUB) -> str:
        if isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        this = self.sql(expression, "this")
        unit = expression.args.get("unit")

@@ -60,8 +65,8 @@ def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str:
    unit = expression.text("unit").upper()
    factor = DATE_DIFF_FACTOR.get(unit)

    end = f"CAST({expression.this} AS TIMESTAMP)"
    start = f"CAST({expression.expression} AS TIMESTAMP)"
    end = f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
    start = f"CAST({self.sql(expression, 'expression')} AS TIMESTAMP)"

    if factor is not None:
        return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)"

@@ -69,7 +74,7 @@ def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str:
    age = f"AGE({end}, {start})"

    if unit == "WEEK":
        unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
        unit = f"EXTRACT(days FROM ({end} - {start})) / 7"
    elif unit == "MONTH":
        unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
    elif unit == "QUARTER":

@@ -183,37 +188,43 @@ def _to_timestamp(args: t.List) -> exp.Expression:
    return format_time_lambda(exp.StrToTime, "postgres")(args)


def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """Remove table refs from columns in when statements."""
    if isinstance(expression, exp.Merge):
        alias = expression.this.args.get("alias")
def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str:
    def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
        """Remove table refs from columns in when statements."""
        if isinstance(expression, exp.Merge):
            alias = expression.this.args.get("alias")

        normalize = (
            lambda identifier: Postgres.normalize_identifier(identifier).name
            if identifier
            else None
        )

        targets = {normalize(expression.this.this)}

        if alias:
            targets.add(normalize(alias.this))

        for when in expression.expressions:
            when.transform(
                lambda node: exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node,
                copy=False,
            normalize = (
                lambda identifier: self.dialect.normalize_identifier(identifier).name
                if identifier
                else None
            )

    return expression
            targets = {normalize(expression.this.this)}

            if alias:
                targets.add(normalize(alias.this))

            for when in expression.expressions:
                when.transform(
                    lambda node: exp.column(node.this)
                    if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                    else node,
                    copy=False,
                )

        return expression

    return transforms.preprocess([_remove_target_from_merge])(self, expression)
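The rewrite matters because Postgres rejects a qualified target column on the left side of a MERGE ... UPDATE SET. A hedged sketch of the intended behavior (output shape is approximate):

    import sqlglot

    sql = "MERGE INTO tgt AS t USING src AS s ON t.id = s.id WHEN MATCHED THEN UPDATE SET t.v = s.v"
    # the target prefix t. should be stripped from the SET clause, leaving SET v = s.v
    print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])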


class Postgres(Dialect):
    INDEX_OFFSET = 1
    TYPED_DIVISION = True
    CONCAT_COALESCE = True
    NULL_ORDERING = "nulls_are_large"
    TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"

    TIME_MAPPING = {
        "AM": "%p",
        "PM": "%p",

@@ -263,6 +274,7 @@ class Postgres(Dialect):
            "BEGIN TRANSACTION": TokenType.BEGIN,
            "BIGSERIAL": TokenType.BIGSERIAL,
            "CHARACTER VARYING": TokenType.VARCHAR,
            "CONSTRAINT TRIGGER": TokenType.COMMAND,
            "DECLARE": TokenType.COMMAND,
            "DO": TokenType.COMMAND,
            "HSTORE": TokenType.HSTORE,

@@ -277,6 +289,7 @@ class Postgres(Dialect):
            "TEMP": TokenType.TEMPORARY,
            "CSTRING": TokenType.PSEUDO_TYPE,
            "OID": TokenType.OBJECT_IDENTIFIER,
            "OPERATOR": TokenType.OPERATOR,
            "REGCLASS": TokenType.OBJECT_IDENTIFIER,
            "REGCOLLATION": TokenType.OBJECT_IDENTIFIER,
            "REGCONFIG": TokenType.OBJECT_IDENTIFIER,

@@ -298,8 +311,6 @@ class Postgres(Dialect):
        VAR_SINGLE_TOKENS = {"$"}

    class Parser(parser.Parser):
        CONCAT_NULL_OUTPUTS_STRING = True

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "DATE_TRUNC": parse_timestamp_trunc,

@@ -326,12 +337,13 @@ class Postgres(Dialect):

        RANGE_PARSERS = {
            **parser.Parser.RANGE_PARSERS,
            TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
            TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
            TokenType.DAT: lambda self, this: self.expression(
                exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this]
            ),
            TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
            TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
            TokenType.OPERATOR: lambda self, this: self._parse_operator(this),
        }

        STATEMENT_PARSERS = {

@@ -339,11 +351,28 @@ class Postgres(Dialect):
            TokenType.END: lambda self: self._parse_commit_or_rollback(),
        }

        def _parse_factor(self) -> t.Optional[exp.Expression]:
            return self._parse_tokens(self._parse_exponent, self.FACTOR)
        def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
            while True:
                if not self._match(TokenType.L_PAREN):
                    break

        def _parse_exponent(self) -> t.Optional[exp.Expression]:
            return self._parse_tokens(self._parse_unary, self.EXPONENT)
                op = ""
                while self._curr and not self._match(TokenType.R_PAREN):
                    op += self._curr.text
                    self._advance()

                this = self.expression(
                    exp.Operator,
                    comments=self._prev_comments,
                    this=this,
                    operator=op,
                    expression=self._parse_bitwise(),
                )

                if not self._match(TokenType.OPERATOR):
                    break

            return this
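_parse_operator consumes Postgres's OPERATOR(schema.op) call syntax by gluing every token up to the closing paren into a single operator string. A small sketch (the operator name is arbitrary):

    import sqlglot

    e = sqlglot.parse_one("SELECT a OPERATOR(pg_catalog.+) b", read="postgres")
    print(e.sql(dialect="postgres"))  # should round-trip unchanged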

        def _parse_date_part(self) -> exp.Expression:
            part = self._parse_type()

@@ -405,7 +434,7 @@ class Postgres(Dialect):
            exp.Max: max_or_greatest,
            exp.MapFromEntries: no_map_from_entries_sql,
            exp.Min: min_or_least,
            exp.Merge: transforms.preprocess([_remove_target_from_merge]),
            exp.Merge: _merge_sql,
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
            exp.PercentileCont: transforms.preprocess(
                [transforms.add_within_group_for_percentiles]

@@ -434,6 +463,8 @@ class Postgres(Dialect):
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
            exp.Trim: trim_sql,
            exp.TryCast: no_trycast_sql,
            exp.TsOrDsAdd: _date_add_sql("+"),
            exp.TsOrDsDiff: _date_diff_sql,
            exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
            exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
            exp.VariancePop: rename_func("VAR_POP"),

@@ -5,9 +5,11 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    NormalizationStrategy,
    binary_from_function,
    bool_xor_sql,
    date_trunc_to_time,
    datestrtodate_sql,
    encode_decode_sql,
    format_time_lambda,
    if_sql,

@@ -22,6 +24,7 @@ from sqlglot.dialects.dialect import (
    struct_extract_sql,
    timestamptrunc_sql,
    timestrtotime_sql,
    ts_or_ds_add_cast,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import apply_index_offset, seq_get

@@ -95,17 +98,16 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)


def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
    this = expression.this
    expression = ts_or_ds_add_cast(expression)
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_ADD", unit, expression.expression, expression.this)

    if not isinstance(this, exp.CurrentDate):
        this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE")

    return self.func(
        "DATE_ADD",
        exp.Literal.string(expression.text("unit") or "day"),
        expression.expression,
        this,
    )
def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
    this = exp.cast(expression.this, "TIMESTAMP")
    expr = exp.cast(expression.expression, "TIMESTAMP")
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_DIFF", unit, expr, this)


def _approx_percentile(args: t.List) -> exp.Expression:

@@ -136,11 +138,11 @@ def _from_unixtime(args: t.List) -> exp.Expression:
    return exp.UnixToTime.from_arg_list(args)


def _parse_element_at(args: t.List) -> exp.SafeBracket:
def _parse_element_at(args: t.List) -> exp.Bracket:
    this = seq_get(args, 0)
    index = seq_get(args, 1)
    assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
    return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1))
    return exp.Bracket(this=this, expressions=[index], offset=1, safe=True)


def _unnest_sequence(expression: exp.Expression) -> exp.Expression:

@@ -168,6 +170,22 @@ def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) ->
    return rename_func("ARBITRARY")(self, expression)


def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str:
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale in (None, exp.UnixToTime.SECONDS):
        return rename_func("FROM_UNIXTIME")(self, expression)
    if scale == exp.UnixToTime.MILLIS:
        return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)"
    if scale == exp.UnixToTime.MICROS:
        return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)"
    if scale == exp.UnixToTime.NANOS:
        return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)"

    self.unsupported(f"Unsupported scale for timestamp: {scale}.")
    return ""


class Presto(Dialect):
    INDEX_OFFSET = 1
    NULL_ORDERING = "nulls_are_last"

@@ -175,11 +193,12 @@ class Presto(Dialect):
    TIME_MAPPING = MySQL.TIME_MAPPING
    STRICT_STRING_CONCAT = True
    SUPPORTS_SEMI_ANTI_JOIN = False
    TYPED_DIVISION = True

    # https://github.com/trinodb/trino/issues/17
    # https://github.com/trinodb/trino/issues/12289
    # https://github.com/prestodb/presto/issues/2863
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {

@@ -229,6 +248,7 @@ class Presto(Dialect):
            ),
            "ROW": exp.Struct.from_arg_list,
            "SEQUENCE": exp.GenerateSeries.from_arg_list,
            "SET_AGG": exp.ArrayUniqueAgg.from_arg_list,
            "SPLIT_TO_MAP": exp.StrToMap.from_arg_list,
            "STRPOS": lambda args: exp.StrPosition(
                this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)

@@ -253,6 +273,7 @@ class Presto(Dialect):
        NVL2_SUPPORTED = False
        STRUCT_DELIMITER = ("(", ")")
        LIMIT_ONLY_LITERALS = True
        SUPPORTS_SINGLE_ARG_CONCAT = False

        PROPERTIES_LOCATION = {
            **generator.Generator.PROPERTIES_LOCATION,

@@ -284,6 +305,7 @@ class Presto(Dialect):
            exp.ArrayConcat: rename_func("CONCAT"),
            exp.ArrayContains: rename_func("CONTAINS"),
            exp.ArraySize: rename_func("CARDINALITY"),
            exp.ArrayUniqueAgg: rename_func("SET_AGG"),
            exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
            exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
            exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})",

@@ -298,7 +320,7 @@ class Presto(Dialect):
            exp.DateDiff: lambda self, e: self.func(
                "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
            ),
            exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
            exp.DateStrToDate: datestrtodate_sql,
            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
            exp.DateSub: lambda self, e: self.func(
                "DATE_ADD",

@@ -330,9 +352,6 @@ class Presto(Dialect):
            exp.Quantile: _quantile_sql,
            exp.RegexpExtract: regexp_extract_sql,
            exp.Right: right_to_substring_sql,
            exp.SafeBracket: lambda self, e: self.func(
                "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0)
            ),
            exp.SafeDivide: no_safe_divide_sql,
            exp.Schema: _schema_sql,
            exp.Select: transforms.preprocess(

@@ -361,10 +380,11 @@ class Presto(Dialect):
            exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
            exp.TsOrDsAdd: _ts_or_ds_add_sql,
            exp.TsOrDsDiff: _ts_or_ds_diff_sql,
            exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
            exp.Unhex: rename_func("FROM_HEX"),
            exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
            exp.UnixToTime: _unix_to_time_sql,
            exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
            exp.VariancePop: rename_func("VAR_POP"),
            exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),

@@ -374,8 +394,24 @@ class Presto(Dialect):
            exp.Xor: bool_xor_sql,
        }

        def bracket_sql(self, expression: exp.Bracket) -> str:
            if expression.args.get("safe"):
                return self.func(
                    "ELEMENT_AT",
                    expression.this,
                    seq_get(
                        apply_index_offset(
                            expression.this,
                            expression.expressions,
                            1 - expression.args.get("offset", 0),
                        ),
                        0,
                    ),
                )
            return super().bracket_sql(expression)
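With SafeBracket gone, safe element access now rides on Bracket's new safe/offset args, and ELEMENT_AT should survive the one-based offset math. A hedged sketch:

    import sqlglot

    # ELEMENT_AT parses to a one-based, safe Bracket and should round-trip intact
    print(sqlglot.transpile("SELECT ELEMENT_AT(arr, 1)", read="presto", write="presto")[0])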

        def struct_sql(self, expression: exp.Struct) -> str:
            if any(isinstance(arg, (exp.EQ, exp.Slice)) for arg in expression.expressions):
            if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions):
                self.unsupported("Struct with key-value definitions is unsupported.")
                return self.function_fallback_sql(expression)

@@ -4,8 +4,10 @@ import typing as t

from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
    NormalizationStrategy,
    concat_to_dpipe_sql,
    concat_ws_to_dpipe_sql,
    date_delta_sql,
    generatedasidentitycolumnconstraint_sql,
    rename_func,
    ts_or_ds_to_date_sql,

@@ -14,30 +16,28 @@ from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
    return f'{self.sql(expression, "this")}."{expression.expression.name}"'


def _parse_date_add(args: t.List) -> exp.DateAdd:
    return exp.DateAdd(
        this=exp.TsOrDsToDate(this=seq_get(args, 2)),
        expression=seq_get(args, 1),
        unit=seq_get(args, 0),
    )
def _parse_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
    def _parse_delta(args: t.List) -> E:
        expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
        if expr_type is exp.TsOrDsAdd:
            expr.set("return_type", exp.DataType.build("TIMESTAMP"))

        return expr

def _parse_datediff(args: t.List) -> exp.DateDiff:
    return exp.DateDiff(
        this=exp.TsOrDsToDate(this=seq_get(args, 2)),
        expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
        unit=seq_get(args, 0),
    )
    return _parse_delta
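The factory lets DATEADD and DATEDIFF share one parser while still tagging TsOrDsAdd with a TIMESTAMP return type. A minimal sketch (assuming this build):

    import sqlglot

    e = sqlglot.parse_one("SELECT DATEADD(month, 1, d)", read="redshift")
    # the DATEADD call should now come back as a TsOrDsAdd node
    print(e.find(sqlglot.exp.TsOrDsAdd))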


class Redshift(Postgres):
    # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

    SUPPORTS_USER_DEFINED_TYPES = False
    INDEX_OFFSET = 0

@@ -52,15 +52,16 @@ class Redshift(Postgres):
    class Parser(Postgres.Parser):
        FUNCTIONS = {
            **Postgres.Parser.FUNCTIONS,
            "ADD_MONTHS": lambda args: exp.DateAdd(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
            "ADD_MONTHS": lambda args: exp.TsOrDsAdd(
                this=seq_get(args, 0),
                expression=seq_get(args, 1),
                unit=exp.var("month"),
                return_type=exp.DataType.build("TIMESTAMP"),
            ),
            "DATEADD": _parse_date_add,
            "DATE_ADD": _parse_date_add,
            "DATEDIFF": _parse_datediff,
            "DATE_DIFF": _parse_datediff,
            "DATEADD": _parse_date_delta(exp.TsOrDsAdd),
            "DATE_ADD": _parse_date_delta(exp.TsOrDsAdd),
            "DATEDIFF": _parse_date_delta(exp.TsOrDsDiff),
            "DATE_DIFF": _parse_date_delta(exp.TsOrDsDiff),
            "LISTAGG": exp.GroupConcat.from_arg_list,
            "STRTOL": exp.FromBase.from_arg_list,
        }

@@ -169,12 +170,8 @@ class Redshift(Postgres):
            exp.ConcatWs: concat_ws_to_dpipe_sql,
            exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
            exp.CurrentTimestamp: lambda self, e: "SYSDATE",
            exp.DateAdd: lambda self, e: self.func(
                "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
            ),
            exp.DateDiff: lambda self, e: self.func(
                "DATEDIFF", exp.var(e.text("unit") or "day"), e.expression, e.this
            ),
            exp.DateAdd: date_delta_sql("DATEADD"),
            exp.DateDiff: date_delta_sql("DATEDIFF"),
            exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
            exp.DistStyleProperty: lambda self, e: self.naked_property(e),
            exp.FromBase: rename_func("STRTOL"),

@@ -183,11 +180,12 @@ class Redshift(Postgres):
            exp.JSONExtractScalar: _json_sql,
            exp.GroupConcat: rename_func("LISTAGG"),
            exp.ParseJSON: rename_func("JSON_PARSE"),
            exp.SafeConcat: concat_to_dpipe_sql,
            exp.Select: transforms.preprocess(
                [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
            ),
            exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
            exp.TsOrDsAdd: date_delta_sql("DATEADD"),
            exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
            exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"),
        }

@@ -3,9 +3,12 @@ from __future__ import annotations
import typing as t

from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
    Dialect,
    NormalizationStrategy,
    binary_from_function,
    date_delta_sql,
    date_trunc_to_time,
    datestrtodate_sql,
    format_time_lambda,

@@ -21,7 +24,6 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType

@@ -50,7 +52,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
        elif second_arg.name == "3":
            timescale = exp.UnixToTime.MILLIS
        elif second_arg.name == "9":
            timescale = exp.UnixToTime.MICROS
            timescale = exp.UnixToTime.NANOS

        return exp.UnixToTime(this=first_arg, scale=timescale)

@@ -95,14 +97,17 @@ def _parse_datediff(args: t.List) -> exp.DateDiff:
def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale in [None, exp.UnixToTime.SECONDS]:
    if scale in (None, exp.UnixToTime.SECONDS):
        return f"TO_TIMESTAMP({timestamp})"
    if scale == exp.UnixToTime.MILLIS:
        return f"TO_TIMESTAMP({timestamp}, 3)"
    if scale == exp.UnixToTime.MICROS:
        return f"TO_TIMESTAMP({timestamp} / 1000, 3)"
    if scale == exp.UnixToTime.NANOS:
        return f"TO_TIMESTAMP({timestamp}, 9)"

    raise ValueError("Improper scale for timestamp")
    self.unsupported(f"Unsupported scale for timestamp: {scale}.")
    return ""


# https://docs.snowflake.com/en/sql-reference/functions/date_part.html

@@ -201,7 +206,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]

class Snowflake(Dialect):
    # https://docs.snowflake.com/en/sql-reference/identifiers-syntax
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
    NULL_ORDERING = "nulls_are_large"
    TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
    SUPPORTS_USER_DEFINED_TYPES = False

@@ -236,6 +241,18 @@ class Snowflake(Dialect):
        "ff6": "%f",
    }

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        # This disables quoting DUAL in SELECT ... FROM DUAL, because Snowflake treats an
        # unquoted DUAL keyword in a special way and does not map it to a user-defined table
        if (
            isinstance(expression, exp.Identifier)
            and isinstance(expression.parent, exp.Table)
            and expression.name.lower() == "dual"
        ):
            return t.cast(E, expression)

        return super().quote_identifier(expression, identify=identify)
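A sketch of the DUAL carve-out (assuming this build):

    import sqlglot

    e = sqlglot.parse_one("SELECT 1 FROM DUAL", read="snowflake")
    # identify=True would normally quote every identifier, but DUAL must stay bare
    print(e.sql(dialect="snowflake", identify=True))  # SELECT 1 FROM DUAL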

    class Parser(parser.Parser):
        IDENTIFY_PIVOT_STRINGS = True

@@ -245,6 +262,9 @@ class Snowflake(Dialect):
            **parser.Parser.FUNCTIONS,
            "ARRAYAGG": exp.ArrayAgg.from_arg_list,
            "ARRAY_CONSTRUCT": exp.Array.from_arg_list,
            "ARRAY_CONTAINS": lambda args: exp.ArrayContains(
                this=seq_get(args, 1), expression=seq_get(args, 0)
            ),
            "ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries(
                # ARRAY_GENERATE_RANGE has an exclusive end; we normalize it to be inclusive
                start=seq_get(args, 0),

@@ -296,8 +316,8 @@ class Snowflake(Dialect):

        RANGE_PARSERS = {
            **parser.Parser.RANGE_PARSERS,
            TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny),
            TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny),
            TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
            TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
        }

        ALTER_PARSERS = {

@@ -317,6 +337,11 @@ class Snowflake(Dialect):
            TokenType.SHOW: lambda self: self._parse_show(),
        }

        PROPERTY_PARSERS = {
            **parser.Parser.PROPERTY_PARSERS,
            "LOCATION": lambda self: self._parse_location(),
        }

        SHOW_PARSERS = {
            "PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
            "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),

@@ -349,7 +374,7 @@ class Snowflake(Dialect):
        table: t.Optional[exp.Expression] = None
        if self._match_text_seq("@"):
            table_name = "@"
            while True:
            while self._curr:
                self._advance()
                table_name += self._prev.text
                if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):

@@ -411,6 +436,20 @@ class Snowflake(Dialect):
            self._match_text_seq("WITH")
            return self.expression(exp.SwapTable, this=self._parse_table(schema=True))

        def _parse_location(self) -> exp.LocationProperty:
            self._match(TokenType.EQ)

            parts = [self._parse_var(any_token=True)]

            while self._match(TokenType.SLASH):
                if self._curr and self._prev.end + 1 == self._curr.start:
                    parts.append(self._parse_var(any_token=True))
                else:
                    parts.append(exp.Var(this=""))
            return self.expression(
                exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts))
            )

    class Tokenizer(tokens.Tokenizer):
        STRING_ESCAPES = ["\\", "'"]
        HEX_STRINGS = [("x'", "'"), ("X'", "'")]

@@ -457,6 +496,7 @@ class Snowflake(Dialect):
        AGGREGATE_FILTER_SUPPORTED = False
        SUPPORTS_TABLE_COPY = False
        COLLATE_IS_FUNC = True
        LIMIT_ONLY_LITERALS = True

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,

@@ -464,15 +504,14 @@ class Snowflake(Dialect):
            exp.ArgMin: rename_func("MIN_BY"),
            exp.Array: inline_array_sql,
            exp.ArrayConcat: rename_func("ARRAY_CAT"),
            exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
            exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
            exp.AtTimeZone: lambda self, e: self.func(
                "CONVERT_TIMEZONE", e.args.get("zone"), e.this
            ),
            exp.BitwiseXor: rename_func("BITXOR"),
            exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
            exp.DateDiff: lambda self, e: self.func(
                "DATEDIFF", e.text("unit"), e.expression, e.this
            ),
            exp.DateAdd: date_delta_sql("DATEADD"),
            exp.DateDiff: date_delta_sql("DATEDIFF"),
            exp.DateStrToDate: datestrtodate_sql,
            exp.DataType: _datatype_sql,
            exp.DayOfMonth: rename_func("DAYOFMONTH"),

@@ -501,10 +540,11 @@ class Snowflake(Dialect):
            exp.Select: transforms.preprocess(
                [
                    transforms.eliminate_distinct_on,
                    transforms.explode_to_unnest(0),
                    transforms.explode_to_unnest(),
                    transforms.eliminate_semi_and_anti_joins,
                ]
            ),
            exp.SHA: rename_func("SHA1"),
            exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
            exp.StartsWith: rename_func("STARTSWITH"),
            exp.StrPosition: lambda self, e: self.func(

@@ -524,6 +564,8 @@ class Snowflake(Dialect):
            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
            exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
            exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
            exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
            exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
            exp.UnixToTime: _unix_to_time_sql,
            exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),

@@ -547,6 +589,20 @@ class Snowflake(Dialect):
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

        def trycast_sql(self, expression: exp.TryCast) -> str:
            value = expression.this

            if value.type is None:
                from sqlglot.optimizer.annotate_types import annotate_types

                value = annotate_types(value)

            if value.is_type(*exp.DataType.TEXT_TYPES, exp.DataType.Type.UNKNOWN):
                return super().trycast_sql(expression)

            # TRY_CAST only works for string values in Snowflake
            return self.cast_sql(expression)
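A sketch of the type-guarded fallback (assuming this build):

    import sqlglot

    # a numeric operand is not a string, so TRY_CAST should degrade to a plain CAST
    print(sqlglot.transpile("SELECT TRY_CAST(10 AS TEXT)", read="snowflake", write="snowflake")[0])
    # an operand of unknown type keeps TRY_CAST, e.g. TRY_CAST(col AS DATE)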

        def log_sql(self, expression: exp.Log) -> str:
            if not expression.expression:
                return self.func("LN", expression.this)

@@ -554,24 +610,28 @@ class Snowflake(Dialect):
            return super().log_sql(expression)

        def unnest_sql(self, expression: exp.Unnest) -> str:
            selects = ["value"]
            unnest_alias = expression.args.get("alias")

            offset = expression.args.get("offset")
            if offset:
                if unnest_alias:
                    unnest_alias.append("columns", offset.pop())

                selects.append("index")
            columns = [
                exp.to_identifier("seq"),
                exp.to_identifier("key"),
                exp.to_identifier("path"),
                offset.pop() if isinstance(offset, exp.Expression) else exp.to_identifier("index"),
                seq_get(unnest_alias.columns if unnest_alias else [], 0)
                or exp.to_identifier("value"),
                exp.to_identifier("this"),
            ]

            subquery = exp.Subquery(
                this=exp.select(*selects).from_(
                    f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
                ),
            )
            if unnest_alias:
                unnest_alias.set("columns", columns)
            else:
                unnest_alias = exp.TableAlias(this="_u", columns=columns)

            explode = f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
            alias = self.sql(unnest_alias)
            alias = f" AS {alias}" if alias else ""
            return f"{self.sql(subquery)}{alias}"
            return f"{explode}{alias}"

        def show_sql(self, expression: exp.Show) -> str:
            scope = self.sql(expression, "scope")

@@ -632,3 +692,6 @@ class Snowflake(Dialect):
        def swaptable_sql(self, expression: exp.SwapTable) -> str:
            this = self.sql(expression, "this")
            return f"SWAP WITH {this}"

        def with_properties(self, properties: exp.Properties) -> str:
            return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ")

@@ -56,15 +56,17 @@ class Spark(Spark2):

        def _parse_generated_as_identity(
            self,
        ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint:
        ) -> (
            exp.GeneratedAsIdentityColumnConstraint
            | exp.ComputedColumnConstraint
            | exp.GeneratedAsRowColumnConstraint
        ):
            this = super()._parse_generated_as_identity()
            if this.expression:
                return self.expression(exp.ComputedColumnConstraint, this=this.expression)
            return this

    class Generator(Spark2.Generator):
        SUPPORTS_NESTED_CTES = True

        TYPE_MAPPING = {
            **Spark2.Generator.TYPE_MAPPING,
            exp.DataType.Type.MONEY: "DECIMAL(15, 4)",

@@ -48,8 +48,11 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str
        return f"TIMESTAMP_MILLIS({timestamp})"
    if scale == exp.UnixToTime.MICROS:
        return f"TIMESTAMP_MICROS({timestamp})"
    if scale == exp.UnixToTime.NANOS:
        return f"TIMESTAMP_SECONDS({timestamp} / 1000000000)"

    raise ValueError("Improper scale for timestamp")
    self.unsupported(f"Unsupported scale for timestamp: {scale}.")
    return ""


def _unalias_pivot(expression: exp.Expression) -> exp.Expression:

@@ -119,7 +122,11 @@ class Spark2(Hive):
            "DOUBLE": _parse_as_cast("double"),
            "FLOAT": _parse_as_cast("float"),
            "FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
                this=exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("timestamp")),
                this=exp.cast_unless(
                    seq_get(args, 0) or exp.Var(this=""),
                    exp.DataType.build("timestamp"),
                    exp.DataType.build("timestamp"),
                ),
                zone=seq_get(args, 1),
            ),
            "IIF": exp.If.from_arg_list,

@@ -224,6 +231,19 @@ class Spark2(Hive):
        WRAP_DERIVED_VALUES = False
        CREATE_FUNCTION_RETURN_AS = False

        def struct_sql(self, expression: exp.Struct) -> str:
            args = []
            for arg in expression.expressions:
                if isinstance(arg, self.KEY_VALUE_DEFINITONS):
                    if isinstance(arg, exp.Bracket):
                        args.append(exp.alias_(arg.this, arg.expressions[0].name))
                    else:
                        args.append(exp.alias_(arg.expression, arg.this.name))
                else:
                    args.append(arg)

            return self.func("STRUCT", *args)

        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
            # spark2, spark, Databricks require a storage provider for temporary tables
            provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))

@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    NormalizationStrategy,
    any_value_to_max_sql,
    arrow_json_extract_scalar_sql,
    arrow_json_extract_sql,

@@ -63,8 +64,10 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:

class SQLite(Dialect):
    # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
    SUPPORTS_SEMI_ANTI_JOIN = False
    TYPED_DIVISION = True
    SAFE_DIVISION = True

    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]"), "`"]

@@ -124,7 +127,6 @@ class SQLite(Dialect):
        exp.LogicalOr: rename_func("MAX"),
        exp.LogicalAnd: rename_func("MIN"),
        exp.Pivot: no_pivot_sql,
        exp.SafeConcat: concat_to_dpipe_sql,
        exp.Select: transforms.preprocess(
            [
                transforms.eliminate_distinct_on,

@@ -9,6 +9,7 @@ from sqlglot.tokens import TokenType

class Teradata(Dialect):
    SUPPORTS_SEMI_ANTI_JOIN = False
    TYPED_DIVISION = True

    TIME_MAPPING = {
        "Y": "%Y",

@@ -33,8 +34,10 @@ class Teradata(Dialect):

    class Tokenizer(tokens.Tokenizer):
        # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
        # https://docs.teradata.com/r/SQL-Functions-Operators-Expressions-and-Predicates/June-2017/Arithmetic-Trigonometric-Hyperbolic-Operators/Functions
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "**": TokenType.DSTAR,
            "^=": TokenType.NEQ,
            "BYTEINT": TokenType.SMALLINT,
            "COLLECT": TokenType.COMMAND,

@@ -112,10 +115,16 @@ class Teradata(Dialect):

        FUNCTION_PARSERS = {
            **parser.Parser.FUNCTION_PARSERS,
            # https://docs.teradata.com/r/SQL-Functions-Operators-Expressions-and-Predicates/June-2017/Data-Type-Conversions/TRYCAST
            "TRYCAST": parser.Parser.FUNCTION_PARSERS["TRY_CAST"],
            "RANGE_N": lambda self: self._parse_rangen(),
            "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
        }

        EXPONENT = {
            TokenType.DSTAR: exp.Pow,
        }
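The new DSTAR keyword plus the EXPONENT entry make Teradata's ** operator parse as exp.Pow, and the generator maps Pow back to the binary form. A hedged sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT 2 ** 3", read="teradata", write="teradata")[0])  # SELECT 2 ** 3
    print(sqlglot.transpile("SELECT 2 ** 3", read="teradata", write="postgres")[0])  # likely POWER(2, 3)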

        def _parse_translate(self, strict: bool) -> exp.Expression:
            this = self._parse_conjunction()

@@ -177,6 +186,7 @@ class Teradata(Dialect):
            exp.ArgMin: rename_func("MIN_BY"),
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,
            exp.Pow: lambda self, e: self.binary(e, "**"),
            exp.Select: transforms.preprocess(
                [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
            ),

@@ -192,6 +202,9 @@ class Teradata(Dialect):

            return super().cast_sql(expression, safe_prefix=safe_prefix)

        def trycast_sql(self, expression: exp.TryCast) -> str:
            return self.cast_sql(expression, safe_prefix="TRY")

        def tablesample_sql(
            self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
        ) -> str:

@@ -7,7 +7,9 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    NormalizationStrategy,
    any_value_to_max_sql,
    date_delta_sql,
    generatedasidentitycolumnconstraint_sql,
    max_or_greatest,
    min_or_least,

@@ -135,11 +137,7 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
    return exp.func("HASHBYTES", *args)


def generate_date_delta_with_unit_sql(
    self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff
) -> str:
    func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
    return self.func(func, expression.text("unit"), expression.expression, expression.this)
DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"}


def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:

@@ -153,6 +151,11 @@ def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToSt
            )
        )
    )

    # There is no format for "quarter"
    if fmt.name.lower() in DATEPART_ONLY_FORMATS:
        return self.func("DATEPART", fmt.name, expression.this)

    return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))


@@ -202,18 +205,50 @@ def _parse_date_delta(
    return inner_func


def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
    """Ensures all (unnamed) output columns are aliased for CTEs and Subqueries."""
    alias = expression.args.get("alias")

    if (
        isinstance(expression, (exp.CTE, exp.Subquery))
        and isinstance(alias, exp.TableAlias)
        and not alias.columns
    ):
        from sqlglot.optimizer.qualify_columns import qualify_outputs

        # We keep track of the unaliased column projection indexes instead of the expressions
        # themselves, because the latter are going to be replaced by new nodes when the aliases
        # are added and hence we won't be able to reach these newly added Alias parents
        subqueryable = expression.this
        unaliased_column_indexes = (
            i
            for i, c in enumerate(subqueryable.selects)
            if isinstance(c, exp.Column) and not c.alias
        )

        qualify_outputs(subqueryable)

        # Preserve the quoting information of columns for newly added Alias nodes
        subqueryable_selects = subqueryable.selects
        for select_index in unaliased_column_indexes:
            alias = subqueryable_selects[select_index]
            column = alias.this
            if isinstance(column.this, exp.Identifier):
                alias.args["alias"].set("quoted", column.this.quoted)

    return expression
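In practice this means unaliased projections inside a CTE or derived table pick up explicit aliases when generating T-SQL. A hedged sketch (output shape is approximate):

    import sqlglot

    sql = "WITH t AS (SELECT x FROM tbl) SELECT x FROM t"
    # the inner projection should gain an alias, roughly: WITH t AS (SELECT x AS x FROM tbl) ...
    print(sqlglot.transpile(sql, read="tsql", write="tsql")[0])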


class TSQL(Dialect):
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
    NULL_ORDERING = "nulls_are_small"
    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
    TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
    SUPPORTS_SEMI_ANTI_JOIN = False
    LOG_BASE_FIRST = False
    TYPED_DIVISION = True
    CONCAT_COALESCE = True

    TIME_MAPPING = {
        "year": "%Y",
        "qq": "%q",
        "q": "%q",
        "quarter": "%q",
        "dayofyear": "%j",
        "day": "%d",
        "dy": "%d",

@@ -320,6 +355,7 @@ class TSQL(Dialect):
        IDENTIFIERS = ['"', ("[", "]")]
        QUOTES = ["'", '"']
        HEX_STRINGS = [("0x", ""), ("0X", "")]
        VAR_SINGLE_TOKENS = {"@", "$", "#"}

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,

@@ -403,9 +439,7 @@ class TSQL(Dialect):

        LOG_DEFAULTS_TO_LN = True

        CONCAT_NULL_OUTPUTS_STRING = True

        ALTER_TABLE_ADD_COLUMN_KEYWORD = False
        ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False

        def _parse_projections(self) -> t.List[exp.Expression]:
            """

@@ -433,7 +467,7 @@ class TSQL(Dialect):
            """
            rollback = self._prev.token_type == TokenType.ROLLBACK

            self._match_texts({"TRAN", "TRANSACTION"})
            self._match_texts(("TRAN", "TRANSACTION"))
            this = self._parse_id_var()

            if rollback:

@@ -579,23 +613,35 @@ class TSQL(Dialect):
            return super()._parse_if()

        def _parse_unique(self) -> exp.UniqueColumnConstraint:
            return self.expression(
                exp.UniqueColumnConstraint,
                this=None
                if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"}
                else self._parse_schema(self._parse_id_var(any_token=False)),
            )
            if self._match_texts(("CLUSTERED", "NONCLUSTERED")):
                this = self.CONSTRAINT_PARSERS[self._prev.text.upper()](self)
            else:
                this = self._parse_schema(self._parse_id_var(any_token=False))

            return self.expression(exp.UniqueColumnConstraint, this=this)

    class Generator(generator.Generator):
        LIMIT_IS_TOP = True
        QUERY_HINTS = False
        RETURNING_END = False
        NVL2_SUPPORTED = False
        ALTER_TABLE_ADD_COLUMN_KEYWORD = False
        ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
        LIMIT_FETCH = "FETCH"
        COMPUTED_COLUMN_WITH_TYPE = False
        SUPPORTS_NESTED_CTES = False
        CTE_RECURSIVE_KEYWORD_REQUIRED = False
        ENSURE_BOOLS = True
        NULL_ORDERING_SUPPORTED = False
        SUPPORTS_SINGLE_ARG_CONCAT = False

        EXPRESSIONS_WITHOUT_NESTED_CTES = {
            exp.Delete,
            exp.Insert,
            exp.Merge,
            exp.Select,
            exp.Subquery,
            exp.Union,
            exp.Update,
        }

        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,

@@ -614,14 +660,16 @@ class TSQL(Dialect):
            **generator.Generator.TRANSFORMS,
            exp.AnyValue: any_value_to_max_sql,
            exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
            exp.DateAdd: generate_date_delta_with_unit_sql,
            exp.DateDiff: generate_date_delta_with_unit_sql,
            exp.DateAdd: date_delta_sql("DATEADD"),
            exp.DateDiff: date_delta_sql("DATEDIFF"),
            exp.CTE: transforms.preprocess([qualify_derived_table_outputs]),
            exp.CurrentDate: rename_func("GETDATE"),
            exp.CurrentTimestamp: rename_func("GETDATE"),
            exp.Extract: rename_func("DATEPART"),
            exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
            exp.GroupConcat: _string_agg_sql,
            exp.If: rename_func("IIF"),
            exp.Length: rename_func("LEN"),
            exp.Max: max_or_greatest,
            exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
            exp.Min: min_or_least,

@@ -633,15 +681,16 @@ class TSQL(Dialect):
                    transforms.eliminate_qualify,
                ]
            ),
            exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]),
            exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
            exp.SHA2: lambda self, e: self.func(
                "HASHBYTES",
                exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
                e.this,
                "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
            ),
            exp.TemporaryProperty: lambda self, e: "",
            exp.TimeStrToTime: timestrtotime_sql,
            exp.TimeToStr: _format_sql,
            exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
            exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
            exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
        }

@@ -690,8 +739,21 @@ class TSQL(Dialect):

        table = expression.find(exp.Table)

        # Convert CTAS statement to SELECT .. INTO ..
        if kind == "TABLE" and expression.expression:
            sql = f"SELECT * INTO {self.sql(table)} FROM ({self.sql(expression.expression)}) AS temp"
            ctas_with = expression.expression.args.get("with")
            if ctas_with:
                ctas_with = ctas_with.pop()

            subquery = expression.expression
            if isinstance(subquery, exp.Subqueryable):
                subquery = subquery.subquery()

            select_into = exp.select("*").from_(exp.alias_(subquery, "temp", table=True))
            select_into.set("into", exp.Into(this=table))
            select_into.set("with", ctas_with)

            sql = self.sql(select_into)

        if exists:
            identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else ""))
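The CTAS rewrite builds the SELECT ... INTO as an AST and hoists any WITH clause out of the subquery, since T-SQL does not allow a CTE inside a derived table. A hedged sketch:

    import sqlglot

    print(sqlglot.transpile("CREATE TABLE t AS SELECT * FROM src", read="duckdb", write="tsql")[0])
    # expected shape: SELECT * INTO t FROM (SELECT * FROM src) AS temp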


@@ -139,10 +139,16 @@ def interval(this, unit):
    return datetime.timedelta(**{unit: float(this)})


@null_if_any("this", "expression")
def arrayjoin(this, expression, null=None):
    return expression.join(x for x in (x if x is not None else null for x in this) if x is not None)


ENV = {
    "exp": exp,
    # aggs
    "ARRAYAGG": list,
    "ARRAYUNIQUEAGG": filter_nulls(lambda acc: list(set(acc))),
    "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
    "MAX": filter_nulls(max),

@@ -152,6 +158,7 @@ ENV = {
    "ABS": null_if_any(lambda this: abs(this)),
    "ADD": null_if_any(lambda e, this: e + this),
    "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
    "ARRAYJOIN": arrayjoin,
    "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
    "BITWISEAND": null_if_any(lambda this, e: this & e),
    "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),

@@ -203,4 +210,9 @@ ENV = {
    "CURRENTDATE": datetime.date.today,
    "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
    "TRIM": null_if_any(lambda this, e=None: this.strip(e)),
    "STRUCT": lambda *args: {
        args[x]: args[x + 1]
        for x in range(0, len(args), 2)
        if (args[x + 1] is not None and args[x] is not None)
    },
}

@@ -397,6 +397,20 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
    return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"


def _div_sql(self: generator.Generator, e: exp.Div) -> str:
    denominator = self.sql(e, "expression")

    if e.args.get("safe"):
        denominator += " or None"

    sql = f"DIV({self.sql(e, 'this')}, {denominator})"

    if e.args.get("typed"):
        sql = f"int({sql})"

    return sql


class Python(Dialect):
    class Tokenizer(tokens.Tokenizer):
        STRING_ESCAPES = ["\\"]

@@ -413,7 +427,11 @@ class Python(Dialect):
            exp.Boolean: lambda self, e: "True" if e.this else "False",
            exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
            exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
            exp.Concat: lambda self, e: self.func(
                "SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions
            ),
            exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
            exp.Div: _div_sql,
            exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
            exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
            exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",

@@ -120,20 +120,22 @@ def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict
    depth = dict_depth(d)
    if depth > 1:
        return {
            normalize_name(k, dialect=dialect, is_table=True): _ensure_tables(v, dialect=dialect)
            normalize_name(k, dialect=dialect, is_table=True).name: _ensure_tables(
                v, dialect=dialect
            )
            for k, v in d.items()
        }

    result = {}
    for table_name, table in d.items():
        table_name = normalize_name(table_name, dialect=dialect)
        table_name = normalize_name(table_name, dialect=dialect).name

        if isinstance(table, Table):
            result[table_name] = table
        else:
            table = [
                {
                    normalize_name(column_name, dialect=dialect): value
                    normalize_name(column_name, dialect=dialect).name: value
                    for column_name, value in row.items()
                }
                for row in table

@@ -53,6 +53,7 @@ class _Expression(type):


SQLGLOT_META = "sqlglot.meta"
TABLE_PARTS = ("this", "db", "catalog")


class Expression(metaclass=_Expression):

@@ -134,7 +135,7 @@ class Expression(metaclass=_Expression):
        return self.args.get("expression")

    @property
    def expressions(self):
    def expressions(self) -> t.List[t.Any]:
        """
        Retrieves the argument with key "expressions".
        """

@@ -238,6 +239,9 @@ class Expression(metaclass=_Expression):
            dtype = DataType.build(dtype)
        self._type = dtype  # type: ignore

    def is_type(self, *dtypes) -> bool:
        return self.type is not None and self.type.is_type(*dtypes)

    @property
    def meta(self) -> t.Dict[str, t.Any]:
        if self._meta is None:

@@ -481,7 +485,7 @@ class Expression(metaclass=_Expression):

    def flatten(self, unnest=True):
        """
        Returns a generator which yields child nodes who's parents are the same class.
        Returns a generator which yields child nodes whose parents are the same class.

        A AND B AND C -> [A, B, C]
        """

@@ -508,7 +512,7 @@ class Expression(metaclass=_Expression):
        """
        from sqlglot.dialects import Dialect

        return Dialect.get_or_raise(dialect)().generate(self, **opts)
        return Dialect.get_or_raise(dialect).generate(self, **opts)

    def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
        indent = "" if not level else "\n"

@@ -821,6 +825,12 @@ class Expression(metaclass=_Expression):
    def rlike(self, other: ExpOrStr) -> RegexpLike:
        return self._binop(RegexpLike, other)

    def div(self, other: ExpOrStr, typed: bool = False, safe: bool = False) -> Div:
        div = self._binop(Div, other)
        div.args["typed"] = typed
        div.args["safe"] = safe
        return div
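The new builder mirrors the Div arg_types introduced further down; the typed/safe flags flow straight into the node. A small sketch:

    from sqlglot import exp

    div = exp.column("a").div("b", typed=True, safe=True)
    print(div.args["typed"], div.args["safe"])  # True True
    print(div.sql())  # the default generator should still render a / b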
|
||||
|
||||
def __lt__(self, other: t.Any) -> LT:
|
||||
return self._binop(LT, other)
|
||||
|
||||
|
@ -1000,7 +1010,6 @@ class UDTF(DerivedTable, Unionable):
|
|||
|
||||
class Cache(Expression):
|
||||
arg_types = {
|
||||
"with": False,
|
||||
"this": True,
|
||||
"lazy": False,
|
||||
"options": False,
|
||||
|
@ -1012,6 +1021,10 @@ class Uncache(Expression):
|
|||
arg_types = {"this": True, "exists": False}
|
||||
|
||||
|
||||
class Refresh(Expression):
|
||||
pass
|
||||
|
||||
|
||||
class DDL(Expression):
|
||||
@property
|
||||
def ctes(self):
|
||||
|
@ -1033,6 +1046,43 @@ class DDL(Expression):
|
|||
return []
|
||||
|
||||
|
||||
class DML(Expression):
|
||||
def returning(
|
||||
self,
|
||||
expression: ExpOrStr,
|
||||
dialect: DialectType = None,
|
||||
copy: bool = True,
|
||||
**opts,
|
||||
) -> DML:
|
||||
"""
|
||||
Set the RETURNING expression. Not supported by all dialects.
|
||||
|
||||
Example:
|
||||
>>> delete("tbl").returning("*", dialect="postgres").sql()
|
||||
'DELETE FROM tbl RETURNING *'
|
||||
|
||||
Args:
|
||||
expression: the SQL code strings to parse.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
dialect: the dialect used to parse the input expressions.
|
||||
copy: if `False`, modify this expression instance in-place.
|
||||
opts: other options to use to parse the input expressions.
|
||||
|
||||
Returns:
|
||||
Delete: the modified expression.
|
||||
"""
|
||||
return _apply_builder(
|
||||
expression=expression,
|
||||
instance=self,
|
||||
arg="returning",
|
||||
prefix="RETURNING",
|
||||
dialect=dialect,
|
||||
copy=copy,
|
||||
into=Returning,
|
||||
**opts,
|
||||
)
|
||||
|
||||
|
||||
class Create(DDL):
|
||||
arg_types = {
|
||||
"with": False,
|
||||
|
@@ -1133,8 +1183,10 @@ class WithinGroup(Expression):
     arg_types = {"this": True, "expression": False}
 
 
 # clickhouse supports scalar ctes
 # https://clickhouse.com/docs/en/sql-reference/statements/select/with
 class CTE(DerivedTable):
-    arg_types = {"this": True, "alias": True}
+    arg_types = {"this": True, "alias": True, "scalar": False}
 
 
 class TableAlias(Expression):
@@ -1297,6 +1349,10 @@ class AutoIncrementColumnConstraint(ColumnConstraintKind):
     pass
 
 
+class PeriodForSystemTimeConstraint(ColumnConstraintKind):
+    arg_types = {"this": True, "expression": True}
+
+
 class CaseSpecificColumnConstraint(ColumnConstraintKind):
     arg_types = {"not_": True}
@@ -1351,6 +1407,10 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
     }
 
 
+class GeneratedAsRowColumnConstraint(ColumnConstraintKind):
+    arg_types = {"start": True, "hidden": False}
+
+
 # https://dev.mysql.com/doc/refman/8.0/en/create-table.html
 class IndexColumnConstraint(ColumnConstraintKind):
     arg_types = {
@@ -1383,6 +1443,11 @@ class OnUpdateColumnConstraint(ColumnConstraintKind):
     pass
 
 
+# https://docs.snowflake.com/en/sql-reference/sql/create-external-table#optional-parameters
+class TransformColumnConstraint(ColumnConstraintKind):
+    pass
+
+
 class PrimaryKeyColumnConstraint(ColumnConstraintKind):
     arg_types = {"desc": False}
@@ -1413,7 +1478,7 @@ class Constraint(Expression):
     arg_types = {"this": True, "expressions": True}
 
 
-class Delete(Expression):
+class Delete(DML):
     arg_types = {
         "with": False,
         "this": False,
@@ -1496,41 +1561,6 @@ class Delete(Expression):
             **opts,
         )
 
-    def returning(
-        self,
-        expression: ExpOrStr,
-        dialect: DialectType = None,
-        copy: bool = True,
-        **opts,
-    ) -> Delete:
-        """
-        Set the RETURNING expression. Not supported by all dialects.
-
-        Example:
-            >>> delete("tbl").returning("*", dialect="postgres").sql()
-            'DELETE FROM tbl RETURNING *'
-
-        Args:
-            expression: the SQL code strings to parse.
-                If an `Expression` instance is passed, it will be used as-is.
-            dialect: the dialect used to parse the input expressions.
-            copy: if `False`, modify this expression instance in-place.
-            opts: other options to use to parse the input expressions.
-
-        Returns:
-            Delete: the modified expression.
-        """
-        return _apply_builder(
-            expression=expression,
-            instance=self,
-            arg="returning",
-            prefix="RETURNING",
-            dialect=dialect,
-            copy=copy,
-            into=Returning,
-            **opts,
-        )
-
 
 class Drop(Expression):
     arg_types = {
@@ -1648,7 +1678,7 @@ class Index(Expression):
     }
 
 
-class Insert(DDL):
+class Insert(DDL, DML):
     arg_types = {
         "with": False,
         "this": True,
@@ -2259,6 +2289,11 @@ class WithJournalTableProperty(Property):
     arg_types = {"this": True}
 
 
+class WithSystemVersioningProperty(Property):
+    # this -> history table name, expression -> data consistency check
+    arg_types = {"this": False, "expression": False}
+
+
 class Properties(Expression):
     arg_types = {"expressions": True}
@@ -3663,6 +3698,7 @@ class DataType(Expression):
         Type.BIGINT,
         Type.INT128,
         Type.INT256,
+        Type.BIT,
     }
 
     FLOAT_TYPES = {
@@ -3692,7 +3728,7 @@ class DataType(Expression):
     @classmethod
     def build(
         cls,
-        dtype: str | DataType | DataType.Type,
+        dtype: DATA_TYPE,
         dialect: DialectType = None,
         udt: bool = False,
         **kwargs,
@@ -3733,7 +3769,7 @@ class DataType(Expression):
 
         return DataType(**{**data_type_exp.args, **kwargs})
 
-    def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
+    def is_type(self, *dtypes: DATA_TYPE) -> bool:
         """
         Checks whether this DataType matches one of the provided data types. Nested types or precision
         will be compared using "structural equivalence" semantics, so e.g. array<int> != array<float>.
@@ -3761,6 +3797,9 @@ class DataType(Expression):
         return False
 
 
+DATA_TYPE = t.Union[str, DataType, DataType.Type]
+
+
 # https://www.postgresql.org/docs/15/datatype-pseudo.html
 class PseudoType(DataType):
     arg_types = {"this": True}
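The new `DATA_TYPE` union means `DataType.build`, `is_type` and `cast` accept a string, a `DataType` node, or a `DataType.Type` enum member interchangeably. A small sketch under that assumption:

    from sqlglot import exp

    # All three spellings should resolve to equivalent types.
    t1 = exp.DataType.build("array<int>")
    t2 = exp.DataType.build(exp.DataType.Type.INT)
    print(t1.is_type("array<int>"), t2.is_type(exp.DataType.Type.INT))  # True True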
|
@ -3868,7 +3907,7 @@ class BitwiseXor(Binary):
|
|||
|
||||
|
||||
class Div(Binary):
|
||||
pass
|
||||
arg_types = {"this": True, "expression": True, "typed": False, "safe": False}
|
||||
|
||||
|
||||
class Overlaps(Binary):
|
||||
|
@ -3892,13 +3931,25 @@ class Dot(Binary):
|
|||
|
||||
return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions))
|
||||
|
||||
@property
|
||||
def parts(self) -> t.List[Expression]:
|
||||
"""Return the parts of a table / column in order catalog, db, table."""
|
||||
this, *parts = self.flatten()
|
||||
|
||||
parts.reverse()
|
||||
|
||||
for arg in ("this", "table", "db", "catalog"):
|
||||
part = this.args.get(arg)
|
||||
|
||||
if isinstance(part, Expression):
|
||||
parts.append(part)
|
||||
|
||||
parts.reverse()
|
||||
return parts
|
||||
|
||||
|
||||
class DPipe(Binary):
|
||||
pass
|
||||
|
||||
|
||||
class SafeDPipe(DPipe):
|
||||
pass
|
||||
arg_types = {"this": True, "expression": True, "safe": False}
|
||||
|
||||
|
||||
class EQ(Binary, Predicate):
|
||||
|
@ -3913,6 +3964,11 @@ class NullSafeNEQ(Binary, Predicate):
|
|||
pass
|
||||
|
||||
|
||||
# Represents e.g. := in DuckDB which is mostly used for setting parameters
|
||||
class PropertyEQ(Binary):
|
||||
pass
|
||||
|
||||
|
||||
class Distance(Binary):
|
||||
pass
|
||||
|
||||
|
@ -3981,6 +4037,11 @@ class NEQ(Binary, Predicate):
|
|||
pass
|
||||
|
||||
|
||||
# https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH
|
||||
class Operator(Binary):
|
||||
arg_types = {"this": True, "operator": True, "expression": True}
|
||||
|
||||
|
||||
class SimilarTo(Binary, Predicate):
|
||||
pass
|
||||
|
||||
|
@ -4048,7 +4109,8 @@ class Between(Predicate):
|
|||
|
||||
|
||||
class Bracket(Condition):
|
||||
arg_types = {"this": True, "expressions": True}
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator
|
||||
arg_types = {"this": True, "expressions": True, "offset": False, "safe": False}
|
||||
|
||||
@property
|
||||
def output_name(self) -> str:
|
||||
|
@ -4058,10 +4120,6 @@ class Bracket(Condition):
|
|||
return super().output_name
|
||||
|
||||
|
||||
class SafeBracket(Bracket):
|
||||
"""Represents array lookup where OOB index yields NULL instead of causing a failure."""
|
||||
|
||||
|
||||
class Distinct(Expression):
|
||||
arg_types = {"expressions": False, "on": False}
|
||||
|
||||
|
@ -4077,6 +4135,11 @@ class In(Predicate):
|
|||
}
|
||||
|
||||
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#for-in
|
||||
class ForIn(Expression):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
|
||||
class TimeUnit(Expression):
|
||||
"""Automatically converts unit arg into a var."""
|
||||
|
||||
|
@ -4248,8 +4311,9 @@ class Array(Func):
|
|||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/to_char
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html
|
||||
class ToChar(Func):
|
||||
arg_types = {"this": True, "format": False}
|
||||
arg_types = {"this": True, "format": False, "nlsparam": False}
|
||||
|
||||
|
||||
class GenerateSeries(Func):
|
||||
|
@ -4260,6 +4324,10 @@ class ArrayAgg(AggFunc):
|
|||
pass
|
||||
|
||||
|
||||
class ArrayUniqueAgg(AggFunc):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayAll(Func):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
|
@ -4358,7 +4426,7 @@ class Cast(Func):
|
|||
def output_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
|
||||
def is_type(self, *dtypes: DATA_TYPE) -> bool:
|
||||
"""
|
||||
Checks whether this Cast's DataType matches one of the provided data types. Nested types
|
||||
like arrays or structs will be compared using "structural equivalence" semantics, so e.g.
|
||||
|
@ -4403,14 +4471,10 @@ class Chr(Func):
|
|||
|
||||
|
||||
class Concat(Func):
|
||||
arg_types = {"expressions": True}
|
||||
arg_types = {"expressions": True, "safe": False, "coalesce": False}
|
||||
is_var_len_args = True
|
||||
|
||||
|
||||
class SafeConcat(Concat):
|
||||
pass
|
||||
|
||||
|
||||
class ConcatWs(Concat):
|
||||
_sql_names = ["CONCAT_WS"]
|
||||
|
||||
|
@ -4643,6 +4707,10 @@ class If(Func):
|
|||
arg_types = {"this": True, "true": True, "false": False}
|
||||
|
||||
|
||||
class Nullif(Func):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
|
||||
class Initcap(Func):
|
||||
arg_types = {"this": True, "expression": False}
|
||||
|
||||
|
@ -4651,6 +4719,10 @@ class IsNan(Func):
|
|||
_sql_names = ["IS_NAN", "ISNAN"]
|
||||
|
||||
|
||||
class IsInf(Func):
|
||||
_sql_names = ["IS_INF", "ISINF"]
|
||||
|
||||
|
||||
class FormatJson(Expression):
|
||||
pass
|
||||
|
||||
|
@ -4970,10 +5042,6 @@ class SafeDivide(Func):
|
|||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
|
||||
class SetAgg(AggFunc):
|
||||
pass
|
||||
|
||||
|
||||
class SHA(Func):
|
||||
_sql_names = ["SHA", "SHA1"]
|
||||
|
||||
|
@ -5118,6 +5186,15 @@ class Trim(Func):
|
|||
|
||||
|
||||
class TsOrDsAdd(Func, TimeUnit):
|
||||
# return_type is used to correctly cast the arguments of this expression when transpiling it
|
||||
arg_types = {"this": True, "expression": True, "unit": False, "return_type": False}
|
||||
|
||||
@property
|
||||
def return_type(self) -> DataType:
|
||||
return DataType.build(self.args.get("return_type") or DataType.Type.DATE)
|
||||
|
||||
|
||||
class TsOrDsDiff(Func, TimeUnit):
|
||||
arg_types = {"this": True, "expression": True, "unit": False}
|
||||
|
||||
|
||||
|
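The `return_type` property tells transpilers what TS_OR_DS_ADD should yield so its arguments can be cast accordingly. A hedged sketch (the column name is illustrative):

    from sqlglot import exp

    node = exp.TsOrDsAdd(this=exp.column("d"), expression=exp.Literal.number(1))
    print(node.return_type.sql())  # DATE, the documented default when no return_type is set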
@@ -5149,6 +5226,7 @@ class UnixToTime(Func):
     SECONDS = Literal.string("seconds")
     MILLIS = Literal.string("millis")
     MICROS = Literal.string("micros")
+    NANOS = Literal.string("nanos")
 
 
 class UnixToTimeStr(Func):
@@ -5202,6 +5280,7 @@ def _norm_arg(arg):
 
 
 ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
+FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()}
 
 
 # Helpers
@@ -5693,7 +5772,9 @@ def delete(
     if where:
         delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts)
     if returning:
-        delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
+        delete_expr = t.cast(
+            Delete, delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
+        )
     return delete_expr
 
 
@@ -5702,6 +5783,7 @@ def insert(
     into: ExpOrStr,
     columns: t.Optional[t.Sequence[ExpOrStr]] = None,
    overwrite: t.Optional[bool] = None,
+    returning: t.Optional[ExpOrStr] = None,
     dialect: DialectType = None,
     copy: bool = True,
     **opts,
@@ -5718,6 +5800,7 @@ def insert(
         into: the tbl to insert data to.
         columns: optionally the table's column names.
         overwrite: whether to INSERT OVERWRITE or not.
+        returning: sql conditional parsed into a RETURNING statement
         dialect: the dialect used to parse the input expressions.
         copy: whether or not to copy the expression.
         **opts: other options to use to parse the input expressions.
@@ -5739,7 +5822,12 @@ def insert(
             **opts,
         )
 
-    return Insert(this=this, expression=expr, overwrite=overwrite)
+    insert = Insert(this=this, expression=expr, overwrite=overwrite)
+
+    if returning:
+        insert = t.cast(Insert, insert.returning(returning, dialect=dialect, copy=False, **opts))
+
+    return insert
 
 
 def condition(
@@ -5913,7 +6001,7 @@ def to_identifier(name, quoted=None, copy=True):
     return identifier
 
 
-def parse_identifier(name: str, dialect: DialectType = None) -> Identifier:
+def parse_identifier(name: str | Identifier, dialect: DialectType = None) -> Identifier:
     """
     Parses a given string into an identifier.
 
@@ -5965,7 +6053,7 @@ def to_table(sql_path: None, **kwargs) -> None:
 
 
 def to_table(
-    sql_path: t.Optional[str | Table], dialect: DialectType = None, **kwargs
+    sql_path: t.Optional[str | Table], dialect: DialectType = None, copy: bool = True, **kwargs
 ) -> t.Optional[Table]:
     """
     Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
@@ -5974,13 +6062,14 @@ def to_table(
     Args:
         sql_path: a `[catalog].[schema].[table]` string.
         dialect: the source dialect according to which the table name will be parsed.
+        copy: Whether or not to copy a table if it is passed in.
         kwargs: the kwargs to instantiate the resulting `Table` expression with.
 
     Returns:
         A table expression.
     """
     if sql_path is None or isinstance(sql_path, Table):
-        return sql_path
+        return maybe_copy(sql_path, copy=copy)
 
     if not isinstance(sql_path, str):
         raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
@@ -6123,7 +6212,7 @@ def column(
     )
 
 
-def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast:
+def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
     """Cast an expression to a data type.
 
     Example:
@@ -6335,12 +6424,15 @@ def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]:
     }
 
 
-def table_name(table: Table | str, dialect: DialectType = None) -> str:
+def table_name(table: Table | str, dialect: DialectType = None, identify: bool = False) -> str:
     """Get the full name of a table as a string.
 
     Args:
         table: Table expression node or string.
         dialect: The dialect to generate the table name for.
+        identify: Determines when an identifier should be quoted. Possible values are:
+            False (default): Never quote, except in cases where it's mandatory by the dialect.
+            True: Always quote.
 
     Examples:
         >>> from sqlglot import exp, parse_one
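`identify=True` forces quoting of every part rather than only the unsafe ones. A sketch with an illustrative three-part name:

    from sqlglot import exp

    print(exp.table_name(exp.to_table("a.b.c"), identify=True))  # expect "a"."b"."c" under default quoting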
@@ -6358,37 +6450,68 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str:
 
     return ".".join(
         part.sql(dialect=dialect, identify=True)
-        if not SAFE_IDENTIFIER_RE.match(part.name)
+        if identify or not SAFE_IDENTIFIER_RE.match(part.name)
         else part.name
         for part in table.parts
     )
 
 
-def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E:
+def normalize_table_name(table: str | Table, dialect: DialectType = None, copy: bool = True) -> str:
+    """Returns a case normalized table name without quotes.
+
+    Args:
+        table: the table to normalize
+        dialect: the dialect to use for normalization rules
+        copy: whether or not to copy the expression.
+
+    Examples:
+        >>> normalize_table_name("`A-B`.c", dialect="bigquery")
+        'A-B.c'
+    """
+    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
+
+    return ".".join(
+        p.name
+        for p in normalize_identifiers(
+            to_table(table, dialect=dialect, copy=copy), dialect=dialect
+        ).parts
+    )
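The docstring example above is the whole contract: quoting is stripped and case rules come from the dialect. For instance:

    from sqlglot import exp

    # BigQuery's backtick-quoted identifier is unwrapped and left case-sensitive.
    print(exp.normalize_table_name("`A-B`.c", dialect="bigquery"))  # A-B.c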
+def replace_tables(
+    expression: E, mapping: t.Dict[str, str], dialect: DialectType = None, copy: bool = True
+) -> E:
     """Replace all tables in expression according to the mapping.
 
     Args:
         expression: expression node to be transformed and replaced.
         mapping: mapping of table names.
+        dialect: the dialect of the mapping table
         copy: whether or not to copy the expression.
 
     Examples:
         >>> from sqlglot import exp, parse_one
         >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
-        'SELECT * FROM c'
+        'SELECT * FROM c /* a.b */'
 
     Returns:
         The mapped expression.
     """
 
+    mapping = {normalize_table_name(k, dialect=dialect): v for k, v in mapping.items()}
+
     def _replace_tables(node: Expression) -> Expression:
         if isinstance(node, Table):
-            new_name = mapping.get(table_name(node))
+            original = normalize_table_name(node, dialect=dialect)
+            new_name = mapping.get(original)
+
             if new_name:
-                return to_table(
+                table = to_table(
                     new_name,
-                    **{k: v for k, v in node.args.items() if k not in ("this", "db", "catalog")},
+                    **{k: v for k, v in node.args.items() if k not in TABLE_PARTS},
                 )
+                table.add_comments([original])
+                return table
         return node
 
     return expression.transform(_replace_tables, copy=copy)
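Note the behavioral change captured in the doctest: the original name is now attached as a comment on the replacement table. A minimal sketch:

    from sqlglot import exp, parse_one

    sql = exp.replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
    print(sql)  # SELECT * FROM c /* a.b */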
@@ -6431,7 +6554,10 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
 
 
 def expand(
-    expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True
+    expression: Expression,
+    sources: t.Dict[str, Subqueryable],
+    dialect: DialectType = None,
+    copy: bool = True,
 ) -> Expression:
     """Transforms an expression by expanding all referenced sources into subqueries.
 
@@ -6446,15 +6572,17 @@ def expand(
     Args:
         expression: The expression to expand.
         sources: A dictionary of name to Subqueryables.
+        dialect: The dialect of the sources dict.
         copy: Whether or not to copy the expression during transformation. Defaults to True.
 
     Returns:
         The transformed expression.
     """
+    sources = {normalize_table_name(k, dialect=dialect): v for k, v in sources.items()}
+
     def _expand(node: Expression):
         if isinstance(node, Table):
-            name = table_name(node)
+            name = normalize_table_name(node, dialect=dialect)
             source = sources.get(name)
             if source:
                 subquery = source.subquery(node.alias or name)
@@ -6465,7 +6593,7 @@ def expand(
     return expression.transform(_expand, copy=copy)
 
 
-def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
+def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwargs) -> Func:
     """
     Returns a Func expression.
 
@@ -6479,6 +6607,7 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
     Args:
         name: the name of the function to build.
         args: the args used to instantiate the function of interest.
+        copy: whether or not to copy the argument expressions.
         dialect: the source dialect.
         kwargs: the kwargs used to instantiate the function of interest.
 
@@ -6494,14 +6623,29 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
 
     from sqlglot.dialects.dialect import Dialect
 
-    converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args]
-    kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()}
+    dialect = Dialect.get_or_raise(dialect)
 
-    parser = Dialect.get_or_raise(dialect)().parser()
-    from_args_list = parser.FUNCTIONS.get(name.upper())
+    converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect, copy=copy) for arg in args]
+    kwargs = {key: maybe_parse(value, dialect=dialect, copy=copy) for key, value in kwargs.items()}
 
-    if from_args_list:
-        function = from_args_list(converted) if converted else from_args_list.__self__(**kwargs)  # type: ignore
+    constructor = dialect.parser_class.FUNCTIONS.get(name.upper())
+    if constructor:
+        if converted:
+            if "dialect" in constructor.__code__.co_varnames:
+                function = constructor(converted, dialect=dialect)
+            else:
+                function = constructor(converted)
+        elif constructor.__name__ == "from_arg_list":
+            function = constructor.__self__(**kwargs)  # type: ignore
+        else:
+            constructor = FUNCTION_BY_NAME.get(name.upper())
+            if constructor:
+                function = constructor(**kwargs)
+            else:
+                raise ValueError(
+                    f"Unable to convert '{name}' into a Func. Either manually construct "
+                    "the Func expression of interest or parse the function call."
+                )
     else:
         kwargs = kwargs or {"expressions": converted}
         function = Anonymous(this=name, **kwargs)
@@ -6512,6 +6656,48 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
 
     return function
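The rework resolves a builder through the dialect parser's FUNCTIONS map, falls back to FUNCTION_BY_NAME for kwargs-only calls, and otherwise produces an `Anonymous` node. A hedged sketch of the two common paths (`my_udf` is a made-up name):

    from sqlglot import exp

    known = exp.func("coalesce", "x", "y")  # resolved via the parser's FUNCTIONS map
    unknown = exp.func("my_udf", "x")       # no builder found, becomes exp.Anonymous
    print(type(known).__name__, type(unknown).__name__)  # Coalesce Anonymous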
+def case(
+    expression: t.Optional[ExpOrStr] = None,
+    **opts,
+) -> Case:
+    """
+    Initialize a CASE statement.
+
+    Example:
+        case().when("a = 1", "foo").else_("bar")
+
+    Args:
+        expression: Optionally, the input expression (not all dialects support this)
+        **opts: Extra keyword arguments for parsing `expression`
+    """
+    if expression is not None:
+        this = maybe_parse(expression, **opts)
+    else:
+        this = None
+    return Case(this=this, ifs=[])
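A sketch of the documented chain; `when` and `else_` parse their arguments, so `foo` and `bar` come out as column references:

    from sqlglot import exp

    c = exp.case().when("a = 1", "foo").else_("bar")
    print(c.sql())  # CASE WHEN a = 1 THEN foo ELSE bar END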
+def cast_unless(
+    expression: ExpOrStr,
+    to: DATA_TYPE,
+    *types: DATA_TYPE,
+    **opts: t.Any,
+) -> Expression | Cast:
+    """
+    Cast an expression to a data type unless it is a specified type.
+
+    Args:
+        expression: The expression to cast.
+        to: The data type to cast to.
+        **types: The types to exclude from casting.
+        **opts: Extra keyword arguments for parsing `expression`
+    """
+    expr = maybe_parse(expression, **opts)
+    if expr.is_type(*types):
+        return expr
+    return cast(expr, to, **opts)
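`cast_unless` is a guard around `cast`: if the parsed expression already has one of the excluded types, it is returned untouched. A small sketch:

    from sqlglot import exp

    e = exp.cast_unless("CAST(x AS INT)", "text", "int")
    print(e.sql())  # CAST(x AS INT), already an excluded type, so no re-cast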
 def true() -> Boolean:
     """
     Returns a true Boolean expression.
@@ -9,10 +9,11 @@ from sqlglot import exp
 from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
 from sqlglot.helper import apply_index_offset, csv, seq_get
 from sqlglot.time import format_time
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
 
 if t.TYPE_CHECKING:
     from sqlglot._typing import E
+    from sqlglot.dialects.dialect import DialectType
 
 logger = logging.getLogger("sqlglot")
 
@@ -58,9 +59,6 @@ class Generator:
         exp.DateAdd: lambda self, e: self.func(
             "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
         ),
-        exp.TsOrDsAdd: lambda self, e: self.func(
-            "TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
-        ),
         exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
         exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
         exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
@@ -108,9 +106,6 @@ class Generator:
         exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
     }
 
-    # Whether the base comes first
-    LOG_BASE_FIRST = True
-
     # Whether or not null ordering is supported in order by
     NULL_ORDERING_SUPPORTED = True
 
@@ -201,7 +196,7 @@ class Generator:
     VALUES_AS_TABLE = True
 
     # Whether or not the word COLUMN is included when adding a column with ALTER TABLE
-    ALTER_TABLE_ADD_COLUMN_KEYWORD = True
+    ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = True
 
     # UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery)
     UNNEST_WITH_ORDINALITY = True
@@ -212,9 +207,6 @@ class Generator:
     # Whether or not JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds
     SEMI_ANTI_JOIN_WITH_SIDE = True
 
-    # Whether or not session variables / parameters are supported, e.g. @x in T-SQL
-    SUPPORTS_PARAMETERS = True
-
     # Whether or not to include the type of a computed column in the CREATE DDL
     COMPUTED_COLUMN_WITH_TYPE = True
 
@@ -230,12 +222,15 @@ class Generator:
     # Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle)
     DATA_TYPE_SPECIFIERS_ALLOWED = False
 
-    # Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed
-    SUPPORTS_NESTED_CTES = True
+    # Whether or not conditions require booleans WHERE x = 0 vs WHERE x
+    ENSURE_BOOLS = False
+
+    # Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs
+    CTE_RECURSIVE_KEYWORD_REQUIRED = True
+
+    # Whether or not CONCAT requires >1 arguments
+    SUPPORTS_SINGLE_ARG_CONCAT = True
 
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -335,6 +330,7 @@ class Generator:
         exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
         exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
         exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
+        exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA,
     }
 
     # Keywords that can't be used as unquoted identifier names
@@ -368,37 +364,13 @@ class Generator:
         exp.Paren,
     )
 
+    # Expressions that need to have all CTEs under them bubbled up to them
+    EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set()
+
+    KEY_VALUE_DEFINITONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice)
+
     SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
 
-    # Autofilled
-    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
-    INVERSE_TIME_TRIE: t.Dict = {}
-    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
-    INDEX_OFFSET = 0
-    UNNEST_COLUMN_ONLY = False
-    ALIAS_POST_TABLESAMPLE = False
-    IDENTIFIERS_CAN_START_WITH_DIGIT = False
-    STRICT_STRING_CONCAT = False
-    NORMALIZE_FUNCTIONS: bool | str = "upper"
-    NULL_ORDERING = "nulls_are_small"
-
-    can_identify: t.Callable[[str, str | bool], bool]
-
-    # Delimiters for quotes, identifiers and the corresponding escape characters
-    QUOTE_START = "'"
-    QUOTE_END = "'"
-    IDENTIFIER_START = '"'
-    IDENTIFIER_END = '"'
-    TOKENIZER_CLASS = Tokenizer
-
-    # Delimiters for bit, hex, byte and raw literals
-    BIT_START: t.Optional[str] = None
-    BIT_END: t.Optional[str] = None
-    HEX_START: t.Optional[str] = None
-    HEX_END: t.Optional[str] = None
-    BYTE_START: t.Optional[str] = None
-    BYTE_END: t.Optional[str] = None
-
     __slots__ = (
         "pretty",
         "identify",
@@ -411,6 +383,7 @@ class Generator:
         "leading_comma",
         "max_text_width",
         "comments",
+        "dialect",
         "unsupported_messages",
         "_escaped_quote_end",
         "_escaped_identifier_end",
@@ -429,8 +402,10 @@ class Generator:
         leading_comma: bool = False,
         max_text_width: int = 80,
         comments: bool = True,
+        dialect: DialectType = None,
     ):
         import sqlglot
+        from sqlglot.dialects import Dialect
 
         self.pretty = pretty if pretty is not None else sqlglot.pretty
         self.identify = identify
@@ -442,16 +417,19 @@ class Generator:
         self.leading_comma = leading_comma
         self.max_text_width = max_text_width
         self.comments = comments
+        self.dialect = Dialect.get_or_raise(dialect)
 
+        # This is both a Dialect property and a Generator argument, so we prioritize the latter
         self.normalize_functions = (
-            self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
+            self.dialect.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
        )
 
         self.unsupported_messages: t.List[str] = []
-        self._escaped_quote_end: str = self.TOKENIZER_CLASS.STRING_ESCAPES[0] + self.QUOTE_END
+        self._escaped_quote_end: str = (
+            self.dialect.tokenizer_class.STRING_ESCAPES[0] + self.dialect.QUOTE_END
+        )
         self._escaped_identifier_end: str = (
-            self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
+            self.dialect.tokenizer_class.IDENTIFIER_ESCAPES[0] + self.dialect.IDENTIFIER_END
         )
 
     def generate(self, expression: exp.Expression, copy: bool = True) -> str:
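A generator can now be built directly against a dialect; quoting, escapes and null ordering come from the `Dialect` instance instead of class-level constants. An illustrative sketch (normally you would go through `Dialect.generator()` or `expression.sql(dialect=...)`):

    import sqlglot
    from sqlglot.generator import Generator

    gen = Generator(dialect="postgres")
    print(gen.generate(sqlglot.parse_one("SELECT 1")))  # SELECT 1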
@@ -469,23 +447,14 @@ class Generator:
         if copy:
             expression = expression.copy()
 
-        # Some dialects only support CTEs at the top level expression, so we need to bubble up nested
-        # CTEs to that level in order to produce a syntactically valid expression. This transformation
-        # happens here to minimize code duplication, since many expressions support CTEs.
-        if (
-            not self.SUPPORTS_NESTED_CTES
-            and isinstance(expression, exp.Expression)
-            and not expression.parent
-            and "with" in expression.arg_types
-            and any(node.parent is not expression for node in expression.find_all(exp.With))
-        ):
-            from sqlglot.transforms import move_ctes_to_top_level
-
-            expression = move_ctes_to_top_level(expression)
+        expression = self.preprocess(expression)
 
         self.unsupported_messages = []
         sql = self.sql(expression).strip()
 
-        if self.pretty:
-            sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
-
         if self.unsupported_level == ErrorLevel.IGNORE:
             return sql
 
@@ -495,10 +464,26 @@ class Generator:
         elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
             raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
 
+        if self.pretty:
+            sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
         return sql
 
+    def preprocess(self, expression: exp.Expression) -> exp.Expression:
+        """Apply generic preprocessing transformations to a given expression."""
+        if (
+            not expression.parent
+            and type(expression) in self.EXPRESSIONS_WITHOUT_NESTED_CTES
+            and any(node.parent is not expression for node in expression.find_all(exp.With))
+        ):
+            from sqlglot.transforms import move_ctes_to_top_level
+
+            expression = move_ctes_to_top_level(expression)
+
+        if self.ENSURE_BOOLS:
+            from sqlglot.transforms import ensure_bools
+
+            expression = ensure_bools(expression)
+
+        return expression
+
     def unsupported(self, message: str) -> None:
         if self.unsupported_level == ErrorLevel.IMMEDIATE:
             raise UnsupportedError(message)
@@ -752,9 +737,24 @@ class Generator:
 
         return f"GENERATED{this} AS {expr}{sequence_opts}"
 
+    def generatedasrowcolumnconstraint_sql(
+        self, expression: exp.GeneratedAsRowColumnConstraint
+    ) -> str:
+        start = "START" if expression.args["start"] else "END"
+        hidden = " HIDDEN" if expression.args.get("hidden") else ""
+        return f"GENERATED ALWAYS AS ROW {start}{hidden}"
+
+    def periodforsystemtimeconstraint_sql(
+        self, expression: exp.PeriodForSystemTimeConstraint
+    ) -> str:
+        return f"PERIOD FOR SYSTEM_TIME ({self.sql(expression, 'this')}, {self.sql(expression, 'expression')})"
+
     def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
         return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
 
+    def transformcolumnconstraint_sql(self, expression: exp.TransformColumnConstraint) -> str:
+        return f"AS {self.sql(expression, 'this')}"
+
     def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str:
         desc = expression.args.get("desc")
         if desc is not None:
@@ -900,32 +900,32 @@ class Generator:
         columns = self.expressions(expression, key="columns", flat=True)
         columns = f"({columns})" if columns else ""
 
-        if not alias and not self.UNNEST_COLUMN_ONLY:
+        if not alias and not self.dialect.UNNEST_COLUMN_ONLY:
             alias = "_t"
 
         return f"{alias}{columns}"
 
     def bitstring_sql(self, expression: exp.BitString) -> str:
         this = self.sql(expression, "this")
-        if self.BIT_START:
-            return f"{self.BIT_START}{this}{self.BIT_END}"
+        if self.dialect.BIT_START:
+            return f"{self.dialect.BIT_START}{this}{self.dialect.BIT_END}"
         return f"{int(this, 2)}"
 
     def hexstring_sql(self, expression: exp.HexString) -> str:
         this = self.sql(expression, "this")
-        if self.HEX_START:
-            return f"{self.HEX_START}{this}{self.HEX_END}"
+        if self.dialect.HEX_START:
+            return f"{self.dialect.HEX_START}{this}{self.dialect.HEX_END}"
         return f"{int(this, 16)}"
 
     def bytestring_sql(self, expression: exp.ByteString) -> str:
         this = self.sql(expression, "this")
-        if self.BYTE_START:
-            return f"{self.BYTE_START}{this}{self.BYTE_END}"
+        if self.dialect.BYTE_START:
+            return f"{self.dialect.BYTE_START}{this}{self.dialect.BYTE_END}"
         return this
 
     def rawstring_sql(self, expression: exp.RawString) -> str:
         string = self.escape_str(expression.this.replace("\\", "\\\\"))
-        return f"{self.QUOTE_START}{string}{self.QUOTE_END}"
+        return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}"
 
     def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
         this = self.sql(expression, "this")
@@ -1065,14 +1065,14 @@ class Generator:
         text = expression.name
         lower = text.lower()
         text = lower if self.normalize and not expression.quoted else text
-        text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
+        text = text.replace(self.dialect.IDENTIFIER_END, self._escaped_identifier_end)
         if (
             expression.quoted
-            or self.can_identify(text, self.identify)
+            or self.dialect.can_identify(text, self.identify)
             or lower in self.RESERVED_KEYWORDS
-            or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
+            or (not self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
         ):
-            text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}"
+            text = f"{self.dialect.IDENTIFIER_START}{text}{self.dialect.IDENTIFIER_END}"
         return text
 
     def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
@@ -1121,7 +1121,7 @@ class Generator:
         expressions = self.expressions(properties, sep=sep, indent=False)
         if expressions:
             expressions = self.wrap(expressions) if wrapped else expressions
-            return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
+            return f"{prefix}{' ' if prefix.strip() else ''}{expressions}{suffix}"
         return ""
 
     def with_properties(self, properties: exp.Properties) -> str:
@@ -1286,6 +1286,21 @@ class Generator:
         statistics_sql = f" AND {'NO ' if not statistics else ''}STATISTICS"
         return f"{data_sql}{statistics_sql}"
 
+    def withsystemversioningproperty_sql(self, expression: exp.WithSystemVersioningProperty) -> str:
+        sql = "WITH(SYSTEM_VERSIONING=ON"
+
+        if expression.this:
+            history_table = self.sql(expression, "this")
+            sql = f"{sql}(HISTORY_TABLE={history_table}"
+
+            if expression.expression:
+                data_consistency_check = self.sql(expression, "expression")
+                sql = f"{sql}, DATA_CONSISTENCY_CHECK={data_consistency_check}"
+
+            sql = f"{sql})"
+
+        return f"{sql})"
+
     def insert_sql(self, expression: exp.Insert) -> str:
         overwrite = expression.args.get("overwrite")
 
@@ -1387,13 +1402,13 @@ class Generator:
 
     def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
         table = ".".join(
-            part
-            for part in [
-                self.sql(expression, "catalog"),
-                self.sql(expression, "db"),
-                self.sql(expression, "this"),
-            ]
-            if part
+            self.sql(part)
+            for part in (
+                expression.args.get("catalog"),
+                expression.args.get("db"),
+                expression.args.get("this"),
+            )
+            if part is not None
         )
 
         version = self.sql(expression, "version")
@@ -1426,7 +1441,7 @@ class Generator:
     def tablesample_sql(
         self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
     ) -> str:
-        if self.ALIAS_POST_TABLESAMPLE and expression.this.alias:
+        if self.dialect.ALIAS_POST_TABLESAMPLE and expression.this and expression.this.alias:
             table = expression.this.copy()
             table.set("alias", None)
             this = self.sql(table)
@@ -1676,12 +1691,16 @@ class Generator:
 
     def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
         this = self.sql(expression, "this")
-        args = ", ".join(
-            self.sql(self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e)
+
+        args = [
+            self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e
             for e in (expression.args.get(k) for k in ("offset", "expression"))
             if e
-        )
-        return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}"
+        ]
+
+        args_sql = ", ".join(self.sql(e) for e in args)
+        args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql
+        return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}"
 
     def offset_sql(self, expression: exp.Offset) -> str:
         this = self.sql(expression, "this")
@@ -1732,13 +1751,13 @@ class Generator:
     def literal_sql(self, expression: exp.Literal) -> str:
         text = expression.this or ""
         if expression.is_string:
-            text = f"{self.QUOTE_START}{self.escape_str(text)}{self.QUOTE_END}"
+            text = f"{self.dialect.QUOTE_START}{self.escape_str(text)}{self.dialect.QUOTE_END}"
         return text
 
     def escape_str(self, text: str) -> str:
-        text = text.replace(self.QUOTE_END, self._escaped_quote_end)
-        if self.INVERSE_ESCAPE_SEQUENCES:
-            text = "".join(self.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
+        text = text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)
+        if self.dialect.INVERSE_ESCAPE_SEQUENCES:
+            text = "".join(self.dialect.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
         elif self.pretty:
             text = text.replace("\n", self.SENTINEL_LINE_BREAK)
         return text
@@ -1782,9 +1801,11 @@ class Generator:
 
         nulls_first = expression.args.get("nulls_first")
         nulls_last = not nulls_first
-        nulls_are_large = self.NULL_ORDERING == "nulls_are_large"
-        nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
-        nulls_are_last = self.NULL_ORDERING == "nulls_are_last"
+        nulls_are_large = self.dialect.NULL_ORDERING == "nulls_are_large"
+        nulls_are_small = self.dialect.NULL_ORDERING == "nulls_are_small"
+        nulls_are_last = self.dialect.NULL_ORDERING == "nulls_are_last"
+
+        this = self.sql(expression, "this")
 
         sort_order = " DESC" if desc else (" ASC" if desc is False else "")
         nulls_sort_change = ""
@@ -1799,13 +1820,13 @@ class Generator:
         ):
             nulls_sort_change = " NULLS LAST"
 
         # If the NULLS FIRST/LAST clause is unsupported, we add another sort key to simulate it
         if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
             self.unsupported(
                 "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
             )
+            null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
+            this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
             nulls_sort_change = ""
 
-        return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
+        return f"{this}{sort_order}{nulls_sort_change}"
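The effect is that a dialect lacking NULLS FIRST/LAST gets an extra CASE WHEN ... IS NULL sort key instead of silently losing the clause. A hedged probe (the dialect pair is illustrative, and the exact output depends on each dialect's default null ordering):

    import sqlglot

    print(sqlglot.transpile("SELECT x FROM t ORDER BY x NULLS LAST", read="postgres", write="mysql")[0])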
     def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
         partition = self.partition_by_sql(expression)
@@ -1933,10 +1954,13 @@ class Generator:
         )
         kind = ""
 
+        # We use LIMIT_IS_TOP as a proxy for whether DISTINCT should go first because tsql and Teradata
+        # are the only dialects that use LIMIT_IS_TOP and both place DISTINCT first.
+        top_distinct = f"{distinct}{hint}{top}" if self.LIMIT_IS_TOP else f"{top}{hint}{distinct}"
         expressions = f"{self.sep()}{expressions}" if expressions else expressions
         sql = self.query_modifiers(
             expression,
-            f"SELECT{top}{hint}{distinct}{kind}{expressions}",
+            f"SELECT{top_distinct}{kind}{expressions}",
             self.sql(expression, "into", comment=False),
             self.sql(expression, "from", comment=False),
         )
@@ -1961,7 +1985,7 @@ class Generator:
 
     def parameter_sql(self, expression: exp.Parameter) -> str:
         this = self.sql(expression, "this")
-        return f"{self.PARAMETER_TOKEN}{this}" if self.SUPPORTS_PARAMETERS else this
+        return f"{self.PARAMETER_TOKEN}{this}"
 
     def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
         this = self.sql(expression, "this")
@@ -2009,7 +2033,7 @@ class Generator:
         if alias and isinstance(offset, exp.Expression):
             alias.append("columns", offset)
 
-        if alias and self.UNNEST_COLUMN_ONLY:
+        if alias and self.dialect.UNNEST_COLUMN_ONLY:
             columns = alias.columns
             alias = self.sql(columns[0]) if columns else ""
         else:
@@ -2080,14 +2104,14 @@ class Generator:
         return f"{this} BETWEEN {low} AND {high}"
 
     def bracket_sql(self, expression: exp.Bracket) -> str:
-        expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET)
+        expressions = apply_index_offset(
+            expression.this,
+            expression.expressions,
+            self.dialect.INDEX_OFFSET - expression.args.get("offset", 0),
+        )
         expressions_sql = ", ".join(self.sql(e) for e in expressions)
 
         return f"{self.sql(expression, 'this')}[{expressions_sql}]"
 
-    def safebracket_sql(self, expression: exp.SafeBracket) -> str:
-        return self.bracket_sql(expression)
-
     def all_sql(self, expression: exp.All) -> str:
         return f"ALL {self.wrap(expression)}"
 
@@ -2145,12 +2169,33 @@ class Generator:
         else:
             return self.func("TRIM", expression.this, expression.expression)
 
-    def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
-        expressions = expression.expressions
-        if self.STRICT_STRING_CONCAT:
-            expressions = (exp.cast(e, "text") for e in expressions)
+    def convert_concat_args(self, expression: exp.Concat | exp.ConcatWs) -> t.List[exp.Expression]:
+        args = expression.expressions
+        if isinstance(expression, exp.ConcatWs):
+            args = args[1:]  # Skip the delimiter
+
+        if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
+            args = [exp.cast(e, "text") for e in args]
+
+        if not self.dialect.CONCAT_COALESCE and expression.args.get("coalesce"):
+            args = [exp.func("coalesce", e, exp.Literal.string("")) for e in args]
+
+        return args
+
+    def concat_sql(self, expression: exp.Concat) -> str:
+        expressions = self.convert_concat_args(expression)
+
+        # Some dialects don't allow a single-argument CONCAT call
+        if not self.SUPPORTS_SINGLE_ARG_CONCAT and len(expressions) == 1:
+            return self.sql(expressions[0])
+
         return self.func("CONCAT", *expressions)
 
+    def concatws_sql(self, expression: exp.ConcatWs) -> str:
+        return self.func(
+            "CONCAT_WS", seq_get(expression.expressions, 0), *self.convert_concat_args(expression)
+        )
+
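`ConcatWs` keeps its delimiter as the first expression, which is why `convert_concat_args` skips it before applying the safe/coalesce wrapping. A small probe:

    from sqlglot import exp, parse_one

    node = parse_one("SELECT CONCAT_WS('-', a, b)").find(exp.ConcatWs)
    print(node.expressions[0].sql())  # '-' (the delimiter operand that gets skipped)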
     def check_sql(self, expression: exp.Check) -> str:
         this = self.sql(expression, key="this")
         return f"CHECK ({this})"
@@ -2493,14 +2538,7 @@ class Generator:
         actions = expression.args["actions"]
 
         if isinstance(actions[0], exp.ColumnDef):
-            if self.ALTER_TABLE_ADD_COLUMN_KEYWORD:
-                actions = self.expressions(
-                    expression,
-                    key="actions",
-                    prefix="ADD COLUMN ",
-                )
-            else:
-                actions = f"ADD {self.expressions(expression, key='actions')}"
+            actions = self.add_column_sql(expression)
         elif isinstance(actions[0], exp.Schema):
             actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
         elif isinstance(actions[0], exp.Delete):
@@ -2512,6 +2550,15 @@ class Generator:
         only = " ONLY" if expression.args.get("only") else ""
         return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}"
 
+    def add_column_sql(self, expression: exp.AlterTable) -> str:
+        if self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD:
+            return self.expressions(
+                expression,
+                key="actions",
+                prefix="ADD COLUMN ",
+            )
+        return f"ADD {self.expressions(expression, key='actions', flat=True)}"
+
     def droppartition_sql(self, expression: exp.DropPartition) -> str:
         expressions = self.expressions(expression)
         exists = " IF EXISTS " if expression.args.get("exists") else " "
@@ -2551,14 +2598,31 @@ class Generator:
         )
 
     def dpipe_sql(self, expression: exp.DPipe) -> str:
+        if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
+            return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
         return self.binary(expression, "||")
 
-    def safedpipe_sql(self, expression: exp.SafeDPipe) -> str:
-        if self.STRICT_STRING_CONCAT:
-            return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
-        return self.dpipe_sql(expression)
-
     def div_sql(self, expression: exp.Div) -> str:
+        l, r = expression.left, expression.right
+
+        if not self.dialect.SAFE_DIVISION and expression.args.get("safe"):
+            r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0)))
+
+        if self.dialect.TYPED_DIVISION and not expression.args.get("typed"):
+            if not l.is_type(*exp.DataType.FLOAT_TYPES) and not r.is_type(
+                *exp.DataType.FLOAT_TYPES
+            ):
+                l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE))
+
+        elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"):
+            if l.is_type(*exp.DataType.INTEGER_TYPES) and r.is_type(*exp.DataType.INTEGER_TYPES):
+                return self.sql(
+                    exp.cast(
+                        l / r,
+                        to=exp.DataType.Type.BIGINT,
+                    )
+                )
+
         return self.binary(expression, "/")
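Division semantics now live on the AST: `typed` and `safe` flags set by the source dialect are reconciled against the target dialect's TYPED_DIVISION / SAFE_DIVISION settings. A sketch of the flag mechanics:

    from sqlglot import exp, parse_one

    div = parse_one("a / b")
    assert isinstance(div, exp.Div)
    div.set("safe", True)   # request NULL-on-division-by-zero semantics
    div.set("typed", True)  # request integer / integer -> integer semantics
    print(div.args.get("safe"), div.args.get("typed"))  # True True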
     def overlaps_sql(self, expression: exp.Overlaps) -> str:
@@ -2573,6 +2637,9 @@ class Generator:
     def eq_sql(self, expression: exp.EQ) -> str:
         return self.binary(expression, "=")
 
+    def propertyeq_sql(self, expression: exp.PropertyEQ) -> str:
+        return self.binary(expression, ":=")
+
     def escape_sql(self, expression: exp.Escape) -> str:
         return self.binary(expression, "ESCAPE")
 
@@ -2641,10 +2708,13 @@ class Generator:
         return self.cast_sql(expression, safe_prefix="TRY_")
 
     def log_sql(self, expression: exp.Log) -> str:
-        args = list(expression.args.values())
-        if not self.LOG_BASE_FIRST:
-            args.reverse()
-        return self.func("LOG", *args)
+        this = expression.this
+        expr = expression.expression
+
+        if not self.dialect.LOG_BASE_FIRST:
+            this, expr = expr, this
+
+        return self.func("LOG", this, expr)
 
     def use_sql(self, expression: exp.Use) -> str:
         kind = self.sql(expression, "kind")
@@ -2696,7 +2766,9 @@ class Generator:
 
     def format_time(self, expression: exp.Expression) -> t.Optional[str]:
         return format_time(
-            self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE
+            self.sql(expression, "format"),
+            self.dialect.INVERSE_TIME_MAPPING,
+            self.dialect.INVERSE_TIME_TRIE,
         )
 
     def expressions(
@@ -2963,6 +3035,19 @@ class Generator:
         parameters = self.sql(expression, "params_struct")
         return self.func("PREDICT", model, table, parameters or None)
 
+    def forin_sql(self, expression: exp.ForIn) -> str:
+        this = self.sql(expression, "this")
+        expression_sql = self.sql(expression, "expression")
+        return f"FOR {this} DO {expression_sql}"
+
+    def refresh_sql(self, expression: exp.Refresh) -> str:
+        this = self.sql(expression, "this")
+        table = "" if isinstance(expression.this, exp.Literal) else "TABLE "
+        return f"REFRESH {table}{this}"
+
+    def operator_sql(self, expression: exp.Operator) -> str:
+        return self.binary(expression, f"OPERATOR({self.sql(expression, 'operator')})")
+
     def _simplify_unless_literal(self, expression: E) -> E:
         if not isinstance(expression, exp.Literal):
             from sqlglot.optimizer.simplify import simplify
@@ -2970,3 +3055,10 @@ class Generator:
             expression = simplify(expression)
 
         return expression
+
+    def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]:
+        return [
+            exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string(""))
+            for value in values
+            if value
+        ]
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import datetime
 import inspect
 import logging
 import re
@@ -283,7 +284,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
     file = open_file(read_csv.name)
 
     delimiter = ","
-    args = iter(arg.name for arg in args)
+    args = iter(arg.name for arg in args)  # type: ignore
     for k, v in zip(args, args):
         if k == "delimiter":
             delimiter = v
@@ -463,3 +464,27 @@ def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
         merged.append((start, end))
 
     return merged
+
+
+def is_iso_date(text: str) -> bool:
+    try:
+        datetime.date.fromisoformat(text)
+        return True
+    except ValueError:
+        return False
+
+
+def is_iso_datetime(text: str) -> bool:
+    try:
+        datetime.datetime.fromisoformat(text)
+        return True
+    except ValueError:
+        return False
+
+
+# Interval units that operate on date components
+DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
+
+
+def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
+    return expression is not None and expression.name.lower() in DATE_UNITS
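These helpers rely on the stdlib's strict ISO parsers; note that an ISO date string is also a valid ISO datetime, but not the reverse:

    from sqlglot.helper import is_iso_date, is_iso_datetime

    print(is_iso_date("2023-01-31"), is_iso_datetime("2023-01-31"))  # True True
    print(is_iso_date("2023-01-31T12:00:00"))  # False: the time part is rejected by date.fromisoformat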
@@ -6,7 +6,7 @@ from dataclasses import dataclass, field
 
 from sqlglot import Schema, exp, maybe_parse
 from sqlglot.errors import SqlglotError
-from sqlglot.optimizer import Scope, build_scope, qualify
+from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify
 
 if t.TYPE_CHECKING:
     from sqlglot.dialects.dialect import DialectType
@@ -29,8 +29,38 @@ class Node:
             else:
                 yield d
 
-    def to_html(self, **opts) -> LineageHTML:
-        return LineageHTML(self, **opts)
+    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
+        nodes = {}
+        edges = []
+
+        for node in self.walk():
+            if isinstance(node.expression, exp.Table):
+                label = f"FROM {node.expression.this}"
+                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
+                group = 1
+            else:
+                label = node.expression.sql(pretty=True, dialect=dialect)
+                source = node.source.transform(
+                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
+                    if n is node.expression
+                    else n,
+                    copy=False,
+                ).sql(pretty=True, dialect=dialect)
+                title = f"<pre>{source}</pre>"
+                group = 0
+
+            node_id = id(node)
+
+            nodes[node_id] = {
+                "id": node_id,
+                "label": label,
+                "title": title,
+                "group": group,
+            }
+
+            for d in node.downstream:
+                edges.append({"from": node_id, "to": id(d)})
+        return GraphHTML(nodes, edges, **opts)
 
 
 def lineage(
@@ -64,6 +94,7 @@ def lineage(
             k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
             for k, v in sources.items()
         },
+        dialect=dialect,
     )
 
     qualified = qualify.qualify(
@@ -129,17 +160,6 @@ def lineage(
 
             return upstream
 
-        subquery = select.unalias()
-
-        if isinstance(subquery, exp.Subquery):
-            upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select)
-            scope = t.cast(Scope, build_scope(subquery.unnest()))
-
-            for select in subquery.named_selects:
-                to_node(select, scope=scope, upstream=upstream)
-
-            return upstream
-
         if isinstance(scope.expression, exp.Select):
             # For better ergonomics in our node labels, replace the full select with
             # a version that has only the column we care about.
@@ -156,16 +176,28 @@ def lineage(
             expression=select,
             alias=alias or "",
         )
 
         if upstream:
             upstream.downstream.append(node)
 
+        subquery_scopes = {
+            id(subquery_scope.expression): subquery_scope
+            for subquery_scope in scope.subquery_scopes
+        }
+
+        for subquery in find_all_in_scope(select, exp.Subqueryable):
+            subquery_scope = subquery_scopes[id(subquery)]
+
+            for name in subquery.named_selects:
+                to_node(name, scope=subquery_scope, upstream=node)
+
         # if the select is a star add all scope sources as downstreams
         if select.is_star:
             for source in scope.sources.values():
                 node.downstream.append(Node(name=select.sql(), source=source, expression=source))
 
         # Find all columns that went into creating this one to list their lineage nodes.
-        source_columns = set(select.find_all(exp.Column))
+        source_columns = set(find_all_in_scope(select, exp.Column))
 
         # If the source is a UDTF find columns used in the UTDF to generate the table
         if isinstance(source, exp.UDTF):
@@ -192,20 +224,15 @@ def lineage(
     return to_node(column if isinstance(column, str) else column.name, scope)
 
 
-class LineageHTML:
+class GraphHTML:
     """Node to HTML generator using vis.js.
 
     https://visjs.github.io/vis-network/docs/network/
     """
 
     def __init__(
-        self,
-        node: Node,
-        dialect: DialectType = None,
-        imports: bool = True,
-        **opts: t.Any,
+        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
     ):
-        self.node = node
         self.imports = imports
 
         self.options = {
@@ -235,39 +262,11 @@ class LineageHTML:
                 "maximum": 300,
             },
         },
-            **opts,
+            **(options or {}),
         }
 
-        self.nodes = {}
-        self.edges = []
-
-        for node in node.walk():
-            if isinstance(node.expression, exp.Table):
-                label = f"FROM {node.expression.this}"
-                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
-                group = 1
-            else:
-                label = node.expression.sql(pretty=True, dialect=dialect)
-                source = node.source.transform(
-                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
-                    if n is node.expression
-                    else n,
-                    copy=False,
-                ).sql(pretty=True, dialect=dialect)
-                title = f"<pre>{source}</pre>"
-                group = 0
-
-            node_id = id(node)
-
-            self.nodes[node_id] = {
-                "id": node_id,
-                "label": label,
-                "title": title,
-                "group": group,
-            }
-
-            for d in node.downstream:
-                self.edges.append({"from": node_id, "to": id(d)})
+        self.nodes = nodes
+        self.edges = edges
 
     def __str__(self):
         nodes = json.dumps(list(self.nodes.values()))
@@ -1,12 +1,18 @@
 from __future__ import annotations
 
-import datetime
 import functools
 import typing as t
 
 from sqlglot import exp
 from sqlglot._typing import E
-from sqlglot.helper import ensure_list, seq_get, subclasses
+from sqlglot.helper import (
+    ensure_list,
+    is_date_unit,
+    is_iso_date,
+    is_iso_datetime,
+    seq_get,
+    subclasses,
+)
 from sqlglot.optimizer.scope import Scope, traverse_scope
 from sqlglot.schema import Schema, ensure_schema
 
@@ -20,10 +26,6 @@ if t.TYPE_CHECKING:
     ]
 
 
-# Interval units that operate on date components
-DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
-
-
 def annotate_types(
     expression: E,
     schema: t.Optional[t.Dict | Schema] = None,
@@ -60,43 +62,22 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type
     return lambda self, e: self._annotate_with_type(e, data_type)
 
 
-def _is_iso_date(text: str) -> bool:
-    try:
-        datetime.date.fromisoformat(text)
-        return True
-    except ValueError:
-        return False
-
-
-def _is_iso_datetime(text: str) -> bool:
-    try:
-        datetime.datetime.fromisoformat(text)
-        return True
-    except ValueError:
-        return False
-
-
-def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
     date_text = l.name
-    unit = r.text("unit").lower()
-
-    is_iso_date = _is_iso_date(date_text)
+    is_iso_date_ = is_iso_date(date_text)
 
-    if is_iso_date and unit in DATE_UNITS:
-        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
+    if is_iso_date_ and is_date_unit(unit):
         return exp.DataType.Type.DATE
 
     # An ISO date is also an ISO datetime, but not vice versa
-    if is_iso_date or _is_iso_datetime(date_text):
-        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
+    if is_iso_date_ or is_iso_datetime(date_text):
        return exp.DataType.Type.DATETIME
 
     return exp.DataType.Type.UNKNOWN
 
 
-def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
-    unit = r.text("unit").lower()
-    if unit not in DATE_UNITS:
+def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
+    if not is_date_unit(unit):
         return exp.DataType.Type.DATETIME
     return l.type.this if l.type else exp.DataType.Type.UNKNOWN
@ -171,7 +152,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Date,
|
||||
exp.DateFromParts,
|
||||
exp.DateStrToDate,
|
||||
exp.DateTrunc,
|
||||
exp.DiToDate,
|
||||
exp.StrToDate,
|
||||
exp.TimeStrToDate,
|
||||
|
@ -185,6 +165,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.DataType.Type.DOUBLE: {
|
||||
exp.ApproxQuantile,
|
||||
exp.Avg,
|
||||
exp.Div,
|
||||
exp.Exp,
|
||||
exp.Ln,
|
||||
exp.Log,
|
||||
|
@ -203,8 +184,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
},
|
||||
exp.DataType.Type.INT: {
|
||||
exp.Ceil,
|
||||
exp.DateDiff,
|
||||
exp.DatetimeDiff,
|
||||
exp.DateDiff,
|
||||
exp.Extract,
|
||||
exp.TimestampDiff,
|
||||
exp.TimeDiff,
|
||||
|
@ -240,8 +221,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.GroupConcat,
|
||||
exp.Initcap,
|
||||
exp.Lower,
|
||||
exp.SafeConcat,
|
||||
exp.SafeDPipe,
|
||||
exp.Substring,
|
||||
exp.TimeToStr,
|
||||
exp.TimeToTimeStr,
|
||||
|
@ -267,6 +246,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
|
||||
for expr_type in expressions
|
||||
},
|
||||
exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
|
||||
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
|
||||
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
|
||||
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
|
||||
|
@ -276,9 +256,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
|
||||
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
|
||||
exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
|
||||
exp.DateSub: lambda self, e: self._annotate_dateadd(e),
|
||||
exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
|
||||
exp.DateSub: lambda self, e: self._annotate_timeunit(e),
|
||||
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
|
||||
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
|
||||
exp.Div: lambda self, e: self._annotate_div(e),
|
||||
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
|
||||
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
|
||||
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
|
||||
|
@ -288,6 +270,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
|
||||
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
|
||||
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
|
||||
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
|
||||
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
|
||||
|
@ -306,13 +289,27 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
BINARY_COERCIONS: BinaryCoercions = {
|
||||
**swap_all(
|
||||
{
|
||||
(t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
|
||||
(t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
|
||||
l, r.args.get("unit")
|
||||
)
|
||||
for t in exp.DataType.TEXT_TYPES
|
||||
}
|
||||
),
|
||||
**swap_all(
|
||||
{
|
||||
(exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
|
||||
# text + numeric will yield the numeric type to match most dialects' semantics
|
||||
(text, numeric): lambda l, r: t.cast(
|
||||
exp.DataType.Type, l.type if l.type in exp.DataType.NUMERIC_TYPES else r.type
|
||||
)
|
||||
for text in exp.DataType.TEXT_TYPES
|
||||
for numeric in exp.DataType.NUMERIC_TYPES
|
||||
}
|
||||
),
|
||||
**swap_all(
|
||||
{
|
||||
(exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
|
||||
l, r.args.get("unit")
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
|
@ -511,18 +508,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
return expression
|
||||
|
||||
def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
|
||||
def _annotate_timeunit(
|
||||
self, expression: exp.TimeUnit | exp.DateTrunc
|
||||
) -> exp.TimeUnit | exp.DateTrunc:
|
||||
self._annotate_args(expression)
|
||||
|
||||
if expression.this.type.this in exp.DataType.TEXT_TYPES:
|
||||
datatype = _coerce_literal_and_interval(expression.this, expression.interval())
|
||||
elif (
|
||||
expression.this.type.is_type(exp.DataType.Type.DATE)
|
||||
and expression.text("unit").lower() not in DATE_UNITS
|
||||
):
|
||||
datatype = exp.DataType.Type.DATETIME
|
||||
datatype = _coerce_date_literal(expression.this, expression.unit)
|
||||
elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
|
||||
datatype = _coerce_date(expression.this, expression.unit)
|
||||
else:
|
||||
datatype = expression.this.type
|
||||
datatype = exp.DataType.Type.UNKNOWN
|
||||
|
||||
self._set_type(expression, datatype)
|
||||
return expression
|
||||
|
@ -547,3 +543,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
self._set_type(expression, exp.DataType.Type.UNKNOWN)
|
||||
|
||||
return expression
|
||||
|
||||
def _annotate_div(self, expression: exp.Div) -> exp.Div:
|
||||
self._annotate_args(expression)
|
||||
|
||||
left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore
|
||||
|
||||
if (
|
||||
expression.args.get("typed")
|
||||
and left_type in exp.DataType.INTEGER_TYPES
|
||||
and right_type in exp.DataType.INTEGER_TYPES
|
||||
):
|
||||
self._set_type(expression, exp.DataType.Type.BIGINT)
|
||||
else:
|
||||
self._set_type(expression, self._maybe_coerce(left_type, right_type))
|
||||
|
||||
return expression
|
||||
|
|
|
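
The type annotator now routes DATE_ADD/DATE_SUB/DATE_TRUNC through the shared
_annotate_timeunit path above, with the ISO-date helpers moved into
sqlglot.helper. A small illustration of the resulting behavior (a sketch, not
part of the diff; output shown as a comment):

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    # An ISO date literal plus a date-unit interval stays a DATE; a sub-day
    # unit such as 'hour' would promote the result to DATETIME instead.
    expr = annotate_types(parse_one("SELECT '2023-01-01' + INTERVAL '1' day"))
    print(expr.selects[0].type)  # DATE
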
@@ -1,8 +1,10 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp
from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime


def canonicalize(expression: exp.Expression) -> exp.Expression:

@@ -20,7 +22,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = replace_date_funcs(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)
expression = ensure_bools(expression, _replace_int_predicate)
expression = remove_ascending_order(expression)

return expression

@@ -40,8 +42,22 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
return node


COERCIBLE_DATE_OPS = (
exp.Add,
exp.Sub,
exp.EQ,
exp.NEQ,
exp.GT,
exp.GTE,
exp.LT,
exp.LTE,
exp.NullSafeEQ,
exp.NullSafeNEQ,
)


def coerce_type(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Binary):
if isinstance(node, COERCIBLE_DATE_OPS):
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])

@@ -49,6 +65,10 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
*exp.DataType.TEMPORAL_TYPES
):
_replace_cast(node.expression, exp.DataType.Type.DATETIME)
elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
_coerce_timeunit_arg(node.this, node.unit)
elif isinstance(node, exp.DateDiff):
_coerce_datediff_args(node)

return node


@@ -64,17 +84,21 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
return expression


def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
def ensure_bools(
expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
if isinstance(expression, exp.Connector):
_replace_int_predicate(expression.left)
_replace_int_predicate(expression.right)

elif isinstance(expression, (exp.Where, exp.Having)) or (
replace_func(expression.left)
replace_func(expression.right)
elif isinstance(expression, exp.Not):
replace_func(expression.this)
# We can't replace num in CASE x WHEN num ..., because it's not the full predicate
isinstance(expression, exp.If)
and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
elif isinstance(expression, exp.If) and not (
isinstance(expression.parent, exp.Case) and expression.parent.this
):
_replace_int_predicate(expression.this)
replace_func(expression.this)
elif isinstance(expression, (exp.Where, exp.Having)):
replace_func(expression.this)

return expression


@@ -89,22 +113,59 @@ def remove_ascending_order(expression: exp.Expression) -> exp.Expression:

def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if isinstance(b, exp.Interval):
a = _coerce_timeunit_arg(a, b.unit)
if (
a.type
and a.type.this == exp.DataType.Type.DATE
and b.type
and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
and b.type.this
not in (
exp.DataType.Type.DATE,
exp.DataType.Type.INTERVAL,
)
):
_replace_cast(b, exp.DataType.Type.DATE)


def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
if not arg.type:
return arg

if arg.type.this in exp.DataType.TEXT_TYPES:
date_text = arg.name
is_iso_date_ = is_iso_date(date_text)

if is_iso_date_ and is_date_unit(unit):
return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE))

# An ISO date is also an ISO datetime, but not vice versa
if is_iso_date_ or is_iso_datetime(date_text):
return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))

elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit):
return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))

return arg


def _coerce_datediff_args(node: exp.DateDiff) -> None:
for e in (node.this, node.expression):
if e.type.this not in exp.DataType.TEMPORAL_TYPES:
e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME))


def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
node.replace(exp.cast(node.copy(), to=to))


# this was originally designed for presto, there is a similar transform for tsql
# this is different in that it only operates on int types, this is because
# presto has a boolean type whereas tsql doesn't (people use bits)
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
if isinstance(expression, exp.Coalesce):
for _, child in expression.iter_expressions():
_replace_int_predicate(child)
elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
expression.replace(expression.neq(0))
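
The ensure_bools rewrite above is what turns integer "predicates" into explicit
comparisons during canonicalization. A sketch of the end-to-end effect (assumes
a schema so the column gets an INT type; illustrative only):

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.canonicalize import canonicalize

    ast = annotate_types(parse_one("SELECT x FROM y WHERE x"), schema={"y": {"x": "int"}})
    print(canonicalize(ast).sql())  # SELECT x FROM y WHERE x <> 0
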
@@ -186,13 +186,13 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and not (
isinstance(from_or_join, exp.Join)
and inner_select.args.get("where")
and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
and from_or_join.side in ("FULL", "LEFT", "RIGHT")
)
and not (
isinstance(from_or_join, exp.From)
and inner_select.args.get("where")
and any(
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", [])
)
)
and not _outer_select_joins_on_inner_select_join()
@@ -13,7 +13,7 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:


@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression:
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
...


@@ -48,11 +48,11 @@ def normalize_identifiers(expression, dialect=None):
Returns:
The transformed expression.
"""
dialect = Dialect.get_or_raise(dialect)

if isinstance(expression, str):
expression = exp.parse_identifier(expression, dialect=dialect)

dialect = Dialect.get_or_raise(dialect)

def _normalize(node: E) -> E:
if not node.meta.get("case_sensitive"):
exp.replace_children(node, _normalize)
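
With the hunk above, a plain string is parsed into an identifier with the
already-resolved dialect before normalization, and the str overload is typed
as returning exp.Identifier. A usage sketch (illustrative):

    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    # Snowflake normalizes unquoted identifiers to upper case
    print(normalize_identifiers("Foo", dialect="snowflake").sql("snowflake"))  # FOO
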
@@ -42,8 +42,8 @@ RULES = (
def optimize(
expression: str | exp.Expression,
schema: t.Optional[dict | Schema] = None,
db: t.Optional[str] = None,
catalog: t.Optional[str] = None,
db: t.Optional[str | exp.Identifier] = None,
catalog: t.Optional[str | exp.Identifier] = None,
dialect: DialectType = None,
rules: t.Sequence[t.Callable] = RULES,
**kwargs,
@@ -8,7 +8,7 @@ from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema


@@ -58,7 +58,7 @@ def qualify_columns(

if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
_qualify_outputs(scope)
qualify_outputs(scope)

_expand_group_by(scope)
_expand_order_by(scope, resolver)

@@ -237,7 +237,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
ordereds = order.expressions
for ordered, new_expression in zip(
ordereds,
_expand_positional_references(scope, (o.this for o in ordereds)),
_expand_positional_references(scope, (o.this for o in ordereds), alias=True),
):
for agg in ordered.find_all(exp.AggFunc):
for col in agg.find_all(exp.Column):

@@ -259,17 +259,23 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
)


def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
new_nodes = []
def _expand_positional_references(
scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
new_nodes: t.List[exp.Expression] = []
for node in expressions:
if node.is_int:
select = _select_by_pos(scope, t.cast(exp.Literal, node)).this
select = _select_by_pos(scope, t.cast(exp.Literal, node))

if isinstance(select, exp.Literal):
new_nodes.append(node)
if alias:
new_nodes.append(exp.column(select.args["alias"].copy()))
else:
new_nodes.append(select.copy())
scope.clear_cache()
select = select.this

if isinstance(select, exp.Literal):
new_nodes.append(node)
else:
new_nodes.append(select.copy())
else:
new_nodes.append(node)


@@ -307,7 +313,9 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
if column_table:
column.set("table", column_table)
elif column_table not in scope.sources and (
not scope.parent or column_table not in scope.parent.sources
not scope.parent
or column_table not in scope.parent.sources
or not scope.is_correlated_subquery
):
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

@@ -381,15 +389,18 @@ def _expand_stars(
columns = [name for name in columns if name.upper() not in pseudocolumns]

if columns and "*" not in columns:
table_id = id(table)
columns_to_exclude = except_columns.get(table_id) or set()

if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
implicit_columns = [col for col in columns if col not in pivot_columns]
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
for name in implicit_columns + pivot_output_columns
if name not in columns_to_exclude
)
continue

table_id = id(table)
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
if name in coalesced_columns:

@@ -406,7 +417,7 @@ def _expand_stars(
copy=False,
)
)
elif name not in except_columns.get(table_id, set()):
elif name not in columns_to_exclude:
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table=table)
new_selections.append(

@@ -448,10 +459,16 @@ def _add_replace_columns(
replace_columns[id(table)] = columns


def _qualify_outputs(scope: Scope) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
"""Ensure all output columns are aliased"""
new_selections = []
if isinstance(scope_or_expression, exp.Expression):
scope = build_scope(scope_or_expression)
if not isinstance(scope, Scope):
return
else:
scope = scope_or_expression

new_selections = []
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
):
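
qualify_outputs (formerly the private _qualify_outputs) now also accepts a bare
expression and builds the scope itself. A sketch of calling it directly
(illustrative; the _col_0 naming follows sqlglot's usual output aliasing):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_outputs

    e = parse_one("SELECT x + 1 FROM t")
    qualify_outputs(e)
    print(e.sql())  # SELECT x + 1 AS _col_0 FROM t
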
@@ -1,8 +1,11 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema

@@ -10,9 +13,10 @@ from sqlglot.schema import Schema

def qualify_tables(
expression: E,
db: t.Optional[str] = None,
catalog: t.Optional[str] = None,
db: t.Optional[str | exp.Identifier] = None,
catalog: t.Optional[str | exp.Identifier] = None,
schema: t.Optional[Schema] = None,
dialect: DialectType = None,
) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as

@@ -33,11 +37,14 @@ def qualify_tables(
db: Database name
catalog: Catalog name
schema: A schema to populate
dialect: The dialect to parse catalog and schema into.

Returns:
The qualified expression.
"""
next_alias_name = name_sequence("_q_")
db = exp.parse_identifier(db, dialect=dialect) if db else None
catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None

for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):

@@ -61,9 +68,9 @@ def qualify_tables(
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", exp.to_identifier(db))
source.set("db", db)
if not source.args.get("catalog") and source.args.get("db"):
source.set("catalog", exp.to_identifier(catalog))
source.set("catalog", catalog)

if not source.alias:
# Mutates the source by attaching an alias to it
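
Because db and catalog are now parsed as identifiers with the given dialect
(rather than wrapped with exp.to_identifier), dialect-specific quoting rules
apply to them. A usage sketch (illustrative):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables

    e = qualify_tables(parse_one("SELECT 1 FROM tbl"), db="db", catalog="cat")
    print(e.sql())  # SELECT 1 FROM cat.db.tbl AS tbl
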
@@ -1,3 +1,5 @@
from __future__ import annotations

import itertools
import logging
import typing as t

@@ -507,6 +507,9 @@ def simplify_literals(expression, root=True):
return exp.Literal.number(value[1:])
return exp.Literal.number(f"-{value}")

if type(expression) in INVERSE_DATE_OPS:
return _simplify_binary(expression, expression.this, expression.interval()) or expression

return expression


@@ -530,22 +533,24 @@ def _simplify_binary(expression, a, b):
return exp.null()

if a.is_number and b.is_number:
a = int(a.name) if a.is_int else Decimal(a.name)
b = int(b.name) if b.is_int else Decimal(b.name)
num_a = int(a.name) if a.is_int else Decimal(a.name)
num_b = int(b.name) if b.is_int else Decimal(b.name)

if isinstance(expression, exp.Add):
return exp.Literal.number(a + b)
if isinstance(expression, exp.Sub):
return exp.Literal.number(a - b)
return exp.Literal.number(num_a + num_b)
if isinstance(expression, exp.Mul):
return exp.Literal.number(a * b)
return exp.Literal.number(num_a * num_b)

# We only simplify Sub, Div if a and b have the same parent because they're not associative
if isinstance(expression, exp.Sub):
return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
if isinstance(expression, exp.Div):
# engines have differing int div behavior so intdiv is not safe
if isinstance(a, int) and isinstance(b, int):
if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
return None
return exp.Literal.number(a / b)
return exp.Literal.number(num_a / num_b)

boolean = eval_boolean(expression, a, b)
boolean = eval_boolean(expression, num_a, num_b)

if boolean:
return boolean

@@ -557,15 +562,21 @@ def _simplify_binary(expression, a, b):
elif _is_date_literal(a) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if a and b:
if isinstance(expression, exp.Add):
if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
return date_literal(a + b)
if isinstance(expression, exp.Sub):
if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
return date_literal(a - b)
elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
return date_literal(a + b)
elif _is_date_literal(a) and _is_date_literal(b):
if isinstance(expression, exp.Predicate):
a, b = extract_date(a), extract_date(b)
boolean = eval_boolean(expression, a, b)
if boolean:
return boolean

return None


@@ -590,6 +601,11 @@ def simplify_parens(expression):
return expression


NONNULL_CONSTANTS = (
exp.Literal,
exp.Boolean,
)

CONSTANTS = (
exp.Literal,
exp.Boolean,

@@ -597,11 +613,19 @@ CONSTANTS = (
)


def _is_nonnull_constant(expression: exp.Expression) -> bool:
return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
return isinstance(expression, CONSTANTS) or _is_date_literal(expression)


def simplify_coalesce(expression):
# COALESCE(x) -> x
if (
isinstance(expression, exp.Coalesce)
and not expression.expressions
and (not expression.expressions or _is_nonnull_constant(expression.this))
# COALESCE is also used as a Spark partitioning hint
and not isinstance(expression.parent, exp.Hint)
):

@@ -621,12 +645,12 @@ def simplify_coalesce(expression):

# This transformation is valid for non-constants,
# but it really only does anything if they are both constants.
if not isinstance(other, CONSTANTS):
if not _is_constant(other):
return expression

# Find the first constant arg
for arg_index, arg in enumerate(coalesce.expressions):
if isinstance(arg, CONSTANTS):
if _is_constant(other):
break
else:
return expression

@@ -656,7 +680,6 @@ def simplify_coalesce(expression):


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):

@@ -672,10 +695,15 @@ def simplify_concat(expression):
sep_expr, *expressions = expression.expressions
sep = sep_expr.name
concat_type = exp.ConcatWs
args = {}
else:
expressions = expression.expressions
sep = ""
concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
concat_type = exp.Concat
args = {
"safe": expression.args.get("safe"),
"coalesce": expression.args.get("coalesce"),
}

new_args = []
for is_string_group, group in itertools.groupby(

@@ -692,7 +720,7 @@ def simplify_concat(expression):
if concat_type is exp.ConcatWs:
new_args = [sep_expr] + new_args

return concat_type(expressions=new_args)
return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):

@@ -947,7 +975,7 @@ def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.da
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
if isinstance(cast, exp.Cast):
to = cast.to
elif isinstance(cast, exp.TsOrDsToDate):
elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
to = exp.DataType.build(exp.DataType.Type.DATE)
else:
return None

@@ -966,12 +994,11 @@ def _is_date_literal(expression: exp.Expression) -> bool:


def extract_interval(expression):
n = int(expression.name)
unit = expression.text("unit").lower()

try:
n = int(expression.name)
unit = expression.text("unit").lower()
return interval(unit, n)
except (UnsupportedUnit, ModuleNotFoundError):
except (UnsupportedUnit, ModuleNotFoundError, ValueError):
return None


@@ -1099,8 +1126,6 @@ GEN_MAP = {
exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
exp.Div: lambda e: _binary(e, "/"),
exp.Dot: lambda e: _binary(e, "."),
exp.DPipe: lambda e: _binary(e, "||"),
exp.SafeDPipe: lambda e: _binary(e, "||"),
exp.EQ: lambda e: _binary(e, "="),
exp.GT: lambda e: _binary(e, ">"),
exp.GTE: lambda e: _binary(e, ">="),
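
The literal-folding changes above make Sub and Div fold only when both operands
share the same parent node, since those operators are not associative, and
int/int division is left alone because engines disagree on its semantics.
A sketch of the observable behavior (illustrative):

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    print(simplify(parse_one("1 - 2 - 3")).sql())  # folds pairwise to -4
    print(simplify(parse_one("6 / 2")).sql())      # 6 / 2 stays: int division is unsafe to fold
    print(simplify(parse_one("6.0 / 2")).sql())    # folds to 3
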
@ -13,6 +13,7 @@ from sqlglot.trie import TrieResult, in_trie, new_trie
|
|||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import E
|
||||
from sqlglot.dialects.dialect import Dialect, DialectType
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
|
@ -46,6 +47,19 @@ def binary_range_parser(
|
|||
)
|
||||
|
||||
|
||||
def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
|
||||
# Default argument order is base, expression
|
||||
this = seq_get(args, 0)
|
||||
expression = seq_get(args, 1)
|
||||
|
||||
if expression:
|
||||
if not dialect.LOG_BASE_FIRST:
|
||||
this, expression = expression, this
|
||||
return exp.Log(this=this, expression=expression)
|
||||
|
||||
return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)
|
||||
|
||||
|
||||
class _Parser(type):
|
||||
def __new__(cls, clsname, bases, attrs):
|
||||
klass = super().__new__(cls, clsname, bases, attrs)
|
||||
|
@ -72,13 +86,24 @@ class Parser(metaclass=_Parser):
|
|||
"""
|
||||
|
||||
FUNCTIONS: t.Dict[str, t.Callable] = {
|
||||
**{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()},
|
||||
**{name: func.from_arg_list for name, func in exp.FUNCTION_BY_NAME.items()},
|
||||
"CONCAT": lambda args, dialect: exp.Concat(
|
||||
expressions=args,
|
||||
safe=not dialect.STRICT_STRING_CONCAT,
|
||||
coalesce=dialect.CONCAT_COALESCE,
|
||||
),
|
||||
"CONCAT_WS": lambda args, dialect: exp.ConcatWs(
|
||||
expressions=args,
|
||||
safe=not dialect.STRICT_STRING_CONCAT,
|
||||
coalesce=dialect.CONCAT_COALESCE,
|
||||
),
|
||||
"DATE_TO_DATE_STR": lambda args: exp.Cast(
|
||||
this=seq_get(args, 0),
|
||||
to=exp.DataType(this=exp.DataType.Type.TEXT),
|
||||
),
|
||||
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
|
||||
"LIKE": parse_like,
|
||||
"LOG": parse_logarithm,
|
||||
"TIME_TO_TIME_STR": lambda args: exp.Cast(
|
||||
this=seq_get(args, 0),
|
||||
to=exp.DataType(this=exp.DataType.Type.TEXT),
|
||||
|
@ -229,7 +254,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.SOME: exp.Any,
|
||||
}
|
||||
|
||||
RESERVED_KEYWORDS = {
|
||||
RESERVED_TOKENS = {
|
||||
*Tokenizer.SINGLE_TOKENS.values(),
|
||||
TokenType.SELECT,
|
||||
}
|
||||
|
@ -245,9 +270,11 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
CREATABLES = {
|
||||
TokenType.COLUMN,
|
||||
TokenType.CONSTRAINT,
|
||||
TokenType.FUNCTION,
|
||||
TokenType.INDEX,
|
||||
TokenType.PROCEDURE,
|
||||
TokenType.FOREIGN_KEY,
|
||||
*DB_CREATABLES,
|
||||
}
|
||||
|
||||
|
@ -291,6 +318,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.NATURAL,
|
||||
TokenType.NEXT,
|
||||
TokenType.OFFSET,
|
||||
TokenType.OPERATOR,
|
||||
TokenType.ORDINALITY,
|
||||
TokenType.OVERLAPS,
|
||||
TokenType.OVERWRITE,
|
||||
|
@ -299,7 +327,10 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.PIVOT,
|
||||
TokenType.PRAGMA,
|
||||
TokenType.RANGE,
|
||||
TokenType.RECURSIVE,
|
||||
TokenType.REFERENCES,
|
||||
TokenType.REFRESH,
|
||||
TokenType.REPLACE,
|
||||
TokenType.RIGHT,
|
||||
TokenType.ROW,
|
||||
TokenType.ROWS,
|
||||
|
@ -390,6 +421,7 @@ class Parser(metaclass=_Parser):
|
|||
}
|
||||
|
||||
EQUALITY = {
|
||||
TokenType.COLON_EQ: exp.PropertyEQ,
|
||||
TokenType.EQ: exp.EQ,
|
||||
TokenType.NEQ: exp.NEQ,
|
||||
TokenType.NULLSAFE_EQ: exp.NullSafeEQ,
|
||||
|
@ -406,7 +438,6 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.AMP: exp.BitwiseAnd,
|
||||
TokenType.CARET: exp.BitwiseXor,
|
||||
TokenType.PIPE: exp.BitwiseOr,
|
||||
TokenType.DPIPE: exp.DPipe,
|
||||
}
|
||||
|
||||
TERM = {
|
||||
|
@ -423,6 +454,8 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.STAR: exp.Mul,
|
||||
}
|
||||
|
||||
EXPONENT: t.Dict[TokenType, t.Type[exp.Expression]] = {}
|
||||
|
||||
TIMES = {
|
||||
TokenType.TIME,
|
||||
TokenType.TIMETZ,
|
||||
|
@ -558,6 +591,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.MERGE: lambda self: self._parse_merge(),
|
||||
TokenType.PIVOT: lambda self: self._parse_simplified_pivot(),
|
||||
TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()),
|
||||
TokenType.REFRESH: lambda self: self._parse_refresh(),
|
||||
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
|
||||
TokenType.SET: lambda self: self._parse_set(),
|
||||
TokenType.UNCACHE: lambda self: self._parse_uncache(),
|
||||
|
@ -697,6 +731,7 @@ class Parser(metaclass=_Parser):
|
|||
exp.StabilityProperty, this=exp.Literal.string("STABLE")
|
||||
),
|
||||
"STORED": lambda self: self._parse_stored(),
|
||||
"SYSTEM_VERSIONING": lambda self: self._parse_system_versioning_property(),
|
||||
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
|
||||
"TEMP": lambda self: self.expression(exp.TemporaryProperty),
|
||||
"TEMPORARY": lambda self: self.expression(exp.TemporaryProperty),
|
||||
|
@ -754,6 +789,7 @@ class Parser(metaclass=_Parser):
|
|||
)
|
||||
or self.expression(exp.OnProperty, this=self._parse_id_var()),
|
||||
"PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
|
||||
"PERIOD": lambda self: self._parse_period_for_system_time(),
|
||||
"PRIMARY KEY": lambda self: self._parse_primary_key(),
|
||||
"REFERENCES": lambda self: self._parse_references(match=False),
|
||||
"TITLE": lambda self: self.expression(
|
||||
|
@ -775,7 +811,7 @@ class Parser(metaclass=_Parser):
|
|||
"RENAME": lambda self: self._parse_alter_table_rename(),
|
||||
}
|
||||
|
||||
SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}
|
||||
SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE", "PERIOD"}
|
||||
|
||||
NO_PAREN_FUNCTION_PARSERS = {
|
||||
"ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
|
||||
|
@ -794,14 +830,11 @@ class Parser(metaclass=_Parser):
|
|||
FUNCTION_PARSERS = {
|
||||
"ANY_VALUE": lambda self: self._parse_any_value(),
|
||||
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
|
||||
"CONCAT": lambda self: self._parse_concat(),
|
||||
"CONCAT_WS": lambda self: self._parse_concat_ws(),
|
||||
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
|
||||
"DECODE": lambda self: self._parse_decode(),
|
||||
"EXTRACT": lambda self: self._parse_extract(),
|
||||
"JSON_OBJECT": lambda self: self._parse_json_object(),
|
||||
"JSON_TABLE": lambda self: self._parse_json_table(),
|
||||
"LOG": lambda self: self._parse_logarithm(),
|
||||
"MATCH": lambda self: self._parse_match_against(),
|
||||
"OPENJSON": lambda self: self._parse_open_json(),
|
||||
"POSITION": lambda self: self._parse_position(),
|
||||
|
@ -877,6 +910,7 @@ class Parser(metaclass=_Parser):
|
|||
CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
|
||||
|
||||
OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"}
|
||||
OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN}
|
||||
|
||||
TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE}
|
||||
|
||||
|
@ -896,17 +930,13 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
STRICT_CAST = True
|
||||
|
||||
# A NULL arg in CONCAT yields NULL by default
|
||||
CONCAT_NULL_OUTPUTS_STRING = False
|
||||
|
||||
PREFIXED_PIVOT_COLUMNS = False
|
||||
IDENTIFY_PIVOT_STRINGS = False
|
||||
|
||||
LOG_BASE_FIRST = True
|
||||
LOG_DEFAULTS_TO_LN = False
|
||||
|
||||
# Whether or not ADD is present for each column added by ALTER TABLE
|
||||
ALTER_TABLE_ADD_COLUMN_KEYWORD = True
|
||||
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = True
|
||||
|
||||
# Whether or not the table sample clause expects CSV syntax
|
||||
TABLESAMPLE_CSV = False
|
||||
|
@ -921,6 +951,7 @@ class Parser(metaclass=_Parser):
|
|||
"error_level",
|
||||
"error_message_context",
|
||||
"max_errors",
|
||||
"dialect",
|
||||
"sql",
|
||||
"errors",
|
||||
"_tokens",
|
||||
|
@ -929,35 +960,25 @@ class Parser(metaclass=_Parser):
|
|||
"_next",
|
||||
"_prev",
|
||||
"_prev_comments",
|
||||
"_tokenizer",
|
||||
)
|
||||
|
||||
# Autofilled
|
||||
TOKENIZER_CLASS: t.Type[Tokenizer] = Tokenizer
|
||||
INDEX_OFFSET: int = 0
|
||||
UNNEST_COLUMN_ONLY: bool = False
|
||||
ALIAS_POST_TABLESAMPLE: bool = False
|
||||
STRICT_STRING_CONCAT = False
|
||||
SUPPORTS_USER_DEFINED_TYPES = True
|
||||
NORMALIZE_FUNCTIONS = "upper"
|
||||
NULL_ORDERING: str = "nulls_are_small"
|
||||
SHOW_TRIE: t.Dict = {}
|
||||
SET_TRIE: t.Dict = {}
|
||||
FORMAT_MAPPING: t.Dict[str, str] = {}
|
||||
FORMAT_TRIE: t.Dict = {}
|
||||
TIME_MAPPING: t.Dict[str, str] = {}
|
||||
TIME_TRIE: t.Dict = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_level: t.Optional[ErrorLevel] = None,
|
||||
error_message_context: int = 100,
|
||||
max_errors: int = 3,
|
||||
dialect: DialectType = None,
|
||||
):
|
||||
from sqlglot.dialects import Dialect
|
||||
|
||||
self.error_level = error_level or ErrorLevel.IMMEDIATE
|
||||
self.error_message_context = error_message_context
|
||||
self.max_errors = max_errors
|
||||
self._tokenizer = self.TOKENIZER_CLASS()
|
||||
self.dialect = Dialect.get_or_raise(dialect)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
|
@ -1384,7 +1405,7 @@ class Parser(metaclass=_Parser):
|
|||
if self._match_texts(self.CLONE_KEYWORDS):
|
||||
copy = self._prev.text.lower() == "copy"
|
||||
clone = self._parse_table(schema=True)
|
||||
when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper()
|
||||
when = self._match_texts(("AT", "BEFORE")) and self._prev.text.upper()
|
||||
clone_kind = (
|
||||
self._match(TokenType.L_PAREN)
|
||||
and self._match_texts(self.CLONE_KINDS)
|
||||
|
@ -1524,6 +1545,22 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE"))
|
||||
|
||||
def _parse_system_versioning_property(self) -> exp.WithSystemVersioningProperty:
|
||||
self._match_pair(TokenType.EQ, TokenType.ON)
|
||||
|
||||
prop = self.expression(exp.WithSystemVersioningProperty)
|
||||
if self._match(TokenType.L_PAREN):
|
||||
self._match_text_seq("HISTORY_TABLE", "=")
|
||||
prop.set("this", self._parse_table_parts())
|
||||
|
||||
if self._match(TokenType.COMMA):
|
||||
self._match_text_seq("DATA_CONSISTENCY_CHECK", "=")
|
||||
prop.set("expression", self._advance_any() and self._prev.text.upper())
|
||||
|
||||
self._match_r_paren()
|
||||
|
||||
return prop
|
||||
|
||||
def _parse_with_property(
|
||||
self,
|
||||
) -> t.Optional[exp.Expression] | t.List[exp.Expression]:
|
||||
|
@ -2140,7 +2177,11 @@ class Parser(metaclass=_Parser):
|
|||
return self._parse_expressions()
|
||||
|
||||
def _parse_select(
|
||||
self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
|
||||
self,
|
||||
nested: bool = False,
|
||||
table: bool = False,
|
||||
parse_subquery_alias: bool = True,
|
||||
parse_set_operation: bool = True,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
cte = self._parse_with()
|
||||
|
||||
|
@ -2216,7 +2257,11 @@ class Parser(metaclass=_Parser):
|
|||
t.cast(exp.From, self._parse_from(skip_from_token=True))
|
||||
)
|
||||
else:
|
||||
this = self._parse_table() if table else self._parse_select(nested=True)
|
||||
this = (
|
||||
self._parse_table()
|
||||
if table
|
||||
else self._parse_select(nested=True, parse_set_operation=False)
|
||||
)
|
||||
this = self._parse_set_operations(self._parse_query_modifiers(this))
|
||||
|
||||
self._match_r_paren()
|
||||
|
@ -2235,7 +2280,9 @@ class Parser(metaclass=_Parser):
|
|||
else:
|
||||
this = None
|
||||
|
||||
return self._parse_set_operations(this)
|
||||
if parse_set_operation:
|
||||
return self._parse_set_operations(this)
|
||||
return this
|
||||
|
||||
def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]:
|
||||
if not skip_with_token and not self._match(TokenType.WITH):
|
||||
|
@ -2563,9 +2610,8 @@ class Parser(metaclass=_Parser):
|
|||
if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False):
|
||||
return this
|
||||
|
||||
opclass = self._parse_var(any_token=True)
|
||||
if opclass:
|
||||
return self.expression(exp.Opclass, this=this, expression=opclass)
|
||||
if not self._match_set(self.OPTYPE_FOLLOW_TOKENS, advance=False):
|
||||
return self.expression(exp.Opclass, this=this, expression=self._parse_table_parts())
|
||||
|
||||
return this
|
||||
|
||||
|
@ -2630,7 +2676,7 @@ class Parser(metaclass=_Parser):
|
|||
while self._match_set(self.TABLE_INDEX_HINT_TOKENS):
|
||||
hint = exp.IndexTableHint(this=self._prev.text.upper())
|
||||
|
||||
self._match_texts({"INDEX", "KEY"})
|
||||
self._match_texts(("INDEX", "KEY"))
|
||||
if self._match(TokenType.FOR):
|
||||
hint.set("target", self._advance_any() and self._prev.text.upper())
|
||||
|
||||
|
@ -2650,7 +2696,7 @@ class Parser(metaclass=_Parser):
|
|||
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
|
||||
catalog = None
|
||||
db = None
|
||||
table = self._parse_table_part(schema=schema)
|
||||
table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema)
|
||||
|
||||
while self._match(TokenType.DOT):
|
||||
if catalog:
|
||||
|
@ -2661,7 +2707,7 @@ class Parser(metaclass=_Parser):
|
|||
else:
|
||||
catalog = db
|
||||
db = table
|
||||
table = self._parse_table_part(schema=schema)
|
||||
table = self._parse_table_part(schema=schema) or ""
|
||||
|
||||
if not table:
|
||||
self.raise_error(f"Expected table name but got {self._curr}")
|
||||
|
@ -2709,7 +2755,7 @@ class Parser(metaclass=_Parser):
|
|||
if version:
|
||||
this.set("version", version)
|
||||
|
||||
if self.ALIAS_POST_TABLESAMPLE:
|
||||
if self.dialect.ALIAS_POST_TABLESAMPLE:
|
||||
table_sample = self._parse_table_sample()
|
||||
|
||||
alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
|
||||
|
@ -2724,7 +2770,7 @@ class Parser(metaclass=_Parser):
|
|||
if not this.args.get("pivots"):
|
||||
this.set("pivots", self._parse_pivots())
|
||||
|
||||
if not self.ALIAS_POST_TABLESAMPLE:
|
||||
if not self.dialect.ALIAS_POST_TABLESAMPLE:
|
||||
table_sample = self._parse_table_sample()
|
||||
|
||||
if table_sample:
|
||||
|
@ -2776,13 +2822,13 @@ class Parser(metaclass=_Parser):
|
|||
if not self._match(TokenType.UNNEST):
|
||||
return None
|
||||
|
||||
expressions = self._parse_wrapped_csv(self._parse_type)
|
||||
expressions = self._parse_wrapped_csv(self._parse_equality)
|
||||
offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
|
||||
|
||||
alias = self._parse_table_alias() if with_alias else None
|
||||
|
||||
if alias:
|
||||
if self.UNNEST_COLUMN_ONLY:
|
||||
if self.dialect.UNNEST_COLUMN_ONLY:
|
||||
if alias.args.get("columns"):
|
||||
self.raise_error("Unexpected extra column alias in unnest.")
|
||||
|
||||
|
@ -2845,7 +2891,7 @@ class Parser(metaclass=_Parser):
|
|||
num = (
|
||||
self._parse_factor()
|
||||
if self._match(TokenType.NUMBER, advance=False)
|
||||
else self._parse_primary()
|
||||
else self._parse_primary() or self._parse_placeholder()
|
||||
)
|
||||
|
||||
if self._match_text_seq("BUCKET"):
|
||||
|
@ -3108,10 +3154,10 @@ class Parser(metaclass=_Parser):
|
|||
if (
|
||||
not explicitly_null_ordered
|
||||
and (
|
||||
(not desc and self.NULL_ORDERING == "nulls_are_small")
|
||||
or (desc and self.NULL_ORDERING != "nulls_are_small")
|
||||
(not desc and self.dialect.NULL_ORDERING == "nulls_are_small")
|
||||
or (desc and self.dialect.NULL_ORDERING != "nulls_are_small")
|
||||
)
|
||||
and self.NULL_ORDERING != "nulls_are_last"
|
||||
and self.dialect.NULL_ORDERING != "nulls_are_last"
|
||||
):
|
||||
nulls_first = True
|
||||
|
||||
|
@ -3124,7 +3170,7 @@ class Parser(metaclass=_Parser):
|
|||
comments = self._prev_comments
|
||||
if top:
|
||||
limit_paren = self._match(TokenType.L_PAREN)
|
||||
expression = self._parse_number()
|
||||
expression = self._parse_term() if limit_paren else self._parse_number()
|
||||
|
||||
if limit_paren:
|
||||
self._match_r_paren()
|
||||
|
@ -3225,7 +3271,9 @@ class Parser(metaclass=_Parser):
|
|||
this=this,
|
||||
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
|
||||
by_name=self._match_text_seq("BY", "NAME"),
|
||||
expression=self._parse_set_operations(self._parse_select(nested=True)),
|
||||
expression=self._parse_set_operations(
|
||||
self._parse_select(nested=True, parse_set_operation=False)
|
||||
),
|
||||
)
|
||||
|
||||
def _parse_expression(self) -> t.Optional[exp.Expression]:
|
||||
|
@ -3287,7 +3335,8 @@ class Parser(metaclass=_Parser):
|
|||
unnest = self._parse_unnest(with_alias=False)
|
||||
if unnest:
|
||||
this = self.expression(exp.In, this=this, unnest=unnest)
|
||||
elif self._match(TokenType.L_PAREN):
|
||||
elif self._match_set((TokenType.L_PAREN, TokenType.L_BRACKET)):
|
||||
matched_l_paren = self._prev.token_type == TokenType.L_PAREN
|
||||
expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias))
|
||||
|
||||
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
|
||||
|
@ -3295,13 +3344,16 @@ class Parser(metaclass=_Parser):
|
|||
else:
|
||||
this = self.expression(exp.In, this=this, expressions=expressions)
|
||||
|
||||
self._match_r_paren(this)
|
||||
if matched_l_paren:
|
||||
self._match_r_paren(this)
|
||||
elif not self._match(TokenType.R_BRACKET, expression=this):
|
||||
self.raise_error("Expecting ]")
|
||||
else:
|
||||
this = self.expression(exp.In, this=this, field=self._parse_field())
|
||||
|
||||
return this
|
||||
|
||||
def _parse_between(self, this: exp.Expression) -> exp.Between:
|
||||
def _parse_between(self, this: t.Optional[exp.Expression]) -> exp.Between:
|
||||
low = self._parse_bitwise()
|
||||
self._match(TokenType.AND)
|
||||
high = self._parse_bitwise()
|
||||
|
@ -3357,6 +3409,13 @@ class Parser(metaclass=_Parser):
|
|||
this=this,
|
||||
expression=self._parse_term(),
|
||||
)
|
||||
elif self.dialect.DPIPE_IS_STRING_CONCAT and self._match(TokenType.DPIPE):
|
||||
this = self.expression(
|
||||
exp.DPipe,
|
||||
this=this,
|
||||
expression=self._parse_term(),
|
||||
safe=not self.dialect.STRICT_STRING_CONCAT,
|
||||
)
|
||||
elif self._match(TokenType.DQMARK):
|
||||
this = self.expression(exp.Coalesce, this=this, expressions=self._parse_term())
|
||||
elif self._match_pair(TokenType.LT, TokenType.LT):
|
||||
|
@ -3376,7 +3435,17 @@ class Parser(metaclass=_Parser):
|
|||
return self._parse_tokens(self._parse_factor, self.TERM)
|
||||
|
||||
def _parse_factor(self) -> t.Optional[exp.Expression]:
|
||||
return self._parse_tokens(self._parse_unary, self.FACTOR)
|
||||
if self.EXPONENT:
|
||||
factor = self._parse_tokens(self._parse_exponent, self.FACTOR)
|
||||
else:
|
||||
factor = self._parse_tokens(self._parse_unary, self.FACTOR)
|
||||
if isinstance(factor, exp.Div):
|
||||
factor.args["typed"] = self.dialect.TYPED_DIVISION
|
||||
factor.args["safe"] = self.dialect.SAFE_DIVISION
|
||||
return factor
|
||||
|
||||
def _parse_exponent(self) -> t.Optional[exp.Expression]:
|
||||
return self._parse_tokens(self._parse_unary, self.EXPONENT)
|
||||
|
||||
def _parse_unary(self) -> t.Optional[exp.Expression]:
|
||||
if self._match_set(self.UNARY_PARSERS):
|
||||
|
@ -3427,14 +3496,14 @@ class Parser(metaclass=_Parser):
|
|||
)
|
||||
|
||||
if identifier:
|
||||
tokens = self._tokenizer.tokenize(identifier.name)
|
||||
tokens = self.dialect.tokenize(identifier.name)
|
||||
|
||||
if len(tokens) != 1:
|
||||
self.raise_error("Unexpected identifier", self._prev)
|
||||
|
||||
if tokens[0].token_type in self.TYPE_TOKENS:
|
||||
self._prev = tokens[0]
|
||||
elif self.SUPPORTS_USER_DEFINED_TYPES:
|
||||
elif self.dialect.SUPPORTS_USER_DEFINED_TYPES:
|
||||
type_name = identifier.name
|
||||
|
||||
while self._match(TokenType.DOT):
|
||||
|
@ -3713,6 +3782,7 @@ class Parser(metaclass=_Parser):
|
|||
if not self._curr:
|
||||
return None
|
||||
|
||||
comments = self._curr.comments
|
||||
token_type = self._curr.token_type
|
||||
this = self._curr.text
|
||||
upper = this.upper()
|
||||
|
@ -3754,13 +3824,22 @@ class Parser(metaclass=_Parser):
|
|||
args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
|
||||
|
||||
if function and not anonymous:
|
||||
func = self.validate_expression(function(args), args)
|
||||
if not self.NORMALIZE_FUNCTIONS:
|
||||
if "dialect" in function.__code__.co_varnames:
|
||||
func = function(args, dialect=self.dialect)
|
||||
else:
|
||||
func = function(args)
|
||||
|
||||
func = self.validate_expression(func, args)
|
||||
if not self.dialect.NORMALIZE_FUNCTIONS:
|
||||
func.meta["name"] = this
|
||||
|
||||
this = func
|
||||
else:
|
||||
this = self.expression(exp.Anonymous, this=this, expressions=args)
|
||||
|
||||
if isinstance(this, exp.Expression):
|
||||
this.add_comments(comments)
|
||||
|
||||
self._match_r_paren(this)
|
||||
return self._parse_window(this)
|
||||
|
||||
|
@ -3875,6 +3954,11 @@ class Parser(metaclass=_Parser):
|
|||
not_null=self._match_pair(TokenType.NOT, TokenType.NULL),
|
||||
)
|
||||
)
|
||||
elif kind and self._match_pair(TokenType.ALIAS, TokenType.L_PAREN, advance=False):
|
||||
self._match(TokenType.ALIAS)
|
||||
constraints.append(
|
||||
self.expression(exp.TransformColumnConstraint, this=self._parse_field())
|
||||
)
|
||||
|
||||
while True:
|
||||
constraint = self._parse_column_constraint()
|
||||
|
@ -3917,7 +4001,11 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
def _parse_generated_as_identity(
|
||||
self,
|
||||
) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint:
|
||||
) -> (
|
||||
exp.GeneratedAsIdentityColumnConstraint
|
||||
| exp.ComputedColumnConstraint
|
||||
| exp.GeneratedAsRowColumnConstraint
|
||||
):
|
||||
if self._match_text_seq("BY", "DEFAULT"):
|
||||
on_null = self._match_pair(TokenType.ON, TokenType.NULL)
|
||||
this = self.expression(
|
||||
|
@ -3928,6 +4016,14 @@ class Parser(metaclass=_Parser):
|
|||
this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
|
||||
|
||||
self._match(TokenType.ALIAS)
|
||||
|
||||
if self._match_text_seq("ROW"):
|
||||
start = self._match_text_seq("START")
|
||||
if not start:
|
||||
self._match(TokenType.END)
|
||||
hidden = self._match_text_seq("HIDDEN")
|
||||
return self.expression(exp.GeneratedAsRowColumnConstraint, start=start, hidden=hidden)
|
||||
|
||||
identity = self._match_text_seq("IDENTITY")
|
||||
|
||||
if self._match(TokenType.L_PAREN):
|
||||
|
@ -4100,6 +4196,16 @@ class Parser(metaclass=_Parser):
|
|||
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
|
||||
return self._parse_field()
|
||||
|
||||
def _parse_period_for_system_time(self) -> exp.PeriodForSystemTimeConstraint:
|
||||
self._match(TokenType.TIMESTAMP_SNAPSHOT)
|
||||
|
||||
id_vars = self._parse_wrapped_id_vars()
|
||||
return self.expression(
|
||||
exp.PeriodForSystemTimeConstraint,
|
||||
this=seq_get(id_vars, 0),
|
||||
expression=seq_get(id_vars, 1),
|
||||
)
|
||||
|
||||
def _parse_primary_key(
|
||||
self, wrapped_optional: bool = False, in_props: bool = False
|
||||
) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey:
|
||||
|
@ -4145,7 +4251,7 @@ class Parser(metaclass=_Parser):
|
|||
elif not this or this.name.upper() == "ARRAY":
|
||||
this = self.expression(exp.Array, expressions=expressions)
|
||||
else:
|
||||
expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET)
|
||||
expressions = apply_index_offset(this, expressions, -self.dialect.INDEX_OFFSET)
|
||||
this = self.expression(exp.Bracket, this=this, expressions=expressions)
|
||||
|
||||
self._add_comments(this)
|
||||
|
@ -4259,8 +4365,8 @@ class Parser(metaclass=_Parser):
|
|||
format=exp.Literal.string(
|
||||
format_time(
|
||||
fmt_string.this if fmt_string else "",
|
||||
self.FORMAT_MAPPING or self.TIME_MAPPING,
|
||||
self.FORMAT_TRIE or self.TIME_TRIE,
|
||||
self.dialect.FORMAT_MAPPING or self.dialect.TIME_MAPPING,
|
||||
self.dialect.FORMAT_TRIE or self.dialect.TIME_TRIE,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
@ -4280,30 +4386,6 @@ class Parser(metaclass=_Parser):
|
|||
exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe
|
||||
)
|
||||
|
||||
def _parse_concat(self) -> t.Optional[exp.Expression]:
|
||||
args = self._parse_csv(self._parse_conjunction)
|
||||
if self.CONCAT_NULL_OUTPUTS_STRING:
|
||||
args = self._ensure_string_if_null(args)
|
||||
|
||||
# Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when
|
||||
# we find such a call we replace it with its argument.
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
|
||||
return self.expression(
|
||||
exp.Concat if self.STRICT_STRING_CONCAT else exp.SafeConcat, expressions=args
|
||||
)
|
||||
|
||||
def _parse_concat_ws(self) -> t.Optional[exp.Expression]:
|
||||
args = self._parse_csv(self._parse_conjunction)
|
||||
if len(args) < 2:
|
||||
return self.expression(exp.ConcatWs, expressions=args)
|
||||
delim, *values = args
|
||||
if self.CONCAT_NULL_OUTPUTS_STRING:
|
||||
values = self._ensure_string_if_null(values)
|
||||
|
||||
return self.expression(exp.ConcatWs, expressions=[delim] + values)
|
||||
|
||||
def _parse_string_agg(self) -> exp.Expression:
|
||||
if self._match(TokenType.DISTINCT):
|
||||
args: t.List[t.Optional[exp.Expression]] = [
|
||||
|
@ -4495,19 +4577,6 @@ class Parser(metaclass=_Parser):
|
|||
empty_handling=empty_handling,
|
||||
)
|
||||
|
||||
def _parse_logarithm(self) -> exp.Func:
|
||||
# Default argument order is base, expression
|
||||
args = self._parse_csv(self._parse_range)
|
||||
|
||||
if len(args) > 1:
|
||||
if not self.LOG_BASE_FIRST:
|
||||
args.reverse()
|
||||
return exp.Log.from_arg_list(args)
|
||||
|
||||
return self.expression(
|
||||
exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0)
|
||||
)
|
||||
|
||||
def _parse_match_against(self) -> exp.MatchAgainst:
|
||||
expressions = self._parse_csv(self._parse_column)
|
||||
|
||||
|
@@ -4755,6 +4824,7 @@ class Parser(metaclass=_Parser):
        self, this: t.Optional[exp.Expression], explicit: bool = False
    ) -> t.Optional[exp.Expression]:
        any_token = self._match(TokenType.ALIAS)
        comments = self._prev_comments

        if explicit and not any_token:
            return this

@@ -4762,6 +4832,7 @@ class Parser(metaclass=_Parser):
        if self._match(TokenType.L_PAREN):
            aliases = self.expression(
                exp.Aliases,
                comments=comments,
                this=this,
                expressions=self._parse_csv(lambda: self._parse_id_var(any_token)),
            )

@@ -4771,7 +4842,7 @@ class Parser(metaclass=_Parser):
        alias = self._parse_id_var(any_token)

        if alias:
            return self.expression(exp.Alias, this=this, alias=alias)
            return self.expression(exp.Alias, comments=comments, this=this, alias=alias)

        return this

@@ -4792,8 +4863,8 @@ class Parser(metaclass=_Parser):
        return None

    def _parse_string(self) -> t.Optional[exp.Expression]:
        if self._match(TokenType.STRING):
            return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
        if self._match_set((TokenType.STRING, TokenType.RAW_STRING)):
            return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
        return self._parse_placeholder()

    def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]:

@@ -4821,7 +4892,7 @@ class Parser(metaclass=_Parser):
        return self._parse_placeholder()

    def _advance_any(self) -> t.Optional[Token]:
        if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
        if self._curr and self._curr.token_type not in self.RESERVED_TOKENS:
            self._advance()
            return self._prev
        return None

@@ -4951,7 +5022,7 @@ class Parser(metaclass=_Parser):
        if self._match_texts(self.TRANSACTION_KIND):
            this = self._prev.text

        self._match_texts({"TRANSACTION", "WORK"})
        self._match_texts(("TRANSACTION", "WORK"))

        modes = []
        while True:

@@ -4971,7 +5042,7 @@ class Parser(metaclass=_Parser):
        savepoint = None
        is_rollback = self._prev.token_type == TokenType.ROLLBACK

        self._match_texts({"TRANSACTION", "WORK"})
        self._match_texts(("TRANSACTION", "WORK"))

        if self._match_text_seq("TO"):
            self._match_text_seq("SAVEPOINT")

@@ -4986,6 +5057,10 @@ class Parser(metaclass=_Parser):

        return self.expression(exp.Commit, chain=chain)

    def _parse_refresh(self) -> exp.Refresh:
        self._match(TokenType.TABLE)
        return self.expression(exp.Refresh, this=self._parse_string() or self._parse_table())

    def _parse_add_column(self) -> t.Optional[exp.Expression]:
        if not self._match_text_seq("ADD"):
            return None
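The new REFRESH support is reachable through the public API; a hedged usage
sketch (this assumes the default dialect routes REFRESH statements to
_parse_refresh):

    import sqlglot

    node = sqlglot.parse_one("REFRESH TABLE my_table")
    print(type(node).__name__)  # Refresh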
@@ -5050,10 +5125,9 @@ class Parser(metaclass=_Parser):
            return self._parse_csv(self._parse_add_constraint)

        self._retreat(index)
        if not self.ALTER_TABLE_ADD_COLUMN_KEYWORD and self._match_text_seq("ADD"):
            return self._parse_csv(self._parse_field_def)

        return self._parse_csv(self._parse_add_column)
        if not self.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and self._match_text_seq("ADD"):
            return self._parse_wrapped_csv(self._parse_field_def, optional=True)
        return self._parse_wrapped_csv(self._parse_add_column, optional=True)

    def _parse_alter_table_alter(self) -> exp.AlterColumn:
        self._match(TokenType.COLUMN)

@@ -5198,7 +5272,7 @@ class Parser(metaclass=_Parser):
    ) -> t.Optional[exp.Expression]:
        index = self._index

        if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
        if kind in ("GLOBAL", "SESSION") and self._match_text_seq("TRANSACTION"):
            return self._parse_set_transaction(global_=kind == "GLOBAL")

        left = self._parse_primary() or self._parse_id_var()

@@ -5292,7 +5366,9 @@ class Parser(metaclass=_Parser):
        self._match_r_paren()
        return self.expression(exp.DictRange, this=this, min=min, max=max)

    def _parse_comprehension(self, this: exp.Expression) -> t.Optional[exp.Comprehension]:
    def _parse_comprehension(
        self, this: t.Optional[exp.Expression]
    ) -> t.Optional[exp.Comprehension]:
        index = self._index
        expression = self._parse_column()
        if not self._match(TokenType.IN):

@@ -5441,10 +5517,3 @@ class Parser(metaclass=_Parser):
            else:
                column.replace(dot_or_id)
        return node

    def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]:
        return [
            exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string(""))
            for value in values
            if value
        ]
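What _ensure_string_if_null builds for each value can be reproduced with the
same sqlglot expression helpers; a minimal, self-contained illustration:

    from sqlglot import exp

    value = exp.column("a")
    wrapped = exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string(""))
    print(wrapped.sql())  # COALESCE(CAST(a AS TEXT), '')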
@@ -15,8 +15,6 @@ if t.TYPE_CHECKING:

ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

TABLE_ARGS = ("this", "db", "catalog")


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

@@ -147,7 +145,7 @@ class AbstractMappingSchema:
        if not depth:  # None
            self._supported_table_args = tuple()
        elif 1 <= depth <= 3:
            self._supported_table_args = TABLE_ARGS[:depth]
            self._supported_table_args = exp.TABLE_PARTS[:depth]
        else:
            raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

@@ -156,7 +154,7 @@ class AbstractMappingSchema:
    def table_parts(self, table: exp.Table) -> t.List[str]:
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True

@@ -365,13 +363,11 @@ class MappingSchema(AbstractMappingSchema, Schema):
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

@@ -383,21 +379,19 @@ class MappingSchema(AbstractMappingSchema, Schema):
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

@@ -413,7 +407,7 @@ class MappingSchema(AbstractMappingSchema, Schema):
                dialect=dialect or self.dialect,
                is_table=is_table,
                normalize=self.normalize if normalize is None else normalize,
            )
            ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:

@@ -451,16 +445,16 @@ def normalize_name(
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
) -> exp.Identifier:
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier.name
        return identifier

    # This can be useful for normalize_identifier
    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
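Because normalize_name now returns an exp.Identifier rather than a plain
string, callers that only need the text take .name, as _normalize_name does
above. A hedged usage sketch (assuming a dialect that lowercases unquoted
identifiers, e.g. duckdb):

    from sqlglot.schema import normalize_name

    ident = normalize_name("FOO", dialect="duckdb")
    print(ident.name)  # foo, per the dialect's normalize_identifier rules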
@@ -42,6 +42,10 @@ def format_time(
            end -= 1
            chars = sym
            sym = None
        else:
            chars = chars[0]
            end = start + 1

        start += len(chars)
        chunks.append(chars)
        current = trie

@@ -7,6 +7,9 @@ from sqlglot.errors import TokenError
from sqlglot.helper import AutoName
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType


class TokenType(AutoName):
    L_PAREN = auto()

@@ -34,6 +37,7 @@ class TokenType(AutoName):
    EQ = auto()
    NEQ = auto()
    NULLSAFE_EQ = auto()
    COLON_EQ = auto()
    AND = auto()
    OR = auto()
    AMP = auto()

@@ -56,6 +60,7 @@ class TokenType(AutoName):
    SESSION_PARAMETER = auto()
    DAMP = auto()
    XOR = auto()
    DSTAR = auto()

    BLOCK_START = auto()
    BLOCK_END = auto()

@@ -274,6 +279,7 @@ class TokenType(AutoName):
    OBJECT_IDENTIFIER = auto()
    OFFSET = auto()
    ON = auto()
    OPERATOR = auto()
    ORDER_BY = auto()
    ORDERED = auto()
    ORDINALITY = auto()

@@ -295,6 +301,7 @@ class TokenType(AutoName):
    QUOTE = auto()
    RANGE = auto()
    RECURSIVE = auto()
    REFRESH = auto()
    REPLACE = auto()
    RETURNING = auto()
    REFERENCES = auto()

@@ -371,7 +378,7 @@ class Token:
        col: int = 1,
        start: int = 0,
        end: int = 0,
        comments: t.List[str] = [],
        comments: t.Optional[t.List[str]] = None,
    ) -> None:
        """Token initializer.

@@ -390,7 +397,7 @@ class Token:
        self.col = col
        self.start = start
        self.end = end
        self.comments = comments
        self.comments = [] if comments is None else comments

    def __repr__(self) -> str:
        attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
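The comments change above fixes a classic Python pitfall: a mutable default
argument is created once at function definition time and shared by every call.
A self-contained sketch of the difference:

    def bad(item, acc=[]):  # one list object shared across all calls
        acc.append(item)
        return acc

    def good(item, acc=None):  # fresh list per call, as Token now does
        acc = [] if acc is None else acc
        acc.append(item)
        return acc

    print(bad(1), bad(2))    # [1, 2] [1, 2]
    print(good(1), good(2))  # [1] [2]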
@@ -497,11 +504,8 @@ class Tokenizer(metaclass=_Tokenizer):
    QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
    STRING_ESCAPES = ["'"]
    VAR_SINGLE_TOKENS: t.Set[str] = set()
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Autofilled
    IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False

    _COMMENTS: t.Dict[str, str] = {}
    _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
    _IDENTIFIERS: t.Dict[str, str] = {}

@@ -523,6 +527,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "<=": TokenType.LTE,
        "<>": TokenType.NEQ,
        "!=": TokenType.NEQ,
        ":=": TokenType.COLON_EQ,
        "<=>": TokenType.NULLSAFE_EQ,
        "->": TokenType.ARROW,
        "->>": TokenType.DARROW,

@@ -689,17 +694,22 @@ class Tokenizer(metaclass=_Tokenizer):
        "BOOLEAN": TokenType.BOOLEAN,
        "BYTE": TokenType.TINYINT,
        "MEDIUMINT": TokenType.MEDIUMINT,
        "INT1": TokenType.TINYINT,
        "TINYINT": TokenType.TINYINT,
        "INT16": TokenType.SMALLINT,
        "SHORT": TokenType.SMALLINT,
        "SMALLINT": TokenType.SMALLINT,
        "INT128": TokenType.INT128,
        "HUGEINT": TokenType.INT128,
        "INT2": TokenType.SMALLINT,
        "INTEGER": TokenType.INT,
        "INT": TokenType.INT,
        "INT4": TokenType.INT,
        "INT32": TokenType.INT,
        "INT64": TokenType.BIGINT,
        "LONG": TokenType.BIGINT,
        "BIGINT": TokenType.BIGINT,
        "INT8": TokenType.BIGINT,
        "INT8": TokenType.TINYINT,
        "DEC": TokenType.DECIMAL,
        "DECIMAL": TokenType.DECIMAL,
        "BIGDECIMAL": TokenType.BIGDECIMAL,

@@ -781,7 +791,6 @@ class Tokenizer(metaclass=_Tokenizer):
        "\t": TokenType.SPACE,
        "\n": TokenType.BREAK,
        "\r": TokenType.BREAK,
        "\r\n": TokenType.BREAK,
    }

    COMMANDS = {

@@ -803,6 +812,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "sql",
        "size",
        "tokens",
        "dialect",
        "_start",
        "_current",
        "_line",

@@ -814,7 +824,10 @@ class Tokenizer(metaclass=_Tokenizer):
        "_prev_token_line",
    )

    def __init__(self) -> None:
    def __init__(self, dialect: DialectType = None) -> None:
        from sqlglot.dialects import Dialect

        self.dialect = Dialect.get_or_raise(dialect)
        self.reset()

    def reset(self) -> None:
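A Tokenizer can now be bound to a dialect at construction time; a hedged usage
sketch (dialect-specific token maps still come from subclassing, but
dialect-level flags are read from self.dialect):

    from sqlglot.tokens import Tokenizer

    tokenizer = Tokenizer(dialect="mysql")
    for token in tokenizer.tokenize("SELECT 1"):
        print(token.token_type, token.text)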
@@ -850,13 +863,26 @@ class Tokenizer(metaclass=_Tokenizer):

    def _scan(self, until: t.Optional[t.Callable] = None) -> None:
        while self.size and not self._end:
            self._start = self._current
            self._advance()
            current = self._current

            # skip spaces inline rather than iteratively call advance()
            # for performance reasons
            while current < self.size:
                char = self.sql[current]

                if char.isspace() and (char == " " or char == "\t"):
                    current += 1
                else:
                    break

            n = current - self._current
            self._start = current
            self._advance(n if n > 1 else 1)

            if self._char is None:
                break

            if self._char not in self.WHITE_SPACE:
            if not self._char.isspace():
                if self._char.isdigit():
                    self._scan_number()
                elif self._char in self._IDENTIFIERS:

@@ -881,6 +907,10 @@ class Tokenizer(metaclass=_Tokenizer):

    def _advance(self, i: int = 1, alnum: bool = False) -> None:
        if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
            # Ensures we don't count an extra line if we get a \r\n line break sequence
            if self._char == "\r" and self._peek == "\n":
                i = 2

            self._col = 1
            self._line += 1
        else:
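A self-contained sketch of the CRLF handling added above — a "\r\n" pair must
advance two characters but count as a single line break:

    def count_lines(sql: str) -> int:
        line, i = 1, 0
        while i < len(sql):
            if sql[i] == "\r" and i + 1 < len(sql) and sql[i + 1] == "\n":
                line += 1
                i += 2  # consume both characters of the CRLF pair
            elif sql[i] in "\r\n":
                line += 1
                i += 1
            else:
                i += 1
        return line

    print(count_lines("SELECT 1\r\nFROM t"))  # 2, not 3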
@@ -982,7 +1012,7 @@ class Tokenizer(metaclass=_Tokenizer):
            if end < self.size:
                char = self.sql[end]
                single_token = single_token or char in self.SINGLE_TOKENS
                is_space = char in self.WHITE_SPACE
                is_space = char.isspace()

                if not is_space or not prev_space:
                    if is_space:

@@ -994,7 +1024,7 @@ class Tokenizer(metaclass=_Tokenizer):
                        skip = True
                    else:
                        char = ""
                        chars = " "
                        break

            if word:
                if self._scan_string(word):

@@ -1086,7 +1116,7 @@ class Tokenizer(metaclass=_Tokenizer):
                self._add(TokenType.NUMBER, number_text)
                self._add(TokenType.DCOLON, "::")
                return self._add(token_type, literal)
            elif self.IDENTIFIERS_CAN_START_WITH_DIGIT:
            elif self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT:
                return self._add(TokenType.VAR)

            self._advance(-len(literal))

@@ -1208,8 +1238,12 @@ class Tokenizer(metaclass=_Tokenizer):
            if self._end:
                raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")

            if self.ESCAPE_SEQUENCES and self._peek and self._char in self.STRING_ESCAPES:
                escaped_sequence = self.ESCAPE_SEQUENCES.get(self._char + self._peek)
            if (
                self.dialect.ESCAPE_SEQUENCES
                and self._peek
                and self._char in self.STRING_ESCAPES
            ):
                escaped_sequence = self.dialect.ESCAPE_SEQUENCES.get(self._char + self._peek)
                if escaped_sequence:
                    self._advance(2)
                    text += escaped_sequence
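Escape sequences now come from the dialect rather than the tokenizer class; a
minimal sketch of how such a two-character sequence resolves, with an assumed
example mapping:

    escape_sequences = {"\\n": "\n", "\\t": "\t"}  # hypothetical dialect mapping

    def resolve(char: str, peek: str):
        return escape_sequences.get(char + peek)

    print(repr(resolve("\\", "n")))  # '\n'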
@@ -141,7 +141,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
    """Convert cross join unnest into lateral view explode."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            unnest = join.this

@@ -166,7 +166,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest (used in hive -> presto)."""
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):

@@ -199,11 +199,11 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
                explode_alias = ""

                if isinstance(select, exp.Alias):
                    explode_alias = select.alias
                    explode_alias = select.args["alias"]
                    alias = select
                elif isinstance(select, exp.Aliases):
                    pos_alias = select.aliases[0].name
                    explode_alias = select.aliases[1].name
                    pos_alias = select.aliases[0]
                    explode_alias = select.aliases[1]
                    alias = select.replace(exp.alias_(select.this, "", copy=False))
                else:
                    alias = select.replace(exp.alias_(select, ""))

@@ -230,9 +230,12 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:

                alias.set("alias", exp.to_identifier(explode_alias))

                series_table_alias = series.args["alias"].this
                column = exp.If(
                    this=exp.column(series_alias).eq(exp.column(pos_alias)),
                    true=exp.column(explode_alias),
                    this=exp.column(series_alias, table=series_table_alias).eq(
                        exp.column(pos_alias, table=unnest_source_alias)
                    ),
                    true=exp.column(explode_alias, table=unnest_source_alias),
                )

                explode.replace(column)

@@ -242,8 +245,10 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
                    expressions.insert(
                        expressions.index(alias) + 1,
                        exp.If(
                            this=exp.column(series_alias).eq(exp.column(pos_alias)),
                            true=exp.column(pos_alias),
                            this=exp.column(series_alias, table=series_table_alias).eq(
                                exp.column(pos_alias, table=unnest_source_alias)
                            ),
                            true=exp.column(pos_alias, table=unnest_source_alias),
                        ).as_(pos_alias),
                    )
                    expression.set("expressions", expressions)

@@ -276,10 +281,12 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )
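The extra table qualifiers above exist because the rewritten query joins two
generated relations, and an unqualified pos/col reference would be ambiguous
between them. A minimal sketch of how exp.column builds such a qualified
comparison (the alias names are illustrative):

    from sqlglot import exp

    cond = exp.column("pos", table="_u").eq(exp.column("pos_2", table="_u_2"))
    print(cond.sql())  # _u.pos = _u_2.pos_2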
@@ -386,14 +393,16 @@ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    full_outer_joins = [
        (index, join)
        for index, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL" and join.kind == "OUTER"
        if join.side == "FULL"
    ]

    if len(full_outer_joins) == 1:
        expression_copy = expression.copy()
        expression.set("limit", None)
        index, full_outer_join = full_outer_joins[0]
        full_outer_join.set("side", "left")
        expression_copy.args["joins"][index].set("side", "right")
        expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

        return exp.union(expression, expression_copy, copy=False)
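For reference, the rewrite this function performs turns a single FULL OUTER
JOIN into a union of one-sided joins, roughly:

    # SELECT * FROM a FULL OUTER JOIN b ON a.id = b.id LIMIT 10
    # becomes
    # SELECT * FROM a LEFT JOIN b ON a.id = b.id
    # UNION
    # SELECT * FROM a RIGHT JOIN b ON a.id = b.id LIMIT 10
    #
    # The added expression.set("limit", None) strips the LIMIT from the left
    # branch; the copy keeps it, so it trails the UNION in the generated SQL,
    # where most dialects apply it to the combined result.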
@@ -430,6 +439,33 @@ def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node, *_ in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
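A hedged usage sketch of the two new transforms (the outputs show the expected
shape, not verified against every dialect):

    import sqlglot
    from sqlglot.transforms import ensure_bools, unqualify_columns

    tree = sqlglot.parse_one("SELECT t.a FROM t WHERE t.b")
    print(unqualify_columns(tree).sql())  # SELECT a FROM t WHERE b

    tree = sqlglot.parse_one("SELECT a FROM t WHERE b")
    print(ensure_bools(tree).sql())       # SELECT a FROM t WHERE b <> 0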