Merging upstream version 11.4.5.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent 0a06643852
commit 88f99e1c27
131 changed files with 53004 additions and 37079 deletions
sqlglot/__init__.py
@@ -47,7 +47,7 @@ if t.TYPE_CHECKING:
     T = t.TypeVar("T", bound=Expression)
 
 
-__version__ = "11.4.1"
+__version__ = "11.4.5"
 
 pretty = False
 """Whether to format generated SQL by default."""

sqlglot/dialects/bigquery.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
     Dialect,
     datestrtodate_sql,
     inline_array_sql,
+    max_or_greatest,
     min_or_least,
     no_ilike_sql,
     rename_func,
@@ -212,6 +213,9 @@ class BigQuery(Dialect):
             ),
         }
 
+        LOG_BASE_FIRST = False
+        LOG_DEFAULTS_TO_LN = True
+
     class Generator(generator.Generator):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
@@ -227,6 +231,7 @@ class BigQuery(Dialect):
             exp.GroupConcat: rename_func("STRING_AGG"),
             exp.ILike: no_ilike_sql,
             exp.IntDiv: rename_func("DIV"),
+            exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.Select: transforms.preprocess(
                 [_unqualify_unnest], transforms.delegate("select_sql")
@@ -253,17 +258,19 @@ class BigQuery(Dialect):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
-            exp.DataType.Type.TINYINT: "INT64",
-            exp.DataType.Type.SMALLINT: "INT64",
-            exp.DataType.Type.INT: "INT64",
-            exp.DataType.Type.BIGINT: "INT64",
-            exp.DataType.Type.DECIMAL: "NUMERIC",
-            exp.DataType.Type.FLOAT: "FLOAT64",
-            exp.DataType.Type.DOUBLE: "FLOAT64",
-            exp.DataType.Type.BOOLEAN: "BOOL",
-            exp.DataType.Type.TEXT: "STRING",
-            exp.DataType.Type.VARCHAR: "STRING",
-            exp.DataType.Type.CHAR: "STRING",
+            exp.DataType.Type.BIGINT: "INT64",
+            exp.DataType.Type.BOOLEAN: "BOOL",
+            exp.DataType.Type.CHAR: "STRING",
+            exp.DataType.Type.DECIMAL: "NUMERIC",
+            exp.DataType.Type.DOUBLE: "FLOAT64",
+            exp.DataType.Type.FLOAT: "FLOAT64",
+            exp.DataType.Type.INT: "INT64",
+            exp.DataType.Type.NCHAR: "STRING",
+            exp.DataType.Type.NVARCHAR: "STRING",
+            exp.DataType.Type.SMALLINT: "INT64",
+            exp.DataType.Type.TEXT: "STRING",
+            exp.DataType.Type.TINYINT: "INT64",
+            exp.DataType.Type.VARCHAR: "STRING",
         }
 
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
@@ -271,6 +278,7 @@ class BigQuery(Dialect):
         }
 
         EXPLICIT_UNION = True
+        LIMIT_FETCH = "LIMIT"
 
         def array_sql(self, expression: exp.Array) -> str:
             first_arg = seq_get(expression.expressions, 0)

sqlglot/dialects/clickhouse.py
@@ -68,6 +68,8 @@ class ClickHouse(Dialect):
 
         TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}  # type: ignore
 
+        LOG_DEFAULTS_TO_LN = True
+
         def _parse_in(
             self, this: t.Optional[exp.Expression], is_global: bool = False
         ) -> exp.Expression:

sqlglot/dialects/databricks.py
@@ -16,6 +16,8 @@ class Databricks(Spark):
             "DATEDIFF": parse_date_delta(exp.DateDiff),
         }
 
+        LOG_DEFAULTS_TO_LN = True
+
     class Generator(Spark.Generator):
         TRANSFORMS = {
             **Spark.Generator.TRANSFORMS,  # type: ignore

sqlglot/dialects/dialect.py
@@ -430,6 +430,11 @@ def min_or_least(self: Generator, expression: exp.Min) -> str:
     return rename_func(name)(self, expression)
 
 
+def max_or_greatest(self: Generator, expression: exp.Max) -> str:
+    name = "GREATEST" if expression.expressions else "MAX"
+    return rename_func(name)(self, expression)
+
+
 def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
     cond = expression.this

sqlglot/dialects/drill.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import re
 import typing as t
 
 from sqlglot import exp, generator, parser, tokens
@@ -102,6 +101,8 @@ class Drill(Dialect):
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
         }
 
+        LOG_DEFAULTS_TO_LN = True
+
     class Generator(generator.Generator):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
@@ -154,4 +155,4 @@ class Drill(Dialect):
         }
 
         def normalize_func(self, name: str) -> str:
-            return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
+            return name if exp.SAFE_IDENTIFIER_RE.match(name) else f"`{name}`"

sqlglot/dialects/duckdb.py
@@ -80,6 +80,7 @@ class DuckDB(Dialect):
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "~": TokenType.RLIKE,
             ":=": TokenType.EQ,
+            "ATTACH": TokenType.COMMAND,
             "BINARY": TokenType.VARBINARY,
@@ -212,5 +213,7 @@ class DuckDB(Dialect):
             "except": "EXCLUDE",
         }
 
+        LIMIT_FETCH = "LIMIT"
+
         def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
             return super().tablesample_sql(expression, seed_prefix="REPEATABLE")

sqlglot/dialects/hive.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
     format_time_lambda,
     if_sql,
     locate_to_strposition,
+    max_or_greatest,
     min_or_least,
     no_ilike_sql,
     no_recursive_cte_sql,
@@ -34,6 +35,13 @@ DATE_DELTA_INTERVAL = {
     "DAY": ("DATE_ADD", 1),
 }
 
+TIME_DIFF_FACTOR = {
+    "MILLISECOND": " * 1000",
+    "SECOND": "",
+    "MINUTE": " / 60",
+    "HOUR": " / 3600",
+}
+
 DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
 
 
@@ -51,6 +59,14 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
 
 def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
     unit = expression.text("unit").upper()
+
+    factor = TIME_DIFF_FACTOR.get(unit)
+    if factor is not None:
+        left = self.sql(expression, "this")
+        right = self.sql(expression, "expression")
+        sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})"
+        return f"({sec_diff}){factor}" if factor else sec_diff
+
     sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
     _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
     multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
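
Reviewer note: with the new TIME_DIFF_FACTOR branch, sub-day units are emitted as a scaled UNIX_TIMESTAMP difference instead of DATEDIFF. A minimal sketch of the intended behavior (the input SQL and dialect pair are illustrative, and the exact output string is assumed from _date_diff_sql above):

    import sqlglot

    # MINUTE maps to " / 60" in TIME_DIFF_FACTOR, so the Hive output divides
    # a UNIX_TIMESTAMP delta instead of calling DATEDIFF.
    sql = sqlglot.transpile(
        "SELECT DATEDIFF(MINUTE, start_ts, end_ts)", read="tsql", write="hive"
    )[0]
    print(sql)  # e.g. SELECT (UNIX_TIMESTAMP(end_ts) - UNIX_TIMESTAMP(start_ts)) / 60
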
@@ -237,11 +253,6 @@ class Hive(Dialect):
             "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
             "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
             "LOCATE": locate_to_strposition,
-            "LOG": (
-                lambda args: exp.Log.from_arg_list(args)
-                if len(args) > 1
-                else exp.Ln.from_arg_list(args)
-            ),
             "MAP": parse_var_map,
             "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
             "PERCENTILE": exp.Quantile.from_arg_list,
@@ -261,6 +272,8 @@ class Hive(Dialect):
             ),
         }
 
+        LOG_DEFAULTS_TO_LN = True
+
     class Generator(generator.Generator):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
@@ -293,6 +306,7 @@ class Hive(Dialect):
             exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
             exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
             exp.Map: var_map_sql,
+            exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.VarMap: var_map_sql,
             exp.Create: create_with_partitions_sql,
@@ -338,6 +352,8 @@ class Hive(Dialect):
             exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
         }
 
+        LIMIT_FETCH = "LIMIT"
+
         def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
             return self.func(
                 "COLLECT_LIST",

sqlglot/dialects/mysql.py
@@ -3,7 +3,9 @@ from __future__ import annotations
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
     locate_to_strposition,
+    max_or_greatest,
+    min_or_least,
     no_ilike_sql,
     no_paren_current_date_sql,
@@ -288,6 +290,8 @@ class MySQL(Dialect):
             "SWAPS",
         }
 
+        LOG_DEFAULTS_TO_LN = True
+
         def _parse_show_mysql(self, this, target=False, full=None, global_=None):
             if target:
                 if isinstance(target, str):
@@ -303,7 +307,13 @@ class MySQL(Dialect):
                 db = None
             else:
                 position = None
-                db = self._parse_id_var() if self._match_text_seq("FROM") else None
+                db = None
+
+                if self._match(TokenType.FROM):
+                    db = self._parse_id_var()
+                elif self._match(TokenType.DOT):
+                    db = target_id
+                    target_id = self._parse_id_var()
 
             channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None
 
@@ -384,6 +394,8 @@ class MySQL(Dialect):
             exp.CurrentDate: no_paren_current_date_sql,
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.ILike: no_ilike_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+            exp.Max: max_or_greatest,
+            exp.Min: min_or_least,
             exp.TableSample: no_tablesample_sql,
             exp.TryCast: no_trycast_sql,
@@ -415,6 +427,8 @@ class MySQL(Dialect):
             exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
+        LIMIT_FETCH = "LIMIT"
+
         def show_sql(self, expression):
             this = f" {expression.name}"
             full = " FULL" if expression.args.get("full") else ""

sqlglot/dialects/oracle.py
@@ -4,7 +4,7 @@ import typing as t
 
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
-from sqlglot.helper import csv, seq_get
+from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 
 PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
@@ -13,10 +13,6 @@ PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
 }
 
 
-def _limit_sql(self, expression):
-    return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression))
-
-
 def _parse_xml_table(self) -> exp.XMLTable:
     this = self._parse_string()
 
@@ -89,6 +85,20 @@ class Oracle(Dialect):
                 column.set("join_mark", self._match(TokenType.JOIN_MARKER))
             return column
 
+        def _parse_hint(self) -> t.Optional[exp.Expression]:
+            if self._match(TokenType.HINT):
+                start = self._curr
+                while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):
+                    self._advance()
+
+                if not self._curr:
+                    self.raise_error("Expected */ after HINT")
+
+                end = self._tokens[self._index - 3]
+                return exp.Hint(expressions=[self._find_sql(start, end)])
+
+            return None
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
 
@@ -110,41 +120,20 @@ class Oracle(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.UNALIAS_GROUP,  # type: ignore
             exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
             exp.ILike: no_ilike_sql,
-            exp.Limit: _limit_sql,
-            exp.Trim: trim_sql,
             exp.Matches: rename_func("DECODE"),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
             exp.Substring: rename_func("SUBSTR"),
             exp.Table: lambda self, e: self.table_sql(e, sep=" "),
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+            exp.Trim: trim_sql,
+            exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
         }
 
-        def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
-            return csv(
-                *sqls,
-                *[self.sql(sql) for sql in expression.args.get("joins") or []],
-                self.sql(expression, "match"),
-                *[self.sql(sql) for sql in expression.args.get("laterals") or []],
-                self.sql(expression, "where"),
-                self.sql(expression, "group"),
-                self.sql(expression, "having"),
-                self.sql(expression, "qualify"),
-                self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
-                if expression.args.get("windows")
-                else "",
-                self.sql(expression, "distribute"),
-                self.sql(expression, "sort"),
-                self.sql(expression, "cluster"),
-                self.sql(expression, "order"),
-                self.sql(expression, "offset"),  # offset before limit in oracle
-                self.sql(expression, "limit"),
-                self.sql(expression, "lock"),
-                sep="",
-            )
+        LIMIT_FETCH = "FETCH"
 
         def offset_sql(self, expression: exp.Offset) -> str:
             return f"{super().offset_sql(expression)} ROWS"

sqlglot/dialects/postgres.py
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
     arrow_json_extract_scalar_sql,
     arrow_json_extract_sql,
     format_time_lambda,
+    max_or_greatest,
     min_or_least,
     no_paren_current_date_sql,
     no_tablesample_sql,
@@ -315,6 +316,7 @@ class Postgres(Dialect):
             exp.DateDiff: _date_diff_sql,
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.LogicalAnd: rename_func("BOOL_AND"),
+            exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
             exp.ArrayContains: lambda self, e: self.binary(e, "@>"),

sqlglot/dialects/snowflake.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
     datestrtodate_sql,
     format_time_lambda,
     inline_array_sql,
+    max_or_greatest,
     min_or_least,
     rename_func,
     timestamptrunc_sql,
@@ -275,6 +276,9 @@ class Snowflake(Dialect):
             exp.ArrayConcat: rename_func("ARRAY_CAT"),
             exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
             exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
+            exp.DateDiff: lambda self, e: self.func(
+                "DATEDIFF", e.text("unit"), e.expression, e.this
+            ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DataType: _datatype_sql,
             exp.If: rename_func("IFF"),
@@ -296,6 +300,7 @@ class Snowflake(Dialect):
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
             exp.UnixToTime: _unix_to_time_sql,
             exp.DayOfWeek: rename_func("DAYOFWEEK"),
+            exp.Max: max_or_greatest,
             exp.Min: min_or_least,
         }
 
@@ -314,12 +319,6 @@ class Snowflake(Dialect):
             exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
-        def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
-            return self.binary(expression, "ILIKE ANY")
-
-        def likeany_sql(self, expression: exp.LikeAny) -> str:
-            return self.binary(expression, "LIKE ANY")
-
         def except_op(self, expression):
             if not expression.args.get("distinct", False):
                 self.unsupported("EXCEPT with All is not supported in Snowflake")

sqlglot/dialects/sqlite.py
@@ -82,6 +82,8 @@ class SQLite(Dialect):
             exp.TryCast: no_trycast_sql,
         }
 
+        LIMIT_FETCH = "LIMIT"
+
         def cast_sql(self, expression: exp.Cast) -> str:
             if expression.to.this == exp.DataType.Type.DATE:
                 return self.func("DATE", expression.this)
@@ -115,9 +117,6 @@ class SQLite(Dialect):
 
             return f"CAST({sql} AS INTEGER)"
 
-        def fetch_sql(self, expression: exp.Fetch) -> str:
-            return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
-
         # https://www.sqlite.org/lang_aggfunc.html#group_concat
         def groupconcat_sql(self, expression):
             this = expression.this

sqlglot/dialects/teradata.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from sqlglot import exp, generator, parser, tokens
-from sqlglot.dialects.dialect import Dialect, min_or_least
+from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
 from sqlglot.tokens import TokenType
 
 
@@ -128,6 +128,7 @@ class Teradata(Dialect):
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
+            exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
         }

sqlglot/dialects/tsql.py
@@ -6,6 +6,7 @@ import typing as t
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
+    max_or_greatest,
     min_or_least,
     parse_date_delta,
     rename_func,
@@ -269,7 +270,6 @@ class TSQL(Dialect):
 
         # TSQL allows @, # to appear as a variable/identifier prefix
         SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
-        SINGLE_TOKENS.pop("@")
         SINGLE_TOKENS.pop("#")
 
     class Parser(parser.Parser):
@@ -313,6 +313,9 @@ class TSQL(Dialect):
             TokenType.END: lambda self: self._parse_command(),
         }
 
+        LOG_BASE_FIRST = False
+        LOG_DEFAULTS_TO_LN = True
+
         def _parse_system_time(self) -> t.Optional[exp.Expression]:
             if not self._match_text_seq("FOR", "SYSTEM_TIME"):
                 return None
@@ -435,11 +438,17 @@ class TSQL(Dialect):
             exp.NumberToStr: _format_sql,
             exp.TimeToStr: _format_sql,
             exp.GroupConcat: _string_agg_sql,
+            exp.Max: max_or_greatest,
+            exp.Min: min_or_least,
         }
 
         TRANSFORMS.pop(exp.ReturnsProperty)
 
+        LIMIT_FETCH = "FETCH"
+
+        def offset_sql(self, expression: exp.Offset) -> str:
+            return f"{super().offset_sql(expression)} ROWS"
+
         def systemtime_sql(self, expression: exp.SystemTime) -> str:
             kind = expression.args["kind"]
             if kind == "ALL":

sqlglot/diff.py
@@ -12,7 +12,7 @@ from dataclasses import dataclass
 from heapq import heappop, heappush
 
 from sqlglot import Dialect, expressions as exp
-from sqlglot.helper import ensure_collection
+from sqlglot.helper import ensure_list
 
 
 @dataclass(frozen=True)
@@ -151,8 +151,8 @@ class ChangeDistiller:
 
         self._source = source
         self._target = target
-        self._source_index = {id(n[0]): n[0] for n in source.bfs()}
-        self._target_index = {id(n[0]): n[0] for n in target.bfs()}
+        self._source_index = {id(n): n for n, *_ in self._source.bfs()}
+        self._target_index = {id(n): n for n, *_ in self._target.bfs()}
         self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
         self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
         self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
@@ -199,10 +199,10 @@ class ChangeDistiller:
         matching_set = leaves_matching_set.copy()
 
         ordered_unmatched_source_nodes = {
-            id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes
+            id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes
         }
         ordered_unmatched_target_nodes = {
-            id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes
+            id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes
         }
 
         for source_node_id in ordered_unmatched_source_nodes:
@@ -304,18 +304,18 @@ class ChangeDistiller:
 def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
     has_child_exprs = False
 
-    for a in expression.args.values():
-        for node in ensure_collection(a):
-            if isinstance(node, exp.Expression):
-                has_child_exprs = True
-                yield from _get_leaves(node)
+    for _, node in expression.iter_expressions():
+        has_child_exprs = True
+        yield from _get_leaves(node)
 
     if not has_child_exprs:
         yield expression
 
 
 def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
-    if type(source) is type(target):
+    if type(source) is type(target) and (
+        not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent)
+    ):
         if isinstance(source, exp.Join):
             return source.args.get("side") == target.args.get("side")
@@ -331,7 +331,7 @@ def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
     args: t.List[t.Union[exp.Expression, t.List]] = []
     if expression:
         for a in expression.args.values():
-            args.extend(ensure_collection(a))
+            args.extend(ensure_list(a))
     return [a for a in args if isinstance(a, exp.Expression)]
 
 

sqlglot/executor/__init__.py
@@ -57,7 +57,7 @@ def execute(
         for name, table in tables_.mapping.items()
     }
 
-    schema = ensure_schema(schema)
+    schema = ensure_schema(schema, dialect=read)
 
    if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
        raise ExecuteError("Tables must support the same table args as schema")

sqlglot/executor/python.py
@@ -94,13 +94,10 @@ class PythonExecutor:
         if source and isinstance(source, exp.Expression):
             source = source.name or source.alias
 
-        condition = self.generate(step.condition)
-        projections = self.generate_tuple(step.projections)
-
         if source is None:
             context, table_iter = self.static()
         elif source in context:
-            if not projections and not condition:
+            if not step.projections and not step.condition:
                 return self.context({step.name: context.tables[source]})
             table_iter = context.table_iter(source)
         elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
@@ -109,10 +106,12 @@ class PythonExecutor:
         else:
             context, table_iter = self.scan_table(step)
 
-        if projections:
-            sink = self.table(step.projections)
-        else:
-            sink = self.table(context.columns)
+        return self.context({step.name: self._project_and_filter(context, step, table_iter)})
+
+    def _project_and_filter(self, context, step, table_iter):
+        sink = self.table(step.projections if step.projections else context.columns)
+        condition = self.generate(step.condition)
+        projections = self.generate_tuple(step.projections)
 
         for reader in table_iter:
             if len(sink) >= step.limit:
@@ -126,7 +125,7 @@ class PythonExecutor:
             else:
                 sink.append(reader.row)
 
-        return self.context({step.name: sink})
+        return sink
 
     def static(self):
         return self.context({}), [RowReader(())]
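
Reviewer note: the scan and filter paths now funnel through the shared _project_and_filter helper, so projection and filtering behave identically across sources; execute() output should be unchanged. A sketch against the public executor API (table data illustrative):

    from sqlglot.executor import execute

    result = execute(
        "SELECT a FROM t WHERE a = 'x'",
        tables={"t": [{"a": "x"}, {"a": "y"}]},
    )
    print(result.rows)  # e.g. [('x',)]
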
@@ -185,27 +184,16 @@ class PythonExecutor:
         if condition:
             source_context.filter(condition)
 
-        condition = self.generate(step.condition)
-        projections = self.generate_tuple(step.projections)
-
-        if not condition and not projections:
+        if not step.condition and not step.projections:
             return source_context
 
-        sink = self.table(step.projections if projections else source_context.columns)
+        sink = self._project_and_filter(
+            source_context,
+            step,
+            (reader for reader, _ in iter(source_context)),
+        )
 
-        for reader, ctx in source_context:
-            if condition and not ctx.eval(condition):
-                continue
-
-            if projections:
-                sink.append(ctx.eval_tuple(projections))
-            else:
-                sink.append(reader.row)
-
-            if len(sink) >= step.limit:
-                break
-
-        if projections:
+        if step.projections:
             return self.context({step.name: sink})
         else:
             return self.context(

sqlglot/expressions.py
@@ -26,6 +26,7 @@ from sqlglot.helper import (
     AutoName,
     camel_to_snake_case,
     ensure_collection,
+    ensure_list,
     seq_get,
     split_num_words,
     subclasses,
@@ -84,7 +85,7 @@ class Expression(metaclass=_Expression):
 
     key = "expression"
     arg_types = {"this": True}
-    __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta")
+    __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash")
 
     def __init__(self, **args: t.Any):
         self.args: t.Dict[str, t.Any] = args
@@ -93,22 +94,30 @@ class Expression(metaclass=_Expression):
         self.comments: t.Optional[t.List[str]] = None
         self._type: t.Optional[DataType] = None
         self._meta: t.Optional[t.Dict[str, t.Any]] = None
+        self._hash: t.Optional[int] = None
 
         for arg_key, value in self.args.items():
             self._set_parent(arg_key, value)
 
     def __eq__(self, other) -> bool:
-        return type(self) is type(other) and _norm_args(self) == _norm_args(other)
+        return type(self) is type(other) and hash(self) == hash(other)
+
+    @property
+    def hashable_args(self) -> t.Any:
+        args = (self.args.get(k) for k in self.arg_types)
+
+        return tuple(
+            (tuple(_norm_arg(a) for a in arg) if arg else None)
+            if type(arg) is list
+            else (_norm_arg(arg) if arg is not None and arg is not False else None)
+            for arg in args
+        )
 
     def __hash__(self) -> int:
-        return hash(
-            (
-                self.key,
-                tuple(
-                    (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
-                ),
-            )
-        )
+        if self._hash is not None:
+            return self._hash
+
+        return hash((self.__class__, self.hashable_args))
 
     @property
     def this(self):
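
Reviewer note: __eq__ now delegates to __hash__, and both are driven by the new hashable_args property; the _hash slot lets callers pin a node's hash. A sketch of the observable contract, inferred from the code above:

    from sqlglot import parse_one

    a = parse_one("SELECT col FROM tbl")
    b = parse_one("SELECT col FROM tbl")

    # Structurally identical trees compare equal and hash equal.
    assert a == b and hash(a) == hash(b)

    # _hash is None by default, so hashes are recomputed on demand;
    # while_changing() in sqlglot/helper.py (below) pins it temporarily.
    assert a._hash is None
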
@@ -247,9 +256,6 @@ class Expression(metaclass=_Expression):
         """
         new = deepcopy(self)
         new.parent = self.parent
-        for item, parent, _ in new.bfs():
-            if isinstance(item, Expression) and parent:
-                item.parent = parent
         return new
 
     def append(self, arg_key, value):
@@ -277,12 +283,12 @@ class Expression(metaclass=_Expression):
         self._set_parent(arg_key, value)
 
     def _set_parent(self, arg_key, value):
-        if isinstance(value, Expression):
+        if hasattr(value, "parent"):
             value.parent = self
             value.arg_key = arg_key
-        elif isinstance(value, list):
+        elif type(value) is list:
             for v in value:
-                if isinstance(v, Expression):
+                if hasattr(v, "parent"):
                     v.parent = self
                     v.arg_key = arg_key
 
@@ -295,6 +301,17 @@ class Expression(metaclass=_Expression):
             return self.parent.depth + 1
         return 0
 
+    def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]:
+        """Yields the key and expression for all arguments, exploding list args."""
+        for k, vs in self.args.items():
+            if type(vs) is list:
+                for v in vs:
+                    if hasattr(v, "parent"):
+                        yield k, v
+            else:
+                if hasattr(vs, "parent"):
+                    yield k, vs
+
     def find(self, *expression_types: t.Type[E], bfs=True) -> E | None:
         """
         Returns the first node in this tree which matches at least one of
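
Reviewer note: iter_expressions is the new traversal primitive that dfs, bfs, unnest_operands and several optimizer rules now share. A quick sketch of what it yields, assuming the method as defined above:

    from sqlglot import parse_one

    node = parse_one("SELECT a, b FROM t")
    # Yields (arg_key, child) pairs, exploding list-valued args such as
    # "expressions"; non-Expression args (strings, booleans) are skipped.
    for key, child in node.iter_expressions():
        print(key, type(child).__name__)
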
@@ -319,7 +336,7 @@ class Expression(metaclass=_Expression):
         Returns:
             The generator object.
         """
-        for expression, _, _ in self.walk(bfs=bfs):
+        for expression, *_ in self.walk(bfs=bfs):
             if isinstance(expression, expression_types):
                 yield expression
 
@@ -345,6 +362,11 @@ class Expression(metaclass=_Expression):
         """
         return self.find_ancestor(Select)
 
+    @property
+    def same_parent(self):
+        """Returns if the parent is the same class as itself."""
+        return type(self.parent) is self.__class__
+
     def root(self) -> Expression:
         """
         Returns the root expression of this tree.
@@ -385,10 +407,8 @@ class Expression(metaclass=_Expression):
         if prune and prune(self, parent, key):
             return
 
-        for k, v in self.args.items():
-            for node in ensure_collection(v):
-                if isinstance(node, Expression):
-                    yield from node.dfs(self, k, prune)
+        for k, v in self.iter_expressions():
+            yield from v.dfs(self, k, prune)
 
     def bfs(self, prune=None):
         """
@@ -407,18 +427,15 @@ class Expression(metaclass=_Expression):
             if prune and prune(item, parent, key):
                 continue
 
-            if isinstance(item, Expression):
-                for k, v in item.args.items():
-                    for node in ensure_collection(v):
-                        if isinstance(node, Expression):
-                            queue.append((node, item, k))
+            for k, v in item.iter_expressions():
+                queue.append((v, item, k))
 
     def unnest(self):
         """
         Returns the first non parenthesis child or self.
         """
         expression = self
-        while isinstance(expression, Paren):
+        while type(expression) is Paren:
             expression = expression.this
         return expression
 
@@ -434,7 +451,7 @@ class Expression(metaclass=_Expression):
         """
         Returns unnested operands as a tuple.
         """
-        return tuple(arg.unnest() for arg in self.args.values() if arg)
+        return tuple(arg.unnest() for _, arg in self.iter_expressions())
 
     def flatten(self, unnest=True):
         """
@@ -442,8 +459,8 @@ class Expression(metaclass=_Expression):
 
             A AND B AND C -> [A, B, C]
         """
-        for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
-            if not isinstance(node, self.__class__):
+        for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
+            if not type(node) is self.__class__:
                 yield node.unnest() if unnest else node
 
     def __str__(self):
@@ -477,7 +494,7 @@ class Expression(metaclass=_Expression):
                 v._to_s(hide_missing=hide_missing, level=level + 1)
                 if hasattr(v, "_to_s")
                 else str(v)
-                for v in ensure_collection(vs)
+                for v in ensure_list(vs)
                 if v is not None
             )
             for k, vs in self.args.items()
@@ -812,6 +829,10 @@ class Describe(Expression):
     arg_types = {"this": True, "kind": False}
 
 
+class Pragma(Expression):
+    pass
+
+
 class Set(Expression):
     arg_types = {"expressions": False}
 
@@ -1170,6 +1191,7 @@ class Drop(Expression):
         "temporary": False,
         "materialized": False,
         "cascade": False,
+        "constraints": False,
     }
 
 
@@ -1232,11 +1254,11 @@ class Identifier(Expression):
     def quoted(self):
         return bool(self.args.get("quoted"))
 
-    def __eq__(self, other):
-        return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
-
-    def __hash__(self):
-        return hash((self.key, self.this.lower()))
+    @property
+    def hashable_args(self) -> t.Any:
+        if self.quoted and any(char.isupper() for char in self.this):
+            return (self.this, self.quoted)
+        return self.this.lower()
 
     @property
     def output_name(self):
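
Reviewer note: Identifier keeps its case-insensitive identity except for quoted identifiers containing uppercase characters, which now hash on their exact spelling. A sketch:

    from sqlglot import exp

    # Unquoted identifiers still compare case-insensitively.
    assert exp.Identifier(this="FOO") == exp.Identifier(this="foo")

    # Quoted identifiers with uppercase characters are now distinct.
    assert exp.Identifier(this="FOO", quoted=True) != exp.Identifier(this="foo", quoted=True)
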
@@ -1322,15 +1344,9 @@
 class Literal(Condition):
     arg_types = {"this": True, "is_string": True}
 
-    def __eq__(self, other):
-        return (
-            isinstance(other, Literal)
-            and self.this == other.this
-            and self.args["is_string"] == other.args["is_string"]
-        )
-
-    def __hash__(self):
-        return hash((self.key, self.this, self.args["is_string"]))
+    @property
+    def hashable_args(self) -> t.Any:
+        return (self.this, self.args.get("is_string"))
 
     @classmethod
     def number(cls, number) -> Literal:
@@ -1784,7 +1800,7 @@ class Subqueryable(Unionable):
         instance = _maybe_copy(self, copy)
         return Subquery(
             this=instance,
-            alias=TableAlias(this=to_identifier(alias)),
+            alias=TableAlias(this=to_identifier(alias)) if alias else None,
         )
 
     def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
@@ -2058,6 +2074,7 @@ class Lock(Expression):
 class Select(Subqueryable):
     arg_types = {
         "with": False,
+        "kind": False,
         "expressions": False,
         "hint": False,
         "distinct": False,
@@ -3595,6 +3612,21 @@ class Initcap(Func):
     pass
 
 
+class JSONKeyValue(Expression):
+    arg_types = {"this": True, "expression": True}
+
+
+class JSONObject(Func):
+    arg_types = {
+        "expressions": False,
+        "null_handling": False,
+        "unique_keys": False,
+        "return_type": False,
+        "format_json": False,
+        "encoding": False,
+    }
+
+
 class JSONBContains(Binary):
     _sql_names = ["JSONB_CONTAINS"]
 
@@ -3766,8 +3798,10 @@ class RegexpILike(Func):
     arg_types = {"this": True, "expression": True, "flag": False}
 
 
+# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html
+# limit is the number of times a pattern is applied
 class RegexpSplit(Func):
-    arg_types = {"this": True, "expression": True}
+    arg_types = {"this": True, "expression": True, "limit": False}
 
 
 class Repeat(Func):
@@ -3967,25 +4001,8 @@ class When(Func):
     arg_types = {"matched": True, "source": False, "condition": False, "then": True}
 
 
-def _norm_args(expression):
-    args = {}
-
-    for k, arg in expression.args.items():
-        if isinstance(arg, list):
-            arg = [_norm_arg(a) for a in arg]
-            if not arg:
-                arg = None
-        else:
-            arg = _norm_arg(arg)
-
-        if arg is not None and arg is not False:
-            args[k] = arg
-
-    return args
-
-
 def _norm_arg(arg):
-    return arg.lower() if isinstance(arg, str) else arg
+    return arg.lower() if type(arg) is str else arg
 
 
 ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
@@ -4512,7 +4529,7 @@ def to_identifier(name, quoted=None):
     elif isinstance(name, str):
         identifier = Identifier(
             this=name,
-            quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
+            quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted,
         )
     else:
         raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
@@ -4586,8 +4603,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
         return sql_path
     if not isinstance(sql_path, str):
         raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
-    table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
-    return Column(this=column_name, table=table_name, **kwargs)
+    return column(*reversed(sql_path.split(".")), **kwargs)  # type: ignore
 
 
 def alias_(
@@ -4672,7 +4688,8 @@ def subquery(expression, alias=None, dialect=None, **opts):
 def column(
     col: str | Identifier,
     table: t.Optional[str | Identifier] = None,
+    schema: t.Optional[str | Identifier] = None,
     db: t.Optional[str | Identifier] = None,
     catalog: t.Optional[str | Identifier] = None,
     quoted: t.Optional[bool] = None,
 ) -> Column:
     """
@@ -4681,7 +4698,8 @@ def column(
     Args:
         col: column name
         table: table name
+        schema: schema name
         db: db name
         catalog: catalog name
         quoted: whether or not to force quote each part
     Returns:
         Column: column instance
@@ -4689,7 +4707,8 @@ def column(
     return Column(
         this=to_identifier(col, quoted=quoted),
         table=to_identifier(table, quoted=quoted),
+        schema=to_identifier(schema, quoted=quoted),
        db=to_identifier(db, quoted=quoted),
        catalog=to_identifier(catalog, quoted=quoted),
    )
 
 
@@ -4864,7 +4883,7 @@ def replace_children(expression, fun, *args, **kwargs):
     Replace children of an expression with the result of a lambda fun(child) -> exp.
     """
     for k, v in expression.args.items():
-        is_list_arg = isinstance(v, list)
+        is_list_arg = type(v) is list
 
         child_nodes = v if is_list_arg else [v]
         new_child_nodes = []

sqlglot/generator.py
@@ -110,6 +110,10 @@ class Generator:
     # Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
     MATCHED_BY_SOURCE = True
 
+    # Whether or not limit and fetch are supported
+    # "ALL", "LIMIT", "FETCH"
+    LIMIT_FETCH = "ALL"
+
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",
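
Reviewer note: LIMIT_FETCH is the per-dialect switch ("ALL", "LIMIT" or "FETCH") that query_modifiers further down uses to coerce one pagination style into the other. A sketch of the intended effect (output shape assumed):

    import sqlglot

    # TSQL sets LIMIT_FETCH = "FETCH", so a LIMIT is rewritten into FETCH.
    print(sqlglot.transpile("SELECT x FROM t LIMIT 10", write="tsql")[0])
    # e.g. SELECT x FROM t FETCH FIRST 10 ROWS ONLY
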
@@ -209,6 +213,7 @@ class Generator:
         "_leading_comma",
         "_max_text_width",
         "_comments",
+        "_cache",
     )
 
     def __init__(
@@ -265,19 +270,28 @@ class Generator:
         self._leading_comma = leading_comma
         self._max_text_width = max_text_width
         self._comments = comments
+        self._cache = None
 
-    def generate(self, expression: t.Optional[exp.Expression]) -> str:
+    def generate(
+        self,
+        expression: t.Optional[exp.Expression],
+        cache: t.Optional[t.Dict[int, str]] = None,
+    ) -> str:
         """
         Generates a SQL string by interpreting the given syntax tree.
 
         Args
             expression: the syntax tree.
+            cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node.
 
         Returns
             the SQL string.
         """
+        if cache is not None:
+            self._cache = cache
         self.unsupported_messages = []
         sql = self.sql(expression).strip()
+        self._cache = None
 
         if self.unsupported_level == ErrorLevel.IGNORE:
             return sql
@@ -387,6 +401,12 @@ class Generator:
         if key:
             return self.sql(expression.args.get(key))
 
+        if self._cache is not None:
+            expression_id = hash(expression)
+
+            if expression_id in self._cache:
+                return self._cache[expression_id]
+
         transform = self.TRANSFORMS.get(expression.__class__)
 
         if callable(transform):
@@ -407,7 +427,11 @@ class Generator:
         else:
             raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
 
-        return self.maybe_comment(sql, expression) if self._comments and comment else sql
+        sql = self.maybe_comment(sql, expression) if self._comments and comment else sql
+
+        if self._cache is not None:
+            self._cache[expression_id] = sql
+        return sql
 
     def uncache_sql(self, expression: exp.Uncache) -> str:
         table = self.sql(expression, "this")
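
Reviewer note: the cache added to sql() is keyed by hash(expression), which is only cheap while node hashes are pinned via _hash. A sketch of how a caller is expected to drive generate() with a shared cache (usage assumed from the signature above):

    from sqlglot import parse_one
    from sqlglot.generator import Generator

    cache = {}
    generator = Generator()
    tree = parse_one("SELECT a FROM t")

    # Pin hashes first so cache keys stay stable and O(1) to compute.
    for node, *_ in tree.walk():
        node._hash = hash(node)

    sql = generator.generate(tree, cache)  # repeated subtrees hit the cache
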
@@ -697,7 +721,8 @@ class Generator:
         temporary = " TEMPORARY" if expression.args.get("temporary") else ""
         materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
         cascade = " CASCADE" if expression.args.get("cascade") else ""
-        return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
+        constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
+        return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}"
 
     def except_sql(self, expression: exp.Except) -> str:
         return self.prepend_ctes(
@@ -733,9 +758,9 @@ class Generator:
 
     def identifier_sql(self, expression: exp.Identifier) -> str:
         text = expression.name
-        text = text.lower() if self.normalize else text
+        text = text.lower() if self.normalize and not expression.quoted else text
         text = text.replace(self.identifier_end, self._escaped_identifier_end)
-        if expression.args.get("quoted") or should_identify(text, self.identify):
+        if expression.quoted or should_identify(text, self.identify):
             text = f"{self.identifier_start}{text}{self.identifier_end}"
         return text
 
@@ -1191,6 +1216,9 @@ class Generator:
         )
         return f"SET{expressions}"
 
+    def pragma_sql(self, expression: exp.Pragma) -> str:
+        return f"PRAGMA {self.sql(expression, 'this')}"
+
     def lock_sql(self, expression: exp.Lock) -> str:
         if self.LOCKING_READS_SUPPORTED:
             lock_type = "UPDATE" if expression.args["update"] else "SHARE"
@@ -1299,6 +1327,15 @@ class Generator:
         return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}"
 
     def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
+        limit = expression.args.get("limit")
+
+        if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
+            limit = exp.Limit(expression=limit.args.get("count"))
+        elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
+            limit = exp.Fetch(direction="FIRST", count=limit.expression)
+
+        fetch = isinstance(limit, exp.Fetch)
+
         return csv(
             *sqls,
             *[self.sql(sql) for sql in expression.args.get("joins") or []],
@@ -1315,14 +1352,16 @@ class Generator:
             self.sql(expression, "sort"),
             self.sql(expression, "cluster"),
             self.sql(expression, "order"),
-            self.sql(expression, "limit"),
-            self.sql(expression, "offset"),
+            self.sql(expression, "offset") if fetch else self.sql(limit),
+            self.sql(limit) if fetch else self.sql(expression, "offset"),
             self.sql(expression, "lock"),
             self.sql(expression, "sample"),
             sep="",
         )
 
     def select_sql(self, expression: exp.Select) -> str:
+        kind = expression.args.get("kind")
+        kind = f" AS {kind}" if kind else ""
         hint = self.sql(expression, "hint")
         distinct = self.sql(expression, "distinct")
         distinct = f" {distinct}" if distinct else ""
@@ -1330,7 +1369,7 @@ class Generator:
         expressions = f"{self.sep()}{expressions}" if expressions else expressions
         sql = self.query_modifiers(
             expression,
-            f"SELECT{hint}{distinct}{expressions}",
+            f"SELECT{kind}{hint}{distinct}{expressions}",
             self.sql(expression, "into", comment=False),
             self.sql(expression, "from", comment=False),
         )
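
Reviewer note: the new "kind" arg on Select is rendered immediately after SELECT, which is what BigQuery-style SELECT AS STRUCT needs. A sketch built directly from the code above:

    from sqlglot import parse_one

    select = parse_one("SELECT 1 AS a")
    select.set("kind", "STRUCT")
    print(select.sql())  # SELECT AS STRUCT 1 AS a
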
@@ -1552,6 +1591,25 @@ class Generator:
             exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
         )
 
+    def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
+        return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
+
+    def jsonobject_sql(self, expression: exp.JSONObject) -> str:
+        expressions = self.expressions(expression)
+        null_handling = expression.args.get("null_handling")
+        null_handling = f" {null_handling}" if null_handling else ""
+        unique_keys = expression.args.get("unique_keys")
+        if unique_keys is not None:
+            unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS"
+        else:
+            unique_keys = ""
+        return_type = self.sql(expression, "return_type")
+        return_type = f" RETURNING {return_type}" if return_type else ""
+        format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
+        encoding = self.sql(expression, "encoding")
+        encoding = f" ENCODING {encoding}" if encoding else ""
+        return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
+
     def in_sql(self, expression: exp.In) -> str:
         query = expression.args.get("query")
         unnest = expression.args.get("unnest")
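
Reviewer note: a sketch of the pieces jsonobject_sql assembles, using the JSONKeyValue and JSONObject expressions added in sqlglot/expressions.py above:

    from sqlglot import exp
    from sqlglot.generator import Generator

    node = exp.JSONObject(
        expressions=[
            exp.JSONKeyValue(this=exp.Literal.string("a"), expression=exp.Literal.number(1))
        ],
        unique_keys=True,
    )
    print(Generator().generate(node))  # JSON_OBJECT('a': 1 WITH UNIQUE KEYS)
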
@@ -1808,12 +1866,18 @@ class Generator:
     def ilike_sql(self, expression: exp.ILike) -> str:
         return self.binary(expression, "ILIKE")
 
+    def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
+        return self.binary(expression, "ILIKE ANY")
+
     def is_sql(self, expression: exp.Is) -> str:
         return self.binary(expression, "IS")
 
     def like_sql(self, expression: exp.Like) -> str:
         return self.binary(expression, "LIKE")
 
+    def likeany_sql(self, expression: exp.LikeAny) -> str:
+        return self.binary(expression, "LIKE ANY")
+
     def similarto_sql(self, expression: exp.SimilarTo) -> str:
         return self.binary(expression, "SIMILAR TO")
 

sqlglot/helper.py
@@ -59,7 +59,7 @@ def ensure_list(value):
     """
     if value is None:
         return []
-    elif isinstance(value, (list, tuple)):
+    if isinstance(value, (list, tuple)):
         return list(value)
 
     return [value]
@@ -162,9 +162,7 @@ def camel_to_snake_case(name: str) -> str:
     return CAMEL_CASE_PATTERN.sub("_", name).upper()
 
 
-def while_changing(
-    expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E]
-) -> E:
+def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
     """
     Applies a transformation to a given expression until a fix point is reached.
 
@@ -176,8 +174,13 @@
         The transformed expression.
     """
     while True:
+        for n, *_ in reversed(tuple(expression.walk())):
+            n._hash = hash(n)
+
         start = hash(expression)
         expression = func(expression)
+        for n, *_ in expression.walk():
+            n._hash = None
 
         if start == hash(expression):
             break
     return expression
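
Reviewer note: while_changing now pins every node's hash before each pass and clears it afterwards, so the fix-point check compares cached hashes instead of re-walking the tree. The calling convention is unchanged:

    import sqlglot
    from sqlglot.helper import while_changing
    from sqlglot.optimizer.simplify import simplify

    expression = sqlglot.parse_one("TRUE AND TRUE")
    print(while_changing(expression, simplify).sql())  # TRUE
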

sqlglot/optimizer/annotate_types.py
@@ -1,5 +1,5 @@
 from sqlglot import exp
-from sqlglot.helper import ensure_collection, ensure_list, subclasses
+from sqlglot.helper import ensure_list, subclasses
 from sqlglot.optimizer.scope import Scope, traverse_scope
 from sqlglot.schema import ensure_schema
 
@@ -108,6 +108,7 @@ class TypeAnnotator:
             exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
             exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
             exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
+            exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
             exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
             exp.GroupConcat: lambda self, expr: self._annotate_with_type(
                 expr, exp.DataType.Type.VARCHAR
|
|||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
|
||||
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
|
||||
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
|
@@ -296,9 +298,6 @@ class TypeAnnotator:
         return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
 
     def _maybe_annotate(self, expression):
-        if not isinstance(expression, exp.Expression):
-            return None
-
         if expression.type:
             return expression  # We've already inferred the expression's type
 
@@ -311,9 +310,8 @@ class TypeAnnotator:
         )
 
     def _annotate_args(self, expression):
-        for value in expression.args.values():
-            for v in ensure_collection(value):
-                self._maybe_annotate(v)
+        for _, value in expression.iter_expressions():
+            self._maybe_annotate(value)
 
         return expression
 

sqlglot/optimizer/canonicalize.py
@@ -75,7 +75,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
         a.type
         and a.type.this == exp.DataType.Type.DATE
         and b.type
-        and b.type.this != exp.DataType.Type.DATE
+        and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
     ):
         _replace_cast(b, "date")
 

sqlglot/optimizer/eliminate_joins.py
@@ -1,7 +1,6 @@
 from sqlglot import expressions as exp
 from sqlglot.optimizer.normalize import normalized
 from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.optimizer.simplify import simplify
 
 
 def eliminate_joins(expression):
@@ -179,6 +178,4 @@ def join_condition(join):
     for condition in conditions:
         extract_condition(condition)
 
-    on = simplify(on)
-    remaining_condition = None if on == exp.true() else on
-    return source_key, join_key, remaining_condition
+    return source_key, join_key, on

sqlglot/optimizer/eliminate_subqueries.py
@@ -3,7 +3,6 @@ import itertools
 from sqlglot import expressions as exp
 from sqlglot.helper import find_new_name
 from sqlglot.optimizer.scope import build_scope
-from sqlglot.optimizer.simplify import simplify
 
 
 def eliminate_subqueries(expression):
@@ -31,7 +30,6 @@ def eliminate_subqueries(expression):
         eliminate_subqueries(expression.this)
         return expression
 
-    expression = simplify(expression)
     root = build_scope(expression)
 
     # Map of alias->Scope|Table

sqlglot/optimizer/lower_identities.py
@@ -1,5 +1,4 @@
 from sqlglot import exp
-from sqlglot.helper import ensure_collection
 
 
 def lower_identities(expression):
@@ -40,13 +39,10 @@ def lower_identities(expression):
         lower_identities(expression.right)
         traversed |= {"this", "expression"}
 
-    for k, v in expression.args.items():
+    for k, v in expression.iter_expressions():
         if k in traversed:
             continue
 
-        for child in ensure_collection(v):
-            if isinstance(child, exp.Expression):
-                child.transform(_lower, copy=False)
+        v.transform(_lower, copy=False)
 
     return expression
 

sqlglot/optimizer/merge_subqueries.py
@@ -3,7 +3,6 @@ from collections import defaultdict
 from sqlglot import expressions as exp
 from sqlglot.helper import find_new_name
 from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.optimizer.simplify import simplify
 
 
 def merge_subqueries(expression, leave_tables_isolated=False):
@@ -330,11 +329,11 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
 
     if set(exp.column_table_names(where.this)) <= sources:
         from_or_join.on(where.this, copy=False)
-        from_or_join.set("on", simplify(from_or_join.args.get("on")))
+        from_or_join.set("on", from_or_join.args.get("on"))
         return
 
     expression.where(where.this, copy=False)
-    expression.set("where", simplify(expression.args.get("where")))
+    expression.set("where", expression.args.get("where"))
 
 
 def _merge_order(outer_scope, inner_scope):

sqlglot/optimizer/normalize.py
@@ -1,29 +1,63 @@
 from __future__ import annotations
 
+import logging
+import typing as t
+
 from sqlglot import exp
 from sqlglot.errors import OptimizeError
 from sqlglot.helper import while_changing
-from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
+from sqlglot.optimizer.simplify import flatten, uniq_sort
+
+logger = logging.getLogger("sqlglot")
 
 
-def normalize(expression, dnf=False, max_distance=128):
+def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
     """
-    Rewrite sqlglot AST into conjunctive normal form.
+    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
 
     Example:
         >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
-        >>> normalize(expression).sql()
+        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'
 
     Args:
-        expression (sqlglot.Expression): expression to normalize
-        dnf (bool): rewrite in disjunctive normal form instead
-        max_distance (int): the maximal estimated distance from cnf to attempt conversion
+        expression: expression to normalize
+        dnf: rewrite in disjunctive normal form instead.
+        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
     Returns:
         sqlglot.Expression: normalized expression
     """
-    expression = simplify(expression)
-
-    expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
-    return simplify(expression)
+    cache: t.Dict[int, str] = {}
+
+    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
+        if isinstance(node, exp.Connector):
+            if normalized(node, dnf=dnf):
+                continue
+
+            distance = normalization_distance(node, dnf=dnf)
+
+            if distance > max_distance:
+                logger.info(
+                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
+                )
+                return expression
+
+            root = node is expression
+            original = node.copy()
+            try:
+                node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+            except OptimizeError as e:
+                logger.info(e)
+                node.replace(original)
+                if root:
+                    return original
+                return expression
+
+            if root:
+                expression = node
+
+    return expression
 
 
 def normalized(expression, dnf=False):
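
Reviewer note: normalize() now visits each connector subtree individually, skips already-normalized ones, and restores the original node when distribution raises OptimizeError, instead of simplifying the whole tree up front. The docstring example above is the intended smoke test:

    import sqlglot
    from sqlglot.optimizer.normalize import normalize

    expression = sqlglot.parse_one("(x AND y) OR z")
    print(normalize(expression, dnf=False).sql())  # (x OR z) AND (y OR z)
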
@@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False):
         int: difference
     """
     return sum(_predicate_lengths(expression, dnf)) - (
-        len(list(expression.find_all(exp.Connector))) + 1
+        sum(1 for _ in expression.find_all(exp.Connector)) + 1
     )
 
 
@@ -64,30 +98,33 @@ def _predicate_lengths(expression, dnf):
     expression = expression.unnest()
 
     if not isinstance(expression, exp.Connector):
-        return [1]
+        return (1,)
 
     left, right = expression.args.values()
 
     if isinstance(expression, exp.And if dnf else exp.Or):
-        return [
+        return tuple(
             a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
-        ]
+        )
     return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
 
 
-def distributive_law(expression, dnf, max_distance):
+def distributive_law(expression, dnf, max_distance, cache=None):
     """
     x OR (y AND z) -> (x OR y) AND (x OR z)
     (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
     """
-    if isinstance(expression.unnest(), exp.Connector):
-        if normalization_distance(expression, dnf) > max_distance:
-            return expression
+    if normalized(expression, dnf=dnf):
+        return expression
+
+    distance = normalization_distance(expression, dnf=dnf)
+
+    if distance > max_distance:
+        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
+
+    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
     to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
 
-    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
-
     if isinstance(expression, from_exp):
         a, b = expression.unnest_operands()
 
@@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance):

         if isinstance(a, to_exp) and isinstance(b, to_exp):
             if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
-                return _distribute(a, b, from_func, to_func)
-            return _distribute(b, a, from_func, to_func)
+                return _distribute(a, b, from_func, to_func, cache)
+            return _distribute(b, a, from_func, to_func, cache)
         if isinstance(a, to_exp):
-            return _distribute(b, a, from_func, to_func)
+            return _distribute(b, a, from_func, to_func, cache)
         if isinstance(b, to_exp):
-            return _distribute(a, b, from_func, to_func)
+            return _distribute(a, b, from_func, to_func, cache)

     return expression


-def _distribute(a, b, from_func, to_func):
+def _distribute(a, b, from_func, to_func, cache):
     if isinstance(a, exp.Connector):
         exp.replace_children(
             a,
             lambda c: to_func(
-                exp.paren(from_func(c, b.left)),
-                exp.paren(from_func(c, b.right)),
+                uniq_sort(flatten(from_func(c, b.left)), cache),
+                uniq_sort(flatten(from_func(c, b.right)), cache),
             ),
         )
     else:
-        a = to_func(from_func(a, b.left), from_func(a, b.right))
+        a = to_func(
+            uniq_sort(flatten(from_func(a, b.left)), cache),
+            uniq_sort(flatten(from_func(a, b.right)), cache),
+        )

-    return _simplify(a)
-
-
-def _simplify(node):
-    node = uniq_sort(flatten(node))
-    exp.replace_children(node, _simplify)
-    return node
+    return a
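A small end-to-end sketch of the distributive law (illustrative, not from the diff; exact parenthesization of the output may vary):

```python
import sqlglot
from sqlglot.optimizer.normalize import normalize

# Targeting DNF, AND distributes over OR. The operands come back deduped
# and ordered because _distribute() now routes every clause through
# uniq_sort(flatten(...)) instead of wrapping it in parentheses.
print(normalize(sqlglot.parse_one("a AND (b OR c)"), dnf=True).sql())
# roughly: (a AND b) OR (a AND c)
```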
@@ -1,6 +1,5 @@
 from sqlglot import exp
 from sqlglot.helper import tsort
-from sqlglot.optimizer.simplify import simplify


 def optimize_joins(expression):

@@ -29,7 +28,6 @@ def optimize_joins(expression):
         for name, join in cross_joins:
             for dep in references.get(name, []):
                 on = dep.args["on"]
-                on = on.replace(simplify(on))

                 if isinstance(on, exp.Connector):
                     for predicate in on.flatten():
@@ -21,6 +21,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
 from sqlglot.optimizer.pushdown_projections import pushdown_projections
 from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
 from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.optimizer.simplify import simplify
 from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
 from sqlglot.schema import ensure_schema

@@ -43,6 +44,7 @@ RULES = (
     eliminate_ctes,
     annotate_types,
     canonicalize,
+    simplify,
 )

@@ -78,7 +80,7 @@ def optimize(
     Returns:
         sqlglot.Expression: optimized expression
     """
-    schema = ensure_schema(schema or sqlglot.schema)
+    schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
     possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
     expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
     for rule in rules:
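A usage sketch for the new final rule (the schema and query here are invented for illustration):

```python
import sqlglot
from sqlglot.optimizer import optimize

# With simplify appended to RULES, tautologies left over by earlier rules
# are folded away in the optimized output.
sql = "SELECT a FROM x WHERE 1 = 1 AND a > 0"
print(optimize(sqlglot.parse_one(sql), schema={"x": {"a": "INT"}}).sql())
```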
@@ -30,11 +30,12 @@ def qualify_columns(expression, schema):
         resolver = Resolver(scope, schema)
         _pop_table_column_aliases(scope.ctes)
         _pop_table_column_aliases(scope.derived_tables)
-        _expand_using(scope, resolver)
+        using_column_tables = _expand_using(scope, resolver)
         _qualify_columns(scope, resolver)
         if not isinstance(scope.expression, exp.UDTF):
-            _expand_stars(scope, resolver)
+            _expand_stars(scope, resolver, using_column_tables)
             _qualify_outputs(scope)
+        _expand_alias_refs(scope, resolver)
         _expand_group_by(scope, resolver)
         _expand_order_by(scope)
@@ -69,11 +70,11 @@ def _pop_table_column_aliases(derived_tables):


 def _expand_using(scope, resolver):
-    joins = list(scope.expression.find_all(exp.Join))
+    joins = list(scope.find_all(exp.Join))
     names = {join.this.alias for join in joins}
     ordered = [key for key in scope.selected_sources if key not in names]

-    # Mapping of automatically joined column names to source names
+    # Mapping of automatically joined column names to an ordered set of source names (dict).
     column_tables = {}

     for join in joins:

@@ -112,11 +113,12 @@ def _expand_using(scope, resolver):
                 )
             )

-            tables = column_tables.setdefault(identifier, [])
+            # Set all values in the dict to None, because we only care about the key ordering
+            tables = column_tables.setdefault(identifier, {})
             if table not in tables:
-                tables.append(table)
+                tables[table] = None
             if join_table not in tables:
-                tables.append(join_table)
+                tables[join_table] = None

         join.args.pop("using")
         join.set("on", exp.and_(*conditions))

@@ -134,11 +136,11 @@ def _expand_using(scope, resolver):
             scope.replace(column, replacement)

+    return column_tables
+

-def _expand_group_by(scope, resolver):
-    group = scope.expression.args.get("group")
-    if not group:
-        return
+def _expand_alias_refs(scope, resolver):
+    selects = {}

     # Replace references to select aliases
     def transform(node, *_):

@@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver):
                 node.set("table", table)
             return node

-        selects = {s.alias_or_name: s for s in scope.selects}
+        if not selects:
+            for s in scope.selects:
+                selects[s.alias_or_name] = s
         select = selects.get(node.name)

         if select:
             scope.clear_cache()
             if isinstance(select, exp.Alias):
@@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver):
         return node

-    group.transform(transform, copy=False)
+    for select in scope.expression.selects:
+        select.transform(transform, copy=False)
+
+    for modifier in ("where", "group"):
+        part = scope.expression.args.get(modifier)
+
+        if part:
+            part.transform(transform, copy=False)
+
+
+def _expand_group_by(scope, resolver):
+    group = scope.expression.args.get("group")
+    if not group:
+        return
+
     group.set("expressions", _expand_positional_references(scope, group.expressions))
     scope.expression.set("group", group)
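A sketch of what `_expand_alias_refs` enables (hypothetical query; the output shape follows the transform above): select aliases referenced from WHERE or GROUP BY are replaced by the aliased projection.

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

sql = "SELECT a + 1 AS b FROM x WHERE b > 1"
print(qualify_columns(sqlglot.parse_one(sql), schema={"x": {"a": "INT"}}).sql())
# expected shape: SELECT x.a + 1 AS b FROM x WHERE x.a + 1 > 1
```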
@@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver):
             column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

     columns_missing_from_scope = []

     # Determine whether each reference in the order by clause is to a column or an alias.
-    for ordered in scope.find_all(exp.Ordered):
-        for column in ordered.find_all(exp.Column):
-            if (
-                not column.table
-                and column.parent is not ordered
-                and column.name in resolver.all_columns
-            ):
-                columns_missing_from_scope.append(column)
+    order = scope.expression.args.get("order")
+
+    if order:
+        for ordered in order.expressions:
+            for column in ordered.find_all(exp.Column):
+                if (
+                    not column.table
+                    and column.parent is not ordered
+                    and column.name in resolver.all_columns
+                ):
+                    columns_missing_from_scope.append(column)

     # Determine whether each reference in the having clause is to a column or an alias.
-    for having in scope.find_all(exp.Having):
+    having = scope.expression.args.get("having")
+
+    if having:
         for column in having.find_all(exp.Column):
             if (
                 not column.table
@@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver):
         column.set("table", column_table)


-def _expand_stars(scope, resolver):
+def _expand_stars(scope, resolver, using_column_tables):
     """Expand stars to lists of column selections"""

     new_selections = []
     except_columns = {}
     replace_columns = {}
+    coalesced_columns = set()

     for expression in scope.selects:
         if isinstance(expression, exp.Star):

@@ -286,7 +311,20 @@ def _expand_stars(scope, resolver):
             if columns and "*" not in columns:
                 table_id = id(table)
                 for name in columns:
-                    if name not in except_columns.get(table_id, set()):
+                    if name in using_column_tables and table in using_column_tables[name]:
+                        if name in coalesced_columns:
+                            continue
+
+                        coalesced_columns.add(name)
+                        tables = using_column_tables[name]
+                        coalesce = [exp.column(name, table=table) for table in tables]
+
+                        new_selections.append(
+                            exp.alias_(
+                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
+                            )
+                        )
+                    elif name not in except_columns.get(table_id, set()):
                         alias_ = replace_columns.get(table_id, {}).get(name, name)
                         column = exp.column(name, table)
                         new_selections.append(alias(column, alias_) if alias_ != name else column)
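The star-expansion change in action (illustrative; schema invented): columns joined via USING are now projected once, coalesced across their source tables.

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

sql = "SELECT * FROM x JOIN y USING (b)"
schema = {"x": {"a": "INT", "b": "INT"}, "y": {"b": "INT", "c": "INT"}}
print(qualify_columns(sqlglot.parse_one(sql), schema).sql())
# expected shape:
# SELECT x.a AS a, COALESCE(x.b, y.b) AS b, y.c AS c FROM x JOIN y ON x.b = y.b
```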
@@ -160,7 +160,7 @@ class Scope:
         Yields:
             exp.Expression: nodes
         """
-        for expression, _, _ in self.walk(bfs=bfs):
+        for expression, *_ in self.walk(bfs=bfs):
             if isinstance(expression, expression_types):
                 yield expression
@@ -5,11 +5,10 @@ from collections import deque
 from decimal import Decimal

 from sqlglot import exp
-from sqlglot.expressions import FALSE, NULL, TRUE
 from sqlglot.generator import Generator
 from sqlglot.helper import first, while_changing

-GENERATOR = Generator(normalize=True, identify=True)
+GENERATOR = Generator(normalize=True, identify="safe")


 def simplify(expression):
@@ -28,18 +27,20 @@ def simplify(expression):
         sqlglot.Expression: simplified expression
     """

+    cache = {}
+
     def _simplify(expression, root=True):
         node = expression
         node = rewrite_between(node)
-        node = uniq_sort(node)
-        node = absorb_and_eliminate(node)
+        node = uniq_sort(node, cache, root)
+        node = absorb_and_eliminate(node, root)
         exp.replace_children(node, lambda e: _simplify(e, False))
         node = simplify_not(node)
         node = flatten(node)
-        node = simplify_connectors(node)
-        node = remove_compliments(node)
+        node = simplify_connectors(node, root)
+        node = remove_compliments(node, root)
         node.parent = expression.parent
-        node = simplify_literals(node)
+        node = simplify_literals(node, root)
         node = simplify_parens(node)
         if root:
             expression.replace(node)
@@ -70,7 +71,7 @@ def simplify_not(expression):
     NOT (x AND y) -> NOT x OR NOT y
     """
     if isinstance(expression, exp.Not):
-        if isinstance(expression.this, exp.Null):
+        if is_null(expression.this):
             return exp.null()
         if isinstance(expression.this, exp.Paren):
             condition = expression.this.unnest()

@@ -78,11 +79,11 @@ def simplify_not(expression):
                 return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
             if isinstance(condition, exp.Or):
                 return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
-            if isinstance(condition, exp.Null):
+            if is_null(condition):
                 return exp.null()
         if always_true(expression.this):
             return exp.false()
-        if expression.this == FALSE:
+        if is_false(expression.this):
             return exp.true()
         if isinstance(expression.this, exp.Not):
             # double negation
@@ -104,42 +105,42 @@ def flatten(expression):
     return expression


-def simplify_connectors(expression):
+def simplify_connectors(expression, root=True):
     def _simplify_connectors(expression, left, right):
-        if isinstance(expression, exp.Connector):
-            if left == right:
-                return left
-            if isinstance(expression, exp.And):
-                if FALSE in (left, right):
-                    return exp.false()
-                if NULL in (left, right):
-                    return exp.null()
-                if always_true(left) and always_true(right):
-                    return exp.true()
-                if always_true(left):
-                    return right
-                if always_true(right):
-                    return left
-                return _simplify_comparison(expression, left, right)
-            elif isinstance(expression, exp.Or):
-                if always_true(left) or always_true(right):
-                    return exp.true()
-                if left == FALSE and right == FALSE:
-                    return exp.false()
-                if (
-                    (left == NULL and right == NULL)
-                    or (left == NULL and right == FALSE)
-                    or (left == FALSE and right == NULL)
-                ):
-                    return exp.null()
-                if left == FALSE:
-                    return right
-                if right == FALSE:
-                    return left
-                return _simplify_comparison(expression, left, right, or_=True)
-        return None
+        if left == right:
+            return left
+        if isinstance(expression, exp.And):
+            if is_false(left) or is_false(right):
+                return exp.false()
+            if is_null(left) or is_null(right):
+                return exp.null()
+            if always_true(left) and always_true(right):
+                return exp.true()
+            if always_true(left):
+                return right
+            if always_true(right):
+                return left
+            return _simplify_comparison(expression, left, right)
+        elif isinstance(expression, exp.Or):
+            if always_true(left) or always_true(right):
+                return exp.true()
+            if is_false(left) and is_false(right):
+                return exp.false()
+            if (
+                (is_null(left) and is_null(right))
+                or (is_null(left) and is_false(right))
+                or (is_false(left) and is_null(right))
+            ):
+                return exp.null()
+            if is_false(left):
+                return right
+            if is_false(right):
+                return left
+            return _simplify_comparison(expression, left, right, or_=True)

-    return _flat_simplify(expression, _simplify_connectors)
+    if isinstance(expression, exp.Connector):
+        return _flat_simplify(expression, _simplify_connectors, root)
+    return expression


 LT_LTE = (exp.LT, exp.LTE)
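A quick sketch of the rewritten connector simplification (illustrative):

```python
import sqlglot
from sqlglot.optimizer.simplify import simplify

# FALSE absorbs AND, TRUE absorbs OR; the new is_false/is_null helpers
# drive these decisions instead of comparisons against singleton nodes.
print(simplify(sqlglot.parse_one("x = 1 AND FALSE")).sql())  # FALSE
print(simplify(sqlglot.parse_one("x = 1 OR TRUE")).sql())    # TRUE
```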
@@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False):
     return None


-def remove_compliments(expression):
+def remove_compliments(expression, root=True):
     """
     Removing compliments.

     A AND NOT A -> FALSE
     A OR NOT A -> TRUE
     """
-    if isinstance(expression, exp.Connector):
+    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
         compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

         for a, b in itertools.permutations(expression.flatten(), 2):
@@ -236,23 +237,23 @@ def remove_compliments(expression):
     return expression


-def uniq_sort(expression):
+def uniq_sort(expression, cache=None, root=True):
     """
     Uniq and sort a connector.

     C AND A AND B AND B -> A AND B AND C
     """
-    if isinstance(expression, exp.Connector):
+    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
         result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
         flattened = tuple(expression.flatten())
-        deduped = {GENERATOR.generate(e): e for e in flattened}
+        deduped = {GENERATOR.generate(e, cache): e for e in flattened}
         arr = tuple(deduped.items())

         # check if the operands are already sorted, if not sort them
         # A AND C AND B -> A AND B AND C
         for i, (sql, e) in enumerate(arr[1:]):
             if sql < arr[i][0]:
-                expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
+                expression = result_func(*(e for _, e in sorted(arr)))
                 break
         else:
             # we didn't have to sort but maybe we need to dedup
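A behavior sketch for `uniq_sort` (illustrative): duplicates are dropped and operands are ordered by their generated SQL, which is exactly what the new `cache` argument memoizes across calls.

```python
import sqlglot
from sqlglot.optimizer.simplify import simplify

# C AND A AND B AND B -> A AND B AND C, as in the docstring above.
print(simplify(sqlglot.parse_one("c = 1 AND a = 2 AND b = 3 AND b = 3")).sql())
# expected shape: a = 2 AND b = 3 AND c = 1
```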
@@ -262,7 +263,7 @@ def uniq_sort(expression):
     return expression


-def absorb_and_eliminate(expression):
+def absorb_and_eliminate(expression, root=True):
     """
     absorption:
         A AND (A OR B) -> A

@@ -273,7 +274,7 @@ def absorb_and_eliminate(expression):
         (A AND B) OR (A AND NOT B) -> A
         (A OR B) AND (A OR NOT B) -> A
     """
-    if isinstance(expression, exp.Connector):
+    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
         kind = exp.Or if isinstance(expression, exp.And) else exp.And

         for a, b in itertools.permutations(expression.flatten(), 2):

@@ -302,9 +303,9 @@ def absorb_and_eliminate(expression):
     return expression


-def simplify_literals(expression):
-    if isinstance(expression, exp.Binary):
-        return _flat_simplify(expression, _simplify_binary)
+def simplify_literals(expression, root=True):
+    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
+        return _flat_simplify(expression, _simplify_binary, root)
     elif isinstance(expression, exp.Neg):
         this = expression.this
         if this.is_number:
@@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b):
         c = b
         not_ = False

-    if c == NULL:
+    if is_null(c):
         if isinstance(a, exp.Literal):
             return exp.true() if not_ else exp.false()
-        if a == NULL:
+        if is_null(a):
             return exp.false() if not_ else exp.true()
     elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
         return None
-    elif NULL in (a, b):
+    elif is_null(a) or is_null(b):
         return exp.null()

     if a.is_number and b.is_number:

@@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b):
         if boolean:
             return boolean
     elif a.is_string and b.is_string:
-        boolean = eval_boolean(expression, a, b)
+        boolean = eval_boolean(expression, a.this, b.this)

         if boolean:
             return boolean

@@ -381,7 +382,7 @@ def simplify_parens(expression):
         and not isinstance(expression.this, exp.Select)
         and (
             not isinstance(expression.parent, (exp.Condition, exp.Binary))
-            or isinstance(expression.this, (exp.Is, exp.Like))
+            or isinstance(expression.this, exp.Predicate)
             or not isinstance(expression.this, exp.Binary)
         )
     ):
@@ -400,13 +401,23 @@ def remove_where_true(expression):


 def always_true(expression):
-    return expression == TRUE or isinstance(expression, exp.Literal)
+    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
+        expression, exp.Literal
+    )


 def is_complement(a, b):
     return isinstance(b, exp.Not) and b.this == a


+def is_false(a: exp.Expression) -> bool:
+    return type(a) is exp.Boolean and not a.this
+
+
+def is_null(a: exp.Expression) -> bool:
+    return type(a) is exp.Null
+
+
 def eval_boolean(expression, a, b):
     if isinstance(expression, (exp.EQ, exp.Is)):
         return boolean_literal(a == b)
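What the new helpers check, in one sketch (illustrative):

```python
from sqlglot import exp, parse_one

# is_null()/is_false() test node types directly rather than comparing
# against the removed TRUE/FALSE/NULL singleton expressions.
assert type(parse_one("NULL")) is exp.Null           # is_null
node = parse_one("FALSE")
assert type(node) is exp.Boolean and not node.this   # is_false
```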
@@ -466,24 +477,27 @@ def boolean_literal(condition):
     return exp.true() if condition else exp.false()


-def _flat_simplify(expression, simplifier):
-    operands = []
-    queue = deque(expression.flatten(unnest=False))
-    size = len(queue)
+def _flat_simplify(expression, simplifier, root=True):
+    if root or not expression.same_parent:
+        operands = []
+        queue = deque(expression.flatten(unnest=False))
+        size = len(queue)

-    while queue:
-        a = queue.popleft()
+        while queue:
+            a = queue.popleft()

-        for b in queue:
-            result = simplifier(expression, a, b)
+            for b in queue:
+                result = simplifier(expression, a, b)

-            if result:
-                queue.remove(b)
-                queue.append(result)
-                break
-        else:
-            operands.append(a)
+                if result:
+                    queue.remove(b)
+                    queue.append(result)
+                    break
+            else:
+                operands.append(a)

-    if len(operands) < size:
-        return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+        if len(operands) < size:
+            return functools.reduce(
+                lambda a, b: expression.__class__(this=a, expression=b), operands
+            )
     return expression
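One more illustrative sketch of `_flat_simplify` pairing operands through `_simplify_connectors`:

```python
import sqlglot
from sqlglot.optimizer.simplify import simplify

# The pairwise simplifier merges compatible comparisons on the same column.
print(simplify(sqlglot.parse_one("x < 1 AND x < 2")).sql())
# expected shape: x < 1
```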
@@ -19,7 +19,7 @@ from sqlglot.trie import in_trie, new_trie
 logger = logging.getLogger("sqlglot")


-def parse_var_map(args):
+def parse_var_map(args: t.Sequence) -> exp.Expression:
     keys = []
     values = []
     for i in range(0, len(args), 2):
@@ -31,6 +31,11 @@ def parse_var_map(args):
     )


+def parse_like(args):
+    like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
+    return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
+
+
 def binary_range_parser(
     expr_type: t.Type[exp.Expression],
 ) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]:
@@ -77,6 +82,9 @@ class Parser(metaclass=_Parser):
             this=seq_get(args, 0),
             to=exp.DataType(this=exp.DataType.Type.TEXT),
         ),
+        "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
+        "IFNULL": exp.Coalesce.from_arg_list,
+        "LIKE": parse_like,
         "TIME_TO_TIME_STR": lambda args: exp.Cast(
             this=seq_get(args, 0),
             to=exp.DataType(this=exp.DataType.Type.TEXT),
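A sketch of the new function-call forms (SQLite-style `like(pattern, value[, escape])` and `glob(pattern, value)`; illustrative):

```python
import sqlglot
from sqlglot import exp

# parse_like() swaps the arguments into the binary form and wraps a third
# argument in an ESCAPE node.
node = sqlglot.parse_one("SELECT LIKE('a%', col, '!')").find(exp.Escape)
assert isinstance(node.this, exp.Like)
```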
@@ -90,7 +98,6 @@ class Parser(metaclass=_Parser):
             length=exp.Literal.number(10),
         ),
         "VAR_MAP": parse_var_map,
-        "IFNULL": exp.Coalesce.from_arg_list,
     }

     NO_PAREN_FUNCTIONS = {
@@ -211,6 +218,7 @@ class Parser(metaclass=_Parser):
         TokenType.FILTER,
         TokenType.FOLLOWING,
         TokenType.FORMAT,
+        TokenType.FULL,
         TokenType.IF,
         TokenType.ISNULL,
         TokenType.INTERVAL,

@@ -226,8 +234,10 @@ class Parser(metaclass=_Parser):
         TokenType.ONLY,
         TokenType.OPTIONS,
         TokenType.ORDINALITY,
+        TokenType.PARTITION,
         TokenType.PERCENT,
         TokenType.PIVOT,
+        TokenType.PRAGMA,
         TokenType.PRECEDING,
         TokenType.RANGE,
         TokenType.REFERENCES,

@@ -257,6 +267,7 @@ class Parser(metaclass=_Parser):

     TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
         TokenType.APPLY,
+        TokenType.FULL,
         TokenType.LEFT,
         TokenType.NATURAL,
         TokenType.OFFSET,

@@ -277,6 +288,7 @@ class Parser(metaclass=_Parser):
         TokenType.FILTER,
         TokenType.FIRST,
         TokenType.FORMAT,
+        TokenType.GLOB,
         TokenType.IDENTIFIER,
         TokenType.INDEX,
         TokenType.ISNULL,
@@ -461,6 +473,7 @@ class Parser(metaclass=_Parser):
         TokenType.INSERT: lambda self: self._parse_insert(),
         TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
         TokenType.MERGE: lambda self: self._parse_merge(),
+        TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()),
         TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
         TokenType.SET: lambda self: self._parse_set(),
         TokenType.UNCACHE: lambda self: self._parse_uncache(),
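Usage sketch for the new PRAGMA support (illustrative):

```python
import sqlglot
from sqlglot import exp

# PRAGMA statements now parse into their own node instead of a generic command.
node = sqlglot.parse_one("PRAGMA cache_size = 10000")
assert isinstance(node, exp.Pragma)
```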
@@ -662,6 +675,8 @@ class Parser(metaclass=_Parser):
         "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
         "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
         "EXTRACT": lambda self: self._parse_extract(),
+        "JSON_OBJECT": lambda self: self._parse_json_object(),
+        "LOG": lambda self: self._parse_logarithm(),
         "POSITION": lambda self: self._parse_position(),
         "STRING_AGG": lambda self: self._parse_string_agg(),
         "SUBSTRING": lambda self: self._parse_substring(),
@@ -719,6 +734,9 @@ class Parser(metaclass=_Parser):

     CONVERT_TYPE_FIRST = False

+    LOG_BASE_FIRST = True
+    LOG_DEFAULTS_TO_LN = False
+
     __slots__ = (
         "error_level",
         "error_message_context",
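How the two new flags interact with `_parse_logarithm` (sketch; BigQuery, for instance, flips both flags):

```python
import sqlglot

# LOG_DEFAULTS_TO_LN: a single-argument LOG becomes a natural log.
print(sqlglot.transpile("SELECT LOG(x)", read="bigquery")[0])
# expected: SELECT LN(x)

# LOG_BASE_FIRST = False: LOG(value, base) is reversed into the canonical
# base-first argument order when parsed.
print(sqlglot.transpile("SELECT LOG(x, 10)", read="bigquery")[0])
# expected: SELECT LOG(10, x)
```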
@@ -1032,6 +1050,7 @@ class Parser(metaclass=_Parser):
             temporary=temporary,
             materialized=materialized,
             cascade=self._match(TokenType.CASCADE),
+            constraints=self._match_text_seq("CONSTRAINTS"),
         )

     def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
@@ -1221,7 +1240,7 @@ class Parser(metaclass=_Parser):

             if not identified_property:
                 break
-            for p in ensure_collection(identified_property):
+            for p in ensure_list(identified_property):
                 properties.append(p)

         if properties:
@@ -1704,6 +1723,11 @@ class Parser(metaclass=_Parser):
         elif self._match(TokenType.SELECT):
             comments = self._prev_comments

+            kind = (
+                self._match(TokenType.ALIAS)
+                and self._match_texts(("STRUCT", "VALUE"))
+                and self._prev.text
+            )
             hint = self._parse_hint()
             all_ = self._match(TokenType.ALL)
             distinct = self._match(TokenType.DISTINCT)

@@ -1722,6 +1746,7 @@ class Parser(metaclass=_Parser):

             this = self.expression(
                 exp.Select,
+                kind=kind,
                 hint=hint,
                 distinct=distinct,
                 expressions=expressions,
@@ -2785,7 +2810,6 @@ class Parser(metaclass=_Parser):

         this = seq_get(expressions, 0)
         self._parse_query_modifiers(this)
-        self._match_r_paren()

         if isinstance(this, exp.Subqueryable):
             this = self._parse_set_operations(

@@ -2794,7 +2818,9 @@ class Parser(metaclass=_Parser):
         elif len(expressions) > 1:
             this = self.expression(exp.Tuple, expressions=expressions)
         else:
-            this = self.expression(exp.Paren, this=this)
+            this = self.expression(exp.Paren, this=self._parse_set_operations(this))
+
+        self._match_r_paren()

         if this and comments:
             this.comments = comments
@@ -3318,6 +3344,60 @@ class Parser(metaclass=_Parser):
         return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)

+    def _parse_json_key_value(self) -> t.Optional[exp.Expression]:
+        self._match_text_seq("KEY")
+        key = self._parse_field()
+        self._match(TokenType.COLON)
+        self._match_text_seq("VALUE")
+        value = self._parse_field()
+        if not key and not value:
+            return None
+        return self.expression(exp.JSONKeyValue, this=key, expression=value)
+
+    def _parse_json_object(self) -> exp.Expression:
+        expressions = self._parse_csv(self._parse_json_key_value)
+
+        null_handling = None
+        if self._match_text_seq("NULL", "ON", "NULL"):
+            null_handling = "NULL ON NULL"
+        elif self._match_text_seq("ABSENT", "ON", "NULL"):
+            null_handling = "ABSENT ON NULL"
+
+        unique_keys = None
+        if self._match_text_seq("WITH", "UNIQUE"):
+            unique_keys = True
+        elif self._match_text_seq("WITHOUT", "UNIQUE"):
+            unique_keys = False
+
+        self._match_text_seq("KEYS")
+
+        return_type = self._match_text_seq("RETURNING") and self._parse_type()
+        format_json = self._match_text_seq("FORMAT", "JSON")
+        encoding = self._match_text_seq("ENCODING") and self._parse_var()
+
+        return self.expression(
+            exp.JSONObject,
+            expressions=expressions,
+            null_handling=null_handling,
+            unique_keys=unique_keys,
+            return_type=return_type,
+            format_json=format_json,
+            encoding=encoding,
+        )
+
+    def _parse_logarithm(self) -> exp.Expression:
+        # Default argument order is base, expression
+        args = self._parse_csv(self._parse_range)
+
+        if len(args) > 1:
+            if not self.LOG_BASE_FIRST:
+                args.reverse()
+            return exp.Log.from_arg_list(args)
+
+        return self.expression(
+            exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0)
+        )
+
     def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
         args = self._parse_csv(self._parse_bitwise)
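A round-trip sketch for the new JSON_OBJECT parser (illustrative; exact output formatting may vary):

```python
import sqlglot

# Key/value pairs plus the SQL-standard null-handling and uniqueness clauses.
sql = "SELECT JSON_OBJECT('a': 1, 'b': 2 ABSENT ON NULL WITH UNIQUE KEYS)"
print(sqlglot.parse_one(sql).sql())
```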
@@ -3654,7 +3734,7 @@ class Parser(metaclass=_Parser):
         return parse_result

     def _parse_select_or_expression(self) -> t.Optional[exp.Expression]:
-        return self._parse_select() or self._parse_expression()
+        return self._parse_select() or self._parse_set_operations(self._parse_expression())

     def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
         return self._parse_set_operations(

@@ -3741,6 +3821,8 @@ class Parser(metaclass=_Parser):
             expression = self._parse_foreign_key()
         elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY):
             expression = self._parse_primary_key()
+        else:
+            expression = None

         return self.expression(exp.AddConstraint, this=this, expression=expression)
@@ -3799,12 +3881,15 @@ class Parser(metaclass=_Parser):
         parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None

         if parser:
-            return self.expression(
-                exp.AlterTable,
-                this=this,
-                exists=exists,
-                actions=ensure_list(parser(self)),
-            )
+            actions = ensure_list(parser(self))
+
+            if not self._curr:
+                return self.expression(
+                    exp.AlterTable,
+                    this=this,
+                    exists=exists,
+                    actions=actions,
+                )
+        return self._parse_as_command(start)

     def _parse_merge(self) -> exp.Expression:
@@ -175,7 +175,7 @@ class Step:
             }
             for projection in projections:
                 for i, e in aggregate.group.items():
-                    for child, _, _ in projection.walk():
+                    for child, *_ in projection.walk():
                         if child == e:
                             child.replace(exp.column(i, step.name))
             aggregate.add_dependency(step)
@@ -306,11 +306,11 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         return self._type_mapping_cache[schema_type]


-def ensure_schema(schema: t.Any) -> Schema:
+def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
     if isinstance(schema, Schema):
         return schema

-    return MappingSchema(schema)
+    return MappingSchema(schema, dialect=dialect)


 def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
@@ -252,6 +252,7 @@ class TokenType(AutoName):
     PERCENT = auto()
     PIVOT = auto()
     PLACEHOLDER = auto()
+    PRAGMA = auto()
     PRECEDING = auto()
     PRIMARY_KEY = auto()
     PROCEDURE = auto()

@@ -346,7 +347,8 @@ class Token:
         self.token_type = token_type
         self.text = text
         self.line = line
-        self.col = max(col - len(text), 1)
+        self.col = col - len(text)
+        self.col = self.col if self.col > 1 else 1
         self.comments = comments

     def __repr__(self) -> str:
@@ -586,6 +588,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "PARTITIONED_BY": TokenType.PARTITION_BY,
         "PERCENT": TokenType.PERCENT,
         "PIVOT": TokenType.PIVOT,
+        "PRAGMA": TokenType.PRAGMA,
         "PRECEDING": TokenType.PRECEDING,
         "PRIMARY KEY": TokenType.PRIMARY_KEY,
         "PROCEDURE": TokenType.PROCEDURE,

@@ -654,6 +657,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "LONG": TokenType.BIGINT,
         "BIGINT": TokenType.BIGINT,
+        "INT8": TokenType.BIGINT,
         "DEC": TokenType.DECIMAL,
         "DECIMAL": TokenType.DECIMAL,
         "MAP": TokenType.MAP,
         "NULLABLE": TokenType.NULLABLE,
@@ -714,7 +718,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "VACUUM": TokenType.COMMAND,
     }

-    WHITE_SPACE: t.Dict[str, TokenType] = {
+    WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = {
         " ": TokenType.SPACE,
         "\t": TokenType.SPACE,
         "\n": TokenType.BREAK,
@@ -813,11 +817,8 @@ class Tokenizer(metaclass=_Tokenizer):
             return self.sql[start:end]
         return ""

-    def _line_break(self, char: t.Optional[str]) -> bool:
-        return self.WHITE_SPACE.get(char) == TokenType.BREAK  # type: ignore
-
     def _advance(self, i: int = 1) -> None:
-        if self._line_break(self._char):
+        if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
             self._set_new_line()

         self._col += i
@@ -939,7 +940,7 @@ class Tokenizer(metaclass=_Tokenizer):
             self._comments.append(self._text[comment_start_size : -comment_end_size + 1])  # type: ignore
             self._advance(comment_end_size - 1)
         else:
-            while not self._end and not self._line_break(self._peek):
+            while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
                 self._advance()
             self._comments.append(self._text[comment_start_size:])  # type: ignore