
Merging upstream version 11.7.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:52:09 +01:00
parent 0c053462ae
commit 8d96084fad
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
144 changed files with 44104 additions and 39367 deletions

View file: sqlglot/__init__.py

@ -21,10 +21,12 @@ from sqlglot.expressions import (
Expression as Expression,
alias_ as alias,
and_ as and_,
cast as cast,
column as column,
condition as condition,
except_ as except_,
from_ as from_,
func as func,
intersect as intersect,
maybe_parse as maybe_parse,
not_ as not_,
@ -33,6 +35,7 @@ from sqlglot.expressions import (
subquery as subquery,
table_ as table,
to_column as to_column,
to_identifier as to_identifier,
to_table as to_table,
union as union,
)
@ -47,7 +50,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
__version__ = "11.5.2"
__version__ = "11.7.1"
pretty = False
"""Whether to format generated SQL by default."""

View file: sqlglot/dataframe/sql/column.py

@ -176,7 +176,7 @@ class Column:
return isinstance(self.expression, exp.Column)
@property
def column_expression(self) -> exp.Column:
def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
return self.expression.unalias()
@property

View file: sqlglot/dataframe/sql/dataframe.py

@ -16,7 +16,7 @@ from sqlglot.dataframe.sql.readwriter import DataFrameWriter
from sqlglot.dataframe.sql.transforms import replace_id_value
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
from sqlglot.helper import ensure_list, object_to_dict
from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer import optimize as optimize_func
if t.TYPE_CHECKING:
@ -146,9 +146,9 @@ class DataFrame:
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
def _ensure_and_normalize_cols(self, cols):
def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
cols = self._ensure_list_of_columns(cols)
normalize(self.spark, self.expression, cols)
normalize(self.spark, expression or self.expression, cols)
return cols
def _ensure_and_normalize_col(self, col):
@ -355,12 +355,20 @@ class DataFrame:
cols = self._ensure_and_normalize_cols(cols)
kwargs["append"] = kwargs.get("append", False)
if self.expression.args.get("joins"):
ambiguous_cols = [col for col in cols if not col.column_expression.table]
ambiguous_cols = [
col
for col in cols
if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
]
if ambiguous_cols:
join_table_identifiers = [
x.this for x in get_tables_from_expression_with_join(self.expression)
]
cte_names_in_join = [x.this for x in join_table_identifiers]
# If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
# and therefore we allow multiple columns with the same name in the result. This matches the behavior
# of Spark.
resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
for ambiguous_col in ambiguous_cols:
ctes_with_column = [
cte
@ -368,13 +376,14 @@ class DataFrame:
if cte.alias_or_name in cte_names_in_join
and ambiguous_col.alias_or_name in cte.this.named_selects
]
# If the select column does not specify a table and there is a join
# then we assume they are referring to the left table
if len(ctes_with_column) > 1:
table_identifier = self.expression.args["from"].args["expressions"][0].this
# Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
# use the same CTE we used before
cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
if cte:
resolved_column_position[ambiguous_col] += 1
else:
table_identifier = ctes_with_column[0].args["alias"].this
ambiguous_col.expression.set("table", table_identifier)
cte = ctes_with_column[resolved_column_position[ambiguous_col]]
ambiguous_col.expression.set("table", cte.alias_or_name)
return self.copy(
expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
)
@ -416,59 +425,87 @@ class DataFrame:
**kwargs,
) -> DataFrame:
other_df = other_df._convert_leaf_to_cte()
pre_join_self_latest_cte_name = self.latest_cte_name
columns = self._ensure_and_normalize_cols(on)
join_type = how.replace("_", " ")
if isinstance(columns[0].expression, exp.Column):
join_columns = [
Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
join_columns = self._ensure_list_of_columns(on)
# We will determine the actual "join on" expression later, so we don't provide it at first
join_expression = self.expression.join(
other_df.latest_cte_name, join_type=how.replace("_", " ")
)
join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
self_columns = self._get_outer_select_columns(join_expression)
other_columns = self._get_outer_select_columns(other_df)
# Determines the join clause and select columns to be used based on what type of columns were provided for
# the join. The columns returned change based on how the "on" expression is provided.
if isinstance(join_columns[0].expression, exp.Column):
"""
Unique characteristics of join on column names only:
* The column names are put at the front of the select list
* Only the join column names are deduplicated across the entire select list (other duplicates are allowed)
"""
table_names = [
table.alias_or_name
for table in get_tables_from_expression_with_join(join_expression)
]
potential_ctes = [
cte
for cte in join_expression.ctes
if cte.alias_or_name in table_names
and cte.alias_or_name != other_df.latest_cte_name
]
# Determine the table to reference for the left side of the join by checking each of the left side
# tables and seeing whether they have the column being referenced.
join_column_pairs = []
for join_column in join_columns:
num_matching_ctes = 0
for cte in potential_ctes:
if join_column.alias_or_name in cte.this.named_selects:
left_column = join_column.copy().set_table_name(cte.alias_or_name)
right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
join_column_pairs.append((left_column, right_column))
num_matching_ctes += 1
if num_matching_ctes > 1:
raise ValueError(
f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
)
elif num_matching_ctes == 0:
raise ValueError(
f"Column {join_column.alias_or_name} does not exist in any of the tables."
)
join_clause = functools.reduce(
lambda x, y: x & y,
[
col.copy().set_table_name(pre_join_self_latest_cte_name)
== col.copy().set_table_name(other_df.latest_cte_name)
for col in columns
],
[left_column == right_column for left_column, right_column in join_column_pairs],
)
else:
if len(columns) > 1:
columns = [functools.reduce(lambda x, y: x & y, columns)]
join_clause = columns[0]
join_columns = [
Column(x).set_table_name(pre_join_self_latest_cte_name)
if i % 2 == 0
else Column(x).set_table_name(other_df.latest_cte_name)
for i, x in enumerate(join_clause.expression.find_all(exp.Column))
join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
# To match Spark behavior, only the join clause gets deduplicated, and it is put at the front of the column list
select_column_names = [
column.alias_or_name
if not isinstance(column.expression.this, exp.Star)
else column.sql()
for column in self_columns + other_columns
]
self_columns = [
column.set_table_name(pre_join_self_latest_cte_name, copy=True)
for column in self._get_outer_select_columns(self)
]
other_columns = [
column.set_table_name(other_df.latest_cte_name, copy=True)
for column in self._get_outer_select_columns(other_df)
]
column_value_mapping = {
column.alias_or_name
if not isinstance(column.expression.this, exp.Star)
else column.sql(): column
for column in other_columns + self_columns + join_columns
}
all_columns = [
column_value_mapping[name]
for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
]
new_df = self.copy(
expression=self.expression.join(
other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
)
)
new_df.expression = new_df._add_ctes_to_expression(
new_df.expression, other_df.expression.ctes
)
select_column_names = [
column_name
for column_name in select_column_names
if column_name not in join_column_names
]
select_column_names = join_column_names + select_column_names
else:
"""
Unique characteristics of join on expressions:
* There is no deduplication of the results.
* The left dataframe's columns go first and the right's come after. No sort preference is given to join columns.
"""
join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
if len(join_columns) > 1:
join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
join_clause = join_columns[0]
select_column_names = [column.alias_or_name for column in self_columns + other_columns]
# Update the on expression with the actual join clause to replace the dummy one from before
join_expression.args["joins"][-1].set("on", join_clause.expression)
new_df = self.copy(expression=join_expression)
new_df.pending_join_hints.extend(self.pending_join_hints)
new_df.pending_hints.extend(other_df.pending_hints)
new_df = new_df.select.__wrapped__(new_df, *all_columns)
new_df = new_df.select.__wrapped__(new_df, *select_column_names)
return new_df
@operation(Operation.ORDER_BY)
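A minimal sketch of the join-on-column-names path described above, via the dataframe API; the table names and schemas here are hypothetical and the exact SQL output is assumed:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"id": "INT", "store_id": "INT"})
sqlglot.schema.add_table("store", {"store_id": "INT", "name": "STRING"})

spark = SparkSession()
df = spark.table("employee").join(spark.table("store"), on="store_id", how="inner")
# Per the docstring above, store_id should appear once, at the front of the select list.
print(df.sql(pretty=True)[0])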

View file: sqlglot/dataframe/sql/functions.py

@ -577,11 +577,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days)
return Column.invoke_expression_over_column(
col, expression.DateAdd, expression=days, unit=expression.Var(this="day")
)
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
return Column.invoke_expression_over_column(col, expression.DateSub, expression=days)
return Column.invoke_expression_over_column(
col, expression.DateSub, expression=days, unit=expression.Var(this="day")
)
def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
@ -695,18 +699,17 @@ def crc32(col: ColumnOrName) -> Column:
def md5(col: ColumnOrName) -> Column:
column = col if isinstance(col, Column) else lit(col)
return Column.invoke_anonymous_function(column, "MD5")
return Column.invoke_expression_over_column(column, expression.MD5)
def sha1(col: ColumnOrName) -> Column:
column = col if isinstance(col, Column) else lit(col)
return Column.invoke_anonymous_function(column, "SHA1")
return Column.invoke_expression_over_column(column, expression.SHA)
def sha2(col: ColumnOrName, numBits: int) -> Column:
column = col if isinstance(col, Column) else lit(col)
num_bits = lit(numBits)
return Column.invoke_anonymous_function(column, "SHA2", num_bits)
return Column.invoke_expression_over_column(column, expression.SHA2, length=lit(numBits))
def hash(*cols: ColumnOrName) -> Column:
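Since md5/sha1/sha2 now build typed nodes (exp.MD5, exp.SHA, exp.SHA2) rather than anonymous functions, dialects can transpile them. A small sketch (output assumed):

from sqlglot.dataframe.sql import functions as F

col = F.sha2(F.col("x"), 256)
print(col.sql())  # assumed output: SHA2(x, 256)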

View file: sqlglot/dataframe/sql/readwriter.py

@ -4,7 +4,7 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.helper import object_to_dict
from sqlglot.helper import object_to_dict, should_identify
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.dataframe import DataFrame
@ -19,9 +19,17 @@ class DataFrameReader:
from sqlglot.dataframe.sql.dataframe import DataFrame
sqlglot.schema.add_table(tableName)
return DataFrame(
self.spark,
exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
exp.Select()
.from_(tableName)
.select(
*(
column if should_identify(column, "safe") else f'"{column}"'
for column in sqlglot.schema.column_names(tableName)
)
),
)

View file: sqlglot/dialects/bigquery.py

@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
no_ilike_sql,
parse_date_delta_with_interval,
rename_func,
timestrtotime_sql,
ts_or_ds_to_date_sql,
@ -23,18 +24,6 @@ from sqlglot.tokens import TokenType
E = t.TypeVar("E", bound=exp.Expression)
def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
def func(args):
interval = seq_get(args, 1)
return expression_class(
this=seq_get(args, 0),
expression=interval.this,
unit=interval.args.get("unit"),
)
return func
def _date_add_sql(
data_type: str, kind: str
) -> t.Callable[[generator.Generator, exp.Expression], str]:
@ -142,6 +131,7 @@ class BigQuery(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ANY TYPE": TokenType.VARIANT,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
@ -155,14 +145,19 @@ class BigQuery(Dialect):
KEYWORDS.pop("DIV")
class Parser(parser.Parser):
PREFIXED_PIVOT_COLUMNS = True
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore
this=seq_get(args, 0),
),
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
@ -174,12 +169,12 @@ class BigQuery(Dialect):
if re.compile(str(seq_get(args, 1))).groups == 1
else None,
),
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
"DATETIME_SUB": _date_add(exp.DatetimeSub),
"TIME_SUB": _date_add(exp.TimeSub),
"TIMESTAMP_SUB": _date_add(exp.TimestampSub),
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
"TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(
this=seq_get(args, 1), format=seq_get(args, 0)
),
@ -209,14 +204,17 @@ class BigQuery(Dialect):
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, # type: ignore
"NOT DETERMINISTIC": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
),
}
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
class Generator(generator.Generator):
EXPLICIT_UNION = True
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
@ -236,9 +234,7 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess(
[_unqualify_unnest], transforms.delegate("select_sql")
),
exp.Select: transforms.preprocess([_unqualify_unnest]),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@ -253,7 +249,7 @@ class BigQuery(Dialect):
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
@ -261,6 +257,7 @@ class BigQuery(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.CHAR: "STRING",
@ -272,17 +269,19 @@ class BigQuery(Dialect):
exp.DataType.Type.NVARCHAR: "STRING",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.VARIANT: "ANY TYPE",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
EXPLICIT_UNION = True
LIMIT_FETCH = "LIMIT"
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
if isinstance(first_arg, exp.Subqueryable):

View file: sqlglot/dialects/clickhouse.py

@ -144,6 +144,13 @@ class ClickHouse(Dialect):
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
JOIN_HINTS = False
TABLE_HINTS = False
EXPLICIT_UNION = True
def _param_args_sql(

View file: sqlglot/dialects/databricks.py

@ -9,6 +9,8 @@ from sqlglot.tokens import TokenType
class Databricks(Spark):
class Parser(Spark.Parser):
LOG_DEFAULTS_TO_LN = True
FUNCTIONS = {
**Spark.Parser.FUNCTIONS,
"DATEADD": parse_date_delta(exp.DateAdd),
@ -16,13 +18,17 @@ class Databricks(Spark):
"DATEDIFF": parse_date_delta(exp.DateDiff),
}
LOG_DEFAULTS_TO_LN = True
FACTOR = {
**Spark.Parser.FACTOR,
TokenType.COLON: exp.JSONExtract,
}
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation
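With DATEADD/DATEDIFF wired through parse_date_delta and generate_date_delta_with_unit_sql, a unit-based call should round-trip. A sketch (output assumed):

import sqlglot

sql = "SELECT DATEADD(DAY, 1, d)"
print(sqlglot.transpile(sql, read="databricks", write="databricks")[0])
# assumed: SELECT DATEADD(DAY, 1, d)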

View file: sqlglot/dialects/dialect.py

@ -293,6 +293,13 @@ def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
return ""
def no_comment_column_constraint_sql(
self: Generator, expression: exp.CommentColumnConstraint
) -> str:
self.unsupported("CommentColumnConstraint unsupported")
return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
@ -379,15 +386,35 @@ def parse_date_delta(
) -> t.Callable[[t.Sequence], E]:
def inner_func(args: t.Sequence) -> E:
unit_based = len(args) == 3
this = seq_get(args, 2) if unit_based else seq_get(args, 0)
expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore
return exp_class(this=this, expression=expression, unit=unit)
this = args[2] if unit_based else seq_get(args, 0)
unit = args[0] if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
return inner_func
def parse_date_delta_with_interval(
expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
def func(args: t.Sequence) -> t.Optional[E]:
if len(args) < 2:
return None
interval = args[1]
expression = interval.this
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)
return expression_class(
this=args[0],
expression=expression,
unit=exp.Literal.string(interval.text("unit")),
)
return func
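parse_date_delta_with_interval is the shared helper that the BigQuery and MySQL parsers now use for DATE_ADD(x, INTERVAL n unit)-style calls. A round-trip sketch (exact formatting assumed):

import sqlglot

sql = "SELECT DATE_ADD(CURRENT_DATE, INTERVAL 5 DAY)"
print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])
# assumed: SELECT DATE_ADD(CURRENT_DATE, INTERVAL 5 DAY)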
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
unit = seq_get(args, 0)
this = seq_get(args, 1)

View file: sqlglot/dialects/drill.py

@ -104,6 +104,9 @@ class Drill(Dialect):
LOG_DEFAULTS_TO_LN = True
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER",
@ -120,6 +123,7 @@ class Drill(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {

View file: sqlglot/dialects/duckdb.py

@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
datestrtodate_sql,
format_time_lambda,
no_comment_column_constraint_sql,
no_pivot_sql,
no_properties_sql,
no_safe_divide_sql,
@ -23,7 +24,7 @@ from sqlglot.tokens import TokenType
def _ts_or_ds_add(self, expression):
this = expression.args.get("this")
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
@ -139,6 +140,8 @@ class DuckDB(Dialect):
}
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
@ -150,6 +153,7 @@ class DuckDB(Dialect):
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
@ -213,6 +217,11 @@ class DuckDB(Dialect):
"except": "EXCLUDE",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
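A sketch of the new CommentColumnConstraint handling: DuckDB DDL has no per-column COMMENT, so the constraint is dropped with an "unsupported" warning (output assumed):

import sqlglot

sql = "CREATE TABLE t (x INT COMMENT 'primary key')"
print(sqlglot.transpile(sql, read="mysql", write="duckdb")[0])
# assumed: CREATE TABLE t (x INT)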
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:

View file: sqlglot/dialects/hive.py

@ -45,16 +45,23 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
modified_increment = (
int(expression.text("expression")) * multiplier
if expression.expression.is_number
else expression.expression
)
modified_increment = exp.Literal.number(modified_increment)
return self.func(func, expression.this, modified_increment.this)
if isinstance(expression, exp.DateSub):
multiplier *= -1
if expression.expression.is_number:
modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
else:
modified_increment = expression.expression
if multiplier != 1:
modified_increment = exp.Mul( # type: ignore
this=modified_increment, expression=exp.Literal.number(multiplier)
)
return self.func(func, expression.this, modified_increment)
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
@ -127,24 +134,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
return f"TO_DATE({this})"
def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
unnest = expression.this
if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
return "".join(
self.sql(
exp.Lateral(
this=udtf(this=expression),
view=True,
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
)
)
for expression, column in zip(unnest.expressions, alias.columns if alias else [])
)
return self.join_sql(expression)
def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
@ -195,6 +184,7 @@ class Hive(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
IDENTIFIER_CAN_START_WITH_DIGIT = True
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -217,9 +207,8 @@ class Hive(Dialect):
"BD": "DECIMAL",
}
IDENTIFIER_CAN_START_WITH_DIGIT = True
class Parser(parser.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
FUNCTIONS = {
@ -273,9 +262,13 @@ class Hive(Dialect):
),
}
LOG_DEFAULTS_TO_LN = True
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
TABLESAMPLE_SIZE_IS_PERCENT = True
JOIN_HINTS = False
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
@ -289,6 +282,9 @@ class Hive(Dialect):
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.unnest_to_explode]
),
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
@ -298,13 +294,13 @@ class Hive(Dialect):
exp.DateAdd: _add_date_sql,
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateSub: _add_date_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.Join: _unnest_to_explode_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: rename_func("TO_JSON"),
@ -354,10 +350,9 @@ class Hive(Dialect):
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",
@ -378,4 +373,5 @@ class Hive(Dialect):
expression = exp.DataType.build("text")
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
return super().datatype_sql(expression)
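The reworked _add_date_sql covers both DateAdd and DateSub by negating the increment (or emitting a multiplication when the increment isn't a literal). A sketch (output assumed):

import sqlglot

print(sqlglot.transpile("SELECT DATE_SUB(d, INTERVAL 7 DAY)", read="mysql", write="hive")[0])
# assumed: SELECT DATE_ADD(d, -7)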

View file: sqlglot/dialects/mysql.py

@ -4,6 +4,8 @@ from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
datestrtodate_sql,
format_time_lambda,
locate_to_strposition,
max_or_greatest,
min_or_least,
@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
parse_date_delta_with_interval,
rename_func,
strposition_to_locate_sql,
)
@ -76,18 +79,6 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
def _date_add(expression_class):
def func(args):
interval = seq_get(args, 1)
return expression_class(
this=seq_get(args, 0),
expression=interval.this,
unit=exp.Literal.string(interval.text("unit").lower()),
)
return func
def _date_add_sql(kind):
def func(self, expression):
this = self.sql(expression, "this")
@ -115,6 +106,7 @@ class MySQL(Dialect):
"%k": "%-H",
"%l": "%-I",
"%T": "%H:%M:%S",
"%W": "%a",
}
class Tokenizer(tokens.Tokenizer):
@ -127,12 +119,13 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"CHARSET": TokenType.CHARACTER_SET,
"LONGBLOB": TokenType.LONGBLOB,
"LONGTEXT": TokenType.LONGTEXT,
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
"LONGBLOB": TokenType.LONGBLOB,
"START": TokenType.BEGIN,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"SEPARATOR": TokenType.SEPARATOR,
"START": TokenType.BEGIN,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@ -186,14 +179,15 @@ class MySQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
"LOCATE": locate_to_strposition,
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
),
"LOCATE": locate_to_strposition,
"STR_TO_DATE": _str_to_date,
}
FUNCTION_PARSERS = {
@ -388,32 +382,36 @@ class MySQL(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
JOIN_HINTS = False
TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
exp.DateAdd: _date_add_sql("ADD"),
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
}
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
@ -425,6 +423,7 @@ class MySQL(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
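The new INSTR mapping parses into exp.StrPosition, so the call can be re-rendered by other dialects. A sketch (output assumed):

import sqlglot

print(sqlglot.transpile("SELECT INSTR(haystack, needle)", read="mysql", write="presto")[0])
# assumed: SELECT STRPOS(haystack, needle)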

View file: sqlglot/dialects/oracle.py

@ -7,11 +7,6 @@ from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sq
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.COLUMN,
TokenType.RETURNING,
}
def _parse_xml_table(self) -> exp.XMLTable:
this = self._parse_string()
@ -22,9 +17,7 @@ def _parse_xml_table(self) -> exp.XMLTable:
if self._match_text_seq("PASSING"):
# The BY VALUE keywords are optional and are provided for semantic clarity
self._match_text_seq("BY", "VALUE")
passing = self._parse_csv(
lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS)
)
passing = self._parse_csv(self._parse_column)
by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
@ -68,6 +61,8 @@ class Oracle(Dialect):
}
class Parser(parser.Parser):
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
@ -78,6 +73,12 @@ class Oracle(Dialect):
"XMLTABLE": _parse_xml_table,
}
TYPE_LITERAL_PARSERS = {
exp.DataType.Type.DATE: lambda self, this, _: self.expression(
exp.DateStrToDate, this=this
)
}
def _parse_column(self) -> t.Optional[exp.Expression]:
column = super()._parse_column()
if column:
@ -100,6 +101,8 @@ class Oracle(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@ -119,6 +122,9 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@ -129,6 +135,12 @@ class Oracle(Dialect):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
exp.IfNull: rename_func("NVL"),
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "FETCH"
@ -142,9 +154,9 @@ class Oracle(Dialect):
def xmltable_sql(self, expression: exp.XMLTable) -> str:
this = self.sql(expression, "this")
passing = self.expressions(expression, "passing")
passing = self.expressions(expression, key="passing")
passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else ""
columns = self.expressions(expression, "columns")
columns = self.expressions(expression, key="columns")
columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else ""
by_ref = (
f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else ""

View file: sqlglot/dialects/postgres.py

@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
datestrtodate_sql,
format_time_lambda,
max_or_greatest,
min_or_least,
@ -19,7 +20,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
from sqlglot.transforms import preprocess, remove_target_from_merge
DATE_DIFF_FACTOR = {
"MICROSECOND": " * 1000000",
@ -239,7 +240,6 @@ class Postgres(Dialect):
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
"CSTRING": TokenType.PSEUDO_TYPE,
}
@ -248,18 +248,25 @@ class Postgres(Dialect):
"$": TokenType.PARAMETER,
}
VAR_SINGLE_TOKENS = {"$"}
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"GENERATE_SERIES": _generate_series,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
"GENERATE_SERIES": _generate_series,
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"TO_TIMESTAMP": _to_timestamp,
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": lambda self: self._parse_date_part(),
}
BITWISE = {
@ -279,8 +286,21 @@ class Postgres(Dialect):
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
def _parse_date_part(self) -> exp.Expression:
part = self._parse_type()
self._match(TokenType.COMMA)
value = self._parse_bitwise()
if part and part.is_string:
part = exp.Var(this=part.name)
return self.expression(exp.Extract, this=part, expression=value)
class Generator(generator.Generator):
INTERVAL_ALLOWS_PLURAL_FORM = False
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
@ -301,7 +321,6 @@ class Postgres(Dialect):
_auto_increment_to_serial,
_serial_to_generated,
],
delegate("columndef_sql"),
),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@ -312,6 +331,7 @@ class Postgres(Dialect):
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("-"),
exp.DateDiff: _date_diff_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
@ -321,6 +341,7 @@ class Postgres(Dialect):
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.Merge: preprocess([remove_target_from_merge]),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
@ -344,4 +365,5 @@ class Postgres(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
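A sketch of the new _parse_date_part hook: DATE_PART becomes exp.Extract, with string parts converted to vars (output assumed):

import sqlglot

print(sqlglot.transpile("SELECT DATE_PART('year', created_at)", read="postgres", write="postgres")[0])
# assumed: SELECT EXTRACT(year FROM created_at)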

View file: sqlglot/dialects/presto.py

@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@ -19,20 +21,20 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _approx_distinct_sql(self, expression):
def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
accuracy = expression.args.get("accuracy")
accuracy = ", " + self.sql(accuracy) if accuracy else ""
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
def _datatype_sql(self, expression):
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
sql = f"{sql} WITH TIME ZONE"
return sql
def _explode_to_unnest_sql(self, expression):
def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
return self.sql(
exp.Join(
@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression):
return self.lateral_sql(expression)
def _initcap_sql(self, expression):
def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
_ensure_utf8(expression.args["charset"])
return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
def _encode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
_ensure_utf8(expression.args["charset"])
return f"TO_UTF8({self.sql(expression, 'this')})"
def _no_sort_array(self, expression):
def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
@ -70,49 +72,62 @@ def _no_sort_array(self, expression):
return self.func("ARRAY_SORT", expression.this, comparator)
def _schema_sql(self, expression):
def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
for schema in expression.parent.find_all(exp.Schema):
if isinstance(schema.parent, exp.Property):
expression = expression.copy()
expression.expressions.extend(schema.expressions)
if expression.parent:
for schema in expression.parent.find_all(exp.Schema):
if isinstance(schema.parent, exp.Property):
expression = expression.copy()
expression.expressions.extend(schema.expressions)
return self.schema_sql(expression)
def _quantile_sql(self, expression):
def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
self.unsupported("Presto does not support exact quantiles")
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
def _str_to_time_sql(self, expression):
def _str_to_time_sql(
self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_to_date_sql(self, expression):
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.time_format, Presto.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
def _ts_or_ds_add_sql(self, expression):
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this
if not isinstance(this, exp.CurrentDate):
this = self.func(
"DATE_PARSE",
self.func(
"SUBSTR",
this if this.is_string else exp.cast(this, "VARCHAR"),
exp.Literal.number(1),
exp.Literal.number(10),
),
Presto.date_format,
)
return self.func(
"DATE_ADD",
exp.Literal.string(expression.text("unit") or "day"),
expression.expression,
self.func(
"DATE_PARSE",
self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
Presto.date_format,
),
this,
)
def _sequence_sql(self, expression):
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
@ -135,12 +150,12 @@ def _sequence_sql(self, expression):
return self.func("SEQUENCE", start, end, step)
def _ensure_utf8(charset):
def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
def _approx_percentile(args):
def _approx_percentile(args: t.Sequence) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
@ -157,7 +172,7 @@ def _approx_percentile(args):
return exp.ApproxQuantile.from_arg_list(args)
def _from_unixtime(args):
def _from_unixtime(args: t.Sequence) -> exp.Expression:
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
@ -226,11 +241,15 @@ class Presto(Dialect):
FUNCTION_PARSERS.pop("TRIM")
class Generator(generator.Generator):
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TYPE_MAPPING = {
@ -246,7 +265,6 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@ -284,6 +302,9 @@ class Presto(Dialect):
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.explode_to_unnest]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
@ -308,7 +329,13 @@ class Presto(Dialect):
exp.VariancePop: rename_func("VAR_POP"),
}
def transaction_sql(self, expression):
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
if expression.this and unit.lower().startswith("week"):
return f"({expression.this.name} * INTERVAL '7' day)"
return super().interval_sql(expression)
def transaction_sql(self, expression: exp.Transaction) -> str:
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"

View file: sqlglot/dialects/redshift.py

@ -8,6 +8,10 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _json_sql(self, e) -> str:
return f'{self.sql(e, "this")}."{e.expression.name}"'
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
@ -56,6 +60,7 @@ class Redshift(Postgres):
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
"TOP": TokenType.TOP,
@ -63,7 +68,14 @@ class Redshift(Postgres):
"VARBYTE": TokenType.VARBINARY,
}
# Redshift allows # to appear as a table identifier prefix
SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy()
SINGLE_TOKENS.pop("#")
class Generator(Postgres.Generator):
LOCKING_READS_SUPPORTED = False
SINGLE_STRING_INTERVAL = True
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BINARY: "VARBYTE",
@ -79,6 +91,7 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
),
@ -87,12 +100,16 @@ class Redshift(Postgres):
),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"}
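A sketch of the SYSDATE change: it now tokenizes as CURRENT_TIMESTAMP, so it can be translated for other dialects (output assumed):

import sqlglot

print(sqlglot.transpile("SELECT SYSDATE", read="redshift", write="postgres")[0])
# assumed: SELECT CURRENT_TIMESTAMP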
def values_sql(self, expression: exp.Values) -> str:
"""
Converts a `VALUES...` expression into a series of unions.

View file: sqlglot/dialects/snowflake.py

@ -23,14 +23,14 @@ from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
def _check_int(s):
def _check_int(s: str) -> bool:
if s[0] in ("-", "+"):
return s[1:].isdigit()
return s.isdigit()
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _snowflake_to_timestamp(args):
def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@ -69,7 +69,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime.from_arg_list(args)
def _unix_to_time_sql(self, expression):
def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@ -84,8 +84,12 @@ def _unix_to_time_sql(self, expression):
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
def _parse_date_part(self):
def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
this = self._parse_var() or self._parse_type()
if not this:
return None
self._match(TokenType.COMMA)
expression = self._parse_bitwise()
@ -101,7 +105,7 @@ def _parse_date_part(self):
scale = None
ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
to_unix = self.expression(exp.TimeToUnix, this=ts)
to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
if scale:
to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
@ -112,7 +116,7 @@ def _parse_date_part(self):
# https://docs.snowflake.com/en/sql-reference/functions/div0
def _div0_to_if(args):
def _div0_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@ -120,18 +124,18 @@ def _div0_to_if(args):
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
def _zeroifnull_to_if(args):
def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
def _nullifzero_to_if(args):
def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
def _datatype_sql(self, expression):
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
elif expression.this == exp.DataType.Type.MAP:
@ -155,9 +159,8 @@ class Snowflake(Dialect):
"MM": "%m",
"mm": "%m",
"DD": "%d",
"dd": "%d",
"d": "%-d",
"DY": "%w",
"dd": "%-d",
"DY": "%a",
"dy": "%w",
"HH24": "%H",
"hh24": "%H",
@ -174,6 +177,8 @@ class Snowflake(Dialect):
}
class Parser(parser.Parser):
QUOTED_PIVOT_COLUMNS = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
@ -269,9 +274,14 @@ class Snowflake(Dialect):
"$": TokenType.PARAMETER,
}
VAR_SINGLE_TOKENS = {"$"}
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
MATCHED_BY_SOURCE = False
SINGLE_STRING_INTERVAL = True
JOIN_HINTS = False
TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
@ -287,26 +297,30 @@ class Snowflake(Dialect):
),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TimeToStr: lambda self, e: self.func(
"TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
),
exp.TimestampTrunc: timestamptrunc_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
}
TYPE_MAPPING = {
@ -322,14 +336,15 @@ class Snowflake(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def except_op(self, expression):
def except_op(self, expression: exp.Except) -> str:
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
return super().except_op(expression)
def intersect_op(self, expression):
def intersect_op(self, expression: exp.Intersect) -> str:
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
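A sketch of the DIV0 rewrite (_div0_to_if) combined with the IFF transform shown above (output assumed):

import sqlglot

print(sqlglot.transpile("SELECT DIV0(a, b)", read="snowflake", write="snowflake")[0])
# assumed: SELECT IFF(b = 0, 0, a / b)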

View file: sqlglot/dialects/spark.py

@ -1,13 +1,15 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
def _create_sql(self, e):
kind = e.args.get("kind")
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
if kind.upper() == "TABLE" and any(
@ -18,13 +20,13 @@ def _create_sql(self, e):
return create_with_partitions_sql(self, e)
def _map_sql(self, expression):
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
keys = self.sql(expression.args["keys"])
values = self.sql(expression.args["values"])
return f"MAP_FROM_ARRAYS({keys}, {values})"
def _str_to_date(self, expression):
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.date_format:
@ -32,7 +34,7 @@ def _str_to_date(self, expression):
return f"TO_DATE({this}, {time_format})"
def _unix_to_time(self, expression):
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
@ -75,7 +77,11 @@ class Spark(Hive):
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"BOOLEAN": lambda args: exp.Cast(
this=seq_get(args, 0), to=exp.DataType.build("boolean")
),
"IIF": exp.If.from_arg_list,
"INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
"AGGREGATE": exp.Reduce.from_arg_list,
"DAYOFWEEK": lambda args: exp.DayOfWeek(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@ -89,11 +95,16 @@ class Spark(Hive):
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"TIMESTAMP": lambda args: exp.Cast(
this=seq_get(args, 0), to=exp.DataType.build("timestamp")
),
}
FUNCTION_PARSERS = {
@ -108,16 +119,43 @@ class Spark(Hive):
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
def _parse_add_column(self):
def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def _parse_drop_column(self):
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
)
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
# Spark doesn't add a suffix to the pivot columns when there's a single aggregation
if len(pivot_columns) == 1:
return [""]
names = []
for agg in pivot_columns:
if isinstance(agg, exp.Alias):
names.append(agg.alias)
else:
"""
This case corresponds to aggregations without aliases being used as suffixes
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
Moreover, function names are lowercased in order to mimic Spark's naming scheme.
"""
agg_all_unquoted = agg.transform(
lambda node: exp.Identifier(this=node.name, quoted=False)
if isinstance(node, exp.Identifier)
else node
)
names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
return names
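A sketch of the new cast-style function parsers added above (BOOLEAN, DATE, STRING, TIMESTAMP), which map straight to casts (output assumed):

import sqlglot

print(sqlglot.transpile("SELECT DATE(x), STRING(y)", read="spark", write="spark")[0])
# assumed: SELECT CAST(x AS DATE), CAST(y AS STRING)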
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore
@ -145,7 +183,7 @@ class Spark(Hive):
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.UnixToTime: _unix_to_time_sql,
exp.Create: _create_sql,
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),

View file: sqlglot/dialects/sqlite.py

@ -16,7 +16,7 @@ from sqlglot.tokens import TokenType
def _date_add_sql(self, expression):
modifier = expression.expression
modifier = expression.name if modifier.is_string else self.sql(modifier)
modifier = modifier.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
return self.func("DATE", expression.this, modifier)
@ -38,6 +38,9 @@ class SQLite(Dialect):
}
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "INTEGER",
@ -82,6 +85,11 @@ class SQLite(Dialect):
exp.TryCast: no_trycast_sql,
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
def cast_sql(self, expression: exp.Cast) -> str:

@ -1,7 +1,11 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_sql,
rename_func,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get
@ -10,6 +14,7 @@ class StarRocks(MySQL):
class Parser(MySQL.Parser): # type: ignore
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
@ -25,6 +30,7 @@ class StarRocks(MySQL):
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: approx_count_distinct_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),

@ -21,6 +21,9 @@ def _count_sql(self, expression):
class Tableau(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.If: _if_sql,
@ -28,6 +31,11 @@ class Tableau(Dialect):
exp.Count: _count_sql,
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore

@ -1,7 +1,14 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
max_or_greatest,
min_or_least,
)
from sqlglot.tokens import TokenType
@ -115,7 +122,18 @@ class Teradata(Dialect):
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
def _parse_cast(self, strict: bool) -> exp.Expression:
cast = t.cast(exp.Cast, super()._parse_cast(strict))
if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT):
return format_time_lambda(exp.TimeToStr, "teradata")(
[cast.this, self._parse_string()]
)
return cast
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
@ -130,6 +148,7 @@ class Teradata(Dialect):
**generator.Generator.TRANSFORMS,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}

@ -96,6 +96,23 @@ def _parse_eomonth(args):
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
def _parse_hashbytes(args):
kind, data = args
kind = kind.name.upper() if kind.is_string else ""
if kind == "MD5":
args.pop(0)
return exp.MD5(this=data)
if kind in ("SHA", "SHA1"):
args.pop(0)
return exp.SHA(this=data)
if kind == "SHA2_256":
return exp.SHA2(this=data, length=exp.Literal.number(256))
if kind == "SHA2_512":
return exp.SHA2(this=data, length=exp.Literal.number(512))
return exp.func("HASHBYTES", *args)
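
For context, a quick sketch of what the new HASHBYTES mapping buys (assuming the standard `sqlglot.transpile` entry point; the output shown is the expected shape, not verified against this exact build):

    import sqlglot

    # T-SQL's HASHBYTES with a SHA2_256 tag becomes a portable SHA2 call.
    print(sqlglot.transpile("SELECT HASHBYTES('SHA2_256', x)", read="tsql", write="spark")[0])
    # expected: SELECT SHA2(x, 256)
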
def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return self.func(func, e.text("unit"), e.expression, e.this)
@ -266,6 +283,7 @@ class TSQL(Dialect):
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
"SYSTEM_USER": TokenType.CURRENT_USER,
}
# TSQL allows @, # to appear as a variable/identifier prefix
@ -287,6 +305,7 @@ class TSQL(Dialect):
"EOMONTH": _parse_eomonth,
"FORMAT": _parse_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
"HASHBYTES": _parse_hashbytes,
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
@ -296,6 +315,14 @@ class TSQL(Dialect):
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
"SUSER_NAME": exp.CurrentUser.from_arg_list,
"SUSER_SNAME": exp.CurrentUser.from_arg_list,
"SYSTEM_USER": exp.CurrentUser.from_arg_list,
}
JOIN_HINTS = {
"LOOP",
"HASH",
"MERGE",
"REMOTE",
}
VAR_LENGTH_DATATYPES = {
@ -441,11 +468,21 @@ class TSQL(Dialect):
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
}
TRANSFORMS.pop(exp.ReturnsProperty)
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "FETCH"
def offset_sql(self, expression: exp.Offset) -> str:

@ -701,6 +701,119 @@ class Condition(Expression):
"""
return not_(self)
def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
this = self
other = convert(other)
if not isinstance(this, klass) and not isinstance(other, klass):
this = _wrap(this, Binary)
other = _wrap(other, Binary)
if reverse:
return klass(this=other, expression=this)
return klass(this=this, expression=other)
def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
if isinstance(other, slice):
return Between(
this=self,
low=convert(other.start),
high=convert(other.stop),
)
return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])
def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
return In(
this=self,
expressions=[convert(e) for e in expressions],
query=maybe_parse(query, **opts) if query else None,
)
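
These two methods give bracket/slice sugar on any Condition; a small sketch using the top-level `column` builder:

    from sqlglot import column

    print(column("x")[5:10].sql())          # expected: x BETWEEN 5 AND 10
    print(column("arr")[0].sql())           # expected: arr[0]
    print(column("y").isin(1, 2, 3).sql())  # expected: y IN (1, 2, 3)
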
def like(self, other: ExpOrStr) -> Like:
return self._binop(Like, other)
def ilike(self, other: ExpOrStr) -> ILike:
return self._binop(ILike, other)
def eq(self, other: ExpOrStr) -> EQ:
return self._binop(EQ, other)
def neq(self, other: ExpOrStr) -> NEQ:
return self._binop(NEQ, other)
def rlike(self, other: ExpOrStr) -> RegexpLike:
return self._binop(RegexpLike, other)
def __lt__(self, other: ExpOrStr) -> LT:
return self._binop(LT, other)
def __le__(self, other: ExpOrStr) -> LTE:
return self._binop(LTE, other)
def __gt__(self, other: ExpOrStr) -> GT:
return self._binop(GT, other)
def __ge__(self, other: ExpOrStr) -> GTE:
return self._binop(GTE, other)
def __add__(self, other: ExpOrStr) -> Add:
return self._binop(Add, other)
def __radd__(self, other: ExpOrStr) -> Add:
return self._binop(Add, other, reverse=True)
def __sub__(self, other: ExpOrStr) -> Sub:
return self._binop(Sub, other)
def __rsub__(self, other: ExpOrStr) -> Sub:
return self._binop(Sub, other, reverse=True)
def __mul__(self, other: ExpOrStr) -> Mul:
return self._binop(Mul, other)
def __rmul__(self, other: ExpOrStr) -> Mul:
return self._binop(Mul, other, reverse=True)
def __truediv__(self, other: ExpOrStr) -> Div:
return self._binop(Div, other)
def __rtruediv__(self, other: ExpOrStr) -> Div:
return self._binop(Div, other, reverse=True)
def __floordiv__(self, other: ExpOrStr) -> IntDiv:
return self._binop(IntDiv, other)
def __rfloordiv__(self, other: ExpOrStr) -> IntDiv:
return self._binop(IntDiv, other, reverse=True)
def __mod__(self, other: ExpOrStr) -> Mod:
return self._binop(Mod, other)
def __rmod__(self, other: ExpOrStr) -> Mod:
return self._binop(Mod, other, reverse=True)
def __pow__(self, other: ExpOrStr) -> Pow:
return self._binop(Pow, other)
def __rpow__(self, other: ExpOrStr) -> Pow:
return self._binop(Pow, other, reverse=True)
def __and__(self, other: ExpOrStr) -> And:
return self._binop(And, other)
def __rand__(self, other: ExpOrStr) -> And:
return self._binop(And, other, reverse=True)
def __or__(self, other: ExpOrStr) -> Or:
return self._binop(Or, other)
def __ror__(self, other: ExpOrStr) -> Or:
return self._binop(Or, other, reverse=True)
def __neg__(self) -> Neg:
return Neg(this=_wrap(self, Binary))
def __invert__(self) -> Not:
return not_(self)
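
Taken together, these overloads let expressions be composed with plain Python operators; note how `_wrap` parenthesizes operands that are themselves binary nodes. A minimal sketch:

    from sqlglot import column

    print(((column("a") + column("b")) * 2).sql())         # expected: (a + b) * 2
    print((column("x").eq(2) & ~column("deleted")).sql())  # expected: (x = 2) AND NOT deleted
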
class Predicate(Condition):
"""Relationships like x = y, x > 1, x >= y."""
@ -818,7 +931,6 @@ class Create(Expression):
"properties": False,
"replace": False,
"unique": False,
"volatile": False,
"indexes": False,
"no_schema_binding": False,
"begin": False,
@ -1053,6 +1165,11 @@ class NotNullColumnConstraint(ColumnConstraintKind):
arg_types = {"allow_null": False}
# https://dev.mysql.com/doc/refman/5.7/en/timestamp-initialization.html
class OnUpdateColumnConstraint(ColumnConstraintKind):
pass
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
arg_types = {"desc": False}
@ -1197,6 +1314,7 @@ class Drop(Expression):
"materialized": False,
"cascade": False,
"constraints": False,
"purge": False,
}
@ -1287,6 +1405,7 @@ class Insert(Expression):
"with": False,
"this": True,
"expression": False,
"conflict": False,
"returning": False,
"overwrite": False,
"exists": False,
@ -1295,6 +1414,16 @@ class Insert(Expression):
}
class OnConflict(Expression):
arg_types = {
"duplicate": False,
"expressions": False,
"nothing": False,
"key": False,
"constraint": False,
}
class Returning(Expression):
arg_types = {"expressions": True}
@ -1326,7 +1455,12 @@ class Partition(Expression):
class Fetch(Expression):
arg_types = {"direction": False, "count": False}
arg_types = {
"direction": False,
"count": False,
"percent": False,
"with_ties": False,
}
class Group(Expression):
@ -1374,6 +1508,7 @@ class Join(Expression):
"kind": False,
"using": False,
"natural": False,
"hint": False,
}
@property
@ -1384,6 +1519,10 @@ class Join(Expression):
def side(self):
return self.text("side").upper()
@property
def hint(self):
return self.text("hint").upper()
@property
def alias_or_name(self):
return self.this.alias_or_name
@ -1475,6 +1614,7 @@ class MatchRecognize(Expression):
"after": False,
"pattern": False,
"define": False,
"alias": False,
}
@ -1582,6 +1722,10 @@ class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
class InputOutputFormat(Expression):
arg_types = {"input_format": False, "output_format": False}
class IsolatedLoadingProperty(Property):
arg_types = {
"no": True,
@ -1646,6 +1790,10 @@ class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False, "table": False}
class RowFormatProperty(Property):
arg_types = {"this": True}
class RowFormatDelimitedProperty(Property):
# https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
arg_types = {
@ -1683,6 +1831,10 @@ class SqlSecurityProperty(Property):
arg_types = {"definer": True}
class StabilityProperty(Property):
arg_types = {"this": True}
class TableFormatProperty(Property):
arg_types = {"this": True}
@ -1695,8 +1847,8 @@ class TransientProperty(Property):
arg_types = {"this": False}
class VolatilityProperty(Property):
arg_types = {"this": True}
class VolatileProperty(Property):
arg_types = {"this": False}
class WithDataProperty(Property):
@ -1726,6 +1878,7 @@ class Properties(Expression):
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"RETURNS": ReturnsProperty,
"ROW_FORMAT": RowFormatProperty,
"SORTKEY": SortKeyProperty,
"TABLE_FORMAT": TableFormatProperty,
}
@ -2721,6 +2874,7 @@ class Pivot(Expression):
"expressions": True,
"field": True,
"unpivot": True,
"columns": False,
}
@ -2731,6 +2885,8 @@ class Window(Expression):
"order": False,
"spec": False,
"alias": False,
"over": False,
"first": False,
}
@ -2816,6 +2972,7 @@ class DataType(Expression):
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
BIGDECIMAL = auto()
BIT = auto()
BOOLEAN = auto()
JSON = auto()
@ -2964,7 +3121,7 @@ class DropPartition(Expression):
# Binary expressions like (ADD a b)
class Binary(Expression):
class Binary(Condition):
arg_types = {"this": True, "expression": True}
@property
@ -2980,7 +3137,7 @@ class Add(Binary):
pass
class Connector(Binary, Condition):
class Connector(Binary):
pass
@ -3142,7 +3299,7 @@ class ArrayOverlaps(Binary):
# Unary Expressions
# (NOT a)
class Unary(Expression):
class Unary(Condition):
pass
@ -3150,11 +3307,11 @@ class BitwiseNot(Unary):
pass
class Not(Unary, Condition):
class Not(Unary):
pass
class Paren(Unary, Condition):
class Paren(Unary):
arg_types = {"this": True, "with": False}
@ -3162,7 +3319,6 @@ class Neg(Unary):
pass
# Special Functions
class Alias(Expression):
arg_types = {"this": True, "alias": False}
@ -3381,6 +3537,16 @@ class AnyValue(AggFunc):
class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}
def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
this = self.copy() if copy else self
this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
return this
def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
this = self.copy() if copy else self
this.set("default", maybe_parse(condition, **opts))
return this
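
The new builder methods make CASE expressions chainable; string arguments go through `maybe_parse`. A sketch:

    from sqlglot import exp

    case = exp.Case().when("x = 1", "'one'").else_("'other'")
    print(case.sql())  # expected: CASE WHEN x = 1 THEN 'one' ELSE 'other' END
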
class Cast(Func):
arg_types = {"this": True, "to": True}
@ -3719,6 +3885,10 @@ class Map(Func):
arg_types = {"keys": False, "values": False}
class StarMap(Func):
pass
class VarMap(Func):
arg_types = {"keys": True, "values": True}
is_var_len_args = True
@ -3734,6 +3904,10 @@ class Max(AggFunc):
is_var_len_args = True
class MD5(Func):
_sql_names = ["MD5"]
class Min(AggFunc):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -3840,6 +4014,15 @@ class SetAgg(AggFunc):
pass
class SHA(Func):
_sql_names = ["SHA", "SHA1"]
class SHA2(Func):
_sql_names = ["SHA2"]
arg_types = {"this": True, "length": False}
class SortArray(Func):
arg_types = {"this": True, "asc": False}
@ -4017,6 +4200,12 @@ class When(Func):
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
# https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html
# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16
class NextValueFor(Func):
arg_types = {"this": True, "order": False}
def _norm_arg(arg):
return arg.lower() if type(arg) is str else arg
@ -4025,6 +4214,32 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
# Helpers
@t.overload
def maybe_parse(
sql_or_expression: ExpOrStr,
*,
into: t.Type[E],
dialect: DialectType = None,
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E:
...
@t.overload
def maybe_parse(
sql_or_expression: str | E,
*,
into: t.Optional[IntoType] = None,
dialect: DialectType = None,
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E:
...
def maybe_parse(
sql_or_expression: ExpOrStr,
*,
@ -4200,15 +4415,15 @@ def _combine(expressions, operator, dialect=None, **opts):
expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
this = expressions[0]
if expressions[1:]:
this = _wrap_operator(this)
this = _wrap(this, Connector)
for expression in expressions[1:]:
this = operator(this=this, expression=_wrap_operator(expression))
this = operator(this=this, expression=_wrap(expression, Connector))
return this
def _wrap_operator(expression):
if isinstance(expression, (And, Or, Not)):
expression = Paren(this=expression)
def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
if isinstance(expression, kind):
return Paren(this=expression)
return expression
@ -4506,7 +4721,7 @@ def not_(expression, dialect=None, **opts) -> Not:
dialect=dialect,
**opts,
)
return Not(this=_wrap_operator(this))
return Not(this=_wrap(this, Connector))
def paren(expression) -> Paren:
@ -4657,6 +4872,8 @@ def alias_(
if table:
table_alias = TableAlias(this=alias)
exp = exp.copy() if isinstance(expression, Expression) else exp
exp.set("alias", table_alias)
if not isinstance(table, bool):
@ -4864,16 +5081,22 @@ def convert(value) -> Expression:
"""
if isinstance(value, Expression):
return value
if value is None:
return NULL
if isinstance(value, bool):
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, float) and math.isnan(value):
if isinstance(value, bool):
return Boolean(this=value)
if value is None or (isinstance(value, float) and math.isnan(value)):
return NULL
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
if isinstance(value, list):
@ -4883,14 +5106,6 @@ def convert(value) -> Expression:
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
raise ValueError(f"Cannot convert {value}")
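
The reordering above changes evaluation order only, not results; the temporal branches behave as before. A sketch:

    import datetime
    from sqlglot import exp
    from sqlglot.expressions import convert

    assert isinstance(convert(datetime.date(2023, 1, 15)), exp.DateStrToDate)
    # Naive datetimes are assumed UTC before being wrapped:
    assert isinstance(convert(datetime.datetime(2023, 1, 15, 12, 0)), exp.TimeStrToTime)
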
@ -5030,7 +5245,9 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs)
def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression:
def expand(
expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True
) -> Expression:
"""Transforms an expression by expanding all referenced sources into subqueries.
Examples:
@ -5038,6 +5255,9 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
>>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql()
'SELECT * FROM (SELECT * FROM y) AS z /* source: x */'
>>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y"), "y": parse_one("select * from z")}).sql()
'SELECT * FROM (SELECT * FROM (SELECT * FROM z) AS y /* source: y */) AS z /* source: x */'
Args:
expression: The expression to expand.
sources: A dictionary of name to Subqueryables.
@ -5054,7 +5274,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
if source:
subquery = source.subquery(node.alias or name)
subquery.comments = [f"source: {name}"]
return subquery
return subquery.transform(_expand, copy=False)
return node
return expression.transform(_expand, copy=copy)
@ -5089,8 +5309,8 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
from sqlglot.dialects.dialect import Dialect
converted = [convert(arg) for arg in args]
kwargs = {key: convert(value) for key, value in kwargs.items()}
converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args]
kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()}
parser = Dialect.get_or_raise(dialect)().parser()
from_args_list = parser.FUNCTIONS.get(name.upper())

@ -76,11 +76,13 @@ class Generator:
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.VolatilityProperty: lambda self, e: e.name,
exp.StabilityProperty: lambda self, e: e.name,
exp.VolatileProperty: lambda self, e: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
@ -110,8 +112,19 @@ class Generator:
# Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
MATCHED_BY_SOURCE = True
# Whether or not limit and fetch are supported
# "ALL", "LIMIT", "FETCH"
# Whether or not the INTERVAL expression works only with values like '1 day'
SINGLE_STRING_INTERVAL = False
# Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs
INTERVAL_ALLOWS_PLURAL_FORM = True
# Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
TABLESAMPLE_WITH_METHOD = True
# Whether or not to treat the number in TABLESAMPLE (50) as a percentage
TABLESAMPLE_SIZE_IS_PERCENT = False
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
TYPE_MAPPING = {
@ -129,6 +142,18 @@ class Generator:
"replace": "REPLACE",
}
TIME_PART_SINGULARS = {
"microseconds": "microsecond",
"seconds": "second",
"minutes": "minute",
"hours": "hour",
"days": "day",
"weeks": "week",
"months": "month",
"quarters": "quarter",
"years": "year",
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
@ -168,6 +193,7 @@ class Generator:
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.Property: exp.Properties.Location.POST_WITH,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
@ -175,15 +201,22 @@ class Generator:
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
}
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
JOIN_HINTS = True
TABLE_HINTS = True
RESERVED_KEYWORDS: t.Set[str] = set()
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
__slots__ = (
@ -322,10 +355,15 @@ class Generator:
comment = comment + " " if comment[-1].strip() else comment
return comment
def maybe_comment(self, sql: str, expression: exp.Expression) -> str:
comments = expression.comments if self._comments else None
def maybe_comment(
self,
sql: str,
expression: t.Optional[exp.Expression] = None,
comments: t.Optional[t.List[str]] = None,
) -> str:
comments = (comments or (expression and expression.comments)) if self._comments else None # type: ignore
if not comments:
if not comments or isinstance(expression, exp.Binary):
return sql
sep = "\n" if self.pretty else " "
@ -621,7 +659,6 @@ class Generator:
replace = " OR REPLACE" if expression.args.get("replace") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
volatile = " VOLATILE" if expression.args.get("volatile") else ""
postcreate_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_CREATE):
@ -632,7 +669,7 @@ class Generator:
wrapped=False,
)
modifiers = "".join((replace, unique, volatile, postcreate_props_sql))
modifiers = "".join((replace, unique, postcreate_props_sql))
postexpression_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
@ -684,6 +721,9 @@ class Generator:
def hexstring_sql(self, expression: exp.HexString) -> str:
return self.sql(expression, "this")
def bytestring_sql(self, expression: exp.ByteString) -> str:
return self.sql(expression, "this")
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
@ -695,9 +735,7 @@ class Generator:
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
if expression.args.get("values") is not None:
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
values = (
f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}"
)
values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}"
else:
nested = f"({interior})"
@ -713,7 +751,7 @@ class Generator:
this = self.sql(expression, "this")
this = f" FROM {this}" if this else ""
using_sql = (
f" USING {self.expressions(expression, 'using', sep=', USING ')}"
f" USING {self.expressions(expression, key='using', sep=', USING ')}"
if expression.args.get("using")
else ""
)
@ -730,7 +768,10 @@ class Generator:
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
cascade = " CASCADE" if expression.args.get("cascade") else ""
constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}"
purge = " PURGE" if expression.args.get("purge") else ""
return (
f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}"
)
def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
@ -746,7 +787,10 @@ class Generator:
direction = f" {direction.upper()}" if direction else ""
count = expression.args.get("count")
count = f" {count}" if count else ""
return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY"
if expression.args.get("percent"):
count = f"{count} PERCENT"
with_ties_or_only = "WITH TIES" if expression.args.get("with_ties") else "ONLY"
return f"{self.seg('FETCH')}{direction}{count} ROWS {with_ties_or_only}"
def filter_sql(self, expression: exp.Filter) -> str:
this = self.sql(expression, "this")
@ -766,12 +810,24 @@ class Generator:
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
text = text.lower() if self.normalize and not expression.quoted else text
lower = text.lower()
text = lower if self.normalize and not expression.quoted else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
if expression.quoted or should_identify(text, self.identify):
if (
expression.quoted
or should_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
input_format = self.sql(expression, "input_format")
input_format = f"INPUTFORMAT {input_format}" if input_format else ""
output_format = self.sql(expression, "output_format")
output_format = f"OUTPUTFORMAT {output_format}" if output_format else ""
return self.sep().join((input_format, output_format))
def national_sql(self, expression: exp.National) -> str:
return f"N{self.sql(expression, 'this')}"
@ -984,9 +1040,10 @@ class Generator:
self.sql(expression, "partition") if expression.args.get("partition") else ""
)
expression_sql = self.sql(expression, "expression")
conflict = self.sql(expression, "conflict")
returning = self.sql(expression, "returning")
sep = self.sep() if partition_sql else ""
sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{returning}"
sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{conflict}{returning}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
@ -1004,6 +1061,19 @@ class Generator:
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
return expression.name.upper()
def onconflict_sql(self, expression: exp.OnConflict) -> str:
conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
constraint = self.sql(expression, "constraint")
if constraint:
constraint = f"ON CONSTRAINT {constraint}"
key = self.expressions(expression, key="key", flat=True)
do = "" if expression.args.get("duplicate") else " DO "
nothing = "NOTHING" if expression.args.get("nothing") else ""
expressions = self.expressions(expression, flat=True)
if expressions:
expressions = f"UPDATE SET {expressions}"
return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}"
def returning_sql(self, expression: exp.Returning) -> str:
return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}"
@ -1036,7 +1106,7 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=", ", flat=True)
hints = f" WITH ({hints})" if hints else ""
hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="")
@ -1053,7 +1123,7 @@ class Generator:
this = self.sql(expression, "this")
alias = ""
method = self.sql(expression, "method")
method = f"{method.upper()} " if method else ""
method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
numerator = self.sql(expression, "bucket_numerator")
denominator = self.sql(expression, "bucket_denominator")
field = self.sql(expression, "bucket_field")
@ -1064,6 +1134,8 @@ class Generator:
rows = self.sql(expression, "rows")
rows = f"{rows} ROWS" if rows else ""
size = self.sql(expression, "size")
if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
size = f"{size} PERCENT"
seed = self.sql(expression, "seed")
seed = f" {seed_prefix} ({seed})" if seed else ""
kind = expression.args.get("kind", "TABLESAMPLE")
@ -1154,6 +1226,7 @@ class Generator:
"NATURAL" if expression.args.get("natural") else None,
expression.side,
expression.kind,
expression.hint if self.JOIN_HINTS else None,
"JOIN",
)
if op
@ -1311,16 +1384,20 @@ class Generator:
def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
partition = self.partition_by_sql(expression)
order = self.sql(expression, "order")
measures = self.sql(expression, "measures")
measures = self.seg(f"MEASURES {measures}") if measures else ""
measures = self.expressions(expression, key="measures")
measures = self.seg(f"MEASURES{self.seg(measures)}") if measures else ""
rows = self.sql(expression, "rows")
rows = self.seg(rows) if rows else ""
after = self.sql(expression, "after")
after = self.seg(after) if after else ""
pattern = self.sql(expression, "pattern")
pattern = self.seg(f"PATTERN ({pattern})") if pattern else ""
define = self.sql(expression, "define")
define = self.seg(f"DEFINE {define}") if define else ""
definition_sqls = [
f"{self.sql(definition, 'alias')} AS {self.sql(definition, 'this')}"
for definition in expression.args.get("define", [])
]
definitions = self.expressions(sqls=definition_sqls)
define = self.seg(f"DEFINE{self.seg(definitions)}") if definitions else ""
body = "".join(
(
partition,
@ -1332,7 +1409,9 @@ class Generator:
define,
)
)
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}"
alias = self.sql(expression, "alias")
alias = f" {alias}" if alias else ""
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
limit = expression.args.get("limit")
@ -1353,7 +1432,7 @@ class Generator:
self.sql(expression, "group"),
self.sql(expression, "having"),
self.sql(expression, "qualify"),
self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
@ -1471,15 +1550,21 @@ class Generator:
partition_sql = partition + " " if partition and order else partition
spec = expression.args.get("spec")
spec_sql = " " + self.window_spec_sql(spec) if spec else ""
spec_sql = " " + self.windowspec_sql(spec) if spec else ""
alias = self.sql(expression, "alias")
this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}"
over = self.sql(expression, "over") or "OVER"
this = f"{this} {'AS' if expression.arg_key == 'windows' else over}"
first = expression.args.get("first")
if first is not None:
first = " FIRST " if first else " LAST "
first = first or ""
if not partition and not order and not spec and alias:
return f"{this} {alias}"
window_args = alias + partition_sql + order_sql + spec_sql
window_args = alias + first + partition_sql + order_sql + spec_sql
return f"{this} ({window_args.strip()})"
@ -1487,7 +1572,7 @@ class Generator:
partition = self.expressions(expression, key="partition_by", flat=True)
return f"PARTITION BY {partition}" if partition else ""
def window_spec_sql(self, expression: exp.WindowSpec) -> str:
def windowspec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
end = (
@ -1508,7 +1593,7 @@ class Generator:
return f"{this} BETWEEN {low} AND {high}"
def bracket_sql(self, expression: exp.Bracket) -> str:
expressions = apply_index_offset(expression.expressions, self.index_offset)
expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset)
expressions_sql = ", ".join(self.sql(e) for e in expressions)
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
@ -1550,6 +1635,11 @@ class Generator:
expressions = self.expressions(expression, flat=True)
return f"CONSTRAINT {this} {expressions}"
def nextvaluefor_sql(self, expression: exp.NextValueFor) -> str:
order = expression.args.get("order")
order = f" OVER ({self.order_sql(order, flat=True)})" if order else ""
return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}"
def extract_sql(self, expression: exp.Extract) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
@ -1586,7 +1676,7 @@ class Generator:
def primarykey_sql(self, expression: exp.ForeignKey) -> str:
expressions = self.expressions(expression, flat=True)
options = self.expressions(expression, "options", flat=True, sep=" ")
options = self.expressions(expression, key="options", flat=True, sep=" ")
options = f" {options}" if options else ""
return f"PRIMARY KEY ({expressions}){options}"
@ -1644,17 +1734,20 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str:
this = expression.args.get("this")
if this:
this = (
f" {this}"
if isinstance(this, exp.Literal) or isinstance(this, exp.Paren)
else f" ({this})"
)
else:
this = ""
unit = self.sql(expression, "unit")
if not self.INTERVAL_ALLOWS_PLURAL_FORM:
unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit)
unit = f" {unit}" if unit else ""
if self.SINGLE_STRING_INTERVAL:
this = expression.this.name if expression.this else ""
return f"INTERVAL '{this}{unit}'"
this = self.sql(expression, "this")
if this:
unwrapped = isinstance(expression.this, self.UNWRAPPED_INTERVAL_VALUES)
this = f" {this}" if unwrapped else f" ({this})"
return f"INTERVAL{this}{unit}"
def return_sql(self, expression: exp.Return) -> str:
@ -1664,7 +1757,7 @@ class Generator:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
expressions = f"({expressions})" if expressions else ""
options = self.expressions(expression, "options", flat=True, sep=" ")
options = self.expressions(expression, key="options", flat=True, sep=" ")
options = f" {options}" if options else ""
return f"REFERENCES {this}{expressions}{options}"
@ -1690,9 +1783,9 @@ class Generator:
return f"NOT {self.sql(expression, 'this')}"
def alias_sql(self, expression: exp.Alias) -> str:
to_sql = self.sql(expression, "alias")
to_sql = f" AS {to_sql}" if to_sql else ""
return f"{self.sql(expression, 'this')}{to_sql}"
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
return f"{self.sql(expression, 'this')}{alias}"
def aliases_sql(self, expression: exp.Aliases) -> str:
return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
@ -1712,7 +1805,11 @@ class Generator:
if not self.pretty:
return self.binary(expression, op)
sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False))
sqls = tuple(
self.maybe_comment(self.sql(e), e, e.parent.comments) if i != 1 else self.sql(e)
for i, e in enumerate(expression.flatten(unnest=False))
)
sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
return f"{sep}{op} ".join(sqls)
@ -1797,13 +1894,13 @@ class Generator:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
actions = self.expressions(expression, "actions", prefix="ADD COLUMN ")
actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ")
elif isinstance(actions[0], exp.Schema):
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Delete):
actions = self.expressions(expression, "actions", flat=True)
actions = self.expressions(expression, key="actions", flat=True)
else:
actions = self.expressions(expression, "actions")
actions = self.expressions(expression, key="actions")
exists = " IF EXISTS" if expression.args.get("exists") else ""
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
@ -1935,6 +2032,7 @@ class Generator:
return f"USE{kind}{this}"
def binary(self, expression: exp.Binary, op: str) -> str:
op = self.maybe_comment(op, comments=expression.comments)
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
def function_fallback_sql(self, expression: exp.Func) -> str:
@ -1965,14 +2063,15 @@ class Generator:
def expressions(
self,
expression: exp.Expression,
expression: t.Optional[exp.Expression] = None,
key: t.Optional[str] = None,
sqls: t.Optional[t.List[str]] = None,
flat: bool = False,
indent: bool = True,
sep: str = ", ",
prefix: str = "",
) -> str:
expressions = expression.args.get(key or "expressions")
expressions = expression.args.get(key or "expressions") if expression else sqls
if not expressions:
return ""

@ -131,11 +131,16 @@ def subclasses(
]
def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]:
def apply_index_offset(
this: exp.Expression,
expressions: t.List[t.Optional[E]],
offset: int,
) -> t.List[t.Optional[E]]:
"""
Applies an offset to a given integer literal expression.
Args:
this: the target of the index
expressions: the expression the offset will be applied to, wrapped in a list.
offset: the offset that will be applied.
@ -148,11 +153,28 @@ def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.Lis
expression = expressions[0]
if expression and expression.is_int:
expression = expression.copy()
logger.warning("Applying array index offset (%s)", offset)
expression.args["this"] = str(int(expression.this) + offset) # type: ignore
return [expression]
from sqlglot import exp
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.simplify import simplify
if not this.type:
annotate_types(this)
if t.cast(exp.DataType, this.type).this not in (
exp.DataType.Type.UNKNOWN,
exp.DataType.Type.ARRAY,
):
return expressions
if expression:
if not expression.type:
annotate_types(expression)
if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
logger.warning("Applying array index offset (%s)", offset)
expression = simplify(
exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
)
return [expression]
return expressions
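
With the type-aware version, literal indexes are still shifted between dialects whose arrays start at different bases; a sketch, assuming Spark arrays are 0-based and Presto's 1-based per sqlglot's dialect settings:

    import sqlglot

    print(sqlglot.transpile("SELECT a[0] FROM t", read="spark", write="presto")[0])
    # expected: SELECT a[1] FROM t
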

@ -20,6 +20,7 @@ class Node:
expression: exp.Expression
source: exp.Expression
downstream: t.List[Node] = field(default_factory=list)
alias: str = ""
def walk(self) -> t.Iterator[Node]:
yield self
@ -69,14 +70,19 @@ def lineage(
optimized = optimize(expression, schema=schema, rules=rules)
scope = build_scope(optimized)
tables: t.Dict[str, Node] = {}
def to_node(
column_name: str,
scope: Scope,
scope_name: t.Optional[str] = None,
upstream: t.Optional[Node] = None,
alias: t.Optional[str] = None,
) -> Node:
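# Derived tables produced by exp.expand carry a "source: <name>" comment;
# map each derived-table alias back to its original source name.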
aliases = {
dt.alias: dt.comments[0].split()[1]
for dt in scope.derived_tables
if dt.comments and dt.comments[0].startswith("source: ")
}
if isinstance(scope.expression, exp.Union):
for scope in scope.union_scopes:
node = to_node(
@ -84,37 +90,58 @@ def lineage(
scope=scope,
scope_name=scope_name,
upstream=upstream,
alias=aliases.get(scope_name),
)
return node
select = next(select for select in scope.selects if select.alias_or_name == column_name)
source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules)
select = source.selects[0]
# Find the specific select clause that is the source of the column we want.
# This can either be a specific, named select or a generic `*` clause.
select = next(
(select for select in scope.selects if select.alias_or_name == column_name),
exp.Star() if scope.expression.is_star else None,
)
if not select:
raise ValueError(f"Could not find {column_name} in {scope.expression}")
if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
# a version that has only the column we care about.
# "x", SELECT x, y FROM foo
# => "x", SELECT x FROM foo
source = optimize(
scope.expression.select(select, append=False), schema=schema, rules=rules
)
select = source.selects[0]
else:
source = scope.expression
# Create the node for this step in the lineage chain, and attach it to the previous one.
node = Node(
name=f"{scope_name}.{column_name}" if scope_name else column_name,
source=source,
expression=select,
alias=alias or "",
)
if upstream:
upstream.downstream.append(node)
# Find all columns that went into creating this one to list their lineage nodes.
for c in set(select.find_all(exp.Column)):
table = c.table
source = scope.sources[table]
source = scope.sources.get(table)
if isinstance(source, Scope):
# The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
to_node(
c.name,
scope=source,
scope_name=table,
upstream=node,
c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
)
else:
if table not in tables:
tables[table] = Node(name=c.sql(), source=source, expression=source)
node.downstream.append(tables[table])
# The source is not a scope - we've reached the end of the line. At this point, if a source is not found
# it means this column's lineage is unknown. This can happen if the definition of a source used in a query
# is not passed into the `sources` map.
source = source or exp.Placeholder()
node.downstream.append(Node(name=c.sql(), source=source, expression=source))
return node
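
A sketch of the updated API (signature per this version's `lineage`; the node names shown are the expected ones):

    from sqlglot.lineage import lineage

    # "y" is expanded from sources; the chain now bottoms out at table x
    # instead of raising when a source is missing from scope.sources.
    node = lineage("a", "SELECT a FROM y", sources={"y": "SELECT a FROM x"})
    for n in node.walk():
        print(n.name)  # expected: a, y.a, x.a
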

@ -116,6 +116,9 @@ class TypeAnnotator:
exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
@ -335,7 +338,7 @@ class TypeAnnotator:
left_type = expression.left.type.this
right_type = expression.right.type.this
if isinstance(expression, (exp.And, exp.Or)):
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
@ -344,7 +347,7 @@ class TypeAnnotator:
)
else:
expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)):
elif isinstance(expression, exp.Predicate):
expression.type = exp.DataType.Type.BOOLEAN
else:
expression.type = self._maybe_coerce(left_type, right_type)

@ -46,7 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
root = node is expression
original = node.copy()
try:
node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
node = node.replace(
while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
)
except OptimizeError as e:
logger.info(e)
node.replace(original)

@ -93,6 +93,7 @@ def _expand_using(scope, resolver):
if column not in columns:
columns[column] = k
source_table = ordered[-1]
ordered.append(join_table)
join_columns = resolver.get_source_columns(join_table)
conditions = []
@ -102,8 +103,10 @@ def _expand_using(scope, resolver):
table = columns.get(identifier)
if not table or identifier not in join_columns:
raise OptimizeError(f"Cannot automatically join: {identifier}")
if columns and join_columns:
raise OptimizeError(f"Cannot automatically join: {identifier}")
table = table or source_table
conditions.append(
exp.condition(
exp.EQ(

@ -65,5 +65,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
if not table_alias.name:
table_alias.set("this", next_name())
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
return expression
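
A sketch of the effect (assuming the optimizer's public `qualify_tables` function):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables

    sql = "SELECT * FROM (VALUES (1, 2)) AS t"
    print(qualify_tables(parse_one(sql)).sql())
    # expected: SELECT * FROM (VALUES (1, 2)) AS t(_col_0, _col_1)
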

@ -201,23 +201,24 @@ def _simplify_comparison(expression, left, right, or_=False):
return left if (av < bv if or_ else av >= bv) else right
# we can't ever shortcut to true because the column could be null
if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
if not or_ and av <= bv:
return exp.false()
elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
if not or_ and av >= bv:
return exp.false()
elif isinstance(a, exp.EQ):
if isinstance(b, exp.LT):
return exp.false() if av >= bv else a
if isinstance(b, exp.LTE):
return exp.false() if av > bv else a
if isinstance(b, exp.GT):
return exp.false() if av <= bv else a
if isinstance(b, exp.GTE):
return exp.false() if av < bv else a
if isinstance(b, exp.NEQ):
return exp.false() if av == bv else a
if not or_:
if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
if av <= bv:
return exp.false()
elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
if av >= bv:
return exp.false()
elif isinstance(a, exp.EQ):
if isinstance(b, exp.LT):
return exp.false() if av >= bv else a
if isinstance(b, exp.LTE):
return exp.false() if av > bv else a
if isinstance(b, exp.GT):
return exp.false() if av <= bv else a
if isinstance(b, exp.GTE):
return exp.false() if av < bv else a
if isinstance(b, exp.NEQ):
return exp.false() if av == bv else a
return None
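
The `or_` guard matters because of NULL semantics: an AND of incompatible ranges can collapse to FALSE, but an OR must not shortcut, since the column could be NULL. A sketch:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    print(simplify(parse_one("x < 1 AND x > 2")).sql())  # expected: FALSE
    print(simplify(parse_one("x = 1 OR x > 2")).sql())   # expected: unchanged
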

@ -18,8 +18,13 @@ from sqlglot.trie import in_trie, new_trie
logger = logging.getLogger("sqlglot")
E = t.TypeVar("E", bound=exp.Expression)
def parse_var_map(args: t.Sequence) -> exp.Expression:
if len(args) == 1 and args[0].is_star:
return exp.StarMap(this=args[0])
keys = []
values = []
for i in range(0, len(args), 2):
@ -108,6 +113,8 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_USER: exp.CurrentUser,
}
JOIN_HINTS: t.Set[str] = set()
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
TokenType.MAP,
@ -145,6 +152,7 @@ class Parser(metaclass=_Parser):
TokenType.DATETIME,
TokenType.DATE,
TokenType.DECIMAL,
TokenType.BIGDECIMAL,
TokenType.UUID,
TokenType.GEOGRAPHY,
TokenType.GEOMETRY,
@ -221,8 +229,10 @@ class Parser(metaclass=_Parser):
TokenType.FORMAT,
TokenType.FULL,
TokenType.IF,
TokenType.IS,
TokenType.ISNULL,
TokenType.INTERVAL,
TokenType.KEEP,
TokenType.LAZY,
TokenType.LEADING,
TokenType.LEFT,
@ -235,6 +245,7 @@ class Parser(metaclass=_Parser):
TokenType.ONLY,
TokenType.OPTIONS,
TokenType.ORDINALITY,
TokenType.OVERWRITE,
TokenType.PARTITION,
TokenType.PERCENT,
TokenType.PIVOT,
@ -266,6 +277,8 @@ class Parser(metaclass=_Parser):
*NO_PAREN_FUNCTIONS,
}
INTERVAL_VARS = ID_VAR_TOKENS - {TokenType.END}
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
TokenType.FULL,
@ -276,6 +289,8 @@ class Parser(metaclass=_Parser):
TokenType.WINDOW,
}
COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS}
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
@ -400,7 +415,7 @@ class Parser(metaclass=_Parser):
COLUMN_OPERATORS = {
TokenType.DOT: None,
TokenType.DCOLON: lambda self, this, to: self.expression(
exp.Cast,
exp.Cast if self.STRICT_CAST else exp.TryCast,
this=this,
to=to,
),
@ -560,7 +575,7 @@ class Parser(metaclass=_Parser):
),
"DEFINER": lambda self: self._parse_definer(),
"DETERMINISTIC": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"DISTKEY": lambda self: self._parse_distkey(),
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
@ -571,7 +586,7 @@ class Parser(metaclass=_Parser):
"FREESPACE": lambda self: self._parse_freespace(),
"GLOBAL": lambda self: self._parse_temporary(global_=True),
"IMMUTABLE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"JOURNAL": lambda self: self._parse_journal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
@ -600,20 +615,20 @@ class Parser(metaclass=_Parser):
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
"SET": lambda self: self.expression(exp.SetProperty, multi=False),
"SORTKEY": lambda self: self._parse_sortkey(),
"STABLE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("STABLE")
exp.StabilityProperty, this=exp.Literal.string("STABLE")
),
"STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"STORED": lambda self: self._parse_stored(),
"TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
"TEMP": lambda self: self._parse_temporary(global_=False),
"TEMPORARY": lambda self: self._parse_temporary(global_=False),
"TRANSIENT": lambda self: self.expression(exp.TransientProperty),
"USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"VOLATILE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
),
"VOLATILE": lambda self: self._parse_volatile_property(),
"WITH": lambda self: self._parse_with_property(),
}
@ -648,8 +663,11 @@ class Parser(metaclass=_Parser):
"LIKE": lambda self: self._parse_create_like(),
"NOT": lambda self: self._parse_not_constraint(),
"NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True),
"ON": lambda self: self._match(TokenType.UPDATE)
and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()),
"PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
"PRIMARY KEY": lambda self: self._parse_primary_key(),
"REFERENCES": lambda self: self._parse_references(match=False),
"TITLE": lambda self: self.expression(
exp.TitleColumnConstraint, this=self._parse_var_or_string()
),
@ -668,9 +686,14 @@ class Parser(metaclass=_Parser):
SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}
NO_PAREN_FUNCTION_PARSERS = {
TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
TokenType.CASE: lambda self: self._parse_case(),
TokenType.IF: lambda self: self._parse_if(),
TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
TokenType.NEXT_VALUE_FOR: lambda self: self.expression(
exp.NextValueFor,
this=self._parse_column(),
order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order),
),
}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@ -715,6 +738,8 @@ class Parser(metaclass=_Parser):
SHOW_PARSERS: t.Dict[str, t.Callable] = {}
TYPE_LITERAL_PARSERS: t.Dict[exp.DataType.Type, t.Callable] = {}
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
@ -731,6 +756,7 @@ class Parser(metaclass=_Parser):
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
@ -738,6 +764,9 @@ class Parser(metaclass=_Parser):
CONVERT_TYPE_FIRST = False
QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None
PREFIXED_PIVOT_COLUMNS = False
LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False
@ -895,8 +924,8 @@ class Parser(metaclass=_Parser):
error level setting.
"""
token = token or self._curr or self._prev or Token.string("")
start = self._find_token(token)
end = start + len(token.text)
start = token.start
end = token.end
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context]
@ -918,8 +947,8 @@ class Parser(metaclass=_Parser):
self.errors.append(error)
def expression(
self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs
) -> exp.Expression:
self, exp_class: t.Type[E], comments: t.Optional[t.List[str]] = None, **kwargs
) -> E:
"""
Creates a new, validated Expression.
@ -958,22 +987,7 @@ class Parser(metaclass=_Parser):
self.raise_error(error_message)
def _find_sql(self, start: Token, end: Token) -> str:
return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)]
def _find_token(self, token: Token) -> int:
line = 1
col = 1
index = 0
while line < token.line or col < token.col:
if Tokenizer.WHITE_SPACE.get(self.sql[index]) == TokenType.BREAK:
line += 1
col = 1
else:
col += 1
index += 1
return index
return self.sql[start.start : end.end]
def _advance(self, times: int = 1) -> None:
self._index += times
@ -990,7 +1004,7 @@ class Parser(metaclass=_Parser):
if index != self._index:
self._advance(index - self._index)
def _parse_command(self) -> exp.Expression:
def _parse_command(self) -> exp.Command:
return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
@ -1007,7 +1021,7 @@ class Parser(metaclass=_Parser):
if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=kind.token_type)
elif kind.token_type == TokenType.TABLE:
this = self._parse_table()
this = self._parse_table(alias_tokens=self.COMMENT_TABLE_ALIAS_TOKENS)
elif kind.token_type == TokenType.COLUMN:
this = self._parse_column()
else:
@ -1035,16 +1049,13 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(expression)
return expression
def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
if default_kind:
kind = default_kind
else:
return self._parse_as_command(start)
return self._parse_as_command(start)
return self.expression(
exp.Drop,
@ -1055,6 +1066,7 @@ class Parser(metaclass=_Parser):
materialized=materialized,
cascade=self._match(TokenType.CASCADE),
constraints=self._match_text_seq("CONSTRAINTS"),
purge=self._match_text_seq("PURGE"),
)
def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
@ -1070,7 +1082,6 @@ class Parser(metaclass=_Parser):
TokenType.OR, TokenType.REPLACE
)
unique = self._match(TokenType.UNIQUE)
volatile = self._match(TokenType.VOLATILE)
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE)
@ -1179,7 +1190,6 @@ class Parser(metaclass=_Parser):
kind=create_token.text,
replace=replace,
unique=unique,
volatile=volatile,
expression=expression,
exists=exists,
properties=properties,
@ -1225,6 +1235,21 @@ class Parser(metaclass=_Parser):
return None
def _parse_stored(self) -> exp.Expression:
self._match(TokenType.ALIAS)
input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None
output_format = self._parse_string() if self._match_text_seq("OUTPUTFORMAT") else None
return self.expression(
exp.FileFormatProperty,
this=self.expression(
exp.InputOutputFormat, input_format=input_format, output_format=output_format
)
if input_format or output_format
else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
)
def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
@ -1258,6 +1283,21 @@ class Parser(metaclass=_Parser):
exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
)
def _parse_volatile_property(self) -> exp.Expression:
if self._index >= 2:
pre_volatile_token = self._tokens[self._index - 2]
else:
pre_volatile_token = None
if pre_volatile_token and pre_volatile_token.token_type in (
TokenType.CREATE,
TokenType.REPLACE,
TokenType.UNIQUE,
):
return exp.VolatileProperty()
return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE"))
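
The two-token lookback disambiguates VOLATILE: immediately after CREATE, OR REPLACE, or UNIQUE it marks a Teradata-style volatile table, while in any other position it is treated as a function stability marker alongside IMMUTABLE and STABLE. A sketch of the first path, assuming the keyword is wired into the property parsers as this hunk suggests:

import sqlglot

# VOLATILE right after CREATE parses to exp.VolatileProperty
create = sqlglot.parse_one("CREATE VOLATILE TABLE t (c INT)")
print(create.sql())  # CREATE VOLATILE TABLE t (c INT)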
def _parse_with_property(
self,
) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]:
@ -1574,11 +1614,46 @@ class Parser(metaclass=_Parser):
exists=self._parse_exists(),
partition=self._parse_partition(),
expression=self._parse_ddl_select(),
conflict=self._parse_on_conflict(),
returning=self._parse_returning(),
overwrite=overwrite,
alternative=alternative,
)
def _parse_on_conflict(self) -> t.Optional[exp.Expression]:
conflict = self._match_text_seq("ON", "CONFLICT")
duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY")
if not (conflict or duplicate):
return None
nothing = None
expressions = None
key = None
constraint = None
if conflict:
if self._match_text_seq("ON", "CONSTRAINT"):
constraint = self._parse_id_var()
else:
key = self._parse_csv(self._parse_value)
self._match_text_seq("DO")
if self._match_text_seq("NOTHING"):
nothing = True
else:
self._match(TokenType.UPDATE)
expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
return self.expression(
exp.OnConflict,
duplicate=duplicate,
expressions=expressions,
nothing=nothing,
key=key,
constraint=constraint,
)
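
One parser now covers both upsert spellings, so the Postgres and MySQL forms land in the same exp.OnConflict node attached to the insert. For instance, assuming the 11.7.1 dialects:

import sqlglot

# Postgres spelling
sqlglot.parse_one("INSERT INTO t (a) VALUES (1) ON CONFLICT DO NOTHING")

# MySQL spelling
sqlglot.parse_one(
    "INSERT INTO t (a) VALUES (1) ON DUPLICATE KEY UPDATE a = 1",
    read="mysql",
)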
def _parse_returning(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.RETURNING):
return None
@ -1639,7 +1714,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Delete,
this=self._parse_table(schema=True),
this=self._parse_table(),
using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
where=self._parse_where(),
returning=self._parse_returning(),
@ -1792,6 +1867,7 @@ class Parser(metaclass=_Parser):
if not skip_with_token and not self._match(TokenType.WITH):
return None
comments = self._prev_comments
recursive = self._match(TokenType.RECURSIVE)
expressions = []
@ -1803,7 +1879,9 @@ class Parser(metaclass=_Parser):
else:
self._match(TokenType.WITH)
return self.expression(exp.With, expressions=expressions, recursive=recursive)
return self.expression(
exp.With, comments=comments, expressions=expressions, recursive=recursive
)
def _parse_cte(self) -> exp.Expression:
alias = self._parse_table_alias()
@ -1856,15 +1934,20 @@ class Parser(metaclass=_Parser):
table = isinstance(this, exp.Table)
while True:
lateral = self._parse_lateral()
join = self._parse_join()
comma = None if table else self._match(TokenType.COMMA)
if lateral:
this.append("laterals", lateral)
if join:
this.append("joins", join)
lateral = None
if not join:
lateral = self._parse_lateral()
if lateral:
this.append("laterals", lateral)
comma = None if table else self._match(TokenType.COMMA)
if comma:
this.args["from"].append("expressions", self._parse_table())
if not (lateral or join or comma):
break
@ -1906,14 +1989,13 @@ class Parser(metaclass=_Parser):
def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.MATCH_RECOGNIZE):
return None
self._match_l_paren()
partition = self._parse_partition_by()
order = self._parse_order()
measures = (
self._parse_alias(self._parse_conjunction())
if self._match_text_seq("MEASURES")
else None
self._parse_csv(self._parse_expression) if self._match_text_seq("MEASURES") else None
)
if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):
@ -1967,8 +2049,17 @@ class Parser(metaclass=_Parser):
pattern = None
define = (
self._parse_alias(self._parse_conjunction()) if self._match_text_seq("DEFINE") else None
self._parse_csv(
lambda: self.expression(
exp.Alias,
alias=self._parse_id_var(any_token=True),
this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
)
)
if self._match_text_seq("DEFINE")
else None
)
self._match_r_paren()
return self.expression(
@ -1980,6 +2071,7 @@ class Parser(metaclass=_Parser):
after=after,
pattern=pattern,
define=define,
alias=self._parse_table_alias(),
)
def _parse_lateral(self) -> t.Optional[exp.Expression]:
@ -2022,9 +2114,6 @@ class Parser(metaclass=_Parser):
alias=table_alias,
)
if outer_apply or cross_apply:
return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT")
return expression
def _parse_join_side_and_kind(
@ -2037,11 +2126,26 @@ class Parser(metaclass=_Parser):
)
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
index = self._index
natural, side, kind = self._parse_join_side_and_kind()
hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
join = self._match(TokenType.JOIN)
if not skip_join_token and not self._match(TokenType.JOIN):
if not skip_join_token and not join:
self._retreat(index)
kind = None
natural = None
side = None
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False)
cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY, False)
if not skip_join_token and not join and not outer_apply and not cross_apply:
return None
if outer_apply:
side = Token(TokenType.LEFT, "LEFT")
kwargs: t.Dict[
str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]]
] = {"this": self._parse_table()}
@ -2052,6 +2156,8 @@ class Parser(metaclass=_Parser):
kwargs["side"] = side.text
if kind:
kwargs["kind"] = kind.text
if hint:
kwargs["hint"] = hint
if self._match(TokenType.ON):
kwargs["on"] = self._parse_conjunction()
@ -2179,7 +2285,7 @@ class Parser(metaclass=_Parser):
return None
expressions = self._parse_wrapped_csv(self._parse_column)
ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
alias = self._parse_table_alias()
if alias and self.unnest_column_only:
@ -2191,7 +2297,7 @@ class Parser(metaclass=_Parser):
offset = None
if self._match_pair(TokenType.WITH, TokenType.OFFSET):
self._match(TokenType.ALIAS)
offset = self._parse_conjunction()
offset = self._parse_id_var() or exp.Identifier(this="offset")
return self.expression(
exp.Unnest,
@ -2294,6 +2400,9 @@ class Parser(metaclass=_Parser):
else:
expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function()))
if not expressions:
self.raise_error("Failed to parse PIVOT's aggregation list")
if not self._match(TokenType.FOR):
self.raise_error("Expecting FOR")
@ -2311,8 +2420,26 @@ class Parser(metaclass=_Parser):
if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
pivot.set("alias", self._parse_table_alias())
if not unpivot:
names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions))
columns: t.List[exp.Expression] = []
for col in pivot.args["field"].expressions:
for name in names:
if self.PREFIXED_PIVOT_COLUMNS:
name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name
else:
name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name
columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS))
pivot.set("columns", columns)
return pivot
def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
return [agg.alias for agg in pivot_columns]
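
The column-naming scheme is easiest to see stripped of the expression machinery. A plain-Python sketch with hypothetical values (PREFIXED_PIVOT_COLUMNS and QUOTED_PIVOT_COLUMNS are per-dialect flags):

names = ["sum_v"]    # aliases of the PIVOT aggregations
fields = ["a", "b"]  # values from the FOR ... IN (...) list
prefixed = True      # PREFIXED_PIVOT_COLUMNS

columns = [
    f"{name}_{field}" if prefixed else f"{field}_{name}"
    for field in fields
    for name in names
]
print(columns)  # ['sum_v_a', 'sum_v_b']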
def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_where_token and not self._match(TokenType.WHERE):
return None
@ -2433,10 +2560,25 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST"
count = self._parse_number()
percent = self._match(TokenType.PERCENT)
self._match_set((TokenType.ROW, TokenType.ROWS))
self._match(TokenType.ONLY)
return self.expression(exp.Fetch, direction=direction, count=count)
only = self._match(TokenType.ONLY)
with_ties = self._match_text_seq("WITH", "TIES")
if only and with_ties:
self.raise_error("Cannot specify both ONLY and WITH TIES in FETCH clause")
return self.expression(
exp.Fetch,
direction=direction,
count=count,
percent=percent,
with_ties=with_ties,
)
return this
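
With the extra arguments, percentage fetches and WITH TIES now survive parsing instead of being silently dropped, and the contradictory combination raises. For example, assuming 11.7.1:

import sqlglot

sqlglot.parse_one("SELECT * FROM t ORDER BY a FETCH FIRST 10 PERCENT ROWS ONLY")
sqlglot.parse_one("SELECT * FROM t ORDER BY a FETCH FIRST 5 ROWS WITH TIES")

# "... FETCH FIRST 5 ROWS ONLY WITH TIES" would now raise a ParseError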
@ -2493,7 +2635,11 @@ class Parser(metaclass=_Parser):
negate = self._match(TokenType.NOT)
if self._match_set(self.RANGE_PARSERS):
this = self.RANGE_PARSERS[self._prev.token_type](self, this)
expression = self.RANGE_PARSERS[self._prev.token_type](self, this)
if not expression:
return this
this = expression
elif self._match(TokenType.ISNULL):
this = self.expression(exp.Is, this=this, expression=exp.Null())
@ -2511,17 +2657,19 @@ class Parser(metaclass=_Parser):
return this
def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression:
def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
index = self._index - 1
negate = self._match(TokenType.NOT)
if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
this = self.expression(
exp.Is,
this=this,
expression=self._parse_null() or self._parse_boolean(),
)
expression = self._parse_null() or self._parse_boolean()
if not expression:
self._retreat(index)
return None
this = self.expression(exp.Is, this=this, expression=expression)
return self.expression(exp.Not, this=this) if negate else this
def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression:
@ -2553,6 +2701,27 @@ class Parser(metaclass=_Parser):
return this
return self.expression(exp.Escape, this=this, expression=self._parse_string())
def _parse_interval(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.INTERVAL):
return None
this = self._parse_primary() or self._parse_term()
unit = self._parse_function() or self._parse_var()
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
# each INTERVAL expression into this canonical form so it's easy to transpile
if this and isinstance(this, exp.Literal):
if this.is_number:
this = exp.Literal.string(this.name)
# Try not to clutter Snowflake's multi-part intervals like INTERVAL '1 day, 1 year'
parts = this.name.split()
if not unit and len(parts) <= 2:
this = exp.Literal.string(seq_get(parts, 0))
unit = self.expression(exp.Var, this=seq_get(parts, 1))
return self.expression(exp.Interval, this=this, unit=unit)
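
The canonicalization means the shorthand INTERVAL '5 day' is split into a quoted amount plus a separate unit, which most dialects can then generate directly. A quick check, with the exact output inferred from the logic above:

import sqlglot

print(sqlglot.transpile("SELECT INTERVAL '5 day'")[0])
# SELECT INTERVAL '5' day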
def _parse_bitwise(self) -> t.Optional[exp.Expression]:
this = self._parse_term()
@ -2588,20 +2757,24 @@ class Parser(metaclass=_Parser):
return self._parse_at_time_zone(self._parse_type())
def _parse_type(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.INTERVAL):
return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_field())
interval = self._parse_interval()
if interval:
return interval
index = self._index
type_token = self._parse_types(check_func=True)
data_type = self._parse_types(check_func=True)
this = self._parse_column()
if type_token:
if data_type:
if isinstance(this, exp.Literal):
return self.expression(exp.Cast, this=this, to=type_token)
if not type_token.args.get("expressions"):
parser = self.TYPE_LITERAL_PARSERS.get(data_type.this)
if parser:
return parser(self, this, data_type)
return self.expression(exp.Cast, this=this, to=data_type)
if not data_type.args.get("expressions"):
self._retreat(index)
return self._parse_column()
return type_token
return data_type
return this
@ -2631,11 +2804,10 @@ class Parser(metaclass=_Parser):
else:
expressions = self._parse_csv(self._parse_conjunction)
if not expressions:
if not expressions or not self._match(TokenType.R_PAREN):
self._retreat(index)
return None
self._match_r_paren()
maybe_func = True
if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
@ -2720,15 +2892,14 @@ class Parser(metaclass=_Parser):
)
def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
if self._curr and self._curr.token_type in self.TYPE_TOKENS:
return self._parse_types()
index = self._index
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
if not data_type:
return None
self._retreat(index)
return self._parse_types()
return self.expression(exp.StructKwarg, this=this, expression=data_type)
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
@ -2825,6 +2996,7 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Paren, this=self._parse_set_operations(this))
self._match_r_paren()
comments.extend(self._prev_comments)
if this and comments:
this.comments = comments
@ -2833,8 +3005,16 @@ class Parser(metaclass=_Parser):
return None
def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]:
return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
def _parse_field(
self,
any_token: bool = False,
tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
return (
self._parse_primary()
or self._parse_function()
or self._parse_id_var(any_token=any_token, tokens=tokens)
)
def _parse_function(
self, functions: t.Optional[t.Dict[str, t.Callable]] = None
@ -3079,12 +3259,10 @@ class Parser(metaclass=_Parser):
return None
def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
this = self._parse_references()
if this:
return this
if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var()
else:
this = None
if self._match_texts(self.CONSTRAINT_PARSERS):
return self.expression(
@ -3164,8 +3342,8 @@ class Parser(metaclass=_Parser):
return options
def _parse_references(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.REFERENCES):
def _parse_references(self, match=True) -> t.Optional[exp.Expression]:
if match and not self._match(TokenType.REFERENCES):
return None
expressions = None
@ -3234,7 +3412,7 @@ class Parser(metaclass=_Parser):
elif not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
else:
expressions = apply_index_offset(expressions, -self.index_offset)
expressions = apply_index_offset(this, expressions, -self.index_offset)
this = self.expression(exp.Bracket, this=this, expressions=expressions)
if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
@ -3279,7 +3457,13 @@ class Parser(metaclass=_Parser):
self.validate_expression(this, args)
self._match_r_paren()
else:
index = self._index - 1
condition = self._parse_conjunction()
if not condition:
self._retreat(index)
return None
self._match(TokenType.THEN)
true = self._parse_conjunction()
false = self._parse_conjunction() if self._match(TokenType.ELSE) else None
@ -3591,14 +3775,24 @@ class Parser(metaclass=_Parser):
# bigquery select from window x AS (partition by ...)
if alias:
over = None
self._match(TokenType.ALIAS)
elif not self._match(TokenType.OVER):
elif not self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS):
return this
else:
over = self._prev.text.upper()
if not self._match(TokenType.L_PAREN):
return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
return self.expression(
exp.Window, this=this, alias=self._parse_id_var(False), over=over
)
window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS)
first = self._match(TokenType.FIRST)
if self._match_text_seq("LAST"):
first = False
partition = self._parse_partition_by()
order = self._parse_order()
kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text
@ -3629,6 +3823,8 @@ class Parser(metaclass=_Parser):
order=order,
spec=spec,
alias=window_alias,
over=over,
first=first,
)
def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
@ -3886,7 +4082,10 @@ class Parser(metaclass=_Parser):
return expression
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
drop = self._match(TokenType.DROP) and self._parse_drop()
if drop and not isinstance(drop, exp.Command):
drop.set("kind", drop.args.get("kind", "COLUMN"))
return drop
# https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html
def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression:
@ -4010,7 +4209,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.INSERT):
_this = self._parse_star()
if _this:
then = self.expression(exp.Insert, this=_this)
then: t.Optional[exp.Expression] = self.expression(exp.Insert, this=_this)
else:
then = self.expression(
exp.Insert,
@ -4239,5 +4438,8 @@ class Parser(metaclass=_Parser):
break
parent = parent.parent
else:
column.replace(dot_or_id)
if column is node:
node = dot_or_id
else:
column.replace(dot_or_id)
return node


@ -5,7 +5,7 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
@ -75,12 +75,11 @@ class AbstractMappingSchema(t.Generic[T]):
mapping: dict | None = None,
) -> None:
self.mapping = mapping or {}
self.mapping_trie = self._build_trie(self.mapping)
self.mapping_trie = new_trie(
tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
)
self._supported_table_args: t.Tuple[str, ...] = tuple()
def _build_trie(self, schema: t.Dict) -> t.Dict:
return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
def _depth(self) -> int:
return dict_depth(self.mapping)
@ -179,6 +178,64 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
}
)
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
"""
normalized_table = self._normalize_table(self._ensure_table(table))
normalized_column_mapping = {
self._normalize_name(key): value
for key, value in ensure_column_mapping(column_mapping).items()
}
schema = self.find(normalized_table, raise_on_missing=False)
if schema and not normalized_column_mapping:
return
parts = self.table_parts(normalized_table)
_nested_set(
self.mapping,
tuple(reversed(parts)),
normalized_column_mapping,
)
new_trie([parts], self.mapping_trie)
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table_ = self._normalize_table(self._ensure_table(table))
schema = self.find(table_)
if schema is None:
return []
if not only_visible or not self.visible:
return list(schema)
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
column_name = self._normalize_name(column if isinstance(column, str) else column.this)
table_ = self._normalize_table(self._ensure_table(table))
table_schema = self.find(table_, raise_on_missing=False)
if table_schema:
column_type = table_schema.get(column_name)
if isinstance(column_type, exp.DataType):
return column_type
elif isinstance(column_type, str):
return self._to_data_type(column_type.upper())
raise SchemaError(f"Unknown column type '{column_type}'")
return exp.DataType.build("unknown")
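
Together with the relocation, these methods now normalize consistently: add_table lower-cases unquoted parts, patches the nested mapping, and mutates the existing trie in place through the new new_trie(..., trie) signature instead of rebuilding it on every registration. Typical usage of the public API:

from sqlglot.schema import MappingSchema

schema = MappingSchema(dialect="hive")
schema.add_table("db.events", {"id": "int", "payload": "string"})

print(schema.column_names("db.events"))           # ['id', 'payload']
print(schema.get_column_type("db.events", "id"))  # an exp.DataType for INT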
def _normalize(self, schema: t.Dict) -> t.Dict:
"""
Converts all identifiers in the schema into lowercase, unless they're quoted.
@ -206,84 +263,37 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return normalized_mapping
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
def _normalize_table(self, table: exp.Table) -> exp.Table:
normalized_table = table.copy()
for arg in TABLE_ARGS:
value = normalized_table.args.get(arg)
if isinstance(value, (str, exp.Identifier)):
normalized_table.set(arg, self._normalize_name(value))
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
"""
table_ = self._ensure_table(table)
column_mapping = ensure_column_mapping(column_mapping)
schema = self.find(table_, raise_on_missing=False)
return normalized_table
if schema and not column_mapping:
return
_nested_set(
self.mapping,
list(reversed(self.table_parts(table_))),
column_mapping,
)
self.mapping_trie = self._build_trie(self.mapping)
def _normalize_name(self, name: str) -> str:
def _normalize_name(self, name: str | exp.Identifier) -> str:
try:
identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
name, read=self.dialect, into=exp.Identifier
)
except:
identifier = exp.to_identifier(name)
assert isinstance(identifier, exp.Identifier)
identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
except ParseError:
return name if isinstance(name, str) else name.name
if identifier.quoted:
return identifier.name
return identifier.name.lower()
return identifier.name if identifier.quoted else identifier.name.lower()
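
The collapsed branch reads: quoted identifiers keep their case, everything else is lower-cased, and names that fail to parse now fall back to the raw string instead of tripping an assert. Illustrative only, since the helper is private:

from sqlglot.schema import MappingSchema

schema = MappingSchema(dialect="snowflake")
print(schema._normalize_name('"Foo"'))  # Foo  (quoted, case preserved)
print(schema._normalize_name("Foo"))    # foo  (unquoted, lower-cased)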
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
def _ensure_table(self, table: exp.Table | str) -> exp.Table:
table_ = exp.to_table(table)
if isinstance(table, exp.Table):
return table
table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
if not table_:
raise SchemaError(f"Not a valid table '{table}'")
return table_
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table_ = self._ensure_table(table)
schema = self.find(table_)
if schema is None:
return []
if not only_visible or not self.visible:
return list(schema)
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
table_schema = self.find(table_, raise_on_missing=False)
if table_schema:
column_type = table_schema.get(column_name)
if isinstance(column_type, exp.DataType):
return column_type
elif isinstance(column_type, str):
return self._to_data_type(column_type.upper())
raise SchemaError(f"Unknown column type '{column_type}'")
return exp.DataType(this=exp.DataType.Type.UNKNOWN)
raise SchemaError(f"Could not convert table '{table}'")
def _to_data_type(self, schema_type: str) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
@ -313,7 +323,7 @@ def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
return MappingSchema(schema, dialect=dialect)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
if isinstance(mapping, dict):
return mapping
elif isinstance(mapping, str):
@ -371,7 +381,7 @@ def _nested_get(
return d
def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
"""
In-place set a value for a nested dictionary
@ -384,11 +394,11 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
Args:
d: dictionary to update.
keys: the keys that make up the path to `value`.
value: the value to set in the dictionary for the given key path.
keys: the keys that make up the path to `value`.
value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.
Returns:
The (possibly) updated dictionary.
"""
if not keys:
return d


@ -87,6 +87,7 @@ class TokenType(AutoName):
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
BIGDECIMAL = auto()
CHAR = auto()
NCHAR = auto()
VARCHAR = auto()
@ -214,6 +215,7 @@ class TokenType(AutoName):
ISNULL = auto()
JOIN = auto()
JOIN_MARKER = auto()
KEEP = auto()
LANGUAGE = auto()
LATERAL = auto()
LAZY = auto()
@ -231,6 +233,7 @@ class TokenType(AutoName):
MOD = auto()
NATURAL = auto()
NEXT = auto()
NEXT_VALUE_FOR = auto()
NO_ACTION = auto()
NOTNULL = auto()
NULL = auto()
@ -315,7 +318,7 @@ class TokenType(AutoName):
class Token:
__slots__ = ("token_type", "text", "line", "col", "comments")
__slots__ = ("token_type", "text", "line", "col", "end", "comments")
@classmethod
def number(cls, number: int) -> Token:
@ -343,22 +346,29 @@ class Token:
text: str,
line: int = 1,
col: int = 1,
end: int = 0,
comments: t.List[str] = [],
) -> None:
self.token_type = token_type
self.text = text
self.line = line
self.col = col - len(text)
self.col = self.col if self.col > 1 else 1
size = len(text)
self.col = col
self.end = end if end else size
self.comments = comments
@property
def start(self) -> int:
"""Returns the start of the token."""
return self.end - len(self.text)
def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
return f"<Token {attributes}>"
class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs): # type: ignore
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
klass._QUOTES = {
@ -433,25 +443,25 @@ class Tokenizer(metaclass=_Tokenizer):
"#": TokenType.HASH,
}
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
IDENTIFIER_ESCAPES = ['"']
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
STRING_ESCAPES = ["'"]
VAR_SINGLE_TOKENS: t.Set[str] = set()
_COMMENTS: t.Dict[str, str] = {}
_BIT_STRINGS: t.Dict[str, str] = {}
_BYTE_STRINGS: t.Dict[str, str] = {}
_HEX_STRINGS: t.Dict[str, str] = {}
_IDENTIFIERS: t.Dict[str, str] = {}
_IDENTIFIER_ESCAPES: t.Set[str] = set()
_QUOTES: t.Dict[str, str] = {}
_STRING_ESCAPES: t.Set[str] = set()
IDENTIFIER_ESCAPES = ['"']
_IDENTIFIER_ESCAPES: t.Set[str] = set()
KEYWORDS = {
KEYWORDS: t.Dict[t.Optional[str], TokenType] = {
**{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
**{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")},
"{{+": TokenType.BLOCK_START,
@ -553,6 +563,7 @@ class Tokenizer(metaclass=_Tokenizer):
"IS": TokenType.IS,
"ISNULL": TokenType.ISNULL,
"JOIN": TokenType.JOIN,
"KEEP": TokenType.KEEP,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
"LEADING": TokenType.LEADING,
@ -565,6 +576,7 @@ class Tokenizer(metaclass=_Tokenizer):
"MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR,
"NO ACTION": TokenType.NO_ACTION,
"NOT": TokenType.NOT,
"NOTNULL": TokenType.NOTNULL,
@ -632,6 +644,7 @@ class Tokenizer(metaclass=_Tokenizer):
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
"UUID": TokenType.UUID,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
"VOLATILE": TokenType.VOLATILE,
@ -661,6 +674,8 @@ class Tokenizer(metaclass=_Tokenizer):
"INT8": TokenType.BIGINT,
"DEC": TokenType.DECIMAL,
"DECIMAL": TokenType.DECIMAL,
"BIGDECIMAL": TokenType.BIGDECIMAL,
"BIGNUMERIC": TokenType.BIGDECIMAL,
"MAP": TokenType.MAP,
"NULLABLE": TokenType.NULLABLE,
"NUMBER": TokenType.DECIMAL,
@ -742,7 +757,7 @@ class Tokenizer(metaclass=_Tokenizer):
ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
KEYWORD_TRIE = None # autofilled
KEYWORD_TRIE: t.Dict = {} # autofilled
IDENTIFIER_CAN_START_WITH_DIGIT = False
@ -776,19 +791,28 @@ class Tokenizer(metaclass=_Tokenizer):
self._col = 1
self._comments: t.List[str] = []
self._char = None
self._end = None
self._peek = None
self._char = ""
self._end = False
self._peek = ""
self._prev_token_line = -1
self._prev_token_comments: t.List[str] = []
self._prev_token_type = None
self._prev_token_type: t.Optional[TokenType] = None
def tokenize(self, sql: str) -> t.List[Token]:
"""Returns a list of tokens corresponding to the SQL string `sql`."""
self.reset()
self.sql = sql
self.size = len(sql)
self._scan()
try:
self._scan()
except Exception as e:
start = self._current - 50
end = self._current + 50
start = start if start > 0 else 0
end = end if end < self.size else self.size - 1
context = self.sql[start:end]
raise ValueError(f"Error tokenizing '{context}'") from e
return self.tokens
def _scan(self, until: t.Optional[t.Callable] = None) -> None:
@ -810,9 +834,12 @@ class Tokenizer(metaclass=_Tokenizer):
if until and until():
break
if self.tokens:
self.tokens[-1].comments.extend(self._comments)
def _chars(self, size: int) -> str:
if size == 1:
return self._char # type: ignore
return self._char
start = self._current - 1
end = start + size
if end <= self.size:
@ -821,17 +848,15 @@ class Tokenizer(metaclass=_Tokenizer):
def _advance(self, i: int = 1) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
self._set_new_line()
self._col = 1
self._line += 1
else:
self._col += i
self._col += i
self._current += i
self._end = self._current >= self.size # type: ignore
self._char = self.sql[self._current - 1] # type: ignore
self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore
def _set_new_line(self) -> None:
self._col = 1
self._line += 1
self._end = self._current >= self.size
self._char = self.sql[self._current - 1]
self._peek = "" if self._end else self.sql[self._current]
@property
def _text(self) -> str:
@ -840,13 +865,14 @@ class Tokenizer(metaclass=_Tokenizer):
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self._prev_token_comments = self._comments
self._prev_token_type = token_type # type: ignore
self._prev_token_type = token_type
self.tokens.append(
Token(
token_type,
self._text if text is None else text,
self._line,
self._col,
self._current,
self._comments,
)
)
@ -881,7 +907,7 @@ class Tokenizer(metaclass=_Tokenizer):
if skip:
result = 1
else:
result, trie = in_trie(trie, char.upper()) # type: ignore
result, trie = in_trie(trie, char.upper())
if result == 0:
break
@ -910,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):
if not word:
if self._char in self.SINGLE_TOKENS:
self._add(self.SINGLE_TOKENS[self._char], text=self._char) # type: ignore
self._add(self.SINGLE_TOKENS[self._char], text=self._char)
return
self._scan_var()
return
@ -927,29 +953,31 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(self.KEYWORDS[word], text=word)
def _scan_comment(self, comment_start: str) -> bool:
if comment_start not in self._COMMENTS: # type: ignore
if comment_start not in self._COMMENTS:
return False
comment_start_line = self._line
comment_start_size = len(comment_start)
comment_end = self._COMMENTS[comment_start] # type: ignore
comment_end = self._COMMENTS[comment_start]
if comment_end:
comment_end_size = len(comment_end)
# Skip the comment's start delimiter
self._advance(comment_start_size)
comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
self._advance()
self._comments.append(self._text[comment_start_size:]) # type: ignore
self._comments.append(self._text[comment_start_size:])
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
# Multiple consecutive comments are preserved by appending them to the current comments list.
if comment_start_line == self._prev_token_line or self._end:
if comment_start_line == self._prev_token_line:
self.tokens[-1].comments.extend(self._comments)
self._comments = []
self._prev_token_line = self._line
@ -958,7 +986,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _scan_number(self) -> None:
if self._char == "0":
peek = self._peek.upper() # type: ignore
peek = self._peek.upper()
if peek == "B":
return self._scan_bits()
elif peek == "X":
@ -968,7 +996,7 @@ class Tokenizer(metaclass=_Tokenizer):
scientific = 0
while True:
if self._peek.isdigit(): # type: ignore
if self._peek.isdigit():
self._advance()
elif self._peek == "." and not decimal:
decimal = True
@ -976,24 +1004,23 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek in ("-", "+") and scientific == 1:
scientific += 1
self._advance()
elif self._peek.upper() == "E" and not scientific: # type: ignore
elif self._peek.upper() == "E" and not scientific:
scientific += 1
self._advance()
elif self._peek.isidentifier(): # type: ignore
elif self._peek.isidentifier():
number_text = self._text
literal = []
literal = ""
while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: # type: ignore
literal.append(self._peek.upper()) # type: ignore
while self._peek.strip() and self._peek not in self.SINGLE_TOKENS:
literal += self._peek.upper()
self._advance()
literal = "".join(literal) # type: ignore
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
if token_type:
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal) # type: ignore
return self._add(token_type, literal)
elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
return self._add(TokenType.VAR)
@ -1020,7 +1047,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _extract_value(self) -> str:
while True:
char = self._peek.strip() # type: ignore
char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
@ -1029,35 +1056,35 @@ class Tokenizer(metaclass=_Tokenizer):
return self._text
def _scan_string(self, quote: str) -> bool:
quote_end = self._QUOTES.get(quote) # type: ignore
quote_end = self._QUOTES.get(quote)
if quote_end is None:
return False
self._advance(len(quote))
text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
# X'1234, b'0110', E'\\\\\' etc.
def _scan_formatted_string(self, string_start: str) -> bool:
if string_start in self._HEX_STRINGS: # type: ignore
delimiters = self._HEX_STRINGS # type: ignore
if string_start in self._HEX_STRINGS:
delimiters = self._HEX_STRINGS
token_type = TokenType.HEX_STRING
base = 16
elif string_start in self._BIT_STRINGS: # type: ignore
delimiters = self._BIT_STRINGS # type: ignore
elif string_start in self._BIT_STRINGS:
delimiters = self._BIT_STRINGS
token_type = TokenType.BIT_STRING
base = 2
elif string_start in self._BYTE_STRINGS: # type: ignore
delimiters = self._BYTE_STRINGS # type: ignore
elif string_start in self._BYTE_STRINGS:
delimiters = self._BYTE_STRINGS
token_type = TokenType.BYTE_STRING
base = None
else:
return False
self._advance(len(string_start))
string_end = delimiters.get(string_start)
string_end = delimiters[string_start]
text = self._extract_string(string_end)
if base is None:
@ -1083,20 +1110,20 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
if self._char == identifier_end:
if identifier_end_is_escape and self._peek == identifier_end:
text += identifier_end # type: ignore
text += identifier_end
self._advance()
continue
break
text += self._char # type: ignore
text += self._char
self._add(TokenType.IDENTIFIER, text)
def _scan_var(self) -> None:
while True:
char = self._peek.strip() # type: ignore
if char and char not in self.SINGLE_TOKENS:
char = self._peek.strip()
if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS):
self._advance()
else:
break
@ -1115,9 +1142,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._peek == delimiter or self._peek in self._STRING_ESCAPES
):
if self._peek == delimiter:
text += self._peek # type: ignore
text += self._peek
else:
text += self._char + self._peek # type: ignore
text += self._char + self._peek
if self._current + 1 < self.size:
self._advance(2)
@ -1131,7 +1158,7 @@ class Tokenizer(metaclass=_Tokenizer):
if self._end:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
text += self._char # type: ignore
text += self._char
self._advance()
return text


@ -103,7 +103,11 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
expression.select(exp.alias_(expr.copy(), alias), copy=False)
expr.replace(exp.column(alias))
column = exp.column(alias)
if isinstance(expr.parent, exp.Qualify):
qualify_filters = column
else:
expr.replace(column)
elif expr.name not in expression.named_selects:
expression.select(expr.copy(), copy=False)
@ -133,9 +137,111 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
)
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
"""Convert cross join unnest into lateral view explode (used in presto -> hive)."""
if isinstance(expression, exp.Select):
for join in expression.args.get("joins") or []:
unnest = join.this
if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
expression.args["joins"].remove(join)
for e, column in zip(unnest.expressions, alias.columns if alias else []):
expression.append(
"laterals",
exp.Lateral(
this=udtf(this=e),
view=True,
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
),
)
return expression
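
Exercising the presto -> hive direction the docstring names (the exact output is approximate):

import sqlglot

print(sqlglot.transpile(
    "SELECT c FROM t CROSS JOIN UNNEST(x) AS u(c)",
    read="presto",
    write="hive",
)[0])
# roughly: SELECT c FROM t LATERAL VIEW EXPLODE(x) u AS c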
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
"""Convert explode/posexplode into unnest (used in hive -> presto)."""
if isinstance(expression, exp.Select):
from sqlglot.optimizer.scope import build_scope
taken_select_names = set(expression.named_selects)
taken_source_names = set(build_scope(expression).selected_sources)
for select in expression.selects:
to_replace = select
pos_alias = ""
explode_alias = ""
if isinstance(select, exp.Alias):
explode_alias = select.alias
select = select.this
elif isinstance(select, exp.Aliases):
pos_alias = select.aliases[0].name
explode_alias = select.aliases[1].name
select = select.this
if isinstance(select, (exp.Explode, exp.Posexplode)):
is_posexplode = isinstance(select, exp.Posexplode)
explode_arg = select.this
unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)
# This ensures that we won't use [POS]EXPLODE's argument as a new selection
if isinstance(explode_arg, exp.Column):
taken_select_names.add(explode_arg.output_name)
unnest_source_alias = find_new_name(taken_source_names, "_u")
taken_source_names.add(unnest_source_alias)
if not explode_alias:
explode_alias = find_new_name(taken_select_names, "col")
taken_select_names.add(explode_alias)
if is_posexplode:
pos_alias = find_new_name(taken_select_names, "pos")
taken_select_names.add(pos_alias)
if is_posexplode:
column_names = [explode_alias, pos_alias]
to_replace.pop()
expression.select(pos_alias, explode_alias, copy=False)
else:
column_names = [explode_alias]
to_replace.replace(exp.column(explode_alias))
unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)
if not expression.args.get("from"):
expression.from_(unnest, copy=False)
else:
expression.join(unnest, join_type="CROSS", copy=False)
return expression
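
And the hive -> presto direction; the col/pos/_u aliases are invented via find_new_name, so the exact names can shift when they collide with existing ones:

import sqlglot

print(sqlglot.transpile(
    "SELECT EXPLODE(ids) FROM events",
    read="hive",
    write="presto",
)[0])
# roughly: SELECT col FROM events CROSS JOIN UNNEST(ids) AS _u(col)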
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
"""Remove table refs from columns in when statements."""
if isinstance(expression, exp.Merge):
alias = expression.this.args.get("alias")
targets = {expression.this.this}
if alias:
targets.add(alias.this)
for when in expression.expressions:
when.transform(
lambda node: exp.column(node.name)
if isinstance(node, exp.Column) and node.args.get("table") in targets
else node,
copy=False,
)
return expression
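
A sketch of the merge transform in isolation; which dialects wire it into their MERGE generation is an assumption here, the transform itself is as defined above:

from sqlglot import parse_one
from sqlglot.transforms import remove_target_from_merge

merge = parse_one(
    "MERGE INTO t AS tgt USING s AS src ON tgt.id = src.id "
    "WHEN MATCHED THEN UPDATE SET tgt.v = src.v"
)
print(remove_target_from_merge(merge).sql())
# the tgt. qualifier inside WHEN is stripped: ... UPDATE SET v = src.v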
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
to_sql: t.Callable[[Generator, exp.Expression], str],
) -> t.Callable[[Generator, exp.Expression], str]:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
@ -143,36 +249,23 @@ def preprocess(
Args:
transforms: sequence of transform functions. These will be called in order.
to_sql: final transform that converts the resulting expression to a SQL string.
Returns:
Function that can be used as a generator transform.
"""
def _to_sql(self, expression):
def _to_sql(self, expression: exp.Expression) -> str:
expression = transforms[0](expression.copy())
for t in transforms[1:]:
expression = t(expression)
return to_sql(self, expression)
return getattr(self, expression.key + "_sql")(expression)
return _to_sql
def delegate(attr: str) -> t.Callable:
"""
Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS`
functions that delegate to existing generator methods.
"""
def _transform(self, *args, **kwargs):
return getattr(self, attr)(*args, **kwargs)
return _transform
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify], delegate("select_sql"))}
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])}
ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])}
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
exp.Cast: preprocess([remove_precision_parameterized_types])
}


@ -3,7 +3,7 @@ import typing as t
key = t.Sequence[t.Hashable]
def new_trie(keywords: t.Iterable[key]) -> t.Dict:
def new_trie(keywords: t.Iterable[key], trie: t.Optional[t.Dict] = None) -> t.Dict:
"""
Creates a new trie out of a collection of keywords.
@ -16,11 +16,12 @@ def new_trie(keywords: t.Iterable[key]) -> t.Dict:
Args:
keywords: the keywords to create the trie from.
trie: a trie to mutate instead of creating a new one
Returns:
The trie corresponding to `keywords`.
"""
trie: t.Dict = {}
trie = {} if trie is None else trie
for key in keywords:
current = trie