Merging upstream version 11.7.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 0c053462ae
commit 8d96084fad

144 changed files with 44104 additions and 39367 deletions
@@ -21,10 +21,12 @@ from sqlglot.expressions import (
     Expression as Expression,
     alias_ as alias,
     and_ as and_,
+    cast as cast,
     column as column,
     condition as condition,
     except_ as except_,
     from_ as from_,
+    func as func,
     intersect as intersect,
     maybe_parse as maybe_parse,
     not_ as not_,
@@ -33,6 +35,7 @@ from sqlglot.expressions import (
     subquery as subquery,
     table_ as table,
     to_column as to_column,
     to_identifier as to_identifier,
     to_table as to_table,
     union as union,
 )
@@ -47,7 +50,7 @@ if t.TYPE_CHECKING:
     T = t.TypeVar("T", bound=Expression)


-__version__ = "11.5.2"
+__version__ = "11.7.1"

 pretty = False
 """Whether to format generated SQL by default."""
@@ -176,7 +176,7 @@ class Column:
         return isinstance(self.expression, exp.Column)

     @property
-    def column_expression(self) -> exp.Column:
+    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
         return self.expression.unalias()

     @property
@@ -16,7 +16,7 @@ from sqlglot.dataframe.sql.readwriter import DataFrameWriter
 from sqlglot.dataframe.sql.transforms import replace_id_value
 from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
 from sqlglot.dataframe.sql.window import Window
-from sqlglot.helper import ensure_list, object_to_dict
+from sqlglot.helper import ensure_list, object_to_dict, seq_get
 from sqlglot.optimizer import optimize as optimize_func

 if t.TYPE_CHECKING:
@@ -146,9 +146,9 @@ class DataFrame:
     def _ensure_list_of_columns(self, cols):
         return Column.ensure_cols(ensure_list(cols))

-    def _ensure_and_normalize_cols(self, cols):
+    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
         cols = self._ensure_list_of_columns(cols)
-        normalize(self.spark, self.expression, cols)
+        normalize(self.spark, expression or self.expression, cols)
         return cols

     def _ensure_and_normalize_col(self, col):
@@ -355,12 +355,20 @@ class DataFrame:
         cols = self._ensure_and_normalize_cols(cols)
         kwargs["append"] = kwargs.get("append", False)
         if self.expression.args.get("joins"):
-            ambiguous_cols = [col for col in cols if not col.column_expression.table]
+            ambiguous_cols = [
+                col
+                for col in cols
+                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
+            ]
             if ambiguous_cols:
                 join_table_identifiers = [
                     x.this for x in get_tables_from_expression_with_join(self.expression)
                 ]
                 cte_names_in_join = [x.this for x in join_table_identifiers]
+                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
+                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
+                # of Spark.
+                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                 for ambiguous_col in ambiguous_cols:
                     ctes_with_column = [
                         cte
@@ -368,13 +376,14 @@ class DataFrame:
                         if cte.alias_or_name in cte_names_in_join
                         and ambiguous_col.alias_or_name in cte.this.named_selects
                     ]
-                    # If the select column does not specify a table and there is a join
-                    # then we assume they are referring to the left table
-                    if len(ctes_with_column) > 1:
-                        table_identifier = self.expression.args["from"].args["expressions"][0].this
+                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
+                    # use the same CTE we used before
+                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
+                    if cte:
+                        resolved_column_position[ambiguous_col] += 1
                     else:
-                        table_identifier = ctes_with_column[0].args["alias"].this
-                    ambiguous_col.expression.set("table", table_identifier)
+                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
+                    ambiguous_col.expression.set("table", cte.alias_or_name)
         return self.copy(
             expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
         )
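A minimal sketch of the behavior this hunk targets, using the public dataframe
API (table names and schemas are hypothetical): after a join, an unqualified
column that exists in more than one CTE now resolves left-to-right instead of
always binding to the left table, matching Spark's duplicate-column behavior.

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"employee_id": "INT", "store_id": "INT"})
    sqlglot.schema.add_table("store", {"store_id": "INT", "store_name": "VARCHAR"})

    spark = SparkSession()
    employee = spark.table("employee")
    store = spark.table("store")
    joined = employee.join(store, on=["store_id"], how="inner")
    # The first store_id resolves to the employee CTE, the second to the store CTE.
    print(joined.select("store_id", "store_id").sql(dialect="spark")[0])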
@@ -416,59 +425,87 @@ class DataFrame:
         **kwargs,
     ) -> DataFrame:
         other_df = other_df._convert_leaf_to_cte()
-        pre_join_self_latest_cte_name = self.latest_cte_name
-        columns = self._ensure_and_normalize_cols(on)
-        join_type = how.replace("_", " ")
-        if isinstance(columns[0].expression, exp.Column):
-            join_columns = [
-                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+        join_columns = self._ensure_list_of_columns(on)
+        # We will determine actual "join on" expression later so we don't provide it at first
+        join_expression = self.expression.join(
+            other_df.latest_cte_name, join_type=how.replace("_", " ")
+        )
+        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
+        self_columns = self._get_outer_select_columns(join_expression)
+        other_columns = self._get_outer_select_columns(other_df)
+        # Determines the join clause and select columns to be used passed on what type of columns were provided for
+        # the join. The columns returned changes based on how the on expression is provided.
+        if isinstance(join_columns[0].expression, exp.Column):
+            """
+            Unique characteristics of join on column names only:
+            * The column names are put at the front of the select list
+            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
+            """
+            table_names = [
+                table.alias_or_name
+                for table in get_tables_from_expression_with_join(join_expression)
+            ]
+            potential_ctes = [
+                cte
+                for cte in join_expression.ctes
+                if cte.alias_or_name in table_names
+                and cte.alias_or_name != other_df.latest_cte_name
+            ]
+            # Determine the table to reference for the left side of the join by checking each of the left side
+            # tables and see if they have the column being referenced.
+            join_column_pairs = []
+            for join_column in join_columns:
+                num_matching_ctes = 0
+                for cte in potential_ctes:
+                    if join_column.alias_or_name in cte.this.named_selects:
+                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
+                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
+                        join_column_pairs.append((left_column, right_column))
+                        num_matching_ctes += 1
+                if num_matching_ctes > 1:
+                    raise ValueError(
+                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
+                    )
+                elif num_matching_ctes == 0:
+                    raise ValueError(
+                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
+                    )
             join_clause = functools.reduce(
                 lambda x, y: x & y,
-                [
-                    col.copy().set_table_name(pre_join_self_latest_cte_name)
-                    == col.copy().set_table_name(other_df.latest_cte_name)
-                    for col in columns
-                ],
+                [left_column == right_column for left_column, right_column in join_column_pairs],
             )
-        else:
-            if len(columns) > 1:
-                columns = [functools.reduce(lambda x, y: x & y, columns)]
-            join_clause = columns[0]
-            join_columns = [
-                Column(x).set_table_name(pre_join_self_latest_cte_name)
-                if i % 2 == 0
-                else Column(x).set_table_name(other_df.latest_cte_name)
-                for i, x in enumerate(join_clause.expression.find_all(exp.Column))
+            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
+            # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
+            select_column_names = [
+                column.alias_or_name
+                if not isinstance(column.expression.this, exp.Star)
+                else column.sql()
+                for column in self_columns + other_columns
             ]
-        self_columns = [
-            column.set_table_name(pre_join_self_latest_cte_name, copy=True)
-            for column in self._get_outer_select_columns(self)
-        ]
-        other_columns = [
-            column.set_table_name(other_df.latest_cte_name, copy=True)
-            for column in self._get_outer_select_columns(other_df)
-        ]
-        column_value_mapping = {
-            column.alias_or_name
-            if not isinstance(column.expression.this, exp.Star)
-            else column.sql(): column
-            for column in other_columns + self_columns + join_columns
-        }
-        all_columns = [
-            column_value_mapping[name]
-            for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
-        ]
-        new_df = self.copy(
-            expression=self.expression.join(
-                other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
-            )
-        )
-        new_df.expression = new_df._add_ctes_to_expression(
-            new_df.expression, other_df.expression.ctes
-        )
+            select_column_names = [
+                column_name
+                for column_name in select_column_names
+                if column_name not in join_column_names
+            ]
+            select_column_names = join_column_names + select_column_names
+        else:
+            """
+            Unique characteristics of join on expressions:
+            * There is no deduplication of the results.
+            * The left join dataframe columns go first and right come after. No sort preference is given to join columns
+            """
+            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
+            if len(join_columns) > 1:
+                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
+            join_clause = join_columns[0]
+            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
+
+        # Update the on expression with the actual join clause to replace the dummy one from before
+        join_expression.args["joins"][-1].set("on", join_clause.expression)
+        new_df = self.copy(expression=join_expression)
         new_df.pending_join_hints.extend(self.pending_join_hints)
         new_df.pending_hints.extend(other_df.pending_hints)
-        new_df = new_df.select.__wrapped__(new_df, *all_columns)
+        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
         return new_df

     @operation(Operation.ORDER_BY)
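The two branches above correspond to the two forms `on` can take. A short
sketch mirroring the PySpark-style API (tables as in the previous example,
shapes illustrative):

    # Join on column names: join columns are deduplicated and moved to the
    # front of the select list.
    employee.join(store, on=["store_id"], how="inner")

    # Join on an expression: no deduplication, left columns come first.
    employee.join(store, on=employee["store_id"] == store["store_id"], how="inner")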
@@ -577,11 +577,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col


 def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
-    return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days)
+    return Column.invoke_expression_over_column(
+        col, expression.DateAdd, expression=days, unit=expression.Var(this="day")
+    )


 def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
-    return Column.invoke_expression_over_column(col, expression.DateSub, expression=days)
+    return Column.invoke_expression_over_column(
+        col, expression.DateSub, expression=days, unit=expression.Var(this="day")
+    )


 def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
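With this change the day unit survives into the AST, so generators that need
an explicit unit can emit it. A small sketch (column name hypothetical):

    from sqlglot.dataframe.sql import functions as F

    col = F.date_add(F.col("start_date"), 7)
    # col.expression is now DateAdd(this=..., expression=7, unit=Var(this="day"))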
@@ -695,18 +699,17 @@ def crc32(col: ColumnOrName) -> Column:

 def md5(col: ColumnOrName) -> Column:
     column = col if isinstance(col, Column) else lit(col)
-    return Column.invoke_anonymous_function(column, "MD5")
+    return Column.invoke_expression_over_column(column, expression.MD5)


 def sha1(col: ColumnOrName) -> Column:
     column = col if isinstance(col, Column) else lit(col)
-    return Column.invoke_anonymous_function(column, "SHA1")
+    return Column.invoke_expression_over_column(column, expression.SHA)


 def sha2(col: ColumnOrName, numBits: int) -> Column:
     column = col if isinstance(col, Column) else lit(col)
-    num_bits = lit(numBits)
-    return Column.invoke_anonymous_function(column, "SHA2", num_bits)
+    return Column.invoke_expression_over_column(column, expression.SHA2, length=lit(numBits))


 def hash(*cols: ColumnOrName) -> Column:
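Because the hash helpers now build typed expressions instead of anonymous
functions, parsed SQL can be matched and transpiled by node type. A quick
sketch:

    import sqlglot
    from sqlglot import exp

    node = sqlglot.parse_one("MD5(x)")
    assert isinstance(node, exp.MD5)  # previously round-tripped as exp.Anonymous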
@@ -4,7 +4,7 @@ import typing as t

 import sqlglot
 from sqlglot import expressions as exp
-from sqlglot.helper import object_to_dict
+from sqlglot.helper import object_to_dict, should_identify

 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.dataframe import DataFrame
@@ -19,9 +19,17 @@ class DataFrameReader:
         from sqlglot.dataframe.sql.dataframe import DataFrame

         sqlglot.schema.add_table(tableName)

         return DataFrame(
             self.spark,
-            exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
+            exp.Select()
+            .from_(tableName)
+            .select(
+                *(
+                    column if should_identify(column, "safe") else f'"{column}"'
+                    for column in sqlglot.schema.column_names(tableName)
+                )
+            ),
         )
@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
     max_or_greatest,
     min_or_least,
     no_ilike_sql,
+    parse_date_delta_with_interval,
     rename_func,
     timestrtotime_sql,
     ts_or_ds_to_date_sql,
@@ -23,18 +24,6 @@ from sqlglot.tokens import TokenType
 E = t.TypeVar("E", bound=exp.Expression)


-def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
-    def func(args):
-        interval = seq_get(args, 1)
-        return expression_class(
-            this=seq_get(args, 0),
-            expression=interval.this,
-            unit=interval.args.get("unit"),
-        )
-
-    return func
-
-
 def _date_add_sql(
     data_type: str, kind: str
 ) -> t.Callable[[generator.Generator, exp.Expression], str]:
@@ -142,6 +131,7 @@ class BigQuery(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
+            "ANY TYPE": TokenType.VARIANT,
             "BEGIN": TokenType.COMMAND,
             "BEGIN TRANSACTION": TokenType.BEGIN,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
@@ -155,14 +145,19 @@ class BigQuery(Dialect):
         KEYWORDS.pop("DIV")

     class Parser(parser.Parser):
+        PREFIXED_PIVOT_COLUMNS = True
+
+        LOG_BASE_FIRST = False
+        LOG_DEFAULTS_TO_LN = True
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
             "DATE_TRUNC": lambda args: exp.DateTrunc(
                 unit=exp.Literal.string(seq_get(args, 1).name),  # type: ignore
                 this=seq_get(args, 0),
             ),
-            "DATE_ADD": _date_add(exp.DateAdd),
-            "DATETIME_ADD": _date_add(exp.DatetimeAdd),
+            "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
+            "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
             "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
             "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
             "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
@@ -174,12 +169,12 @@ class BigQuery(Dialect):
                 if re.compile(str(seq_get(args, 1))).groups == 1
                 else None,
             ),
-            "TIME_ADD": _date_add(exp.TimeAdd),
-            "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
-            "DATE_SUB": _date_add(exp.DateSub),
-            "DATETIME_SUB": _date_add(exp.DatetimeSub),
-            "TIME_SUB": _date_add(exp.TimeSub),
-            "TIMESTAMP_SUB": _date_add(exp.TimestampSub),
+            "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
+            "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
+            "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
+            "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
+            "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
+            "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
             "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
                 this=seq_get(args, 1), format=seq_get(args, 0)
             ),
@@ -209,14 +204,17 @@ class BigQuery(Dialect):
         PROPERTY_PARSERS = {
             **parser.Parser.PROPERTY_PARSERS,  # type: ignore
             "NOT DETERMINISTIC": lambda self: self.expression(
-                exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
+                exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
             ),
         }

-        LOG_BASE_FIRST = False
-        LOG_DEFAULTS_TO_LN = True
-
     class Generator(generator.Generator):
+        EXPLICIT_UNION = True
+        INTERVAL_ALLOWS_PLURAL_FORM = False
+        JOIN_HINTS = False
+        TABLE_HINTS = False
+        LIMIT_FETCH = "LIMIT"

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES,  # type: ignore
@@ -236,9 +234,7 @@ class BigQuery(Dialect):
             exp.IntDiv: rename_func("DIV"),
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
-            exp.Select: transforms.preprocess(
-                [_unqualify_unnest], transforms.delegate("select_sql")
-            ),
+            exp.Select: transforms.preprocess([_unqualify_unnest]),
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.TimeAdd: _date_add_sql("TIME", "ADD"),
             exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -253,7 +249,7 @@ class BigQuery(Dialect):
             exp.ReturnsProperty: _returnsproperty_sql,
             exp.Create: _create_sql,
             exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
-            exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
+            exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
             if e.name == "IMMUTABLE"
             else "NOT DETERMINISTIC",
             exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
@@ -261,6 +257,7 @@ class BigQuery(Dialect):

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
+            exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
             exp.DataType.Type.BIGINT: "INT64",
             exp.DataType.Type.BOOLEAN: "BOOL",
             exp.DataType.Type.CHAR: "STRING",
@@ -272,17 +269,19 @@ class BigQuery(Dialect):
             exp.DataType.Type.NVARCHAR: "STRING",
             exp.DataType.Type.SMALLINT: "INT64",
             exp.DataType.Type.TEXT: "STRING",
+            exp.DataType.Type.TIMESTAMP: "DATETIME",
+            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
             exp.DataType.Type.TINYINT: "INT64",
             exp.DataType.Type.VARCHAR: "STRING",
+            exp.DataType.Type.VARIANT: "ANY TYPE",
         }

         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
             exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

-        EXPLICIT_UNION = True
-        LIMIT_FETCH = "LIMIT"
-
         def array_sql(self, expression: exp.Array) -> str:
             first_arg = seq_get(expression.expressions, 0)
             if isinstance(first_arg, exp.Subqueryable):
@@ -144,6 +144,13 @@ class ClickHouse(Dialect):
             exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
         }

+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+        }
+
+        JOIN_HINTS = False
+        TABLE_HINTS = False
         EXPLICIT_UNION = True

         def _param_args_sql(
@@ -9,6 +9,8 @@ from sqlglot.tokens import TokenType

 class Databricks(Spark):
     class Parser(Spark.Parser):
+        LOG_DEFAULTS_TO_LN = True
+
         FUNCTIONS = {
             **Spark.Parser.FUNCTIONS,
             "DATEADD": parse_date_delta(exp.DateAdd),
@@ -16,13 +18,17 @@ class Databricks(Spark):
             "DATEDIFF": parse_date_delta(exp.DateDiff),
         }

-        LOG_DEFAULTS_TO_LN = True
+        FACTOR = {
+            **Spark.Parser.FACTOR,
+            TokenType.COLON: exp.JSONExtract,
+        }

     class Generator(Spark.Generator):
         TRANSFORMS = {
             **Spark.Generator.TRANSFORMS,  # type: ignore
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
+            exp.JSONExtract: lambda self, e: self.binary(e, ":"),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
         }
+
+        TRANSFORMS.pop(exp.Select)  # Remove the ELIMINATE_QUALIFY transformation
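A sketch of the new colon operator support (output shape is illustrative):

    import sqlglot

    sqlglot.transpile("SELECT c:price FROM t", read="databricks", write="databricks")[0]
    # parsed as exp.JSONExtract and generated back through self.binary(e, ":")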
@@ -293,6 +293,13 @@ def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
     return ""


+def no_comment_column_constraint_sql(
+    self: Generator, expression: exp.CommentColumnConstraint
+) -> str:
+    self.unsupported("CommentColumnConstraint unsupported")
+    return ""
+
+
 def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
@@ -379,15 +386,35 @@ def parse_date_delta(
 ) -> t.Callable[[t.Sequence], E]:
     def inner_func(args: t.Sequence) -> E:
         unit_based = len(args) == 3
-        this = seq_get(args, 2) if unit_based else seq_get(args, 0)
-        expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
-        unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
-        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit  # type: ignore
-        return exp_class(this=this, expression=expression, unit=unit)
+        this = args[2] if unit_based else seq_get(args, 0)
+        unit = args[0] if unit_based else exp.Literal.string("DAY")
+        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

     return inner_func


+def parse_date_delta_with_interval(
+    expression_class: t.Type[E],
+) -> t.Callable[[t.Sequence], t.Optional[E]]:
+    def func(args: t.Sequence) -> t.Optional[E]:
+        if len(args) < 2:
+            return None
+
+        interval = args[1]
+        expression = interval.this
+        if expression and expression.is_string:
+            expression = exp.Literal.number(expression.this)
+
+        return expression_class(
+            this=args[0],
+            expression=expression,
+            unit=exp.Literal.string(interval.text("unit")),
+        )
+
+    return func
+
+
 def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
     unit = seq_get(args, 0)
     this = seq_get(args, 1)
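An illustrative round-trip through the new helper; BigQuery and MySQL wire it
up in their parsers (exact output may vary slightly by dialect):

    import sqlglot

    sqlglot.transpile(
        "SELECT DATE_ADD(d, INTERVAL 1 DAY)", read="bigquery", write="mysql"
    )[0]
    # -> "SELECT DATE_ADD(d, INTERVAL 1 DAY)"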
@@ -104,6 +104,9 @@ class Drill(Dialect):
         LOG_DEFAULTS_TO_LN = True

     class Generator(generator.Generator):
+        JOIN_HINTS = False
+        TABLE_HINTS = False
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.INT: "INTEGER",
@@ -120,6 +123,7 @@ class Drill(Dialect):
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
             exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

         TRANSFORMS = {
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
     arrow_json_extract_sql,
     datestrtodate_sql,
     format_time_lambda,
+    no_comment_column_constraint_sql,
     no_pivot_sql,
     no_properties_sql,
     no_safe_divide_sql,
@@ -23,7 +24,7 @@ from sqlglot.tokens import TokenType


 def _ts_or_ds_add(self, expression):
-    this = expression.args.get("this")
+    this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
     return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"

@@ -139,6 +140,8 @@ class DuckDB(Dialect):
         }

     class Generator(generator.Generator):
+        JOIN_HINTS = False
+        TABLE_HINTS = False
         STRUCT_DELIMITER = ("(", ")")

         TRANSFORMS = {
@@ -150,6 +153,7 @@ class DuckDB(Dialect):
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArraySort: _array_sort_sql,
             exp.ArraySum: rename_func("LIST_SUM"),
+            exp.CommentColumnConstraint: no_comment_column_constraint_sql,
             exp.DayOfMonth: rename_func("DAYOFMONTH"),
             exp.DayOfWeek: rename_func("DAYOFWEEK"),
             exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -213,6 +217,11 @@ class DuckDB(Dialect):
             "except": "EXCLUDE",
         }

+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+        }
+
         LIMIT_FETCH = "LIMIT"

         def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
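A sketch of the effect of the new CommentColumnConstraint transform
(hypothetical table; the constraint is dropped with an "unsupported" warning
instead of emitting SQL DuckDB rejects):

    import sqlglot

    sqlglot.transpile(
        "CREATE TABLE t (x INT COMMENT 'id')",
        read="mysql",
        write="duckdb",
        unsupported_level=sqlglot.ErrorLevel.IGNORE,
    )[0]
    # -> "CREATE TABLE t (x INT)"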
@@ -45,16 +45,23 @@ TIME_DIFF_FACTOR = {
 DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")


-def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
+def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
     unit = expression.text("unit").upper()
     func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
-    modified_increment = (
-        int(expression.text("expression")) * multiplier
-        if expression.expression.is_number
-        else expression.expression
-    )
-    modified_increment = exp.Literal.number(modified_increment)
-    return self.func(func, expression.this, modified_increment.this)
+
+    if isinstance(expression, exp.DateSub):
+        multiplier *= -1
+
+    if expression.expression.is_number:
+        modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
+    else:
+        modified_increment = expression.expression
+        if multiplier != 1:
+            modified_increment = exp.Mul(  # type: ignore
+                this=modified_increment, expression=exp.Literal.number(multiplier)
+            )
+
+    return self.func(func, expression.this, modified_increment)


 def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
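A sketch of what the rewritten helper produces for a DateSub node (built
directly here to avoid assuming parser mappings):

    from sqlglot import exp

    node = exp.DateSub(
        this=exp.column("d"), expression=exp.Literal.number(7), unit=exp.var("DAY")
    )
    node.sql(dialect="hive")
    # -> "DATE_ADD(d, -7)", since DateSub reuses DATE_ADD with a negated increment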
@@ -127,24 +134,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
     return f"TO_DATE({this})"


-def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
-    unnest = expression.this
-    if isinstance(unnest, exp.Unnest):
-        alias = unnest.args.get("alias")
-        udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
-        return "".join(
-            self.sql(
-                exp.Lateral(
-                    this=udtf(this=expression),
-                    view=True,
-                    alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
-                )
-            )
-            for expression, column in zip(unnest.expressions, alias.columns if alias else [])
-        )
-    return self.join_sql(expression)
-
-
 def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
     this = self.sql(expression, "this")
     table = self.sql(expression, "table")
@@ -195,6 +184,7 @@ class Hive(Dialect):
         IDENTIFIERS = ["`"]
         STRING_ESCAPES = ["\\"]
         ENCODE = "utf-8"
+        IDENTIFIER_CAN_START_WITH_DIGIT = True

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
@@ -217,9 +207,8 @@ class Hive(Dialect):
             "BD": "DECIMAL",
         }

-        IDENTIFIER_CAN_START_WITH_DIGIT = True
-
     class Parser(parser.Parser):
+        LOG_DEFAULTS_TO_LN = True
         STRICT_CAST = False

         FUNCTIONS = {
@@ -273,9 +262,13 @@ class Hive(Dialect):
             ),
         }

-        LOG_DEFAULTS_TO_LN = True
-
     class Generator(generator.Generator):
+        LIMIT_FETCH = "LIMIT"
+        TABLESAMPLE_WITH_METHOD = False
+        TABLESAMPLE_SIZE_IS_PERCENT = True
+        JOIN_HINTS = False
+        TABLE_HINTS = False

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TEXT: "STRING",
@@ -289,6 +282,9 @@ class Hive(Dialect):
             **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.UNALIAS_GROUP,  # type: ignore
             **transforms.ELIMINATE_QUALIFY,  # type: ignore
+            exp.Select: transforms.preprocess(
+                [transforms.eliminate_qualify, transforms.unnest_to_explode]
+            ),
             exp.Property: _property_sql,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.ArrayConcat: rename_func("CONCAT"),
@@ -298,13 +294,13 @@ class Hive(Dialect):
             exp.DateAdd: _add_date_sql,
             exp.DateDiff: _date_diff_sql,
             exp.DateStrToDate: rename_func("TO_DATE"),
+            exp.DateSub: _add_date_sql,
             exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
-            exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
+            exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
             exp.If: if_sql,
             exp.Index: _index_sql,
             exp.ILike: no_ilike_sql,
-            exp.Join: _unnest_to_explode_sql,
             exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
             exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
             exp.JSONFormat: rename_func("TO_JSON"),
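The UNNEST-to-EXPLODE rewrite now runs as a Select preprocessing step instead
of a Join transform. Illustrative transpilation:

    import sqlglot

    sqlglot.transpile(
        "SELECT c FROM t CROSS JOIN UNNEST(arr) AS u(c)", read="presto", write="hive"
    )[0]
    # -> "SELECT c FROM t LATERAL VIEW EXPLODE(arr) u AS c"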
@@ -354,10 +350,9 @@ class Hive(Dialect):
             exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
             exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
             exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
+            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

-        LIMIT_FETCH = "LIMIT"
-
         def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
             return self.func(
                 "COLLECT_LIST",
@@ -378,4 +373,5 @@ class Hive(Dialect):
                 expression = exp.DataType.build("text")
             elif expression.this in exp.DataType.TEMPORAL_TYPES:
                 expression = exp.DataType.build(expression.this)
+
         return super().datatype_sql(expression)
@@ -4,6 +4,8 @@ from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
+    datestrtodate_sql,
+    format_time_lambda,
     locate_to_strposition,
     max_or_greatest,
     min_or_least,
@@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import (
     no_paren_current_date_sql,
     no_tablesample_sql,
     no_trycast_sql,
+    parse_date_delta_with_interval,
     rename_func,
     strposition_to_locate_sql,
 )
@@ -76,18 +79,6 @@ def _trim_sql(self, expression):
     return f"TRIM({trim_type}{remove_chars}{from_part}{target})"


-def _date_add(expression_class):
-    def func(args):
-        interval = seq_get(args, 1)
-        return expression_class(
-            this=seq_get(args, 0),
-            expression=interval.this,
-            unit=exp.Literal.string(interval.text("unit").lower()),
-        )
-
-    return func
-
-
 def _date_add_sql(kind):
     def func(self, expression):
         this = self.sql(expression, "this")
@@ -115,6 +106,7 @@ class MySQL(Dialect):
         "%k": "%-H",
         "%l": "%-I",
         "%T": "%H:%M:%S",
+        "%W": "%a",
     }

     class Tokenizer(tokens.Tokenizer):
@@ -127,12 +119,13 @@ class MySQL(Dialect):

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
-            "MEDIUMTEXT": TokenType.MEDIUMTEXT,
             "CHARSET": TokenType.CHARACTER_SET,
+            "LONGBLOB": TokenType.LONGBLOB,
             "LONGTEXT": TokenType.LONGTEXT,
             "MEDIUMBLOB": TokenType.MEDIUMBLOB,
-            "LONGBLOB": TokenType.LONGBLOB,
-            "START": TokenType.BEGIN,
+            "MEDIUMTEXT": TokenType.MEDIUMTEXT,
             "SEPARATOR": TokenType.SEPARATOR,
+            "START": TokenType.BEGIN,
             "_ARMSCII8": TokenType.INTRODUCER,
             "_ASCII": TokenType.INTRODUCER,
             "_BIG5": TokenType.INTRODUCER,
@@ -186,14 +179,15 @@ class MySQL(Dialect):

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
-            "DATE_ADD": _date_add(exp.DateAdd),
-            "DATE_SUB": _date_add(exp.DateSub),
-            "STR_TO_DATE": _str_to_date,
-            "LOCATE": locate_to_strposition,
+            "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
+            "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
+            "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
             "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
             "LEFT": lambda args: exp.Substring(
                 this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
             ),
+            "LOCATE": locate_to_strposition,
+            "STR_TO_DATE": _str_to_date,
         }

         FUNCTION_PARSERS = {
@@ -388,32 +382,36 @@ class MySQL(Dialect):
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = False
+        JOIN_HINTS = False
+        TABLE_HINTS = False

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             exp.CurrentDate: no_paren_current_date_sql,
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+            exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
+            exp.DateAdd: _date_add_sql("ADD"),
+            exp.DateStrToDate: datestrtodate_sql,
+            exp.DateSub: _date_add_sql("SUB"),
+            exp.DateTrunc: _date_trunc_sql,
+            exp.DayOfMonth: rename_func("DAYOFMONTH"),
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
+            exp.DayOfYear: rename_func("DAYOFYEAR"),
+            exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
             exp.ILike: no_ilike_sql,
+            exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
-            exp.TableSample: no_tablesample_sql,
-            exp.TryCast: no_trycast_sql,
-            exp.DateAdd: _date_add_sql("ADD"),
-            exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
-            exp.DateSub: _date_add_sql("SUB"),
-            exp.DateTrunc: _date_trunc_sql,
-            exp.DayOfWeek: rename_func("DAYOFWEEK"),
-            exp.DayOfMonth: rename_func("DAYOFMONTH"),
-            exp.DayOfYear: rename_func("DAYOFYEAR"),
-            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
-            exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
-            exp.StrToDate: _str_to_date_sql,
-            exp.StrToTime: _str_to_date_sql,
-            exp.Trim: _trim_sql,
+            exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
+            exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
             exp.StrPosition: strposition_to_locate_sql,
+            exp.StrToDate: _str_to_date_sql,
+            exp.StrToTime: _str_to_date_sql,
+            exp.TableSample: no_tablesample_sql,
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
+            exp.Trim: _trim_sql,
+            exp.TryCast: no_trycast_sql,
+            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
         }

         TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
@@ -425,6 +423,7 @@ class MySQL(Dialect):
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
             exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

         LIMIT_FETCH = "LIMIT"
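With DATE_FORMAT parsed through format_time_lambda, MySQL format strings now
round-trip. A small sketch:

    import sqlglot

    sqlglot.transpile("SELECT DATE_FORMAT(d, '%Y-%m-%d')", read="mysql", write="mysql")[0]
    # -> "SELECT DATE_FORMAT(d, '%Y-%m-%d')"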
@@ -7,11 +7,6 @@ from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sq
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType

-PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
-    TokenType.COLUMN,
-    TokenType.RETURNING,
-}
-

 def _parse_xml_table(self) -> exp.XMLTable:
     this = self._parse_string()
@@ -22,9 +17,7 @@ def _parse_xml_table(self) -> exp.XMLTable:
     if self._match_text_seq("PASSING"):
         # The BY VALUE keywords are optional and are provided for semantic clarity
         self._match_text_seq("BY", "VALUE")
-        passing = self._parse_csv(
-            lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS)
-        )
+        passing = self._parse_csv(self._parse_column)

     by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")

@@ -68,6 +61,8 @@ class Oracle(Dialect):
     }

     class Parser(parser.Parser):
+        WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
             "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
@@ -78,6 +73,12 @@ class Oracle(Dialect):
             "XMLTABLE": _parse_xml_table,
         }

+        TYPE_LITERAL_PARSERS = {
+            exp.DataType.Type.DATE: lambda self, this, _: self.expression(
+                exp.DateStrToDate, this=this
+            )
+        }
+
         def _parse_column(self) -> t.Optional[exp.Expression]:
             column = super()._parse_column()
             if column:
@@ -100,6 +101,8 @@ class Oracle(Dialect):

     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
+        JOIN_HINTS = False
+        TABLE_HINTS = False

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
@@ -119,6 +122,9 @@ class Oracle(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.UNALIAS_GROUP,  # type: ignore
+            exp.DateStrToDate: lambda self, e: self.func(
+                "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
+            ),
             exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
             exp.ILike: no_ilike_sql,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -129,6 +135,12 @@ class Oracle(Dialect):
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.Trim: trim_sql,
             exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
+            exp.IfNull: rename_func("NVL"),
         }

+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+        }
+
         LIMIT_FETCH = "FETCH"
@@ -142,9 +154,9 @@ class Oracle(Dialect):

         def xmltable_sql(self, expression: exp.XMLTable) -> str:
             this = self.sql(expression, "this")
-            passing = self.expressions(expression, "passing")
+            passing = self.expressions(expression, key="passing")
             passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else ""
-            columns = self.expressions(expression, "columns")
+            columns = self.expressions(expression, key="columns")
             columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else ""
             by_ref = (
                 f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else ""
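A sketch of the new DATE literal handling added via TYPE_LITERAL_PARSERS
(output shape is illustrative):

    import sqlglot

    sqlglot.transpile("SELECT DATE '2023-01-01' FROM DUAL", read="oracle", write="oracle")[0]
    # -> "SELECT TO_DATE('2023-01-01', 'YYYY-MM-DD') FROM DUAL"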
@@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
     arrow_json_extract_sql,
+    datestrtodate_sql,
     format_time_lambda,
     max_or_greatest,
     min_or_least,
@@ -19,7 +20,7 @@ from sqlglot.dialects.dialect import (
 from sqlglot.helper import seq_get
 from sqlglot.parser import binary_range_parser
 from sqlglot.tokens import TokenType
-from sqlglot.transforms import delegate, preprocess
+from sqlglot.transforms import preprocess, remove_target_from_merge

 DATE_DIFF_FACTOR = {
     "MICROSECOND": " * 1000000",
@@ -239,7 +240,6 @@ class Postgres(Dialect):
             "SERIAL": TokenType.SERIAL,
             "SMALLSERIAL": TokenType.SMALLSERIAL,
             "TEMP": TokenType.TEMPORARY,
-            "UUID": TokenType.UUID,
             "CSTRING": TokenType.PSEUDO_TYPE,
         }

@@ -248,18 +248,25 @@ class Postgres(Dialect):
             "$": TokenType.PARAMETER,
         }

+        VAR_SINGLE_TOKENS = {"$"}
+
     class Parser(parser.Parser):
         STRICT_CAST = False

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
-            "NOW": exp.CurrentTimestamp.from_arg_list,
-            "TO_TIMESTAMP": _to_timestamp,
-            "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
-            "GENERATE_SERIES": _generate_series,
             "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                 this=seq_get(args, 1), unit=seq_get(args, 0)
             ),
+            "GENERATE_SERIES": _generate_series,
+            "NOW": exp.CurrentTimestamp.from_arg_list,
+            "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
+            "TO_TIMESTAMP": _to_timestamp,
         }

+        FUNCTION_PARSERS = {
+            **parser.Parser.FUNCTION_PARSERS,
+            "DATE_PART": lambda self: self._parse_date_part(),
+        }
+
         BITWISE = {
@@ -279,8 +286,21 @@ class Postgres(Dialect):
             TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
         }

+        def _parse_date_part(self) -> exp.Expression:
+            part = self._parse_type()
+            self._match(TokenType.COMMA)
+            value = self._parse_bitwise()
+
+            if part and part.is_string:
+                part = exp.Var(this=part.name)
+
+            return self.expression(exp.Extract, this=part, expression=value)
+
     class Generator(generator.Generator):
+        INTERVAL_ALLOWS_PLURAL_FORM = False
         LOCKING_READS_SUPPORTED = True
+        JOIN_HINTS = False
+        TABLE_HINTS = False
         PARAMETER_TOKEN = "$"

         TYPE_MAPPING = {
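A sketch of the new DATE_PART parser hook:

    import sqlglot

    sqlglot.parse_one("SELECT DATE_PART('year', d)", read="postgres").sql(dialect="postgres")
    # -> "SELECT EXTRACT(year FROM d)"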
@@ -301,7 +321,6 @@ class Postgres(Dialect):
                     _auto_increment_to_serial,
                     _serial_to_generated,
                 ],
-                delegate("columndef_sql"),
             ),
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@@ -312,6 +331,7 @@ class Postgres(Dialect):
             exp.CurrentDate: no_paren_current_date_sql,
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DateAdd: _date_add_sql("+"),
+            exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: _date_add_sql("-"),
             exp.DateDiff: _date_diff_sql,
             exp.LogicalOr: rename_func("BOOL_OR"),
@@ -321,6 +341,7 @@ class Postgres(Dialect):
             exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
             exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
             exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
+            exp.Merge: preprocess([remove_target_from_merge]),
             exp.RegexpLike: lambda self, e: self.binary(e, "~"),
             exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
             exp.StrPosition: str_position_sql,
@@ -344,4 +365,5 @@ class Postgres(Dialect):
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
             exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
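The new Merge transform strips the target-table qualifier from update
expressions, which Postgres rejects. Illustrative only (assuming a MERGE
statement parsed from another dialect):

    import sqlglot

    sqlglot.transpile(
        "MERGE INTO t USING s ON t.id = s.id WHEN MATCHED THEN UPDATE SET t.x = s.x",
        read="tsql",
        write="postgres",
    )[0]
    # -> "... WHEN MATCHED THEN UPDATE SET x = s.x"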
@@ -1,5 +1,7 @@
 from __future__ import annotations

+import typing as t
+
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
@@ -19,20 +21,20 @@ from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType


-def _approx_distinct_sql(self, expression):
+def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
     accuracy = expression.args.get("accuracy")
     accuracy = ", " + self.sql(accuracy) if accuracy else ""
     return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"


-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
     sql = self.datatype_sql(expression)
     if expression.this == exp.DataType.Type.TIMESTAMPTZ:
         sql = f"{sql} WITH TIME ZONE"
     return sql


-def _explode_to_unnest_sql(self, expression):
+def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
     if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
         return self.sql(
             exp.Join(
@@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression):
     return self.lateral_sql(expression)


-def _initcap_sql(self, expression):
+def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
     regex = r"(\w)(\w*)"
     return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"


-def _decode_sql(self, expression):
-    _ensure_utf8(expression.args.get("charset"))
+def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
+    _ensure_utf8(expression.args["charset"])
     return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))


-def _encode_sql(self, expression):
-    _ensure_utf8(expression.args.get("charset"))
+def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
+    _ensure_utf8(expression.args["charset"])
     return f"TO_UTF8({self.sql(expression, 'this')})"


-def _no_sort_array(self, expression):
+def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
     if expression.args.get("asc") == exp.false():
         comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
     else:
@@ -70,49 +72,62 @@ def _no_sort_array(self, expression):
     return self.func("ARRAY_SORT", expression.this, comparator)


-def _schema_sql(self, expression):
+def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
     if isinstance(expression.parent, exp.Property):
         columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
         return f"ARRAY[{columns}]"

-    for schema in expression.parent.find_all(exp.Schema):
-        if isinstance(schema.parent, exp.Property):
-            expression = expression.copy()
-            expression.expressions.extend(schema.expressions)
+    if expression.parent:
+        for schema in expression.parent.find_all(exp.Schema):
+            if isinstance(schema.parent, exp.Property):
+                expression = expression.copy()
+                expression.expressions.extend(schema.expressions)

     return self.schema_sql(expression)


-def _quantile_sql(self, expression):
+def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
     self.unsupported("Presto does not support exact quantiles")
     return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"


-def _str_to_time_sql(self, expression):
+def _str_to_time_sql(
+    self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
+) -> str:
     return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"


-def _ts_or_ds_to_date_sql(self, expression):
+def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     time_format = self.format_time(expression)
     if time_format and time_format not in (Presto.time_format, Presto.date_format):
         return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
     return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"


-def _ts_or_ds_add_sql(self, expression):
+def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
+    this = expression.this
+
+    if not isinstance(this, exp.CurrentDate):
+        this = self.func(
+            "DATE_PARSE",
+            self.func(
+                "SUBSTR",
+                this if this.is_string else exp.cast(this, "VARCHAR"),
+                exp.Literal.number(1),
+                exp.Literal.number(10),
+            ),
+            Presto.date_format,
+        )
+
     return self.func(
         "DATE_ADD",
         exp.Literal.string(expression.text("unit") or "day"),
         expression.expression,
-        self.func(
-            "DATE_PARSE",
-            self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
-            Presto.date_format,
-        ),
+        this,
     )


-def _sequence_sql(self, expression):
+def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
     start = expression.args["start"]
     end = expression.args["end"]
     step = expression.args.get("step", 1)  # Postgres defaults to 1 for generate_series
@@ -135,12 +150,12 @@ def _sequence_sql(self, expression):
     return self.func("SEQUENCE", start, end, step)


-def _ensure_utf8(charset):
+def _ensure_utf8(charset: exp.Literal) -> None:
     if charset.name.lower() != "utf-8":
         raise UnsupportedError(f"Unsupported charset {charset}")


-def _approx_percentile(args):
+def _approx_percentile(args: t.Sequence) -> exp.Expression:
     if len(args) == 4:
         return exp.ApproxQuantile(
             this=seq_get(args, 0),
@@ -157,7 +172,7 @@ def _approx_percentile(args):
     return exp.ApproxQuantile.from_arg_list(args)


-def _from_unixtime(args):
+def _from_unixtime(args: t.Sequence) -> exp.Expression:
     if len(args) == 3:
         return exp.UnixToTime(
             this=seq_get(args, 0),
@@ -226,11 +241,15 @@ class Presto(Dialect):
         FUNCTION_PARSERS.pop("TRIM")

     class Generator(generator.Generator):
+        INTERVAL_ALLOWS_PLURAL_FORM = False
+        JOIN_HINTS = False
+        TABLE_HINTS = False
         STRUCT_DELIMITER = ("(", ")")

         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
             exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

         TYPE_MAPPING = {
@@ -246,7 +265,6 @@ class Presto(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.UNALIAS_GROUP,  # type: ignore
-            **transforms.ELIMINATE_QUALIFY,  # type: ignore
             exp.ApproxDistinct: _approx_distinct_sql,
             exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
             exp.ArrayConcat: rename_func("CONCAT"),
@@ -284,6 +302,9 @@ class Presto(Dialect):
             exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
             exp.SafeDivide: no_safe_divide_sql,
             exp.Schema: _schema_sql,
+            exp.Select: transforms.preprocess(
+                [transforms.eliminate_qualify, transforms.explode_to_unnest]
+            ),
             exp.SortArray: _no_sort_array,
             exp.StrPosition: rename_func("STRPOS"),
             exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
@@ -308,7 +329,13 @@ class Presto(Dialect):
             exp.VariancePop: rename_func("VAR_POP"),
         }

-        def transaction_sql(self, expression):
+        def interval_sql(self, expression: exp.Interval) -> str:
+            unit = self.sql(expression, "unit")
+            if expression.this and unit.lower().startswith("week"):
+                return f"({expression.this.name} * INTERVAL '7' day)"
+            return super().interval_sql(expression)
+
+        def transaction_sql(self, expression: exp.Transaction) -> str:
             modes = expression.args.get("modes")
             modes = f" {', '.join(modes)}" if modes else ""
             return f"START TRANSACTION{modes}"
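A sketch of the new week-interval rewrite (Presto has no WEEK interval unit):

    import sqlglot

    sqlglot.transpile("SELECT d + INTERVAL '2' week", read="presto", write="presto")[0]
    # -> "SELECT d + (2 * INTERVAL '7' day)"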
@@ -8,6 +8,10 @@ from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType


+def _json_sql(self, e) -> str:
+    return f'{self.sql(e, "this")}."{e.expression.name}"'
+
+
 class Redshift(Postgres):
     time_format = "'YYYY-MM-DD HH:MI:SS'"
     time_mapping = {
@@ -56,6 +60,7 @@ class Redshift(Postgres):
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "HLLSKETCH": TokenType.HLLSKETCH,
             "SUPER": TokenType.SUPER,
+            "SYSDATE": TokenType.CURRENT_TIMESTAMP,
             "TIME": TokenType.TIMESTAMP,
             "TIMETZ": TokenType.TIMESTAMPTZ,
             "TOP": TokenType.TOP,
@@ -63,7 +68,14 @@ class Redshift(Postgres):
             "VARBYTE": TokenType.VARBINARY,
         }

+        # Redshift allows # to appear as a table identifier prefix
+        SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy()
+        SINGLE_TOKENS.pop("#")
+
     class Generator(Postgres.Generator):
+        LOCKING_READS_SUPPORTED = False
+        SINGLE_STRING_INTERVAL = True

         TYPE_MAPPING = {
             **Postgres.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.BINARY: "VARBYTE",
@@ -79,6 +91,7 @@ class Redshift(Postgres):
         TRANSFORMS = {
             **Postgres.Generator.TRANSFORMS,  # type: ignore
             **transforms.ELIMINATE_DISTINCT_ON,  # type: ignore
+            exp.CurrentTimestamp: lambda self, e: "SYSDATE",
             exp.DateAdd: lambda self, e: self.func(
                 "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
             ),
@@ -87,12 +100,16 @@ class Redshift(Postgres):
             ),
             exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
             exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+            exp.JSONExtract: _json_sql,
+            exp.JSONExtractScalar: _json_sql,
             exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
         }

         # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
         TRANSFORMS.pop(exp.Pow)

+        RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"}
+
         def values_sql(self, expression: exp.Values) -> str:
             """
             Converts `VALUES...` expression into a series of unions.
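A sketch of the new JSON handling for Redshift SUPER columns:

    import sqlglot

    sqlglot.transpile("SELECT c -> 'name' FROM t", read="postgres", write="redshift")[0]
    # -> 'SELECT c."name" FROM t'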
@@ -23,14 +23,14 @@ from sqlglot.parser import binary_range_parser
 from sqlglot.tokens import TokenType


-def _check_int(s):
+def _check_int(s: str) -> bool:
     if s[0] in ("-", "+"):
         return s[1:].isdigit()
     return s.isdigit()


 # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _snowflake_to_timestamp(args):
+def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
     if len(args) == 2:
         first_arg, second_arg = args
         if second_arg.is_string:
@@ -69,7 +69,7 @@ def _snowflake_to_timestamp(args):
     return exp.UnixToTime.from_arg_list(args)


-def _unix_to_time_sql(self, expression):
+def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
     scale = expression.args.get("scale")
     timestamp = self.sql(expression, "this")
     if scale in [None, exp.UnixToTime.SECONDS]:
@@ -84,8 +84,12 @@ def _unix_to_time_sql(self, expression):

 # https://docs.snowflake.com/en/sql-reference/functions/date_part.html
 # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
-def _parse_date_part(self):
+def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
     this = self._parse_var() or self._parse_type()
+
+    if not this:
+        return None
+
     self._match(TokenType.COMMA)
     expression = self._parse_bitwise()
@@ -101,7 +105,7 @@ def _parse_date_part(self):
         scale = None

     ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
-    to_unix = self.expression(exp.TimeToUnix, this=ts)
+    to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)

     if scale:
         to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
@@ -112,7 +116,7 @@ def _parse_date_part(self):


 # https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args):
+def _div0_to_if(args: t.Sequence) -> exp.Expression:
     cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
     true = exp.Literal.number(0)
     false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@ -120,18 +124,18 @@ def _div0_to_if(args):
|
|||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
|
||||
def _zeroifnull_to_if(args):
|
||||
def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
|
||||
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
|
||||
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
|
||||
def _nullifzero_to_if(args):
|
||||
def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
|
||||
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
|
||||
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
|
||||
|
||||
|
||||
def _datatype_sql(self, expression):
|
||||
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
|
||||
if expression.this == exp.DataType.Type.ARRAY:
|
||||
return "ARRAY"
|
||||
elif expression.this == exp.DataType.Type.MAP:
|
||||
|
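A hedged sketch of the rewrites these helpers perform (sqlglot ~11.7; exact output text may differ slightly). DIV0 and NULLIFZERO parse into a canonical exp.If, which Snowflake's generator renders back as IFF:

    import sqlglot

    print(sqlglot.transpile("SELECT DIV0(a, b)", read="snowflake", write="snowflake")[0])
    # e.g. SELECT IFF(b = 0, 0, a / b)

    print(sqlglot.transpile("SELECT NULLIFZERO(x)", read="snowflake", write="snowflake")[0])
    # e.g. SELECT IFF(x = 0, NULL, x)
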
@@ -155,9 +159,8 @@ class Snowflake(Dialect):
        "MM": "%m",
        "mm": "%m",
        "DD": "%d",
        "dd": "%d",
        "d": "%-d",
        "DY": "%w",
        "dd": "%-d",
        "DY": "%a",
        "dy": "%w",
        "HH24": "%H",
        "hh24": "%H",
@@ -174,6 +177,8 @@ class Snowflake(Dialect):
    }

    class Parser(parser.Parser):
        QUOTED_PIVOT_COLUMNS = True

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "ARRAYAGG": exp.ArrayAgg.from_arg_list,
@@ -269,9 +274,14 @@ class Snowflake(Dialect):
            "$": TokenType.PARAMETER,
        }

        VAR_SINGLE_TOKENS = {"$"}

    class Generator(generator.Generator):
        PARAMETER_TOKEN = "$"
        MATCHED_BY_SOURCE = False
        SINGLE_STRING_INTERVAL = True
        JOIN_HINTS = False
        TABLE_HINTS = False

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,  # type: ignore
@@ -287,26 +297,30 @@ class Snowflake(Dialect):
            ),
            exp.DateStrToDate: datestrtodate_sql,
            exp.DataType: _datatype_sql,
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
            exp.If: rename_func("IFF"),
            exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
            exp.LogicalOr: rename_func("BOOLOR_AGG"),
            exp.LogicalAnd: rename_func("BOOLAND_AGG"),
            exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
            exp.LogicalOr: rename_func("BOOLOR_AGG"),
            exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
            exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
            exp.StrPosition: lambda self, e: self.func(
                "POSITION", e.args.get("substr"), e.this, e.args.get("position")
            ),
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimestampTrunc: timestamptrunc_sql,
            exp.TimeStrToTime: timestrtotime_sql,
            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
            exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
            exp.TimeToStr: lambda self, e: self.func(
                "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
            ),
            exp.TimestampTrunc: timestamptrunc_sql,
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
            exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
            exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
            exp.UnixToTime: _unix_to_time_sql,
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,
            exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
        }

        TYPE_MAPPING = {
@@ -322,14 +336,15 @@ class Snowflake(Dialect):
        PROPERTIES_LOCATION = {
            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
            exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

        def except_op(self, expression):
        def except_op(self, expression: exp.Except) -> str:
            if not expression.args.get("distinct", False):
                self.unsupported("EXCEPT with All is not supported in Snowflake")
            return super().except_op(expression)

        def intersect_op(self, expression):
        def intersect_op(self, expression: exp.Intersect) -> str:
            if not expression.args.get("distinct", False):
                self.unsupported("INTERSECT with All is not supported in Snowflake")
            return super().intersect_op(expression)

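Illustrative uses of the Snowflake generator rules above (hedged; sqlglot ~11.7):

    import sqlglot

    # exp.If is renamed to Snowflake's IFF.
    print(sqlglot.transpile("SELECT IF(cond, 1, 2)", read="hive", write="snowflake")[0])
    # e.g. SELECT IFF(cond, 1, 2)

    # EXCEPT ALL is flagged via self.unsupported(), which logs a warning by
    # default before delegating to the base except_op.
    sqlglot.transpile("SELECT a FROM x EXCEPT ALL SELECT a FROM y", write="snowflake")
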
@@ -1,13 +1,15 @@
from __future__ import annotations

import typing as t

from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get


def _create_sql(self, e):
    kind = e.args.get("kind")
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
    kind = e.args["kind"]
    properties = e.args.get("properties")

    if kind.upper() == "TABLE" and any(
@@ -18,13 +20,13 @@ def _create_sql(self, e):
    return create_with_partitions_sql(self, e)


def _map_sql(self, expression):
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
    keys = self.sql(expression.args["keys"])
    values = self.sql(expression.args["values"])
    return f"MAP_FROM_ARRAYS({keys}, {values})"


def _str_to_date(self, expression):
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format == Hive.date_format:
@@ -32,7 +34,7 @@ def _str_to_date(self, expression):
    return f"TO_DATE({this}, {time_format})"


def _unix_to_time(self, expression):
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale is None:
@@ -75,7 +77,11 @@ class Spark(Hive):
                length=seq_get(args, 1),
            ),
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
            "BOOLEAN": lambda args: exp.Cast(
                this=seq_get(args, 0), to=exp.DataType.build("boolean")
            ),
            "IIF": exp.If.from_arg_list,
            "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
            "AGGREGATE": exp.Reduce.from_arg_list,
            "DAYOFWEEK": lambda args: exp.DayOfWeek(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@@ -89,11 +95,16 @@ class Spark(Hive):
            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
            ),
            "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                this=seq_get(args, 1),
                unit=exp.var(seq_get(args, 0)),
            ),
            "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
            "TIMESTAMP": lambda args: exp.Cast(
                this=seq_get(args, 0), to=exp.DataType.build("timestamp")
            ),
        }

        FUNCTION_PARSERS = {
@@ -108,16 +119,43 @@ class Spark(Hive):
            "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
        }

        def _parse_add_column(self):
        def _parse_add_column(self) -> t.Optional[exp.Expression]:
            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()

        def _parse_drop_column(self):
        def _parse_drop_column(self) -> t.Optional[exp.Expression]:
            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
                exp.Drop,
                this=self._parse_schema(),
                kind="COLUMNS",
            )

        def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
            # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
            if len(pivot_columns) == 1:
                return [""]

            names = []
            for agg in pivot_columns:
                if isinstance(agg, exp.Alias):
                    names.append(agg.alias)
                else:
                    """
                    This case corresponds to aggregations without aliases being used as suffixes
                    (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
                    be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
                    Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).

                    Moreover, function names are lowercased in order to mimic Spark's naming scheme.
                    """
                    agg_all_unquoted = agg.transform(
                        lambda node: exp.Identifier(this=node.name, quoted=False)
                        if isinstance(node, exp.Identifier)
                        else node
                    )
                    names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))

            return names

    class Generator(Hive.Generator):
        TYPE_MAPPING = {
            **Hive.Generator.TYPE_MAPPING,  # type: ignore
@@ -145,7 +183,7 @@ class Spark(Hive):
            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.UnixToTime: _unix_to_time,
            exp.UnixToTime: _unix_to_time_sql,
            exp.Create: _create_sql,
            exp.Map: _map_sql,
            exp.Reduce: rename_func("AGGREGATE"),

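A quick check of the new Spark function parsers above (assuming sqlglot ~11.7): DATE()/STRING() calls now parse into casts, and DATE_TRUNC(unit, expr) into exp.TimestampTrunc.

    import sqlglot
    from sqlglot import exp

    e = sqlglot.parse_one("SELECT DATE(col), DATE_TRUNC('month', ts)", read="spark")
    casts = list(e.find_all(exp.Cast))            # DATE(col) -> CAST(col AS DATE)
    truncs = list(e.find_all(exp.TimestampTrunc))  # DATE_TRUNC -> TimestampTrunc
    print(len(casts), len(truncs))                 # expected: 1 1
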
@@ -16,7 +16,7 @@ from sqlglot.tokens import TokenType

def _date_add_sql(self, expression):
    modifier = expression.expression
    modifier = expression.name if modifier.is_string else self.sql(modifier)
    modifier = modifier.name if modifier.is_string else self.sql(modifier)
    unit = expression.args.get("unit")
    modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
    return self.func("DATE", expression.this, modifier)
@@ -38,6 +38,9 @@ class SQLite(Dialect):
    }

    class Generator(generator.Generator):
        JOIN_HINTS = False
        TABLE_HINTS = False

        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,  # type: ignore
            exp.DataType.Type.BOOLEAN: "INTEGER",
@@ -82,6 +85,11 @@ class SQLite(Dialect):
            exp.TryCast: no_trycast_sql,
        }

        PROPERTIES_LOCATION = {
            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

        LIMIT_FETCH = "LIMIT"

        def cast_sql(self, expression: exp.Cast) -> str:

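The one-line `_date_add_sql` fix above makes the modifier come from the DATE_ADD argument itself rather than the whole expression. A hedged sketch (sqlglot ~11.7; the exact rendering is illustrative):

    import sqlglot

    print(sqlglot.transpile("SELECT DATE_ADD(d, 1, 'day')", write="sqlite")[0])
    # expected to render something like: SELECT DATE(d, '1 day')
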
@@ -1,7 +1,11 @@
from __future__ import annotations

from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.dialect import (
    approx_count_distinct_sql,
    arrow_json_extract_sql,
    rename_func,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get

@@ -10,6 +14,7 @@ class StarRocks(MySQL):
    class Parser(MySQL.Parser):  # type: ignore
        FUNCTIONS = {
            **MySQL.Parser.FUNCTIONS,
            "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                this=seq_get(args, 1), unit=seq_get(args, 0)
            ),
@@ -25,6 +30,7 @@ class StarRocks(MySQL):

        TRANSFORMS = {
            **MySQL.Generator.TRANSFORMS,  # type: ignore
            exp.ApproxDistinct: approx_count_distinct_sql,
            exp.JSONExtractScalar: arrow_json_extract_sql,
            exp.JSONExtract: arrow_json_extract_sql,
            exp.DateDiff: rename_func("DATEDIFF"),

@@ -21,6 +21,9 @@ def _count_sql(self, expression):

class Tableau(Dialect):
    class Generator(generator.Generator):
        JOIN_HINTS = False
        TABLE_HINTS = False

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,  # type: ignore
            exp.If: _if_sql,
@@ -28,6 +31,11 @@ class Tableau(Dialect):
            exp.Count: _count_sql,
        }

        PROPERTIES_LOCATION = {
            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

    class Parser(parser.Parser):
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,  # type: ignore

@@ -1,7 +1,14 @@
from __future__ import annotations

import typing as t

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.dialects.dialect import (
    Dialect,
    format_time_lambda,
    max_or_greatest,
    min_or_least,
)
from sqlglot.tokens import TokenType


@@ -115,7 +122,18 @@ class Teradata(Dialect):

            return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)

        def _parse_cast(self, strict: bool) -> exp.Expression:
            cast = t.cast(exp.Cast, super()._parse_cast(strict))
            if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT):
                return format_time_lambda(exp.TimeToStr, "teradata")(
                    [cast.this, self._parse_string()]
                )
            return cast

    class Generator(generator.Generator):
        JOIN_HINTS = False
        TABLE_HINTS = False

        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,  # type: ignore
            exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
@@ -130,6 +148,7 @@ class Teradata(Dialect):
            **generator.Generator.TRANSFORMS,
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,
            exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
        }

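The `_parse_cast` override and the `exp.TimeToStr` transform above are two halves of the same feature: Teradata's FORMAT-in-CAST syntax now round-trips through exp.TimeToStr. A hedged sketch (sqlglot ~11.7):

    import sqlglot

    sql = "SELECT CAST(x AS DATE FORMAT 'YYYY-MM-DD') FROM t"
    print(sqlglot.transpile(sql, read="teradata", write="teradata")[0])
    # expected to reproduce the CAST ... FORMAT form
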
@@ -96,6 +96,23 @@ def _parse_eomonth(args):
    return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))


def _parse_hashbytes(args):
    kind, data = args
    kind = kind.name.upper() if kind.is_string else ""

    if kind == "MD5":
        args.pop(0)
        return exp.MD5(this=data)
    if kind in ("SHA", "SHA1"):
        args.pop(0)
        return exp.SHA(this=data)
    if kind == "SHA2_256":
        return exp.SHA2(this=data, length=exp.Literal.number(256))
    if kind == "SHA2_512":
        return exp.SHA2(this=data, length=exp.Literal.number(512))
    return exp.func("HASHBYTES", *args)


def generate_date_delta_with_unit_sql(self, e):
    func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
    return self.func(func, e.text("unit"), e.expression, e.this)

@@ -266,6 +283,7 @@ class TSQL(Dialect):
            "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
            "VARCHAR(MAX)": TokenType.TEXT,
            "XML": TokenType.XML,
            "SYSTEM_USER": TokenType.CURRENT_USER,
        }

        # TSQL allows @, # to appear as a variable/identifier prefix
@@ -287,6 +305,7 @@ class TSQL(Dialect):
            "EOMONTH": _parse_eomonth,
            "FORMAT": _parse_format,
            "GETDATE": exp.CurrentTimestamp.from_arg_list,
            "HASHBYTES": _parse_hashbytes,
            "IIF": exp.If.from_arg_list,
            "ISNULL": exp.Coalesce.from_arg_list,
            "JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
@@ -296,6 +315,14 @@ class TSQL(Dialect):
            "SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
            "SUSER_NAME": exp.CurrentUser.from_arg_list,
            "SUSER_SNAME": exp.CurrentUser.from_arg_list,
            "SYSTEM_USER": exp.CurrentUser.from_arg_list,
        }

        JOIN_HINTS = {
            "LOOP",
            "HASH",
            "MERGE",
            "REMOTE",
        }

        VAR_LENGTH_DATATYPES = {
@@ -441,11 +468,21 @@ class TSQL(Dialect):
            exp.TimeToStr: _format_sql,
            exp.GroupConcat: _string_agg_sql,
            exp.Max: max_or_greatest,
            exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
            exp.Min: min_or_least,
            exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
            exp.SHA2: lambda self, e: self.func(
                "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
            ),
        }

        TRANSFORMS.pop(exp.ReturnsProperty)

        PROPERTIES_LOCATION = {
            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

        LIMIT_FETCH = "FETCH"

        def offset_sql(self, expression: exp.Offset) -> str:

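A hedged sketch of the HASHBYTES mapping above (sqlglot ~11.7): the algorithm literal selects a canonical hash expression, so hashes transpile in both directions.

    import sqlglot

    print(sqlglot.transpile("SELECT HASHBYTES('SHA2_256', x)", read="tsql", write="spark")[0])
    # e.g. SELECT SHA2(x, 256)

    print(sqlglot.transpile("SELECT MD5(x)", read="duckdb", write="tsql")[0])
    # e.g. SELECT HASHBYTES('MD5', x)
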
@@ -701,6 +701,119 @@ class Condition(Expression):
        """
        return not_(self)

    def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
        this = self
        other = convert(other)
        if not isinstance(this, klass) and not isinstance(other, klass):
            this = _wrap(this, Binary)
            other = _wrap(other, Binary)
        if reverse:
            return klass(this=other, expression=this)
        return klass(this=this, expression=other)

    def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
        if isinstance(other, slice):
            return Between(
                this=self,
                low=convert(other.start),
                high=convert(other.stop),
            )
        return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])

    def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
        return In(
            this=self,
            expressions=[convert(e) for e in expressions],
            query=maybe_parse(query, **opts) if query else None,
        )

    def like(self, other: ExpOrStr) -> Like:
        return self._binop(Like, other)

    def ilike(self, other: ExpOrStr) -> ILike:
        return self._binop(ILike, other)

    def eq(self, other: ExpOrStr) -> EQ:
        return self._binop(EQ, other)

    def neq(self, other: ExpOrStr) -> NEQ:
        return self._binop(NEQ, other)

    def rlike(self, other: ExpOrStr) -> RegexpLike:
        return self._binop(RegexpLike, other)

    def __lt__(self, other: ExpOrStr) -> LT:
        return self._binop(LT, other)

    def __le__(self, other: ExpOrStr) -> LTE:
        return self._binop(LTE, other)

    def __gt__(self, other: ExpOrStr) -> GT:
        return self._binop(GT, other)

    def __ge__(self, other: ExpOrStr) -> GTE:
        return self._binop(GTE, other)

    def __add__(self, other: ExpOrStr) -> Add:
        return self._binop(Add, other)

    def __radd__(self, other: ExpOrStr) -> Add:
        return self._binop(Add, other, reverse=True)

    def __sub__(self, other: ExpOrStr) -> Sub:
        return self._binop(Sub, other)

    def __rsub__(self, other: ExpOrStr) -> Sub:
        return self._binop(Sub, other, reverse=True)

    def __mul__(self, other: ExpOrStr) -> Mul:
        return self._binop(Mul, other)

    def __rmul__(self, other: ExpOrStr) -> Mul:
        return self._binop(Mul, other, reverse=True)

    def __truediv__(self, other: ExpOrStr) -> Div:
        return self._binop(Div, other)

    def __rtruediv__(self, other: ExpOrStr) -> Div:
        return self._binop(Div, other, reverse=True)

    def __floordiv__(self, other: ExpOrStr) -> IntDiv:
        return self._binop(IntDiv, other)

    def __rfloordiv__(self, other: ExpOrStr) -> IntDiv:
        return self._binop(IntDiv, other, reverse=True)

    def __mod__(self, other: ExpOrStr) -> Mod:
        return self._binop(Mod, other)

    def __rmod__(self, other: ExpOrStr) -> Mod:
        return self._binop(Mod, other, reverse=True)

    def __pow__(self, other: ExpOrStr) -> Pow:
        return self._binop(Pow, other)

    def __rpow__(self, other: ExpOrStr) -> Pow:
        return self._binop(Pow, other, reverse=True)

    def __and__(self, other: ExpOrStr) -> And:
        return self._binop(And, other)

    def __rand__(self, other: ExpOrStr) -> And:
        return self._binop(And, other, reverse=True)

    def __or__(self, other: ExpOrStr) -> Or:
        return self._binop(Or, other)

    def __ror__(self, other: ExpOrStr) -> Or:
        return self._binop(Or, other, reverse=True)

    def __neg__(self) -> Neg:
        return Neg(this=_wrap(self, Binary))

    def __invert__(self) -> Not:
        return not_(self)


class Predicate(Condition):
    """Relationships like x = y, x > 1, x >= y."""
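The operator overloads added to Condition above let plain Python operators build expression trees. A minimal illustration (hedged; sqlglot ~11.7, and the exact parenthesization in the output may vary):

    from sqlglot import column

    expr = (column("x") + 1) * 2 >= 10   # Add/Mul/GTE nodes via __add__/__mul__/__ge__
    print(expr.sql())                    # e.g. ((x + 1) * 2) >= 10

    # Slicing builds a BETWEEN, per __getitem__ above.
    print(column("y")[5:10].sql())       # e.g. y BETWEEN 5 AND 10
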
@@ -818,7 +931,6 @@ class Create(Expression):
        "properties": False,
        "replace": False,
        "unique": False,
        "volatile": False,
        "indexes": False,
        "no_schema_binding": False,
        "begin": False,
@@ -1053,6 +1165,11 @@ class NotNullColumnConstraint(ColumnConstraintKind):
    arg_types = {"allow_null": False}


# https://dev.mysql.com/doc/refman/5.7/en/timestamp-initialization.html
class OnUpdateColumnConstraint(ColumnConstraintKind):
    pass


class PrimaryKeyColumnConstraint(ColumnConstraintKind):
    arg_types = {"desc": False}

@@ -1197,6 +1314,7 @@ class Drop(Expression):
        "materialized": False,
        "cascade": False,
        "constraints": False,
        "purge": False,
    }


@@ -1287,6 +1405,7 @@ class Insert(Expression):
        "with": False,
        "this": True,
        "expression": False,
        "conflict": False,
        "returning": False,
        "overwrite": False,
        "exists": False,
@@ -1295,6 +1414,16 @@ class Insert(Expression):
    }


class OnConflict(Expression):
    arg_types = {
        "duplicate": False,
        "expressions": False,
        "nothing": False,
        "key": False,
        "constraint": False,
    }


class Returning(Expression):
    arg_types = {"expressions": True}

@@ -1326,7 +1455,12 @@ class Partition(Expression):


class Fetch(Expression):
    arg_types = {"direction": False, "count": False}
    arg_types = {
        "direction": False,
        "count": False,
        "percent": False,
        "with_ties": False,
    }


class Group(Expression):
@@ -1374,6 +1508,7 @@ class Join(Expression):
        "kind": False,
        "using": False,
        "natural": False,
        "hint": False,
    }

    @property
@@ -1384,6 +1519,10 @@ class Join(Expression):
    def side(self):
        return self.text("side").upper()

    @property
    def hint(self):
        return self.text("hint").upper()

    @property
    def alias_or_name(self):
        return self.this.alias_or_name
@@ -1475,6 +1614,7 @@ class MatchRecognize(Expression):
        "after": False,
        "pattern": False,
        "define": False,
        "alias": False,
    }


@@ -1582,6 +1722,10 @@ class FreespaceProperty(Property):
    arg_types = {"this": True, "percent": False}


class InputOutputFormat(Expression):
    arg_types = {"input_format": False, "output_format": False}


class IsolatedLoadingProperty(Property):
    arg_types = {
        "no": True,
@@ -1646,6 +1790,10 @@ class ReturnsProperty(Property):
    arg_types = {"this": True, "is_table": False, "table": False}


class RowFormatProperty(Property):
    arg_types = {"this": True}


class RowFormatDelimitedProperty(Property):
    # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
    arg_types = {
@@ -1683,6 +1831,10 @@ class SqlSecurityProperty(Property):
    arg_types = {"definer": True}


class StabilityProperty(Property):
    arg_types = {"this": True}


class TableFormatProperty(Property):
    arg_types = {"this": True}

@@ -1695,8 +1847,8 @@ class TransientProperty(Property):
    arg_types = {"this": False}


class VolatilityProperty(Property):
    arg_types = {"this": True}
class VolatileProperty(Property):
    arg_types = {"this": False}


class WithDataProperty(Expression):
@@ -1726,6 +1878,7 @@ class Properties(Expression):
        "LOCATION": LocationProperty,
        "PARTITIONED_BY": PartitionedByProperty,
        "RETURNS": ReturnsProperty,
        "ROW_FORMAT": RowFormatProperty,
        "SORTKEY": SortKeyProperty,
        "TABLE_FORMAT": TableFormatProperty,
    }
@@ -2721,6 +2874,7 @@ class Pivot(Expression):
        "expressions": True,
        "field": True,
        "unpivot": True,
        "columns": False,
    }


@@ -2731,6 +2885,8 @@ class Window(Expression):
        "order": False,
        "spec": False,
        "alias": False,
        "over": False,
        "first": False,
    }


@@ -2816,6 +2972,7 @@ class DataType(Expression):
        FLOAT = auto()
        DOUBLE = auto()
        DECIMAL = auto()
        BIGDECIMAL = auto()
        BIT = auto()
        BOOLEAN = auto()
        JSON = auto()
@@ -2964,7 +3121,7 @@ class DropPartition(Expression):


# Binary expressions like (ADD a b)
class Binary(Expression):
class Binary(Condition):
    arg_types = {"this": True, "expression": True}

    @property
@@ -2980,7 +3137,7 @@ class Add(Binary):
    pass


class Connector(Binary, Condition):
class Connector(Binary):
    pass


@@ -3142,7 +3299,7 @@ class ArrayOverlaps(Binary):

# Unary Expressions
# (NOT a)
class Unary(Expression):
class Unary(Condition):
    pass


@@ -3150,11 +3307,11 @@ class BitwiseNot(Unary):
    pass


class Not(Unary, Condition):
class Not(Unary):
    pass


class Paren(Unary, Condition):
class Paren(Unary):
    arg_types = {"this": True, "with": False}


@@ -3162,7 +3319,6 @@ class Neg(Unary):
    pass


# Special Functions
class Alias(Expression):
    arg_types = {"this": True, "alias": False}
@@ -3381,6 +3537,16 @@ class AnyValue(AggFunc):
class Case(Func):
    arg_types = {"this": False, "ifs": True, "default": False}

    def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
        this = self.copy() if copy else self
        this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
        return this

    def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
        this = self.copy() if copy else self
        this.set("default", maybe_parse(condition, **opts))
        return this


class Cast(Func):
    arg_types = {"this": True, "to": True}
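A sketch of the new Case builder methods (hedged; sqlglot ~11.7). String arguments go through maybe_parse, so conditions can be written as SQL fragments:

    from sqlglot import exp

    case = exp.Case().when("x = 1", "'one'").else_("'other'")
    print(case.sql())  # e.g. CASE WHEN x = 1 THEN 'one' ELSE 'other' END
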
@@ -3719,6 +3885,10 @@ class Map(Func):
    arg_types = {"keys": False, "values": False}


class StarMap(Func):
    pass


class VarMap(Func):
    arg_types = {"keys": True, "values": True}
    is_var_len_args = True

@@ -3734,6 +3904,10 @@ class Max(AggFunc):
    is_var_len_args = True


class MD5(Func):
    _sql_names = ["MD5"]


class Min(AggFunc):
    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True

@@ -3840,6 +4014,15 @@ class SetAgg(AggFunc):
    pass


class SHA(Func):
    _sql_names = ["SHA", "SHA1"]


class SHA2(Func):
    _sql_names = ["SHA2"]
    arg_types = {"this": True, "length": False}


class SortArray(Func):
    arg_types = {"this": True, "asc": False}


@@ -4017,6 +4200,12 @@ class When(Func):
    arg_types = {"matched": True, "source": False, "condition": False, "then": True}


# https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html
# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16
class NextValueFor(Func):
    arg_types = {"this": True, "order": False}


def _norm_arg(arg):
    return arg.lower() if type(arg) is str else arg


@@ -4025,6 +4214,32 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))


# Helpers
@t.overload
def maybe_parse(
    sql_or_expression: ExpOrStr,
    *,
    into: t.Type[E],
    dialect: DialectType = None,
    prefix: t.Optional[str] = None,
    copy: bool = False,
    **opts,
) -> E:
    ...


@t.overload
def maybe_parse(
    sql_or_expression: str | E,
    *,
    into: t.Optional[IntoType] = None,
    dialect: DialectType = None,
    prefix: t.Optional[str] = None,
    copy: bool = False,
    **opts,
) -> E:
    ...


def maybe_parse(
    sql_or_expression: ExpOrStr,
    *,
@@ -4200,15 +4415,15 @@ def _combine(expressions, operator, dialect=None, **opts):
    expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
    this = expressions[0]
    if expressions[1:]:
        this = _wrap_operator(this)
        this = _wrap(this, Connector)
    for expression in expressions[1:]:
        this = operator(this=this, expression=_wrap_operator(expression))
        this = operator(this=this, expression=_wrap(expression, Connector))
    return this


def _wrap_operator(expression):
    if isinstance(expression, (And, Or, Not)):
        expression = Paren(this=expression)
def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
    if isinstance(expression, kind):
        return Paren(this=expression)
    return expression


@@ -4506,7 +4721,7 @@ def not_(expression, dialect=None, **opts) -> Not:
        dialect=dialect,
        **opts,
    )
    return Not(this=_wrap_operator(this))
    return Not(this=_wrap(this, Connector))


def paren(expression) -> Paren:

@@ -4657,6 +4872,8 @@ def alias_(

    if table:
        table_alias = TableAlias(this=alias)

        exp = exp.copy() if isinstance(expression, Expression) else exp
        exp.set("alias", table_alias)

        if not isinstance(table, bool):
@@ -4864,16 +5081,22 @@ def convert(value) -> Expression:
    """
    if isinstance(value, Expression):
        return value
    if value is None:
        return NULL
    if isinstance(value, bool):
        return Boolean(this=value)
    if isinstance(value, str):
        return Literal.string(value)
    if isinstance(value, float) and math.isnan(value):
    if isinstance(value, bool):
        return Boolean(this=value)
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return NULL
    if isinstance(value, numbers.Number):
        return Literal.number(value)
    if isinstance(value, datetime.datetime):
        datetime_literal = Literal.string(
            (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
        )
        return TimeStrToTime(this=datetime_literal)
    if isinstance(value, datetime.date):
        date_literal = Literal.string(value.strftime("%Y-%m-%d"))
        return DateStrToDate(this=date_literal)
    if isinstance(value, tuple):
        return Tuple(expressions=[convert(v) for v in value])
    if isinstance(value, list):
@@ -4883,14 +5106,6 @@ def convert(value) -> Expression:
        keys=[convert(k) for k in value],
        values=[convert(v) for v in value.values()],
    )
    if isinstance(value, datetime.datetime):
        datetime_literal = Literal.string(
            (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
        )
        return TimeStrToTime(this=datetime_literal)
    if isinstance(value, datetime.date):
        date_literal = Literal.string(value.strftime("%Y-%m-%d"))
        return DateStrToDate(this=date_literal)
    raise ValueError(f"Cannot convert {value}")
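A hedged sketch of convert()'s reordered branches above (sqlglot ~11.7): NaN now maps to NULL alongside None, and datetime handling moved ahead of the container branches.

    import math
    from sqlglot import exp

    print(exp.convert(math.nan).sql())         # NULL
    print(exp.convert((1, "a", None)).sql())   # e.g. (1, 'a', NULL)
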
@@ -5030,7 +5245,9 @@ def replace_placeholders(expression, *args, **kwargs):
    return expression.transform(_replace_placeholders, iter(args), **kwargs)


def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression:
def expand(
    expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True
) -> Expression:
    """Transforms an expression by expanding all referenced sources into subqueries.

    Examples:
@@ -5038,6 +5255,9 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
        >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql()
        'SELECT * FROM (SELECT * FROM y) AS z /* source: x */'

        >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y"), "y": parse_one("select * from z")}).sql()
        'SELECT * FROM (SELECT * FROM (SELECT * FROM z) AS y /* source: y */) AS z /* source: x */'

    Args:
        expression: The expression to expand.
        sources: A dictionary of name to Subqueryables.
@@ -5054,7 +5274,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
        if source:
            subquery = source.subquery(node.alias or name)
            subquery.comments = [f"source: {name}"]
            return subquery
            return subquery.transform(_expand, copy=False)
        return node

    return expression.transform(_expand, copy=copy)

@@ -5089,8 +5309,8 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:

    from sqlglot.dialects.dialect import Dialect

    converted = [convert(arg) for arg in args]
    kwargs = {key: convert(value) for key, value in kwargs.items()}
    converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args]
    kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()}

    parser = Dialect.get_or_raise(dialect)().parser()
    from_args_list = parser.FUNCTIONS.get(name.upper())

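The exp.func() change above is worth calling out: string arguments are now parsed as SQL expressions via maybe_parse rather than converted to string literals. A hedged illustration:

    from sqlglot import exp

    print(exp.func("COALESCE", "a.b", 0).sql())
    # e.g. COALESCE(a.b, 0) -- "a.b" parses as a column, not as the literal 'a.b'
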
@@ -76,11 +76,13 @@ class Generator:
        exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
        exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY",
        exp.TransientProperty: lambda self, e: "TRANSIENT",
        exp.VolatilityProperty: lambda self, e: e.name,
        exp.StabilityProperty: lambda self, e: e.name,
        exp.VolatileProperty: lambda self, e: "VOLATILE",
        exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
        exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
        exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
        exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
        exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
        exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
        exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
        exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
@@ -110,8 +112,19 @@ class Generator:
    # Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
    MATCHED_BY_SOURCE = True

    # Whether or not limit and fetch are supported
    # "ALL", "LIMIT", "FETCH"
    # Whether or not the INTERVAL expression works only with values like '1 day'
    SINGLE_STRING_INTERVAL = False

    # Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs
    INTERVAL_ALLOWS_PLURAL_FORM = True

    # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
    TABLESAMPLE_WITH_METHOD = True

    # Whether or not to treat the number in TABLESAMPLE (50) as a percentage
    TABLESAMPLE_SIZE_IS_PERCENT = False

    # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
    LIMIT_FETCH = "ALL"

    TYPE_MAPPING = {
@@ -129,6 +142,18 @@ class Generator:
        "replace": "REPLACE",
    }

    TIME_PART_SINGULARS = {
        "microseconds": "microsecond",
        "seconds": "second",
        "minutes": "minute",
        "hours": "hour",
        "days": "day",
        "weeks": "week",
        "months": "month",
        "quarters": "quarter",
        "years": "year",
    }

    TOKEN_MAPPING: t.Dict[TokenType, str] = {}

    STRUCT_DELIMITER = ("<", ">")

@@ -168,6 +193,7 @@ class Generator:
        exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
        exp.Property: exp.Properties.Location.POST_WITH,
        exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
        exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
        exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
        exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
@@ -175,15 +201,22 @@ class Generator:
        exp.SetProperty: exp.Properties.Location.POST_CREATE,
        exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
        exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
        exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
        exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
        exp.TransientProperty: exp.Properties.Location.POST_CREATE,
        exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA,
        exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
        exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
        exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
    }

    WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
    JOIN_HINTS = True
    TABLE_HINTS = True

    RESERVED_KEYWORDS: t.Set[str] = set()
    WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
    UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column)

    SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"

    __slots__ = (
@@ -322,10 +355,15 @@ class Generator:
            comment = comment + " " if comment[-1].strip() else comment
        return comment

    def maybe_comment(self, sql: str, expression: exp.Expression) -> str:
        comments = expression.comments if self._comments else None
    def maybe_comment(
        self,
        sql: str,
        expression: t.Optional[exp.Expression] = None,
        comments: t.Optional[t.List[str]] = None,
    ) -> str:
        comments = (comments or (expression and expression.comments)) if self._comments else None  # type: ignore

        if not comments:
        if not comments or isinstance(expression, exp.Binary):
            return sql

        sep = "\n" if self.pretty else " "
@@ -621,7 +659,6 @@ class Generator:

        replace = " OR REPLACE" if expression.args.get("replace") else ""
        unique = " UNIQUE" if expression.args.get("unique") else ""
        volatile = " VOLATILE" if expression.args.get("volatile") else ""

        postcreate_props_sql = ""
        if properties_locs.get(exp.Properties.Location.POST_CREATE):
@@ -632,7 +669,7 @@ class Generator:
                wrapped=False,
            )

        modifiers = "".join((replace, unique, volatile, postcreate_props_sql))
        modifiers = "".join((replace, unique, postcreate_props_sql))

        postexpression_props_sql = ""
        if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
@@ -684,6 +721,9 @@ class Generator:
    def hexstring_sql(self, expression: exp.HexString) -> str:
        return self.sql(expression, "this")

    def bytestring_sql(self, expression: exp.ByteString) -> str:
        return self.sql(expression, "this")

    def datatype_sql(self, expression: exp.DataType) -> str:
        type_value = expression.this
        type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
@@ -695,9 +735,7 @@ class Generator:
            nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
            if expression.args.get("values") is not None:
                delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
                values = (
                    f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}"
                )
                values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}"
        else:
            nested = f"({interior})"

@@ -713,7 +751,7 @@ class Generator:
        this = self.sql(expression, "this")
        this = f" FROM {this}" if this else ""
        using_sql = (
            f" USING {self.expressions(expression, 'using', sep=', USING ')}"
            f" USING {self.expressions(expression, key='using', sep=', USING ')}"
            if expression.args.get("using")
            else ""
        )
@@ -730,7 +768,10 @@ class Generator:
        materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
        cascade = " CASCADE" if expression.args.get("cascade") else ""
        constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
        return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}"
        purge = " PURGE" if expression.args.get("purge") else ""
        return (
            f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}"
        )

    def except_sql(self, expression: exp.Except) -> str:
        return self.prepend_ctes(
@@ -746,7 +787,10 @@ class Generator:
        direction = f" {direction.upper()}" if direction else ""
        count = expression.args.get("count")
        count = f" {count}" if count else ""
        return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY"
        if expression.args.get("percent"):
            count = f"{count} PERCENT"
        with_ties_or_only = "WITH TIES" if expression.args.get("with_ties") else "ONLY"
        return f"{self.seg('FETCH')}{direction}{count} ROWS {with_ties_or_only}"

    def filter_sql(self, expression: exp.Filter) -> str:
        this = self.sql(expression, "this")
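Together with the Fetch arg_types added in the expressions hunk earlier, fetch_sql now carries PERCENT and WITH TIES through a round trip. A hedged sketch (sqlglot ~11.7):

    import sqlglot

    sql = "SELECT * FROM t ORDER BY x FETCH FIRST 10 PERCENT ROWS WITH TIES"
    print(sqlglot.transpile(sql)[0])  # expected to preserve PERCENT and WITH TIES
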
@@ -766,12 +810,24 @@ class Generator:

    def identifier_sql(self, expression: exp.Identifier) -> str:
        text = expression.name
        text = text.lower() if self.normalize and not expression.quoted else text
        lower = text.lower()
        text = lower if self.normalize and not expression.quoted else text
        text = text.replace(self.identifier_end, self._escaped_identifier_end)
        if expression.quoted or should_identify(text, self.identify):
        if (
            expression.quoted
            or should_identify(text, self.identify)
            or lower in self.RESERVED_KEYWORDS
        ):
            text = f"{self.identifier_start}{text}{self.identifier_end}"
        return text

    def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
        input_format = self.sql(expression, "input_format")
        input_format = f"INPUTFORMAT {input_format}" if input_format else ""
        output_format = self.sql(expression, "output_format")
        output_format = f"OUTPUTFORMAT {output_format}" if output_format else ""
        return self.sep().join((input_format, output_format))

    def national_sql(self, expression: exp.National) -> str:
        return f"N{self.sql(expression, 'this')}"
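The new RESERVED_KEYWORDS check in identifier_sql ties back to the Redshift hunk earlier, which reserves "snapshot". A hedged illustration (sqlglot ~11.7):

    import sqlglot

    print(sqlglot.transpile("SELECT snapshot FROM t", write="redshift")[0])
    # expected: SELECT "snapshot" FROM t
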
@@ -984,9 +1040,10 @@ class Generator:
            self.sql(expression, "partition") if expression.args.get("partition") else ""
        )
        expression_sql = self.sql(expression, "expression")
        conflict = self.sql(expression, "conflict")
        returning = self.sql(expression, "returning")
        sep = self.sep() if partition_sql else ""
        sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{returning}"
        sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{conflict}{returning}"
        return self.prepend_ctes(expression, sql)

    def intersect_sql(self, expression: exp.Intersect) -> str:
@@ -1004,6 +1061,19 @@ class Generator:
    def pseudotype_sql(self, expression: exp.PseudoType) -> str:
        return expression.name.upper()

    def onconflict_sql(self, expression: exp.OnConflict) -> str:
        conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
        constraint = self.sql(expression, "constraint")
        if constraint:
            constraint = f"ON CONSTRAINT {constraint}"
        key = self.expressions(expression, key="key", flat=True)
        do = "" if expression.args.get("duplicate") else " DO "
        nothing = "NOTHING" if expression.args.get("nothing") else ""
        expressions = self.expressions(expression, flat=True)
        if expressions:
            expressions = f"UPDATE SET {expressions}"
        return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}"

    def returning_sql(self, expression: exp.Returning) -> str:
        return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}"
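A hedged round trip for the new ON CONFLICT support wired up above (the exp.OnConflict node was added in the expressions hunk; sqlglot ~11.7):

    import sqlglot

    print(sqlglot.transpile("INSERT INTO t VALUES (1) ON CONFLICT DO NOTHING", read="postgres")[0])
    # expected: INSERT INTO t VALUES (1) ON CONFLICT DO NOTHING
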
@@ -1036,7 +1106,7 @@ class Generator:
        alias = self.sql(expression, "alias")
        alias = f"{sep}{alias}" if alias else ""
        hints = self.expressions(expression, key="hints", sep=", ", flat=True)
        hints = f" WITH ({hints})" if hints else ""
        hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
        laterals = self.expressions(expression, key="laterals", sep="")
        joins = self.expressions(expression, key="joins", sep="")
        pivots = self.expressions(expression, key="pivots", sep="")
@@ -1053,7 +1123,7 @@ class Generator:
        this = self.sql(expression, "this")
        alias = ""
        method = self.sql(expression, "method")
        method = f"{method.upper()} " if method else ""
        method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
        numerator = self.sql(expression, "bucket_numerator")
        denominator = self.sql(expression, "bucket_denominator")
        field = self.sql(expression, "bucket_field")
@@ -1064,6 +1134,8 @@ class Generator:
        rows = self.sql(expression, "rows")
        rows = f"{rows} ROWS" if rows else ""
        size = self.sql(expression, "size")
        if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
            size = f"{size} PERCENT"
        seed = self.sql(expression, "seed")
        seed = f" {seed_prefix} ({seed})" if seed else ""
        kind = expression.args.get("kind", "TABLESAMPLE")
@@ -1154,6 +1226,7 @@ class Generator:
                "NATURAL" if expression.args.get("natural") else None,
                expression.side,
                expression.kind,
                expression.hint if self.JOIN_HINTS else None,
                "JOIN",
            )
            if op
@@ -1311,16 +1384,20 @@ class Generator:
    def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
        partition = self.partition_by_sql(expression)
        order = self.sql(expression, "order")
        measures = self.sql(expression, "measures")
        measures = self.seg(f"MEASURES {measures}") if measures else ""
        measures = self.expressions(expression, key="measures")
        measures = self.seg(f"MEASURES{self.seg(measures)}") if measures else ""
        rows = self.sql(expression, "rows")
        rows = self.seg(rows) if rows else ""
        after = self.sql(expression, "after")
        after = self.seg(after) if after else ""
        pattern = self.sql(expression, "pattern")
        pattern = self.seg(f"PATTERN ({pattern})") if pattern else ""
        define = self.sql(expression, "define")
        define = self.seg(f"DEFINE {define}") if define else ""
        definition_sqls = [
            f"{self.sql(definition, 'alias')} AS {self.sql(definition, 'this')}"
            for definition in expression.args.get("define", [])
        ]
        definitions = self.expressions(sqls=definition_sqls)
        define = self.seg(f"DEFINE{self.seg(definitions)}") if definitions else ""
        body = "".join(
            (
                partition,
@@ -1332,7 +1409,9 @@ class Generator:
                define,
            )
        )
        return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}"
        alias = self.sql(expression, "alias")
        alias = f" {alias}" if alias else ""
        return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"

    def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
        limit = expression.args.get("limit")
@@ -1353,7 +1432,7 @@ class Generator:
            self.sql(expression, "group"),
            self.sql(expression, "having"),
            self.sql(expression, "qualify"),
            self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
            self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
            if expression.args.get("windows")
            else "",
            self.sql(expression, "distribute"),
@@ -1471,15 +1550,21 @@ class Generator:
        partition_sql = partition + " " if partition and order else partition

        spec = expression.args.get("spec")
        spec_sql = " " + self.window_spec_sql(spec) if spec else ""
        spec_sql = " " + self.windowspec_sql(spec) if spec else ""

        alias = self.sql(expression, "alias")
        this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}"
        over = self.sql(expression, "over") or "OVER"
        this = f"{this} {'AS' if expression.arg_key == 'windows' else over}"

        first = expression.args.get("first")
        if first is not None:
            first = " FIRST " if first else " LAST "
        first = first or ""

        if not partition and not order and not spec and alias:
            return f"{this} {alias}"

        window_args = alias + partition_sql + order_sql + spec_sql
        window_args = alias + first + partition_sql + order_sql + spec_sql

        return f"{this} ({window_args.strip()})"
@@ -1487,7 +1572,7 @@ class Generator:
        partition = self.expressions(expression, key="partition_by", flat=True)
        return f"PARTITION BY {partition}" if partition else ""

    def window_spec_sql(self, expression: exp.WindowSpec) -> str:
    def windowspec_sql(self, expression: exp.WindowSpec) -> str:
        kind = self.sql(expression, "kind")
        start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
        end = (
@@ -1508,7 +1593,7 @@ class Generator:
        return f"{this} BETWEEN {low} AND {high}"

    def bracket_sql(self, expression: exp.Bracket) -> str:
        expressions = apply_index_offset(expression.expressions, self.index_offset)
        expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset)
        expressions_sql = ", ".join(self.sql(e) for e in expressions)

        return f"{self.sql(expression, 'this')}[{expressions_sql}]"

@@ -1550,6 +1635,11 @@ class Generator:
        expressions = self.expressions(expression, flat=True)
        return f"CONSTRAINT {this} {expressions}"

    def nextvaluefor_sql(self, expression: exp.NextValueFor) -> str:
        order = expression.args.get("order")
        order = f" OVER ({self.order_sql(order, flat=True)})" if order else ""
        return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}"

    def extract_sql(self, expression: exp.Extract) -> str:
        this = self.sql(expression, "this")
        expression_sql = self.sql(expression, "expression")

@@ -1586,7 +1676,7 @@ class Generator:

    def primarykey_sql(self, expression: exp.ForeignKey) -> str:
        expressions = self.expressions(expression, flat=True)
        options = self.expressions(expression, "options", flat=True, sep=" ")
        options = self.expressions(expression, key="options", flat=True, sep=" ")
        options = f" {options}" if options else ""
        return f"PRIMARY KEY ({expressions}){options}"

@@ -1644,17 +1734,20 @@ class Generator:
        return f"(SELECT {self.sql(unnest)})"

    def interval_sql(self, expression: exp.Interval) -> str:
        this = expression.args.get("this")
        if this:
            this = (
                f" {this}"
                if isinstance(this, exp.Literal) or isinstance(this, exp.Paren)
                else f" ({this})"
            )
        else:
            this = ""
        unit = self.sql(expression, "unit")
        if not self.INTERVAL_ALLOWS_PLURAL_FORM:
            unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit)
        unit = f" {unit}" if unit else ""

        if self.SINGLE_STRING_INTERVAL:
            this = expression.this.name if expression.this else ""
            return f"INTERVAL '{this}{unit}'"

        this = self.sql(expression, "this")
        if this:
            unwrapped = isinstance(expression.this, self.UNWRAPPED_INTERVAL_VALUES)
            this = f" {this}" if unwrapped else f" ({this})"

        return f"INTERVAL{this}{unit}"

    def return_sql(self, expression: exp.Return) -> str:
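A hedged sketch of nextvaluefor_sql above, which pairs with the exp.NextValueFor node from the expressions hunk (sqlglot ~11.7; assumes the T-SQL parser recognizes the construct in this release):

    import sqlglot

    print(sqlglot.transpile("SELECT NEXT VALUE FOR db.seq", read="tsql")[0])
    # expected: SELECT NEXT VALUE FOR db.seq
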
@ -1664,7 +1757,7 @@ class Generator:
|
|||
this = self.sql(expression, "this")
|
||||
expressions = self.expressions(expression, flat=True)
|
||||
expressions = f"({expressions})" if expressions else ""
|
||||
options = self.expressions(expression, "options", flat=True, sep=" ")
|
||||
options = self.expressions(expression, key="options", flat=True, sep=" ")
|
||||
options = f" {options}" if options else ""
|
||||
return f"REFERENCES {this}{expressions}{options}"
|
||||
|
||||
|
@ -1690,9 +1783,9 @@ class Generator:
|
|||
return f"NOT {self.sql(expression, 'this')}"
|
||||
|
||||
def alias_sql(self, expression: exp.Alias) -> str:
|
||||
to_sql = self.sql(expression, "alias")
|
||||
to_sql = f" AS {to_sql}" if to_sql else ""
|
||||
return f"{self.sql(expression, 'this')}{to_sql}"
|
||||
alias = self.sql(expression, "alias")
|
||||
alias = f" AS {alias}" if alias else ""
|
||||
return f"{self.sql(expression, 'this')}{alias}"
|
||||
|
||||
def aliases_sql(self, expression: exp.Aliases) -> str:
|
||||
return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
|
||||
|
@ -1712,7 +1805,11 @@ class Generator:
|
|||
if not self.pretty:
|
||||
return self.binary(expression, op)
|
||||
|
||||
sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False))
|
||||
sqls = tuple(
|
||||
self.maybe_comment(self.sql(e), e, e.parent.comments) if i != 1 else self.sql(e)
|
||||
for i, e in enumerate(expression.flatten(unnest=False))
|
||||
)
|
||||
|
||||
sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
|
||||
return f"{sep}{op} ".join(sqls)
|
||||
|
||||
|
@ -1797,13 +1894,13 @@ class Generator:
|
|||
actions = expression.args["actions"]
|
||||
|
||||
if isinstance(actions[0], exp.ColumnDef):
|
||||
actions = self.expressions(expression, "actions", prefix="ADD COLUMN ")
|
||||
actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ")
|
||||
elif isinstance(actions[0], exp.Schema):
|
||||
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
|
||||
actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
|
||||
elif isinstance(actions[0], exp.Delete):
|
||||
actions = self.expressions(expression, "actions", flat=True)
|
||||
actions = self.expressions(expression, key="actions", flat=True)
|
||||
else:
|
||||
actions = self.expressions(expression, "actions")
|
||||
actions = self.expressions(expression, key="actions")
|
||||
|
||||
exists = " IF EXISTS" if expression.args.get("exists") else ""
|
||||
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
|
||||
|
@ -1935,6 +2032,7 @@ class Generator:
|
|||
return f"USE{kind}{this}"
|
||||
|
||||
def binary(self, expression: exp.Binary, op: str) -> str:
|
||||
op = self.maybe_comment(op, comments=expression.comments)
|
||||
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
|
||||
|
||||
def function_fallback_sql(self, expression: exp.Func) -> str:
|
||||
|
@ -1965,14 +2063,15 @@ class Generator:
|
|||
|
||||
def expressions(
|
||||
self,
|
||||
expression: exp.Expression,
|
||||
expression: t.Optional[exp.Expression] = None,
|
||||
key: t.Optional[str] = None,
|
||||
sqls: t.Optional[t.List[str]] = None,
|
||||
flat: bool = False,
|
||||
indent: bool = True,
|
||||
sep: str = ", ",
|
||||
prefix: str = "",
|
||||
) -> str:
|
||||
expressions = expression.args.get(key or "expressions")
|
||||
expressions = expression.args.get(key or "expressions") if expression else sqls
|
||||
|
||||
if not expressions:
|
||||
return ""
|
||||
|
|
|
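
The generator half of the NEXT VALUE FOR support above pairs with parser and tokenizer hunks further down in this diff. A minimal sanity check, assuming sqlglot 11.7.x is installed (the sequence name is illustrative):

    import sqlglot

    # T-SQL sequences round-trip as a first-class expression now, including
    # the optional OVER (ORDER BY ...) clause handled by the snippet above.
    ast = sqlglot.parse_one("SELECT NEXT VALUE FOR db.seq OVER (ORDER BY x)", read="tsql")
    print(ast.sql(dialect="tsql"))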

sqlglot/helper.py

@@ -131,11 +131,16 @@ def subclasses(
    ]


def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]:
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[t.Optional[E]],
    offset: int,
) -> t.List[t.Optional[E]]:
    """
    Applies an offset to a given integer literal expression.

    Args:
        this: the target of the index
        expressions: the expression the offset will be applied to, wrapped in a list.
        offset: the offset that will be applied.

@@ -148,11 +153,28 @@ def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]:
    expression = expressions[0]

    if expression and expression.is_int:
        expression = expression.copy()
        logger.warning("Applying array index offset (%s)", offset)
        expression.args["this"] = str(int(expression.this) + offset)  # type: ignore
        return [expression]
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    if not this.type:
        annotate_types(this)

    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if expression:
        if not expression.type:
            annotate_types(expression)
        if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
            logger.warning("Applying array index offset (%s)", offset)
            expression = simplify(
                exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
            )
            return [expression]

    return expressions
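
The rewritten apply_index_offset above now consults annotated types and folds literal offsets with simplify, instead of only rewriting bare integer literals. A rough way to observe the effect through transpilation (Presto subscripts are 1-based, Hive's are 0-based):

    import sqlglot

    # Should print something like: SELECT a[0]
    print(sqlglot.transpile("SELECT a[1]", read="presto", write="hive")[0])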

sqlglot/lineage.py

@@ -20,6 +20,7 @@ class Node:
    expression: exp.Expression
    source: exp.Expression
    downstream: t.List[Node] = field(default_factory=list)
    alias: str = ""

    def walk(self) -> t.Iterator[Node]:
        yield self

@@ -69,14 +70,19 @@ def lineage(
    optimized = optimize(expression, schema=schema, rules=rules)
    scope = build_scope(optimized)
    tables: t.Dict[str, Node] = {}

    def to_node(
        column_name: str,
        scope: Scope,
        scope_name: t.Optional[str] = None,
        upstream: t.Optional[Node] = None,
        alias: t.Optional[str] = None,
    ) -> Node:
        aliases = {
            dt.alias: dt.comments[0].split()[1]
            for dt in scope.derived_tables
            if dt.comments and dt.comments[0].startswith("source: ")
        }
        if isinstance(scope.expression, exp.Union):
            for scope in scope.union_scopes:
                node = to_node(

@@ -84,37 +90,58 @@ def lineage(
                    scope=scope,
                    scope_name=scope_name,
                    upstream=upstream,
                    alias=aliases.get(scope_name),
                )
            return node

        select = next(select for select in scope.selects if select.alias_or_name == column_name)
        source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules)
        select = source.selects[0]
        # Find the specific select clause that is the source of the column we want.
        # This can either be a specific, named select or a generic `*` clause.
        select = next(
            (select for select in scope.selects if select.alias_or_name == column_name),
            exp.Star() if scope.expression.is_star else None,
        )

        if not select:
            raise ValueError(f"Could not find {column_name} in {scope.expression}")

        if isinstance(scope.expression, exp.Select):
            # For better ergonomics in our node labels, replace the full select with
            # a version that has only the column we care about.
            #   "x", SELECT x, y FROM foo
            #   => "x", SELECT x FROM foo
            source = optimize(
                scope.expression.select(select, append=False), schema=schema, rules=rules
            )
            select = source.selects[0]
        else:
            source = scope.expression

        # Create the node for this step in the lineage chain, and attach it to the previous one.
        node = Node(
            name=f"{scope_name}.{column_name}" if scope_name else column_name,
            source=source,
            expression=select,
            alias=alias or "",
        )

        if upstream:
            upstream.downstream.append(node)

        # Find all columns that went into creating this one to list their lineage nodes.
        for c in set(select.find_all(exp.Column)):
            table = c.table
            source = scope.sources[table]
            source = scope.sources.get(table)

            if isinstance(source, Scope):
                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
                to_node(
                    c.name,
                    scope=source,
                    scope_name=table,
                    upstream=node,
                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
                )
            else:
                if table not in tables:
                    tables[table] = Node(name=c.sql(), source=source, expression=source)
                node.downstream.append(tables[table])
                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
                # is not passed into the `sources` map.
                source = source or exp.Placeholder()
                node.downstream.append(Node(name=c.sql(), source=source, expression=source))

        return node
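
The lineage hunks above add a source expression to every Node, tolerate SELECT * scopes, and fall back to an exp.Placeholder when a source is unknown rather than raising. A small sketch of the public entry point (the schema is illustrative):

    from sqlglot.lineage import lineage

    node = lineage(
        "x",
        "SELECT x FROM (SELECT x FROM y)",
        schema={"y": {"x": "int"}},
    )
    for n in node.walk():
        print(n.name, type(n.source).__name__)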

sqlglot/optimizer/annotate_types.py

@@ -116,6 +116,9 @@ class TypeAnnotator:
        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
        exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
        exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),

@@ -335,7 +338,7 @@ class TypeAnnotator:
            left_type = expression.left.type.this
            right_type = expression.right.type.this

            if isinstance(expression, (exp.And, exp.Or)):
            if isinstance(expression, exp.Connector):
                if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                    expression.type = exp.DataType.Type.NULL
                elif exp.DataType.Type.NULL in (left_type, right_type):

@@ -344,7 +347,7 @@ class TypeAnnotator:
                    )
                else:
                    expression.type = exp.DataType.Type.BOOLEAN
            elif isinstance(expression, (exp.Condition, exp.Predicate)):
            elif isinstance(expression, exp.Predicate):
                expression.type = exp.DataType.Type.BOOLEAN
            else:
                expression.type = self._maybe_coerce(left_type, right_type)
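
With the annotator hunk above, intervals are typed as INTERVAL and any Predicate under a Connector resolves to BOOLEAN. A quick check (no schema needed, since predicates are boolean regardless of operand types):

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    expr = annotate_types(parse_one("SELECT x = 1 AND y > 2 FROM t").selects[0])
    print(expr.type)  # expected: BOOLEAN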

sqlglot/optimizer/normalize.py

@@ -46,7 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
        root = node is expression
        original = node.copy()
        try:
            node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
            node = node.replace(
                while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
            )
        except OptimizeError as e:
            logger.info(e)
            node.replace(original)
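
The node.replace(...) fix above keeps the parent tree consistent when while_changing hands back a brand-new root node. Typical usage of the pass is unchanged:

    from sqlglot import parse_one
    from sqlglot.optimizer.normalize import normalize

    # Conjunctive normal form by default; pass dnf=True for disjunctive.
    print(normalize(parse_one("(a AND b) OR c")).sql())  # e.g. (a OR c) AND (b OR c)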

sqlglot/optimizer/qualify_columns.py

@@ -93,6 +93,7 @@ def _expand_using(scope, resolver):
            if column not in columns:
                columns[column] = k

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

@@ -102,8 +103,10 @@ def _expand_using(scope, resolver):
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                raise OptimizeError(f"Cannot automatically join: {identifier}")
                if columns and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
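
With the hunk above, USING expansion only raises when column metadata is actually available and still ambiguous. The usual path is through the optimizer; a sketch with an assumed schema:

    from sqlglot import parse_one
    from sqlglot.optimizer import optimize

    schema = {"x": {"a": "int", "b": "int"}, "y": {"b": "int"}}
    # USING (b) becomes an explicit ON x.b = y.b condition.
    print(optimize(parse_one("SELECT a FROM x JOIN y USING (b)"), schema=schema).sql())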

sqlglot/optimizer/qualify_tables.py

@@ -65,5 +65,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
            if not table_alias.name:
                table_alias.set("this", next_name())
            if isinstance(udtf, exp.Values) and not table_alias.columns:
                for i, e in enumerate(udtf.expressions[0].expressions):
                    table_alias.append("columns", exp.to_identifier(f"_col_{i}"))

    return expression
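
The qualify_tables change above names the columns of unaliased VALUES clauses. Sketch (the generated table alias comes from the pass's own name sequence):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables

    # The derived table gains an alias plus _col_0/_col_1 column names.
    print(qualify_tables(parse_one("SELECT * FROM (VALUES (1, 2))")).sql())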

sqlglot/optimizer/simplify.py

@@ -201,23 +201,24 @@ def _simplify_comparison(expression, left, right, or_=False):
            return left if (av < bv if or_ else av >= bv) else right

        # we can't ever shortcut to true because the column could be null
        if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
            if not or_ and av <= bv:
                return exp.false()
        elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
            if not or_ and av >= bv:
                return exp.false()
        elif isinstance(a, exp.EQ):
            if isinstance(b, exp.LT):
                return exp.false() if av >= bv else a
            if isinstance(b, exp.LTE):
                return exp.false() if av > bv else a
            if isinstance(b, exp.GT):
                return exp.false() if av <= bv else a
            if isinstance(b, exp.GTE):
                return exp.false() if av < bv else a
            if isinstance(b, exp.NEQ):
                return exp.false() if av == bv else a
        if not or_:
            if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                if av <= bv:
                    return exp.false()
            elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                if av >= bv:
                    return exp.false()
            elif isinstance(a, exp.EQ):
                if isinstance(b, exp.LT):
                    return exp.false() if av >= bv else a
                if isinstance(b, exp.LTE):
                    return exp.false() if av > bv else a
                if isinstance(b, exp.GT):
                    return exp.false() if av <= bv else a
                if isinstance(b, exp.GTE):
                    return exp.false() if av < bv else a
                if isinstance(b, exp.NEQ):
                    return exp.false() if av == bv else a
        return None
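
The refactor above hoists the `not or_` guard so that none of these comparison shortcuts fire under OR, where a NULL column would make them unsound. For example:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    print(simplify(parse_one("x > 3 AND x < 1")).sql())  # FALSE
    print(simplify(parse_one("x = 1 OR x < 0")).sql())   # left as-is: x may be NULL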

sqlglot/parser.py

@@ -18,8 +18,13 @@ from sqlglot.trie import in_trie, new_trie
logger = logging.getLogger("sqlglot")

E = t.TypeVar("E", bound=exp.Expression)


def parse_var_map(args: t.Sequence) -> exp.Expression:
    if len(args) == 1 and args[0].is_star:
        return exp.StarMap(this=args[0])

    keys = []
    values = []
    for i in range(0, len(args), 2):

@@ -108,6 +113,8 @@ class Parser(metaclass=_Parser):
        TokenType.CURRENT_USER: exp.CurrentUser,
    }

    JOIN_HINTS: t.Set[str] = set()

    NESTED_TYPE_TOKENS = {
        TokenType.ARRAY,
        TokenType.MAP,

@@ -145,6 +152,7 @@ class Parser(metaclass=_Parser):
        TokenType.DATETIME,
        TokenType.DATE,
        TokenType.DECIMAL,
        TokenType.BIGDECIMAL,
        TokenType.UUID,
        TokenType.GEOGRAPHY,
        TokenType.GEOMETRY,

@@ -221,8 +229,10 @@ class Parser(metaclass=_Parser):
        TokenType.FORMAT,
        TokenType.FULL,
        TokenType.IF,
        TokenType.IS,
        TokenType.ISNULL,
        TokenType.INTERVAL,
        TokenType.KEEP,
        TokenType.LAZY,
        TokenType.LEADING,
        TokenType.LEFT,

@@ -235,6 +245,7 @@ class Parser(metaclass=_Parser):
        TokenType.ONLY,
        TokenType.OPTIONS,
        TokenType.ORDINALITY,
        TokenType.OVERWRITE,
        TokenType.PARTITION,
        TokenType.PERCENT,
        TokenType.PIVOT,

@@ -266,6 +277,8 @@ class Parser(metaclass=_Parser):
        *NO_PAREN_FUNCTIONS,
    }

    INTERVAL_VARS = ID_VAR_TOKENS - {TokenType.END}

    TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
        TokenType.APPLY,
        TokenType.FULL,

@@ -276,6 +289,8 @@ class Parser(metaclass=_Parser):
        TokenType.WINDOW,
    }

    COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS}

    UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}

    TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}

@@ -400,7 +415,7 @@ class Parser(metaclass=_Parser):
    COLUMN_OPERATORS = {
        TokenType.DOT: None,
        TokenType.DCOLON: lambda self, this, to: self.expression(
            exp.Cast,
            exp.Cast if self.STRICT_CAST else exp.TryCast,
            this=this,
            to=to,
        ),

@@ -560,7 +575,7 @@ class Parser(metaclass=_Parser):
        ),
        "DEFINER": lambda self: self._parse_definer(),
        "DETERMINISTIC": lambda self: self.expression(
            exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
            exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
        ),
        "DISTKEY": lambda self: self._parse_distkey(),
        "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),

@@ -571,7 +586,7 @@ class Parser(metaclass=_Parser):
        "FREESPACE": lambda self: self._parse_freespace(),
        "GLOBAL": lambda self: self._parse_temporary(global_=True),
        "IMMUTABLE": lambda self: self.expression(
            exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
            exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
        ),
        "JOURNAL": lambda self: self._parse_journal(
            no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"

@@ -600,20 +615,20 @@ class Parser(metaclass=_Parser):
        "PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
        "RETURNS": lambda self: self._parse_returns(),
        "ROW": lambda self: self._parse_row(),
        "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
        "SET": lambda self: self.expression(exp.SetProperty, multi=False),
        "SORTKEY": lambda self: self._parse_sortkey(),
        "STABLE": lambda self: self.expression(
            exp.VolatilityProperty, this=exp.Literal.string("STABLE")
            exp.StabilityProperty, this=exp.Literal.string("STABLE")
        ),
        "STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
        "STORED": lambda self: self._parse_stored(),
        "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
        "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
        "TEMP": lambda self: self._parse_temporary(global_=False),
        "TEMPORARY": lambda self: self._parse_temporary(global_=False),
        "TRANSIENT": lambda self: self.expression(exp.TransientProperty),
        "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
        "VOLATILE": lambda self: self.expression(
            exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
        ),
        "VOLATILE": lambda self: self._parse_volatile_property(),
        "WITH": lambda self: self._parse_with_property(),
    }

@@ -648,8 +663,11 @@ class Parser(metaclass=_Parser):
        "LIKE": lambda self: self._parse_create_like(),
        "NOT": lambda self: self._parse_not_constraint(),
        "NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True),
        "ON": lambda self: self._match(TokenType.UPDATE)
        and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()),
        "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
        "PRIMARY KEY": lambda self: self._parse_primary_key(),
        "REFERENCES": lambda self: self._parse_references(match=False),
        "TITLE": lambda self: self.expression(
            exp.TitleColumnConstraint, this=self._parse_var_or_string()
        ),

@@ -668,9 +686,14 @@ class Parser(metaclass=_Parser):
    SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}

    NO_PAREN_FUNCTION_PARSERS = {
        TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
        TokenType.CASE: lambda self: self._parse_case(),
        TokenType.IF: lambda self: self._parse_if(),
        TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
        TokenType.NEXT_VALUE_FOR: lambda self: self.expression(
            exp.NextValueFor,
            this=self._parse_column(),
            order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order),
        ),
    }

    FUNCTION_PARSERS: t.Dict[str, t.Callable] = {

@@ -715,6 +738,8 @@ class Parser(metaclass=_Parser):
    SHOW_PARSERS: t.Dict[str, t.Callable] = {}

    TYPE_LITERAL_PARSERS: t.Dict[exp.DataType.Type, t.Callable] = {}

    MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)

    TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}

@@ -731,6 +756,7 @@ class Parser(metaclass=_Parser):
    INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}

    WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
    WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}

    ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}

@@ -738,6 +764,9 @@ class Parser(metaclass=_Parser):
    CONVERT_TYPE_FIRST = False

    QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None
    PREFIXED_PIVOT_COLUMNS = False

    LOG_BASE_FIRST = True
    LOG_DEFAULTS_TO_LN = False
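
Among the property-parser changes above, "STORED" now dispatches to the new _parse_stored (defined further down), which understands Hive's INPUTFORMAT/OUTPUTFORMAT pair. A sketch (the class names are the stock Hive text formats):

    import sqlglot

    sql = (
        "CREATE TABLE t (x INT) STORED AS "
        "INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' "
        "OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'"
    )
    print(sqlglot.parse_one(sql, read="hive").sql(dialect="hive"))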

@@ -895,8 +924,8 @@ class Parser(metaclass=_Parser):
            error level setting.
        """
        token = token or self._curr or self._prev or Token.string("")
        start = self._find_token(token)
        end = start + len(token.text)
        start = token.start
        end = token.end
        start_context = self.sql[max(start - self.error_message_context, 0) : start]
        highlight = self.sql[start:end]
        end_context = self.sql[end : end + self.error_message_context]

@@ -918,8 +947,8 @@ class Parser(metaclass=_Parser):
        self.errors.append(error)

    def expression(
        self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs
    ) -> exp.Expression:
        self, exp_class: t.Type[E], comments: t.Optional[t.List[str]] = None, **kwargs
    ) -> E:
        """
        Creates a new, validated Expression.

@@ -958,22 +987,7 @@ class Parser(metaclass=_Parser):
            self.raise_error(error_message)

    def _find_sql(self, start: Token, end: Token) -> str:
        return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)]

    def _find_token(self, token: Token) -> int:
        line = 1
        col = 1
        index = 0

        while line < token.line or col < token.col:
            if Tokenizer.WHITE_SPACE.get(self.sql[index]) == TokenType.BREAK:
                line += 1
                col = 1
            else:
                col += 1
            index += 1

        return index
        return self.sql[start.start : end.end]

    def _advance(self, times: int = 1) -> None:
        self._index += times
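
Error reporting above now reads absolute offsets off the token itself instead of re-scanning the SQL by line and column (the _find_token walk is deleted). The offsets are visible on any token:

    from sqlglot.tokens import Tokenizer

    token = Tokenizer().tokenize("SELECT foo")[-1]
    print(token.text, token.start, token.end)  # e.g. foo 7 10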

@@ -990,7 +1004,7 @@ class Parser(metaclass=_Parser):
        if index != self._index:
            self._advance(index - self._index)

    def _parse_command(self) -> exp.Expression:
    def _parse_command(self) -> exp.Command:
        return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())

    def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:

@@ -1007,7 +1021,7 @@ class Parser(metaclass=_Parser):
        if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
            this = self._parse_user_defined_function(kind=kind.token_type)
        elif kind.token_type == TokenType.TABLE:
            this = self._parse_table()
            this = self._parse_table(alias_tokens=self.COMMENT_TABLE_ALIAS_TOKENS)
        elif kind.token_type == TokenType.COLUMN:
            this = self._parse_column()
        else:

@@ -1035,16 +1049,13 @@ class Parser(metaclass=_Parser):
            self._parse_query_modifiers(expression)
        return expression

    def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
    def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
        start = self._prev
        temporary = self._match(TokenType.TEMPORARY)
        materialized = self._match(TokenType.MATERIALIZED)
        kind = self._match_set(self.CREATABLES) and self._prev.text
        if not kind:
            if default_kind:
                kind = default_kind
            else:
                return self._parse_as_command(start)
            return self._parse_as_command(start)

        return self.expression(
            exp.Drop,

@@ -1055,6 +1066,7 @@ class Parser(metaclass=_Parser):
            materialized=materialized,
            cascade=self._match(TokenType.CASCADE),
            constraints=self._match_text_seq("CONSTRAINTS"),
            purge=self._match_text_seq("PURGE"),
        )

    def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:

@@ -1070,7 +1082,6 @@ class Parser(metaclass=_Parser):
            TokenType.OR, TokenType.REPLACE
        )
        unique = self._match(TokenType.UNIQUE)
        volatile = self._match(TokenType.VOLATILE)

        if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
            self._match(TokenType.TABLE)

@@ -1179,7 +1190,6 @@ class Parser(metaclass=_Parser):
            kind=create_token.text,
            replace=replace,
            unique=unique,
            volatile=volatile,
            expression=expression,
            exists=exists,
            properties=properties,

@@ -1225,6 +1235,21 @@ class Parser(metaclass=_Parser):
        return None

    def _parse_stored(self) -> exp.Expression:
        self._match(TokenType.ALIAS)

        input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None
        output_format = self._parse_string() if self._match_text_seq("OUTPUTFORMAT") else None

        return self.expression(
            exp.FileFormatProperty,
            this=self.expression(
                exp.InputOutputFormat, input_format=input_format, output_format=output_format
            )
            if input_format or output_format
            else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
        )

    def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
        self._match(TokenType.EQ)
        self._match(TokenType.ALIAS)

@@ -1258,6 +1283,21 @@ class Parser(metaclass=_Parser):
            exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
        )

    def _parse_volatile_property(self) -> exp.Expression:
        if self._index >= 2:
            pre_volatile_token = self._tokens[self._index - 2]
        else:
            pre_volatile_token = None

        if pre_volatile_token and pre_volatile_token.token_type in (
            TokenType.CREATE,
            TokenType.REPLACE,
            TokenType.UNIQUE,
        ):
            return exp.VolatileProperty()

        return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE"))

    def _parse_with_property(
        self,
    ) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]:
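
_parse_volatile_property above disambiguates VOLATILE by peeking two tokens back: after CREATE / OR REPLACE / UNIQUE it is a Teradata-style table modifier, otherwise a function stability marker. A sketch of the Teradata round-trip:

    import sqlglot

    sql = "CREATE VOLATILE TABLE t (x INT)"
    print(sqlglot.parse_one(sql, read="teradata").sql(dialect="teradata"))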

@@ -1574,11 +1614,46 @@ class Parser(metaclass=_Parser):
            exists=self._parse_exists(),
            partition=self._parse_partition(),
            expression=self._parse_ddl_select(),
            conflict=self._parse_on_conflict(),
            returning=self._parse_returning(),
            overwrite=overwrite,
            alternative=alternative,
        )

    def _parse_on_conflict(self) -> t.Optional[exp.Expression]:
        conflict = self._match_text_seq("ON", "CONFLICT")
        duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY")

        if not (conflict or duplicate):
            return None

        nothing = None
        expressions = None
        key = None
        constraint = None

        if conflict:
            if self._match_text_seq("ON", "CONSTRAINT"):
                constraint = self._parse_id_var()
            else:
                key = self._parse_csv(self._parse_value)

        self._match_text_seq("DO")
        if self._match_text_seq("NOTHING"):
            nothing = True
        else:
            self._match(TokenType.UPDATE)
            expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)

        return self.expression(
            exp.OnConflict,
            duplicate=duplicate,
            expressions=expressions,
            nothing=nothing,
            key=key,
            constraint=constraint,
        )

    def _parse_returning(self) -> t.Optional[exp.Expression]:
        if not self._match(TokenType.RETURNING):
            return None
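
The new _parse_on_conflict above models both Postgres/SQLite ON CONFLICT and MySQL ON DUPLICATE KEY UPDATE as one exp.OnConflict node hanging off the Insert. Sketch:

    import sqlglot

    ast = sqlglot.parse_one("INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO NOTHING")
    print(ast.args["conflict"])  # the parsed exp.OnConflict node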

@@ -1639,7 +1714,7 @@ class Parser(metaclass=_Parser):
        return self.expression(
            exp.Delete,
            this=self._parse_table(schema=True),
            this=self._parse_table(),
            using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
            where=self._parse_where(),
            returning=self._parse_returning(),

@@ -1792,6 +1867,7 @@ class Parser(metaclass=_Parser):
        if not skip_with_token and not self._match(TokenType.WITH):
            return None

        comments = self._prev_comments
        recursive = self._match(TokenType.RECURSIVE)

        expressions = []

@@ -1803,7 +1879,9 @@ class Parser(metaclass=_Parser):
            else:
                self._match(TokenType.WITH)

        return self.expression(exp.With, expressions=expressions, recursive=recursive)
        return self.expression(
            exp.With, comments=comments, expressions=expressions, recursive=recursive
        )

    def _parse_cte(self) -> exp.Expression:
        alias = self._parse_table_alias()

@@ -1856,15 +1934,20 @@ class Parser(metaclass=_Parser):
        table = isinstance(this, exp.Table)

        while True:
            lateral = self._parse_lateral()
            join = self._parse_join()
            comma = None if table else self._match(TokenType.COMMA)
            if lateral:
                this.append("laterals", lateral)
            if join:
                this.append("joins", join)

            lateral = None
            if not join:
                lateral = self._parse_lateral()
                if lateral:
                    this.append("laterals", lateral)

            comma = None if table else self._match(TokenType.COMMA)
            if comma:
                this.args["from"].append("expressions", self._parse_table())

            if not (lateral or join or comma):
                break

@@ -1906,14 +1989,13 @@ class Parser(metaclass=_Parser):
    def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
        if not self._match(TokenType.MATCH_RECOGNIZE):
            return None

        self._match_l_paren()

        partition = self._parse_partition_by()
        order = self._parse_order()
        measures = (
            self._parse_alias(self._parse_conjunction())
            if self._match_text_seq("MEASURES")
            else None
            self._parse_csv(self._parse_expression) if self._match_text_seq("MEASURES") else None
        )

        if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):

@@ -1967,8 +2049,17 @@ class Parser(metaclass=_Parser):
            pattern = None

        define = (
            self._parse_alias(self._parse_conjunction()) if self._match_text_seq("DEFINE") else None
            self._parse_csv(
                lambda: self.expression(
                    exp.Alias,
                    alias=self._parse_id_var(any_token=True),
                    this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
                )
            )
            if self._match_text_seq("DEFINE")
            else None
        )

        self._match_r_paren()

        return self.expression(

@@ -1980,6 +2071,7 @@ class Parser(metaclass=_Parser):
            after=after,
            pattern=pattern,
            define=define,
            alias=self._parse_table_alias(),
        )

    def _parse_lateral(self) -> t.Optional[exp.Expression]:

@@ -2022,9 +2114,6 @@ class Parser(metaclass=_Parser):
            alias=table_alias,
        )

        if outer_apply or cross_apply:
            return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT")

        return expression

    def _parse_join_side_and_kind(

@@ -2037,11 +2126,26 @@ class Parser(metaclass=_Parser):
        )

    def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
        index = self._index
        natural, side, kind = self._parse_join_side_and_kind()
        hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
        join = self._match(TokenType.JOIN)

        if not skip_join_token and not self._match(TokenType.JOIN):
        if not skip_join_token and not join:
            self._retreat(index)
            kind = None
            natural = None
            side = None

        outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False)
        cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY, False)

        if not skip_join_token and not join and not outer_apply and not cross_apply:
            return None

        if outer_apply:
            side = Token(TokenType.LEFT, "LEFT")

        kwargs: t.Dict[
            str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]]
        ] = {"this": self._parse_table()}

@@ -2052,6 +2156,8 @@ class Parser(metaclass=_Parser):
            kwargs["side"] = side.text
        if kind:
            kwargs["kind"] = kind.text
        if hint:
            kwargs["hint"] = hint

        if self._match(TokenType.ON):
            kwargs["on"] = self._parse_conjunction()

@@ -2179,7 +2285,7 @@ class Parser(metaclass=_Parser):
            return None

        expressions = self._parse_wrapped_csv(self._parse_column)
        ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
        ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
        alias = self._parse_table_alias()

        if alias and self.unnest_column_only:

@@ -2191,7 +2297,7 @@ class Parser(metaclass=_Parser):
        offset = None
        if self._match_pair(TokenType.WITH, TokenType.OFFSET):
            self._match(TokenType.ALIAS)
            offset = self._parse_conjunction()
            offset = self._parse_id_var() or exp.Identifier(this="offset")

        return self.expression(
            exp.Unnest,

@@ -2294,6 +2400,9 @@ class Parser(metaclass=_Parser):
        else:
            expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function()))

        if not expressions:
            self.raise_error("Failed to parse PIVOT's aggregation list")

        if not self._match(TokenType.FOR):
            self.raise_error("Expecting FOR")

@@ -2311,8 +2420,26 @@ class Parser(metaclass=_Parser):
        if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
            pivot.set("alias", self._parse_table_alias())

        if not unpivot:
            names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions))

            columns: t.List[exp.Expression] = []
            for col in pivot.args["field"].expressions:
                for name in names:
                    if self.PREFIXED_PIVOT_COLUMNS:
                        name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name
                    else:
                        name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name

                    columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS))

            pivot.set("columns", columns)

        return pivot

    def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
        return [agg.alias for agg in pivot_columns]

    def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
        if not skip_where_token and not self._match(TokenType.WHERE):
            return None
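
With the pivot hunk above, the parser precomputes the output column names of a PIVOT, controlled per dialect by PREFIXED_PIVOT_COLUMNS and QUOTED_PIVOT_COLUMNS. A sketch against Snowflake-style pivots (the exact names are dialect-dependent):

    import sqlglot
    from sqlglot import exp

    sql = "SELECT * FROM t PIVOT (SUM(v) AS total FOR k IN ('a', 'b'))"
    pivot = sqlglot.parse_one(sql, read="snowflake").find(exp.Pivot)
    print([c.name for c in pivot.args["columns"]])  # e.g. ['a_total', 'b_total']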

@@ -2433,10 +2560,25 @@ class Parser(metaclass=_Parser):
        if self._match(TokenType.FETCH):
            direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
            direction = self._prev.text if direction else "FIRST"

            count = self._parse_number()
            percent = self._match(TokenType.PERCENT)

            self._match_set((TokenType.ROW, TokenType.ROWS))
            self._match(TokenType.ONLY)
            return self.expression(exp.Fetch, direction=direction, count=count)

            only = self._match(TokenType.ONLY)
            with_ties = self._match_text_seq("WITH", "TIES")

            if only and with_ties:
                self.raise_error("Cannot specify both ONLY and WITH TIES in FETCH clause")

            return self.expression(
                exp.Fetch,
                direction=direction,
                count=count,
                percent=percent,
                with_ties=with_ties,
            )

        return this

@@ -2493,7 +2635,11 @@ class Parser(metaclass=_Parser):
        negate = self._match(TokenType.NOT)

        if self._match_set(self.RANGE_PARSERS):
            this = self.RANGE_PARSERS[self._prev.token_type](self, this)
            expression = self.RANGE_PARSERS[self._prev.token_type](self, this)
            if not expression:
                return this

            this = expression
        elif self._match(TokenType.ISNULL):
            this = self.expression(exp.Is, this=this, expression=exp.Null())

@@ -2511,17 +2657,19 @@ class Parser(metaclass=_Parser):
        return this

    def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression:
    def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        index = self._index - 1
        negate = self._match(TokenType.NOT)
        if self._match(TokenType.DISTINCT_FROM):
            klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
            return self.expression(klass, this=this, expression=self._parse_expression())

        this = self.expression(
            exp.Is,
            this=this,
            expression=self._parse_null() or self._parse_boolean(),
        )
        expression = self._parse_null() or self._parse_boolean()
        if not expression:
            self._retreat(index)
            return None

        this = self.expression(exp.Is, this=this, expression=expression)
        return self.expression(exp.Not, this=this) if negate else this

    def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression:
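
FETCH above gains PERCENT and WITH TIES, and rejects the contradictory ONLY + WITH TIES combination. Round-trip sketch:

    import sqlglot

    sql = "SELECT * FROM t ORDER BY x FETCH FIRST 10 PERCENT ROWS WITH TIES"
    print(sqlglot.parse_one(sql).sql())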

@@ -2553,6 +2701,27 @@ class Parser(metaclass=_Parser):
            return this
        return self.expression(exp.Escape, this=this, expression=self._parse_string())

    def _parse_interval(self) -> t.Optional[exp.Expression]:
        if not self._match(TokenType.INTERVAL):
            return None

        this = self._parse_primary() or self._parse_term()
        unit = self._parse_function() or self._parse_var()

        # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
        # each INTERVAL expression into this canonical form so it's easy to transpile
        if this and isinstance(this, exp.Literal):
            if this.is_number:
                this = exp.Literal.string(this.name)

            # Try to not clutter Snowflake's multi-part intervals like INTERVAL '1 day, 1 year'
            parts = this.name.split()
            if not unit and len(parts) <= 2:
                this = exp.Literal.string(seq_get(parts, 0))
                unit = self.expression(exp.Var, this=seq_get(parts, 1))

        return self.expression(exp.Interval, this=this, unit=unit)

    def _parse_bitwise(self) -> t.Optional[exp.Expression]:
        this = self._parse_term()

@@ -2588,20 +2757,24 @@ class Parser(metaclass=_Parser):
        return self._parse_at_time_zone(self._parse_type())

    def _parse_type(self) -> t.Optional[exp.Expression]:
        if self._match(TokenType.INTERVAL):
            return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_field())
        interval = self._parse_interval()
        if interval:
            return interval

        index = self._index
        type_token = self._parse_types(check_func=True)
        data_type = self._parse_types(check_func=True)
        this = self._parse_column()

        if type_token:
        if data_type:
            if isinstance(this, exp.Literal):
                return self.expression(exp.Cast, this=this, to=type_token)
            if not type_token.args.get("expressions"):
                parser = self.TYPE_LITERAL_PARSERS.get(data_type.this)
                if parser:
                    return parser(self, this, data_type)
                return self.expression(exp.Cast, this=this, to=data_type)
            if not data_type.args.get("expressions"):
                self._retreat(index)
                return self._parse_column()
            return type_token
            return data_type

        return this
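
The canonicalization in _parse_interval above makes interval transpilation predictable: numeric quantities become string literals, and a unit embedded in the string is split out into its own node. For example:

    import sqlglot

    print(sqlglot.transpile("SELECT INTERVAL 5 day")[0])     # SELECT INTERVAL '5' day
    print(sqlglot.transpile("SELECT INTERVAL '1 week'")[0])  # SELECT INTERVAL '1' week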

@@ -2631,11 +2804,10 @@ class Parser(metaclass=_Parser):
            else:
                expressions = self._parse_csv(self._parse_conjunction)

            if not expressions:
            if not expressions or not self._match(TokenType.R_PAREN):
                self._retreat(index)
                return None

            self._match_r_paren()
            maybe_func = True

        if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):

@@ -2720,15 +2892,14 @@ class Parser(metaclass=_Parser):
        )

    def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
        if self._curr and self._curr.token_type in self.TYPE_TOKENS:
            return self._parse_types()

        index = self._index
        this = self._parse_id_var()
        self._match(TokenType.COLON)
        data_type = self._parse_types()

        if not data_type:
            return None
            self._retreat(index)
            return self._parse_types()
        return self.expression(exp.StructKwarg, this=this, expression=data_type)

    def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:

@@ -2825,6 +2996,7 @@ class Parser(metaclass=_Parser):
            this = self.expression(exp.Paren, this=self._parse_set_operations(this))

        self._match_r_paren()
        comments.extend(self._prev_comments)

        if this and comments:
            this.comments = comments

@@ -2833,8 +3005,16 @@ class Parser(metaclass=_Parser):
        return None

    def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]:
        return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
    def _parse_field(
        self,
        any_token: bool = False,
        tokens: t.Optional[t.Collection[TokenType]] = None,
    ) -> t.Optional[exp.Expression]:
        return (
            self._parse_primary()
            or self._parse_function()
            or self._parse_id_var(any_token=any_token, tokens=tokens)
        )

    def _parse_function(
        self, functions: t.Optional[t.Dict[str, t.Callable]] = None

@@ -3079,12 +3259,10 @@ class Parser(metaclass=_Parser):
        return None

    def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
        this = self._parse_references()
        if this:
            return this

        if self._match(TokenType.CONSTRAINT):
            this = self._parse_id_var()
        else:
            this = None

        if self._match_texts(self.CONSTRAINT_PARSERS):
            return self.expression(

@@ -3164,8 +3342,8 @@ class Parser(metaclass=_Parser):
        return options

    def _parse_references(self) -> t.Optional[exp.Expression]:
        if not self._match(TokenType.REFERENCES):
    def _parse_references(self, match=True) -> t.Optional[exp.Expression]:
        if match and not self._match(TokenType.REFERENCES):
            return None

        expressions = None

@@ -3234,7 +3412,7 @@ class Parser(metaclass=_Parser):
        elif not this or this.name.upper() == "ARRAY":
            this = self.expression(exp.Array, expressions=expressions)
        else:
            expressions = apply_index_offset(expressions, -self.index_offset)
            expressions = apply_index_offset(this, expressions, -self.index_offset)
            this = self.expression(exp.Bracket, this=this, expressions=expressions)

        if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:

@@ -3279,7 +3457,13 @@ class Parser(metaclass=_Parser):
            self.validate_expression(this, args)
            self._match_r_paren()
        else:
            index = self._index - 1
            condition = self._parse_conjunction()

            if not condition:
                self._retreat(index)
                return None

            self._match(TokenType.THEN)
            true = self._parse_conjunction()
            false = self._parse_conjunction() if self._match(TokenType.ELSE) else None

@@ -3591,14 +3775,24 @@ class Parser(metaclass=_Parser):
        # bigquery select from window x AS (partition by ...)
        if alias:
            over = None
            self._match(TokenType.ALIAS)
        elif not self._match(TokenType.OVER):
        elif not self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS):
            return this
        else:
            over = self._prev.text.upper()

        if not self._match(TokenType.L_PAREN):
            return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
            return self.expression(
                exp.Window, this=this, alias=self._parse_id_var(False), over=over
            )

        window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS)

        first = self._match(TokenType.FIRST)
        if self._match_text_seq("LAST"):
            first = False

        partition = self._parse_partition_by()
        order = self._parse_order()
        kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text

@@ -3629,6 +3823,8 @@ class Parser(metaclass=_Parser):
            order=order,
            spec=spec,
            alias=window_alias,
            over=over,
            first=first,
        )

    def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:

@@ -3886,7 +4082,10 @@ class Parser(metaclass=_Parser):
        return expression

    def _parse_drop_column(self) -> t.Optional[exp.Expression]:
        return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
        drop = self._match(TokenType.DROP) and self._parse_drop()
        if drop and not isinstance(drop, exp.Command):
            drop.set("kind", drop.args.get("kind", "COLUMN"))
        return drop

    # https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html
    def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression:

@@ -4010,7 +4209,7 @@ class Parser(metaclass=_Parser):
        if self._match(TokenType.INSERT):
            _this = self._parse_star()
            if _this:
                then = self.expression(exp.Insert, this=_this)
                then: t.Optional[exp.Expression] = self.expression(exp.Insert, this=_this)
            else:
                then = self.expression(
                    exp.Insert,

@@ -4239,5 +4438,8 @@ class Parser(metaclass=_Parser):
                break
            parent = parent.parent
        else:
            column.replace(dot_or_id)
            if column is node:
                node = dot_or_id
            else:
                column.replace(dot_or_id)
    return node
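
The window changes above record which keyword introduced a window (WINDOW_BEFORE_PAREN_TOKENS) and capture KEEP-style FIRST/LAST ordering, both stored on exp.Window. Quick look at the new arg:

    import sqlglot
    from sqlglot import exp

    window = sqlglot.parse_one("SELECT SUM(x) OVER (PARTITION BY y) FROM t").find(exp.Window)
    print(window.args.get("over"))  # OVER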

sqlglot/schema.py

@@ -5,7 +5,7 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie

@@ -75,12 +75,11 @@ class AbstractMappingSchema(t.Generic[T]):
        mapping: dict | None = None,
    ) -> None:
        self.mapping = mapping or {}
        self.mapping_trie = self._build_trie(self.mapping)
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    def _build_trie(self, schema: t.Dict) -> t.Dict:
        return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))

    def _depth(self) -> int:
        return dict_depth(self.mapping)

@@ -179,6 +178,64 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
            }
        )

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        normalized_table = self._normalize_table(self._ensure_table(table))
        normalized_column_mapping = {
            self._normalize_name(key): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        _nested_set(
            self.mapping,
            tuple(reversed(parts)),
            normalized_column_mapping,
        )
        new_trie([parts], self.mapping_trie)

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        table_ = self._normalize_table(self._ensure_table(table))
        schema = self.find(table_)

        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self._nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        column_name = self._normalize_name(column if isinstance(column, str) else column.this)
        table_ = self._normalize_table(self._ensure_table(table))

        table_schema = self.find(table_, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper())
            raise SchemaError(f"Unknown column type '{column_type}'")

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

@@ -206,84 +263,37 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
        return normalized_mapping

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
    def _normalize_table(self, table: exp.Table) -> exp.Table:
        normalized_table = table.copy()
        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(arg, self._normalize_name(value))

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        table_ = self._ensure_table(table)
        column_mapping = ensure_column_mapping(column_mapping)
        schema = self.find(table_, raise_on_missing=False)
        return normalized_table

        if schema and not column_mapping:
            return

        _nested_set(
            self.mapping,
            list(reversed(self.table_parts(table_))),
            column_mapping,
        )
        self.mapping_trie = self._build_trie(self.mapping)

    def _normalize_name(self, name: str) -> str:
    def _normalize_name(self, name: str | exp.Identifier) -> str:
        try:
            identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
                name, read=self.dialect, into=exp.Identifier
            )
        except:
            identifier = exp.to_identifier(name)
        assert isinstance(identifier, exp.Identifier)
            identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        if identifier.quoted:
            return identifier.name
        return identifier.name.lower()
        return identifier.name if identifier.quoted else identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        table_ = exp.to_table(table)
        if isinstance(table, exp.Table):
            return table

        table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        table_ = self._ensure_table(table)
        schema = self.find(table_)

        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self._nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        column_name = column if isinstance(column, str) else column.name
        table_ = exp.to_table(table)
        if table_:
            table_schema = self.find(table_, raise_on_missing=False)
            if table_schema:
                column_type = table_schema.get(column_name)

                if isinstance(column_type, exp.DataType):
                    return column_type
                elif isinstance(column_type, str):
                    return self._to_data_type(column_type.upper())
                raise SchemaError(f"Unknown column type '{column_type}'")
            return exp.DataType(this=exp.DataType.Type.UNKNOWN)
        raise SchemaError(f"Could not convert table '{table}'")

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

@@ -313,7 +323,7 @@ def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    return MappingSchema(schema, dialect=dialect)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):

@@ -371,7 +381,7 @@ def _nested_get(
    return d


def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

@@ -384,11 +394,11 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
    Args:
        d: dictionary to update.
            keys: the keys that makeup the path to `value`.
            value: the value to set in the dictionary for the given key path.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

            Returns:
                The (possibly) updated dictionary.
    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d
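
The schema refactor above normalizes identifiers once, on the way in, for both table parts and column names, and updates the trie incrementally instead of rebuilding it. Sketch:

    from sqlglot.schema import MappingSchema

    schema = MappingSchema()
    schema.add_table("DB.T", {"ID": "int"})
    print(schema.column_names("db.t"))  # ['id'] - unquoted identifiers are lowercased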
@ -87,6 +87,7 @@ class TokenType(AutoName):
|
|||
FLOAT = auto()
|
||||
DOUBLE = auto()
|
||||
DECIMAL = auto()
|
||||
BIGDECIMAL = auto()
|
||||
CHAR = auto()
|
||||
NCHAR = auto()
|
||||
VARCHAR = auto()
|
||||
|
@ -214,6 +215,7 @@ class TokenType(AutoName):
|
|||
ISNULL = auto()
|
||||
JOIN = auto()
|
||||
JOIN_MARKER = auto()
|
||||
KEEP = auto()
|
||||
LANGUAGE = auto()
|
||||
LATERAL = auto()
|
||||
LAZY = auto()
|
||||
|
@ -231,6 +233,7 @@ class TokenType(AutoName):
|
|||
MOD = auto()
|
||||
NATURAL = auto()
|
||||
NEXT = auto()
|
||||
NEXT_VALUE_FOR = auto()
|
||||
NO_ACTION = auto()
|
||||
NOTNULL = auto()
|
||||
NULL = auto()
|
||||
|
@ -315,7 +318,7 @@ class TokenType(AutoName):
|
|||
|
||||
|
||||
class Token:
|
||||
__slots__ = ("token_type", "text", "line", "col", "comments")
|
||||
__slots__ = ("token_type", "text", "line", "col", "end", "comments")
|
||||
|
||||
@classmethod
|
||||
def number(cls, number: int) -> Token:
|
||||
|
@ -343,22 +346,29 @@ class Token:
|
|||
text: str,
|
||||
line: int = 1,
|
||||
col: int = 1,
|
||||
end: int = 0,
|
||||
comments: t.List[str] = [],
|
||||
) -> None:
|
||||
self.token_type = token_type
|
||||
self.text = text
|
||||
self.line = line
|
||||
self.col = col - len(text)
|
||||
self.col = self.col if self.col > 1 else 1
|
||||
size = len(text)
|
||||
self.col = col
|
||||
self.end = end if end else size
|
||||
self.comments = comments
|
||||
|
||||
@property
|
||||
def start(self) -> int:
|
||||
"""Returns the start of the token."""
|
||||
return self.end - len(self.text)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
|
||||
return f"<Token {attributes}>"
|
||||
|
||||
|
||||
class _Tokenizer(type):
|
||||
def __new__(cls, clsname, bases, attrs): # type: ignore
|
||||
def __new__(cls, clsname, bases, attrs):
|
||||
klass = super().__new__(cls, clsname, bases, attrs)
|
||||
|
||||
klass._QUOTES = {
|
||||
|
@@ -433,25 +443,25 @@ class Tokenizer(metaclass=_Tokenizer):
         "#": TokenType.HASH,
     }

-    QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
-
     BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
-    HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
     BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
+    HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
     IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
-
+    IDENTIFIER_ESCAPES = ['"']
+    QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
     STRING_ESCAPES = ["'"]
+    VAR_SINGLE_TOKENS: t.Set[str] = set()

     _COMMENTS: t.Dict[str, str] = {}
     _BIT_STRINGS: t.Dict[str, str] = {}
     _BYTE_STRINGS: t.Dict[str, str] = {}
     _HEX_STRINGS: t.Dict[str, str] = {}
     _IDENTIFIERS: t.Dict[str, str] = {}
+    _IDENTIFIER_ESCAPES: t.Set[str] = set()
     _QUOTES: t.Dict[str, str] = {}
     _STRING_ESCAPES: t.Set[str] = set()

-    IDENTIFIER_ESCAPES = ['"']
-
-    _IDENTIFIER_ESCAPES: t.Set[str] = set()
-
-    KEYWORDS = {
+    KEYWORDS: t.Dict[t.Optional[str], TokenType] = {
         **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
         **{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")},
         "{{+": TokenType.BLOCK_START,
@@ -553,6 +563,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "IS": TokenType.IS,
         "ISNULL": TokenType.ISNULL,
         "JOIN": TokenType.JOIN,
+        "KEEP": TokenType.KEEP,
         "LATERAL": TokenType.LATERAL,
         "LAZY": TokenType.LAZY,
         "LEADING": TokenType.LEADING,
@@ -565,6 +576,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "MERGE": TokenType.MERGE,
         "NATURAL": TokenType.NATURAL,
         "NEXT": TokenType.NEXT,
+        "NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR,
         "NO ACTION": TokenType.NO_ACTION,
         "NOT": TokenType.NOT,
         "NOTNULL": TokenType.NOTNULL,
@@ -632,6 +644,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "UPDATE": TokenType.UPDATE,
         "USE": TokenType.USE,
         "USING": TokenType.USING,
+        "UUID": TokenType.UUID,
         "VALUES": TokenType.VALUES,
         "VIEW": TokenType.VIEW,
         "VOLATILE": TokenType.VOLATILE,
@@ -661,6 +674,8 @@ class Tokenizer(metaclass=_Tokenizer):
         "INT8": TokenType.BIGINT,
         "DEC": TokenType.DECIMAL,
         "DECIMAL": TokenType.DECIMAL,
+        "BIGDECIMAL": TokenType.BIGDECIMAL,
+        "BIGNUMERIC": TokenType.BIGDECIMAL,
         "MAP": TokenType.MAP,
         "NULLABLE": TokenType.NULLABLE,
         "NUMBER": TokenType.DECIMAL,
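Both new keywords map to the single new BIGDECIMAL token type, so BigQuery's BIGNUMERIC and the generic spelling should round-trip as the same type. A quick check (the output shape is my expectation, not verified against this exact build):

```python
import sqlglot

# BIGDECIMAL and BIGNUMERIC now tokenize to TokenType.BIGDECIMAL.
print(sqlglot.parse_one("SELECT CAST(1 AS BIGDECIMAL)").sql())
# SELECT CAST(1 AS BIGDECIMAL)
```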
@@ -742,7 +757,7 @@ class Tokenizer(metaclass=_Tokenizer):
     ENCODE: t.Optional[str] = None

     COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
-    KEYWORD_TRIE = None  # autofilled
+    KEYWORD_TRIE: t.Dict = {}  # autofilled

     IDENTIFIER_CAN_START_WITH_DIGIT = False
@@ -776,19 +791,28 @@ class Tokenizer(metaclass=_Tokenizer):
         self._col = 1
         self._comments: t.List[str] = []

-        self._char = None
-        self._end = None
-        self._peek = None
+        self._char = ""
+        self._end = False
+        self._peek = ""
         self._prev_token_line = -1
         self._prev_token_comments: t.List[str] = []
-        self._prev_token_type = None
+        self._prev_token_type: t.Optional[TokenType] = None

     def tokenize(self, sql: str) -> t.List[Token]:
+        """Returns a list of tokens corresponding to the SQL string `sql`."""
         self.reset()
         self.sql = sql
         self.size = len(sql)
-        self._scan()
+        try:
+            self._scan()
+        except Exception as e:
+            start = self._current - 50
+            end = self._current + 50
+            start = start if start > 0 else 0
+            end = end if end < self.size else self.size - 1
+            context = self.sql[start:end]
+            raise ValueError(f"Error tokenizing '{context}'") from e
+
         return self.tokens

     def _scan(self, until: t.Optional[t.Callable] = None) -> None:
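The new try/except around `_scan` means malformed input now surfaces as a `ValueError` carrying up to 50 characters of context on each side of the failure point, instead of a bare internal error:

```python
from sqlglot.tokens import Tokenizer

try:
    Tokenizer().tokenize("SELECT 'unterminated")
except ValueError as e:
    print(e)  # e.g. Error tokenizing 'SELECT 'unterminate'
```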
@@ -810,9 +834,12 @@ class Tokenizer(metaclass=_Tokenizer):
             if until and until():
                 break

+        if self.tokens:
+            self.tokens[-1].comments.extend(self._comments)
+
     def _chars(self, size: int) -> str:
         if size == 1:
-            return self._char  # type: ignore
+            return self._char
         start = self._current - 1
         end = start + size
         if end <= self.size:
@@ -821,17 +848,15 @@ class Tokenizer(metaclass=_Tokenizer):

     def _advance(self, i: int = 1) -> None:
         if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
-            self._set_new_line()
+            self._col = 1
+            self._line += 1
+        else:
+            self._col += i

-        self._col += i
         self._current += i
-        self._end = self._current >= self.size  # type: ignore
-        self._char = self.sql[self._current - 1]  # type: ignore
-        self._peek = self.sql[self._current] if self._current < self.size else ""  # type: ignore
-
-    def _set_new_line(self) -> None:
-        self._col = 1
-        self._line += 1
+        self._end = self._current >= self.size
+        self._char = self.sql[self._current - 1]
+        self._peek = "" if self._end else self.sql[self._current]

     @property
     def _text(self) -> str:
@@ -840,13 +865,14 @@ class Tokenizer(metaclass=_Tokenizer):
     def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
         self._prev_token_line = self._line
         self._prev_token_comments = self._comments
-        self._prev_token_type = token_type  # type: ignore
+        self._prev_token_type = token_type
         self.tokens.append(
             Token(
                 token_type,
                 self._text if text is None else text,
                 self._line,
                 self._col,
+                self._current,
                 self._comments,
             )
         )
@@ -881,7 +907,7 @@ class Tokenizer(metaclass=_Tokenizer):
             if skip:
                 result = 1
             else:
-                result, trie = in_trie(trie, char.upper())  # type: ignore
+                result, trie = in_trie(trie, char.upper())

             if result == 0:
                 break
@@ -910,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):

         if not word:
             if self._char in self.SINGLE_TOKENS:
-                self._add(self.SINGLE_TOKENS[self._char], text=self._char)  # type: ignore
+                self._add(self.SINGLE_TOKENS[self._char], text=self._char)
                 return
             self._scan_var()
             return
@@ -927,29 +953,31 @@ class Tokenizer(metaclass=_Tokenizer):
             self._add(self.KEYWORDS[word], text=word)

     def _scan_comment(self, comment_start: str) -> bool:
-        if comment_start not in self._COMMENTS:  # type: ignore
+        if comment_start not in self._COMMENTS:
             return False

         comment_start_line = self._line
         comment_start_size = len(comment_start)
-        comment_end = self._COMMENTS[comment_start]  # type: ignore
+        comment_end = self._COMMENTS[comment_start]

         if comment_end:
-            comment_end_size = len(comment_end)
-
+            # Skip the comment's start delimiter
             self._advance(comment_start_size)
+
+            comment_end_size = len(comment_end)
             while not self._end and self._chars(comment_end_size) != comment_end:
                 self._advance()

-            self._comments.append(self._text[comment_start_size : -comment_end_size + 1])  # type: ignore
+            self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
             self._advance(comment_end_size - 1)
         else:
             while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
                 self._advance()
-            self._comments.append(self._text[comment_start_size:])  # type: ignore
+            self._comments.append(self._text[comment_start_size:])

         # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
         # Multiple consecutive comments are preserved by appending them to the current comments list.
-        if comment_start_line == self._prev_token_line or self._end:
+        if comment_start_line == self._prev_token_line:
             self.tokens[-1].comments.extend(self._comments)
             self._comments = []
             self._prev_token_line = self._line
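One visible effect of this comment handling: a comment on the same line as a token attaches to that token's `comments` list (and a trailing comment at end of input is now swept up by the `if self.tokens:` block added to `_scan` above). A quick sketch:

```python
from sqlglot.tokens import Tokenizer

tokens = Tokenizer().tokenize("SELECT 1 -- trailing")
# The comment (text after the "--" delimiter) rides on the preceding token.
print(tokens[-1].comments)  # [' trailing']
```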
@@ -958,7 +986,7 @@ class Tokenizer(metaclass=_Tokenizer):

     def _scan_number(self) -> None:
         if self._char == "0":
-            peek = self._peek.upper()  # type: ignore
+            peek = self._peek.upper()
             if peek == "B":
                 return self._scan_bits()
             elif peek == "X":
@@ -968,7 +996,7 @@ class Tokenizer(metaclass=_Tokenizer):
         scientific = 0

         while True:
-            if self._peek.isdigit():  # type: ignore
+            if self._peek.isdigit():
                 self._advance()
             elif self._peek == "." and not decimal:
                 decimal = True
@@ -976,24 +1004,23 @@ class Tokenizer(metaclass=_Tokenizer):
             elif self._peek in ("-", "+") and scientific == 1:
                 scientific += 1
                 self._advance()
-            elif self._peek.upper() == "E" and not scientific:  # type: ignore
+            elif self._peek.upper() == "E" and not scientific:
                 scientific += 1
                 self._advance()
-            elif self._peek.isidentifier():  # type: ignore
+            elif self._peek.isidentifier():
                 number_text = self._text
-                literal = []
+                literal = ""

-                while self._peek.strip() and self._peek not in self.SINGLE_TOKENS:  # type: ignore
-                    literal.append(self._peek.upper())  # type: ignore
+                while self._peek.strip() and self._peek not in self.SINGLE_TOKENS:
+                    literal += self._peek.upper()
                     self._advance()

-                literal = "".join(literal)  # type: ignore
-                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))  # type: ignore
+                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))

                 if token_type:
                     self._add(TokenType.NUMBER, number_text)
                     self._add(TokenType.DCOLON, "::")
-                    return self._add(token_type, literal)  # type: ignore
+                    return self._add(token_type, literal)
                 elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
                     return self._add(TokenType.VAR)
@@ -1020,7 +1047,7 @@ class Tokenizer(metaclass=_Tokenizer):

     def _extract_value(self) -> str:
         while True:
-            char = self._peek.strip()  # type: ignore
+            char = self._peek.strip()
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:
@@ -1029,35 +1056,35 @@ class Tokenizer(metaclass=_Tokenizer):
         return self._text

     def _scan_string(self, quote: str) -> bool:
-        quote_end = self._QUOTES.get(quote)  # type: ignore
+        quote_end = self._QUOTES.get(quote)
         if quote_end is None:
             return False

         self._advance(len(quote))
         text = self._extract_string(quote_end)
-        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text  # type: ignore
+        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
         self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
         return True

     # X'1234, b'0110', E'\\\\\' etc.
     def _scan_formatted_string(self, string_start: str) -> bool:
-        if string_start in self._HEX_STRINGS:  # type: ignore
-            delimiters = self._HEX_STRINGS  # type: ignore
+        if string_start in self._HEX_STRINGS:
+            delimiters = self._HEX_STRINGS
             token_type = TokenType.HEX_STRING
             base = 16
-        elif string_start in self._BIT_STRINGS:  # type: ignore
-            delimiters = self._BIT_STRINGS  # type: ignore
+        elif string_start in self._BIT_STRINGS:
+            delimiters = self._BIT_STRINGS
             token_type = TokenType.BIT_STRING
             base = 2
-        elif string_start in self._BYTE_STRINGS:  # type: ignore
-            delimiters = self._BYTE_STRINGS  # type: ignore
+        elif string_start in self._BYTE_STRINGS:
+            delimiters = self._BYTE_STRINGS
             token_type = TokenType.BYTE_STRING
             base = None
         else:
             return False

         self._advance(len(string_start))
-        string_end = delimiters.get(string_start)
+        string_end = delimiters[string_start]
         text = self._extract_string(string_end)

         if base is None:
@@ -1083,20 +1110,20 @@ class Tokenizer(metaclass=_Tokenizer):
             self._advance()
             if self._char == identifier_end:
                 if identifier_end_is_escape and self._peek == identifier_end:
-                    text += identifier_end  # type: ignore
+                    text += identifier_end
                     self._advance()
                     continue

                 break

-            text += self._char  # type: ignore
+            text += self._char

         self._add(TokenType.IDENTIFIER, text)

     def _scan_var(self) -> None:
         while True:
-            char = self._peek.strip()  # type: ignore
-            if char and char not in self.SINGLE_TOKENS:
+            char = self._peek.strip()
+            if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS):
                 self._advance()
             else:
                 break
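The new `VAR_SINGLE_TOKENS` hook lets a dialect whitelist characters that normally tokenize on their own but may legally appear inside a variable name (Snowflake's `$` is, I believe, the motivating case). A sketch with a made-up tokenizer subclass:

```python
from sqlglot.tokens import Tokenizer

class MyTokenizer(Tokenizer):
    # Hypothetical dialect: "$" is in SINGLE_TOKENS (TokenType.DOLLAR), but
    # whitelisting it here lets it also appear inside a var, so "my$var"
    # scans as one token instead of three.
    VAR_SINGLE_TOKENS = {"$"}

print([t.text for t in MyTokenizer().tokenize("SELECT my$var")])
# ['SELECT', 'my$var']
```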
@@ -1115,9 +1142,9 @@ class Tokenizer(metaclass=_Tokenizer):
                 self._peek == delimiter or self._peek in self._STRING_ESCAPES
             ):
                 if self._peek == delimiter:
-                    text += self._peek  # type: ignore
+                    text += self._peek
                 else:
-                    text += self._char + self._peek  # type: ignore
+                    text += self._char + self._peek

                 if self._current + 1 < self.size:
                     self._advance(2)
@@ -1131,7 +1158,7 @@ class Tokenizer(metaclass=_Tokenizer):

             if self._end:
                 raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
-            text += self._char  # type: ignore
+            text += self._char
             self._advance()

         return text
@@ -103,7 +103,11 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
         if isinstance(expr, exp.Window):
             alias = find_new_name(expression.named_selects, "_w")
             expression.select(exp.alias_(expr.copy(), alias), copy=False)
-            expr.replace(exp.column(alias))
+            column = exp.column(alias)
+            if isinstance(expr.parent, exp.Qualify):
+                qualify_filters = column
+            else:
+                expr.replace(column)
         elif expr.name not in expression.named_selects:
             expression.select(expr.copy(), copy=False)
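A sketch of calling this transform directly; the exact output formatting may differ from what a given dialect produces, but the shape is a subquery whose aliased window value is filtered with WHERE:

```python
import sqlglot
from sqlglot.transforms import eliminate_qualify

e = sqlglot.parse_one("SELECT x FROM t QUALIFY ROW_NUMBER() OVER (ORDER BY x) = 1")
# The QUALIFY predicate is rewritten to reference the "_w" projection added
# by the transform, for engines that don't support QUALIFY natively.
print(eliminate_qualify(e).sql())
```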
@@ -133,9 +137,111 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
     )


+def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
+    """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
+    if isinstance(expression, exp.Select):
+        for join in expression.args.get("joins") or []:
+            unnest = join.this
+
+            if isinstance(unnest, exp.Unnest):
+                alias = unnest.args.get("alias")
+                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
+
+                expression.args["joins"].remove(join)
+
+                for e, column in zip(unnest.expressions, alias.columns if alias else []):
+                    expression.append(
+                        "laterals",
+                        exp.Lateral(
+                            this=udtf(this=e),
+                            view=True,
+                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
+                        ),
+                    )
+    return expression
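Per the docstring, this transform backs Presto-to-Hive transpilation of `CROSS JOIN UNNEST`. An illustrative call (the output is my expectation of the general shape, not a verified string):

```python
import sqlglot

print(
    sqlglot.transpile(
        "SELECT x.a FROM t CROSS JOIN UNNEST(t.ys) AS x(a)",
        read="presto",
        write="hive",
    )[0]
)
# Something like: SELECT x.a FROM t LATERAL VIEW EXPLODE(t.ys) x AS a
```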
+def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
+    """Convert explode/posexplode into unnest (used in hive -> presto)."""
+    if isinstance(expression, exp.Select):
+        from sqlglot.optimizer.scope import build_scope
+
+        taken_select_names = set(expression.named_selects)
+        taken_source_names = set(build_scope(expression).selected_sources)
+
+        for select in expression.selects:
+            to_replace = select
+
+            pos_alias = ""
+            explode_alias = ""
+
+            if isinstance(select, exp.Alias):
+                explode_alias = select.alias
+                select = select.this
+            elif isinstance(select, exp.Aliases):
+                pos_alias = select.aliases[0].name
+                explode_alias = select.aliases[1].name
+                select = select.this
+
+            if isinstance(select, (exp.Explode, exp.Posexplode)):
+                is_posexplode = isinstance(select, exp.Posexplode)
+
+                explode_arg = select.this
+                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)
+
+                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
+                if isinstance(explode_arg, exp.Column):
+                    taken_select_names.add(explode_arg.output_name)
+
+                unnest_source_alias = find_new_name(taken_source_names, "_u")
+                taken_source_names.add(unnest_source_alias)
+
+                if not explode_alias:
+                    explode_alias = find_new_name(taken_select_names, "col")
+                    taken_select_names.add(explode_alias)
+
+                    if is_posexplode:
+                        pos_alias = find_new_name(taken_select_names, "pos")
+                        taken_select_names.add(pos_alias)
+
+                if is_posexplode:
+                    column_names = [explode_alias, pos_alias]
+                    to_replace.pop()
+                    expression.select(pos_alias, explode_alias, copy=False)
+                else:
+                    column_names = [explode_alias]
+                    to_replace.replace(exp.column(explode_alias))
+
+                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)
+
+                if not expression.args.get("from"):
+                    expression.from_(unnest, copy=False)
+                else:
+                    expression.join(unnest, join_type="CROSS", copy=False)
+
+    return expression
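And the inverse direction (the docstring's hive-to-presto case), sketched; the aliases `_u`/`col` come from `find_new_name` as above, and the exact output may vary by version:

```python
import sqlglot

print(sqlglot.transpile("SELECT EXPLODE(arr) FROM t", read="hive", write="presto")[0])
# Expect a form like: SELECT col FROM t CROSS JOIN UNNEST(arr) AS _u(col)
```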
+def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
+    """Remove table refs from columns in when statements."""
+    if isinstance(expression, exp.Merge):
+        alias = expression.this.args.get("alias")
+        targets = {expression.this.this}
+        if alias:
+            targets.add(alias.this)
+
+        for when in expression.expressions:
+            when.transform(
+                lambda node: exp.column(node.name)
+                if isinstance(node, exp.Column) and node.args.get("table") in targets
+                else node,
+                copy=False,
+            )
+    return expression
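This targets engines whose `MERGE ... WHEN` clauses must not qualify target columns; my assumption is that T-SQL is the dialect that wires it in. A sketch:

```python
import sqlglot

sql = """
MERGE INTO target AS t USING source AS s ON t.id = s.id
WHEN MATCHED THEN UPDATE SET t.v = s.v
"""
# Columns qualified with the target table or its alias `t` should lose
# their prefix in the WHEN clause of the generated statement.
print(sqlglot.transpile(sql, write="tsql")[0])
```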
 def preprocess(
     transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
-    to_sql: t.Callable[[Generator, exp.Expression], str],
 ) -> t.Callable[[Generator, exp.Expression], str]:
     """
     Creates a new transform by chaining a sequence of transformations and converts the resulting
@@ -143,36 +249,23 @@ def preprocess(

     Args:
         transforms: sequence of transform functions. These will be called in order.
-        to_sql: final transform that converts the resulting expression to a SQL string.

     Returns:
         Function that can be used as a generator transform.
     """

-    def _to_sql(self, expression):
+    def _to_sql(self, expression: exp.Expression) -> str:
         expression = transforms[0](expression.copy())
         for t in transforms[1:]:
             expression = t(expression)
-        return to_sql(self, expression)
+        return getattr(self, expression.key + "_sql")(expression)

     return _to_sql


-def delegate(attr: str) -> t.Callable:
-    """
-    Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS`
-    functions that delegate to existing generator methods.
-    """
-
-    def _transform(self, *args, **kwargs):
-        return getattr(self, attr)(*args, **kwargs)
-
-    return _transform
-
-
-UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
-ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
-ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify], delegate("select_sql"))}
+UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])}
+ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])}
+ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])}
 REMOVE_PRECISION_PARAMETERIZED_TYPES = {
-    exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
+    exp.Cast: preprocess([remove_precision_parameterized_types])
 }
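With `delegate` gone, `preprocess` finds the final generator method itself by dispatching on `expression.key` (e.g. `"select"` becomes `select_sql`), so a dialect hook only needs the transform list. A hypothetical usage, with the dialect name made up for illustration:

```python
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.transforms import eliminate_distinct_on, preprocess

class MyDialect(Dialect):
    class Generator(Generator):
        TRANSFORMS = {
            # No more delegate("select_sql"): after running the transforms,
            # preprocess falls back to getattr(self, "select" + "_sql").
            exp.Select: preprocess([eliminate_distinct_on]),
        }
```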
@@ -3,7 +3,7 @@ import typing as t
 key = t.Sequence[t.Hashable]


-def new_trie(keywords: t.Iterable[key]) -> t.Dict:
+def new_trie(keywords: t.Iterable[key], trie: t.Optional[t.Dict] = None) -> t.Dict:
     """
     Creates a new trie out of a collection of keywords.

@@ -16,11 +16,12 @@ def new_trie(keywords: t.Iterable[key]) -> t.Dict:

     Args:
         keywords: the keywords to create the trie from.
+        trie: a trie to mutate instead of creating a new one

     Returns:
         The trie corresponding to `keywords`.
     """
-    trie: t.Dict = {}
+    trie = {} if trie is None else trie

     for key in keywords:
         current = trie
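The new optional `trie` argument makes incremental construction possible, e.g. extending an existing keyword trie in place instead of rebuilding it. A sketch:

```python
from sqlglot.trie import in_trie, new_trie

trie = new_trie(["CAT"])
new_trie(["CAR"], trie)  # mutates and returns the same trie

# in_trie returns (state, node); state 2 means an exact keyword match.
print(in_trie(trie, "CAT")[0], in_trie(trie, "CAR")[0])  # 2 2
```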