1
0
Fork 0

Merging upstream version 11.4.5.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:48:10 +01:00
parent 0a06643852
commit 88f99e1c27
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
131 changed files with 53004 additions and 37079 deletions

View file

@ -47,7 +47,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
__version__ = "11.4.1"
__version__ = "11.4.5"
pretty = False
"""Whether to format generated SQL by default."""

View file

@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
inline_array_sql,
max_or_greatest,
min_or_least,
no_ilike_sql,
rename_func,
@ -212,6 +213,9 @@ class BigQuery(Dialect):
),
}
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
@ -227,6 +231,7 @@ class BigQuery(Dialect):
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess(
[_unqualify_unnest], transforms.delegate("select_sql")
@ -253,17 +258,19 @@ class BigQuery(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.FLOAT: "FLOAT64",
exp.DataType.Type.DOUBLE: "FLOAT64",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.CHAR: "STRING",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DOUBLE: "FLOAT64",
exp.DataType.Type.FLOAT: "FLOAT64",
exp.DataType.Type.INT: "INT64",
exp.DataType.Type.NCHAR: "STRING",
exp.DataType.Type.NVARCHAR: "STRING",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARCHAR: "STRING",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
@ -271,6 +278,7 @@ class BigQuery(Dialect):
}
EXPLICIT_UNION = True
LIMIT_FETCH = "LIMIT"
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)

View file

@ -68,6 +68,8 @@ class ClickHouse(Dialect):
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
LOG_DEFAULTS_TO_LN = True
def _parse_in(
self, this: t.Optional[exp.Expression], is_global: bool = False
) -> exp.Expression:

View file

@ -16,6 +16,8 @@ class Databricks(Spark):
"DATEDIFF": parse_date_delta(exp.DateDiff),
}
LOG_DEFAULTS_TO_LN = True
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS, # type: ignore

View file

@ -430,6 +430,11 @@ def min_or_least(self: Generator, expression: exp.Min) -> str:
return rename_func(name)(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
name = "GREATEST" if expression.expressions else "MAX"
return rename_func(name)(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
cond = expression.this

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import re
import typing as t
from sqlglot import exp, generator, parser, tokens
@ -102,6 +101,8 @@ class Drill(Dialect):
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}
LOG_DEFAULTS_TO_LN = True
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@ -154,4 +155,4 @@ class Drill(Dialect):
}
def normalize_func(self, name: str) -> str:
return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
return name if exp.SAFE_IDENTIFIER_RE.match(name) else f"`{name}`"

View file

@ -80,6 +80,7 @@ class DuckDB(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~": TokenType.RLIKE,
":=": TokenType.EQ,
"ATTACH": TokenType.COMMAND,
"BINARY": TokenType.VARBINARY,
@ -212,5 +213,7 @@ class DuckDB(Dialect):
"except": "EXCLUDE",
}
LIMIT_FETCH = "LIMIT"
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE")

View file

@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
no_ilike_sql,
no_recursive_cte_sql,
@ -34,6 +35,13 @@ DATE_DELTA_INTERVAL = {
"DAY": ("DATE_ADD", 1),
}
TIME_DIFF_FACTOR = {
"MILLISECOND": " * 1000",
"SECOND": "",
"MINUTE": " / 60",
"HOUR": " / 3600",
}
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
@ -51,6 +59,14 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = TIME_DIFF_FACTOR.get(unit)
if factor is not None:
left = self.sql(expression, "this")
right = self.sql(expression, "expression")
sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})"
return f"({sec_diff}){factor}" if factor else sec_diff
sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
@ -237,11 +253,6 @@ class Hive(Dialect):
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": locate_to_strposition,
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
else exp.Ln.from_arg_list(args)
),
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
@ -261,6 +272,8 @@ class Hive(Dialect):
),
}
LOG_DEFAULTS_TO_LN = True
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@ -293,6 +306,7 @@ class Hive(Dialect):
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
@ -338,6 +352,8 @@ class Hive(Dialect):
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
}
LIMIT_FETCH = "LIMIT"
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",

View file

@ -3,7 +3,9 @@ from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
no_ilike_sql,
no_paren_current_date_sql,
@ -288,6 +290,8 @@ class MySQL(Dialect):
"SWAPS",
}
LOG_DEFAULTS_TO_LN = True
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
@ -303,7 +307,13 @@ class MySQL(Dialect):
db = None
else:
position = None
db = self._parse_id_var() if self._match_text_seq("FROM") else None
db = None
if self._match(TokenType.FROM):
db = self._parse_id_var()
elif self._match(TokenType.DOT):
db = target_id
target_id = self._parse_id_var()
channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None
@ -384,6 +394,8 @@ class MySQL(Dialect):
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
@ -415,6 +427,8 @@ class MySQL(Dialect):
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
def show_sql(self, expression):
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""

View file

@ -4,7 +4,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
from sqlglot.helper import csv, seq_get
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
@ -13,10 +13,6 @@ PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
}
def _limit_sql(self, expression):
return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression))
def _parse_xml_table(self) -> exp.XMLTable:
this = self._parse_string()
@ -89,6 +85,20 @@ class Oracle(Dialect):
column.set("join_mark", self._match(TokenType.JOIN_MARKER))
return column
def _parse_hint(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.HINT):
start = self._curr
while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):
self._advance()
if not self._curr:
self.raise_error("Expected */ after HINT")
end = self._tokens[self._index - 3]
return exp.Hint(expressions=[self._find_sql(start, end)])
return None
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
@ -110,41 +120,20 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
exp.Trim: trim_sql,
exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
exp.Substring: rename_func("SUBSTR"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
}
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins") or []],
self.sql(expression, "match"),
*[self.sql(sql) for sql in expression.args.get("laterals") or []],
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
self.sql(expression, "qualify"),
self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
self.sql(expression, "order"),
self.sql(expression, "offset"), # offset before limit in oracle
self.sql(expression, "limit"),
self.sql(expression, "lock"),
sep="",
)
LIMIT_FETCH = "FETCH"
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"

View file

@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
format_time_lambda,
max_or_greatest,
min_or_least,
no_paren_current_date_sql,
no_tablesample_sql,
@ -315,6 +316,7 @@ class Postgres(Dialect):
exp.DateDiff: _date_diff_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),

View file

@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
max_or_greatest,
min_or_least,
rename_func,
timestamptrunc_sql,
@ -275,6 +276,9 @@ class Snowflake(Dialect):
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
exp.DateDiff: lambda self, e: self.func(
"DATEDIFF", e.text("unit"), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
@ -296,6 +300,7 @@ class Snowflake(Dialect):
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
}
@ -314,12 +319,6 @@ class Snowflake(Dialect):
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
}
def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
return self.binary(expression, "ILIKE ANY")
def likeany_sql(self, expression: exp.LikeAny) -> str:
return self.binary(expression, "LIKE ANY")
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")

View file

@ -82,6 +82,8 @@ class SQLite(Dialect):
exp.TryCast: no_trycast_sql,
}
LIMIT_FETCH = "LIMIT"
def cast_sql(self, expression: exp.Cast) -> str:
if expression.to.this == exp.DataType.Type.DATE:
return self.func("DATE", expression.this)
@ -115,9 +117,6 @@ class SQLite(Dialect):
return f"CAST({sql} AS INTEGER)"
def fetch_sql(self, expression: exp.Fetch) -> str:
return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def groupconcat_sql(self, expression):
this = expression.this

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, min_or_least
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.tokens import TokenType
@ -128,6 +128,7 @@ class Teradata(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}

View file

@ -6,6 +6,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
max_or_greatest,
min_or_least,
parse_date_delta,
rename_func,
@ -269,7 +270,6 @@ class TSQL(Dialect):
# TSQL allows @, # to appear as a variable/identifier prefix
SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
SINGLE_TOKENS.pop("@")
SINGLE_TOKENS.pop("#")
class Parser(parser.Parser):
@ -313,6 +313,9 @@ class TSQL(Dialect):
TokenType.END: lambda self: self._parse_command(),
}
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
def _parse_system_time(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("FOR", "SYSTEM_TIME"):
return None
@ -435,11 +438,17 @@ class TSQL(Dialect):
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
}
TRANSFORMS.pop(exp.ReturnsProperty)
LIMIT_FETCH = "FETCH"
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
def systemtime_sql(self, expression: exp.SystemTime) -> str:
kind = expression.args["kind"]
if kind == "ALL":

View file

@ -12,7 +12,7 @@ from dataclasses import dataclass
from heapq import heappop, heappush
from sqlglot import Dialect, expressions as exp
from sqlglot.helper import ensure_collection
from sqlglot.helper import ensure_list
@dataclass(frozen=True)
@ -151,8 +151,8 @@ class ChangeDistiller:
self._source = source
self._target = target
self._source_index = {id(n[0]): n[0] for n in source.bfs()}
self._target_index = {id(n[0]): n[0] for n in target.bfs()}
self._source_index = {id(n): n for n, *_ in self._source.bfs()}
self._target_index = {id(n): n for n, *_ in self._target.bfs()}
self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
@ -199,10 +199,10 @@ class ChangeDistiller:
matching_set = leaves_matching_set.copy()
ordered_unmatched_source_nodes = {
id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes
id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes
}
ordered_unmatched_target_nodes = {
id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes
id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes
}
for source_node_id in ordered_unmatched_source_nodes:
@ -304,18 +304,18 @@ class ChangeDistiller:
def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False
for a in expression.args.values():
for node in ensure_collection(a):
if isinstance(node, exp.Expression):
has_child_exprs = True
yield from _get_leaves(node)
for _, node in expression.iter_expressions():
has_child_exprs = True
yield from _get_leaves(node)
if not has_child_exprs:
yield expression
def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
if type(source) is type(target):
if type(source) is type(target) and (
not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent)
):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")
@ -331,7 +331,7 @@ def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:
for a in expression.args.values():
args.extend(ensure_collection(a))
args.extend(ensure_list(a))
return [a for a in args if isinstance(a, exp.Expression)]

View file

@ -57,7 +57,7 @@ def execute(
for name, table in tables_.mapping.items()
}
schema = ensure_schema(schema)
schema = ensure_schema(schema, dialect=read)
if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
raise ExecuteError("Tables must support the same table args as schema")

View file

@ -94,13 +94,10 @@ class PythonExecutor:
if source and isinstance(source, exp.Expression):
source = source.name or source.alias
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
if source is None:
context, table_iter = self.static()
elif source in context:
if not projections and not condition:
if not step.projections and not step.condition:
return self.context({step.name: context.tables[source]})
table_iter = context.table_iter(source)
elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
@ -109,10 +106,12 @@ class PythonExecutor:
else:
context, table_iter = self.scan_table(step)
if projections:
sink = self.table(step.projections)
else:
sink = self.table(context.columns)
return self.context({step.name: self._project_and_filter(context, step, table_iter)})
def _project_and_filter(self, context, step, table_iter):
sink = self.table(step.projections if step.projections else context.columns)
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
for reader in table_iter:
if len(sink) >= step.limit:
@ -126,7 +125,7 @@ class PythonExecutor:
else:
sink.append(reader.row)
return self.context({step.name: sink})
return sink
def static(self):
return self.context({}), [RowReader(())]
@ -185,27 +184,16 @@ class PythonExecutor:
if condition:
source_context.filter(condition)
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
if not condition and not projections:
if not step.condition and not step.projections:
return source_context
sink = self.table(step.projections if projections else source_context.columns)
sink = self._project_and_filter(
source_context,
step,
(reader for reader, _ in iter(source_context)),
)
for reader, ctx in source_context:
if condition and not ctx.eval(condition):
continue
if projections:
sink.append(ctx.eval_tuple(projections))
else:
sink.append(reader.row)
if len(sink) >= step.limit:
break
if projections:
if step.projections:
return self.context({step.name: sink})
else:
return self.context(

View file

@ -26,6 +26,7 @@ from sqlglot.helper import (
AutoName,
camel_to_snake_case,
ensure_collection,
ensure_list,
seq_get,
split_num_words,
subclasses,
@ -84,7 +85,7 @@ class Expression(metaclass=_Expression):
key = "expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta")
__slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash")
def __init__(self, **args: t.Any):
self.args: t.Dict[str, t.Any] = args
@ -93,22 +94,30 @@ class Expression(metaclass=_Expression):
self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None
self._meta: t.Optional[t.Dict[str, t.Any]] = None
self._hash: t.Optional[int] = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
def __eq__(self, other) -> bool:
return type(self) is type(other) and _norm_args(self) == _norm_args(other)
return type(self) is type(other) and hash(self) == hash(other)
@property
def hashable_args(self) -> t.Any:
args = (self.args.get(k) for k in self.arg_types)
return tuple(
(tuple(_norm_arg(a) for a in arg) if arg else None)
if type(arg) is list
else (_norm_arg(arg) if arg is not None and arg is not False else None)
for arg in args
)
def __hash__(self) -> int:
return hash(
(
self.key,
tuple(
(k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
),
)
)
if self._hash is not None:
return self._hash
return hash((self.__class__, self.hashable_args))
@property
def this(self):
@ -247,9 +256,6 @@ class Expression(metaclass=_Expression):
"""
new = deepcopy(self)
new.parent = self.parent
for item, parent, _ in new.bfs():
if isinstance(item, Expression) and parent:
item.parent = parent
return new
def append(self, arg_key, value):
@ -277,12 +283,12 @@ class Expression(metaclass=_Expression):
self._set_parent(arg_key, value)
def _set_parent(self, arg_key, value):
if isinstance(value, Expression):
if hasattr(value, "parent"):
value.parent = self
value.arg_key = arg_key
elif isinstance(value, list):
elif type(value) is list:
for v in value:
if isinstance(v, Expression):
if hasattr(v, "parent"):
v.parent = self
v.arg_key = arg_key
@ -295,6 +301,17 @@ class Expression(metaclass=_Expression):
return self.parent.depth + 1
return 0
def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]:
"""Yields the key and expression for all arguments, exploding list args."""
for k, vs in self.args.items():
if type(vs) is list:
for v in vs:
if hasattr(v, "parent"):
yield k, v
else:
if hasattr(vs, "parent"):
yield k, vs
def find(self, *expression_types: t.Type[E], bfs=True) -> E | None:
"""
Returns the first node in this tree which matches at least one of
@ -319,7 +336,7 @@ class Expression(metaclass=_Expression):
Returns:
The generator object.
"""
for expression, _, _ in self.walk(bfs=bfs):
for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
@ -345,6 +362,11 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
@property
def same_parent(self):
"""Returns if the parent is the same class as itself."""
return type(self.parent) is self.__class__
def root(self) -> Expression:
"""
Returns the root expression of this tree.
@ -385,10 +407,8 @@ class Expression(metaclass=_Expression):
if prune and prune(self, parent, key):
return
for k, v in self.args.items():
for node in ensure_collection(v):
if isinstance(node, Expression):
yield from node.dfs(self, k, prune)
for k, v in self.iter_expressions():
yield from v.dfs(self, k, prune)
def bfs(self, prune=None):
"""
@ -407,18 +427,15 @@ class Expression(metaclass=_Expression):
if prune and prune(item, parent, key):
continue
if isinstance(item, Expression):
for k, v in item.args.items():
for node in ensure_collection(v):
if isinstance(node, Expression):
queue.append((node, item, k))
for k, v in item.iter_expressions():
queue.append((v, item, k))
def unnest(self):
"""
Returns the first non parenthesis child or self.
"""
expression = self
while isinstance(expression, Paren):
while type(expression) is Paren:
expression = expression.this
return expression
@ -434,7 +451,7 @@ class Expression(metaclass=_Expression):
"""
Returns unnested operands as a tuple.
"""
return tuple(arg.unnest() for arg in self.args.values() if arg)
return tuple(arg.unnest() for _, arg in self.iter_expressions())
def flatten(self, unnest=True):
"""
@ -442,8 +459,8 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
if not isinstance(node, self.__class__):
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
if not type(node) is self.__class__:
yield node.unnest() if unnest else node
def __str__(self):
@ -477,7 +494,7 @@ class Expression(metaclass=_Expression):
v._to_s(hide_missing=hide_missing, level=level + 1)
if hasattr(v, "_to_s")
else str(v)
for v in ensure_collection(vs)
for v in ensure_list(vs)
if v is not None
)
for k, vs in self.args.items()
@ -812,6 +829,10 @@ class Describe(Expression):
arg_types = {"this": True, "kind": False}
class Pragma(Expression):
pass
class Set(Expression):
arg_types = {"expressions": False}
@ -1170,6 +1191,7 @@ class Drop(Expression):
"temporary": False,
"materialized": False,
"cascade": False,
"constraints": False,
}
@ -1232,11 +1254,11 @@ class Identifier(Expression):
def quoted(self):
return bool(self.args.get("quoted"))
def __eq__(self, other):
return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
def __hash__(self):
return hash((self.key, self.this.lower()))
@property
def hashable_args(self) -> t.Any:
if self.quoted and any(char.isupper() for char in self.this):
return (self.this, self.quoted)
return self.this.lower()
@property
def output_name(self):
@ -1322,15 +1344,9 @@ class Limit(Expression):
class Literal(Condition):
arg_types = {"this": True, "is_string": True}
def __eq__(self, other):
return (
isinstance(other, Literal)
and self.this == other.this
and self.args["is_string"] == other.args["is_string"]
)
def __hash__(self):
return hash((self.key, self.this, self.args["is_string"]))
@property
def hashable_args(self) -> t.Any:
return (self.this, self.args.get("is_string"))
@classmethod
def number(cls, number) -> Literal:
@ -1784,7 +1800,7 @@ class Subqueryable(Unionable):
instance = _maybe_copy(self, copy)
return Subquery(
this=instance,
alias=TableAlias(this=to_identifier(alias)),
alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
@ -2058,6 +2074,7 @@ class Lock(Expression):
class Select(Subqueryable):
arg_types = {
"with": False,
"kind": False,
"expressions": False,
"hint": False,
"distinct": False,
@ -3595,6 +3612,21 @@ class Initcap(Func):
pass
class JSONKeyValue(Expression):
arg_types = {"this": True, "expression": True}
class JSONObject(Func):
arg_types = {
"expressions": False,
"null_handling": False,
"unique_keys": False,
"return_type": False,
"format_json": False,
"encoding": False,
}
class JSONBContains(Binary):
_sql_names = ["JSONB_CONTAINS"]
@ -3766,8 +3798,10 @@ class RegexpILike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html
# limit is the number of times a pattern is applied
class RegexpSplit(Func):
arg_types = {"this": True, "expression": True}
arg_types = {"this": True, "expression": True, "limit": False}
class Repeat(Func):
@ -3967,25 +4001,8 @@ class When(Func):
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
def _norm_args(expression):
args = {}
for k, arg in expression.args.items():
if isinstance(arg, list):
arg = [_norm_arg(a) for a in arg]
if not arg:
arg = None
else:
arg = _norm_arg(arg)
if arg is not None and arg is not False:
args[k] = arg
return args
def _norm_arg(arg):
return arg.lower() if isinstance(arg, str) else arg
return arg.lower() if type(arg) is str else arg
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
@ -4512,7 +4529,7 @@ def to_identifier(name, quoted=None):
elif isinstance(name, str):
identifier = Identifier(
this=name,
quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted,
)
else:
raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
@ -4586,8 +4603,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
return Column(this=column_name, table=table_name, **kwargs)
return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore
def alias_(
@ -4672,7 +4688,8 @@ def subquery(expression, alias=None, dialect=None, **opts):
def column(
col: str | Identifier,
table: t.Optional[str | Identifier] = None,
schema: t.Optional[str | Identifier] = None,
db: t.Optional[str | Identifier] = None,
catalog: t.Optional[str | Identifier] = None,
quoted: t.Optional[bool] = None,
) -> Column:
"""
@ -4681,7 +4698,8 @@ def column(
Args:
col: column name
table: table name
schema: schema name
db: db name
catalog: catalog name
quoted: whether or not to force quote each part
Returns:
Column: column instance
@ -4689,7 +4707,8 @@ def column(
return Column(
this=to_identifier(col, quoted=quoted),
table=to_identifier(table, quoted=quoted),
schema=to_identifier(schema, quoted=quoted),
db=to_identifier(db, quoted=quoted),
catalog=to_identifier(catalog, quoted=quoted),
)
@ -4864,7 +4883,7 @@ def replace_children(expression, fun, *args, **kwargs):
Replace children of an expression with the result of a lambda fun(child) -> exp.
"""
for k, v in expression.args.items():
is_list_arg = isinstance(v, list)
is_list_arg = type(v) is list
child_nodes = v if is_list_arg else [v]
new_child_nodes = []

View file

@ -110,6 +110,10 @@ class Generator:
# Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
MATCHED_BY_SOURCE = True
# Whether or not limit and fetch are supported
# "ALL", "LIMIT", "FETCH"
LIMIT_FETCH = "ALL"
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -209,6 +213,7 @@ class Generator:
"_leading_comma",
"_max_text_width",
"_comments",
"_cache",
)
def __init__(
@ -265,19 +270,28 @@ class Generator:
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._comments = comments
self._cache = None
def generate(self, expression: t.Optional[exp.Expression]) -> str:
def generate(
self,
expression: t.Optional[exp.Expression],
cache: t.Optional[t.Dict[int, str]] = None,
) -> str:
"""
Generates a SQL string by interpreting the given syntax tree.
Args
expression: the syntax tree.
cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node.
Returns
the SQL string.
"""
if cache is not None:
self._cache = cache
self.unsupported_messages = []
sql = self.sql(expression).strip()
self._cache = None
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
@ -387,6 +401,12 @@ class Generator:
if key:
return self.sql(expression.args.get(key))
if self._cache is not None:
expression_id = hash(expression)
if expression_id in self._cache:
return self._cache[expression_id]
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
@ -407,7 +427,11 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
return self.maybe_comment(sql, expression) if self._comments and comment else sql
sql = self.maybe_comment(sql, expression) if self._comments and comment else sql
if self._cache is not None:
self._cache[expression_id] = sql
return sql
def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
@ -697,7 +721,8 @@ class Generator:
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
cascade = " CASCADE" if expression.args.get("cascade") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}"
def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
@ -733,9 +758,9 @@ class Generator:
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
text = text.lower() if self.normalize else text
text = text.lower() if self.normalize and not expression.quoted else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
if expression.args.get("quoted") or should_identify(text, self.identify):
if expression.quoted or should_identify(text, self.identify):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
@ -1191,6 +1216,9 @@ class Generator:
)
return f"SET{expressions}"
def pragma_sql(self, expression: exp.Pragma) -> str:
return f"PRAGMA {self.sql(expression, 'this')}"
def lock_sql(self, expression: exp.Lock) -> str:
if self.LOCKING_READS_SUPPORTED:
lock_type = "UPDATE" if expression.args["update"] else "SHARE"
@ -1299,6 +1327,15 @@ class Generator:
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
limit = expression.args.get("limit")
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
limit = exp.Limit(expression=limit.args.get("count"))
elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
limit = exp.Fetch(direction="FIRST", count=limit.expression)
fetch = isinstance(limit, exp.Fetch)
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins") or []],
@ -1315,14 +1352,16 @@ class Generator:
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
self.sql(expression, "order"),
self.sql(expression, "limit"),
self.sql(expression, "offset"),
self.sql(expression, "offset") if fetch else self.sql(limit),
self.sql(limit) if fetch else self.sql(expression, "offset"),
self.sql(expression, "lock"),
self.sql(expression, "sample"),
sep="",
)
def select_sql(self, expression: exp.Select) -> str:
kind = expression.args.get("kind")
kind = f" AS {kind}" if kind else ""
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
@ -1330,7 +1369,7 @@ class Generator:
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
f"SELECT{hint}{distinct}{expressions}",
f"SELECT{kind}{hint}{distinct}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
@ -1552,6 +1591,25 @@ class Generator:
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
def jsonobject_sql(self, expression: exp.JSONObject) -> str:
expressions = self.expressions(expression)
null_handling = expression.args.get("null_handling")
null_handling = f" {null_handling}" if null_handling else ""
unique_keys = expression.args.get("unique_keys")
if unique_keys is not None:
unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS"
else:
unique_keys = ""
return_type = self.sql(expression, "return_type")
return_type = f" RETURNING {return_type}" if return_type else ""
format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
encoding = self.sql(expression, "encoding")
encoding = f" ENCODING {encoding}" if encoding else ""
return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
def in_sql(self, expression: exp.In) -> str:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
@ -1808,12 +1866,18 @@ class Generator:
def ilike_sql(self, expression: exp.ILike) -> str:
return self.binary(expression, "ILIKE")
def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
return self.binary(expression, "ILIKE ANY")
def is_sql(self, expression: exp.Is) -> str:
return self.binary(expression, "IS")
def like_sql(self, expression: exp.Like) -> str:
return self.binary(expression, "LIKE")
def likeany_sql(self, expression: exp.LikeAny) -> str:
return self.binary(expression, "LIKE ANY")
def similarto_sql(self, expression: exp.SimilarTo) -> str:
return self.binary(expression, "SIMILAR TO")

View file

@ -59,7 +59,7 @@ def ensure_list(value):
"""
if value is None:
return []
elif isinstance(value, (list, tuple)):
if isinstance(value, (list, tuple)):
return list(value)
return [value]
@ -162,9 +162,7 @@ def camel_to_snake_case(name: str) -> str:
return CAMEL_CASE_PATTERN.sub("_", name).upper()
def while_changing(
expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E]
) -> E:
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
"""
Applies a transformation to a given expression until a fix point is reached.
@ -176,8 +174,13 @@ def while_changing(
The transformed expression.
"""
while True:
for n, *_ in reversed(tuple(expression.walk())):
n._hash = hash(n)
start = hash(expression)
expression = func(expression)
for n, *_ in expression.walk():
n._hash = None
if start == hash(expression):
break
return expression

View file

@ -1,5 +1,5 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@ -108,6 +108,7 @@ class TypeAnnotator:
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
@ -116,6 +117,7 @@ class TypeAnnotator:
expr, exp.DataType.Type.VARCHAR
),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@ -296,9 +298,6 @@ class TypeAnnotator:
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
if not isinstance(expression, exp.Expression):
return None
if expression.type:
return expression # We've already inferred the expression's type
@ -311,9 +310,8 @@ class TypeAnnotator:
)
def _annotate_args(self, expression):
for value in expression.args.values():
for v in ensure_collection(value):
self._maybe_annotate(v)
for _, value in expression.iter_expressions():
self._maybe_annotate(value)
return expression

View file

@ -75,7 +75,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
a.type
and a.type.this == exp.DataType.Type.DATE
and b.type
and b.type.this != exp.DataType.Type.DATE
and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
):
_replace_cast(b, "date")

View file

@ -1,7 +1,6 @@
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify
def eliminate_joins(expression):
@ -179,6 +178,4 @@ def join_condition(join):
for condition in conditions:
extract_condition(condition)
on = simplify(on)
remaining_condition = None if on == exp.true() else on
return source_key, join_key, remaining_condition
return source_key, join_key, on

View file

@ -3,7 +3,6 @@ import itertools
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
@ -31,7 +30,6 @@ def eliminate_subqueries(expression):
eliminate_subqueries(expression.this)
return expression
expression = simplify(expression)
root = build_scope(expression)
# Map of alias->Scope|Table

View file

@ -1,5 +1,4 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection
def lower_identities(expression):
@ -40,13 +39,10 @@ def lower_identities(expression):
lower_identities(expression.right)
traversed |= {"this", "expression"}
for k, v in expression.args.items():
for k, v in expression.iter_expressions():
if k in traversed:
continue
for child in ensure_collection(v):
if isinstance(child, exp.Expression):
child.transform(_lower, copy=False)
v.transform(_lower, copy=False)
return expression

View file

@ -3,7 +3,6 @@ from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify
def merge_subqueries(expression, leave_tables_isolated=False):
@ -330,11 +329,11 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if set(exp.column_table_names(where.this)) <= sources:
from_or_join.on(where.this, copy=False)
from_or_join.set("on", simplify(from_or_join.args.get("on")))
from_or_join.set("on", from_or_join.args.get("on"))
return
expression.where(where.this, copy=False)
expression.set("where", simplify(expression.args.get("where")))
expression.set("where", expression.args.get("where"))
def _merge_order(outer_scope, inner_scope):

View file

@ -1,29 +1,63 @@
from __future__ import annotations
import logging
import typing as t
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import while_changing
from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
from sqlglot.optimizer.simplify import flatten, uniq_sort
logger = logging.getLogger("sqlglot")
def normalize(expression, dnf=False, max_distance=128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
"""
Rewrite sqlglot AST into conjunctive normal form.
Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression).sql()
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Args:
expression (sqlglot.Expression): expression to normalize
dnf (bool): rewrite in disjunctive normal form instead
max_distance (int): the maximal estimated distance from cnf to attempt conversion
expression: expression to normalize
dnf: rewrite in disjunctive normal form instead.
max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
"""
expression = simplify(expression)
cache: t.Dict[int, str] = {}
expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
return simplify(expression)
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
continue
distance = normalization_distance(node, dnf=dnf)
if distance > max_distance:
logger.info(
f"Skipping normalization because distance {distance} exceeds max {max_distance}"
)
return expression
root = node is expression
original = node.copy()
try:
node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
except OptimizeError as e:
logger.info(e)
node.replace(original)
if root:
return original
return expression
if root:
expression = node
return expression
def normalized(expression, dnf=False):
@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False):
int: difference
"""
return sum(_predicate_lengths(expression, dnf)) - (
len(list(expression.find_all(exp.Connector))) + 1
sum(1 for _ in expression.find_all(exp.Connector)) + 1
)
@ -64,30 +98,33 @@ def _predicate_lengths(expression, dnf):
expression = expression.unnest()
if not isinstance(expression, exp.Connector):
return [1]
return (1,)
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
return [
return tuple(
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
]
)
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
def distributive_law(expression, dnf, max_distance):
def distributive_law(expression, dnf, max_distance, cache=None):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
"""
if isinstance(expression.unnest(), exp.Connector):
if normalization_distance(expression, dnf) > max_distance:
return expression
if normalized(expression, dnf=dnf):
return expression
distance = normalization_distance(expression, dnf=dnf)
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
if isinstance(expression, from_exp):
a, b = expression.unnest_operands()
@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
return _distribute(a, b, from_func, to_func)
return _distribute(b, a, from_func, to_func)
return _distribute(a, b, from_func, to_func, cache)
return _distribute(b, a, from_func, to_func, cache)
if isinstance(a, to_exp):
return _distribute(b, a, from_func, to_func)
return _distribute(b, a, from_func, to_func, cache)
if isinstance(b, to_exp):
return _distribute(a, b, from_func, to_func)
return _distribute(a, b, from_func, to_func, cache)
return expression
def _distribute(a, b, from_func, to_func):
def _distribute(a, b, from_func, to_func, cache):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
exp.paren(from_func(c, b.left)),
exp.paren(from_func(c, b.right)),
uniq_sort(flatten(from_func(c, b.left)), cache),
uniq_sort(flatten(from_func(c, b.right)), cache),
),
)
else:
a = to_func(from_func(a, b.left), from_func(a, b.right))
a = to_func(
uniq_sort(flatten(from_func(a, b.left)), cache),
uniq_sort(flatten(from_func(a, b.right)), cache),
)
return _simplify(a)
def _simplify(node):
node = uniq_sort(flatten(node))
exp.replace_children(node, _simplify)
return node
return a

View file

@ -1,6 +1,5 @@
from sqlglot import exp
from sqlglot.helper import tsort
from sqlglot.optimizer.simplify import simplify
def optimize_joins(expression):
@ -29,7 +28,6 @@ def optimize_joins(expression):
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
on = on.replace(simplify(on))
if isinstance(on, exp.Connector):
for predicate in on.flatten():

View file

@ -21,6 +21,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
@ -43,6 +44,7 @@ RULES = (
eliminate_ctes,
annotate_types,
canonicalize,
simplify,
)
@ -78,7 +80,7 @@ def optimize(
Returns:
sqlglot.Expression: optimized expression
"""
schema = ensure_schema(schema or sqlglot.schema)
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:

View file

@ -30,11 +30,12 @@ def qualify_columns(expression, schema):
resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
_expand_using(scope, resolver)
using_column_tables = _expand_using(scope, resolver)
_qualify_columns(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
_expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
@ -69,11 +70,11 @@ def _pop_table_column_aliases(derived_tables):
def _expand_using(scope, resolver):
joins = list(scope.expression.find_all(exp.Join))
joins = list(scope.find_all(exp.Join))
names = {join.this.alias for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
# Mapping of automatically joined column names to source names
# Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables = {}
for join in joins:
@ -112,11 +113,12 @@ def _expand_using(scope, resolver):
)
)
tables = column_tables.setdefault(identifier, [])
# Set all values in the dict to None, because we only care about the key ordering
tables = column_tables.setdefault(identifier, {})
if table not in tables:
tables.append(table)
tables[table] = None
if join_table not in tables:
tables.append(join_table)
tables[join_table] = None
join.args.pop("using")
join.set("on", exp.and_(*conditions))
@ -134,11 +136,11 @@ def _expand_using(scope, resolver):
scope.replace(column, replacement)
return column_tables
def _expand_group_by(scope, resolver):
group = scope.expression.args.get("group")
if not group:
return
def _expand_alias_refs(scope, resolver):
selects = {}
# Replace references to select aliases
def transform(node, *_):
@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver):
node.set("table", table)
return node
selects = {s.alias_or_name: s for s in scope.selects}
if not selects:
for s in scope.selects:
selects[s.alias_or_name] = s
select = selects.get(node.name)
if select:
scope.clear_cache()
if isinstance(select, exp.Alias):
@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver):
return node
group.transform(transform, copy=False)
for select in scope.expression.selects:
select.transform(transform, copy=False)
for modifier in ("where", "group"):
part = scope.expression.args.get(modifier)
if part:
part.transform(transform, copy=False)
def _expand_group_by(scope, resolver):
group = scope.expression.args.get("group")
if not group:
return
group.set("expressions", _expand_positional_references(scope, group.expressions))
scope.expression.set("group", group)
@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver):
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
columns_missing_from_scope = []
# Determine whether each reference in the order by clause is to a column or an alias.
for ordered in scope.find_all(exp.Ordered):
for column in ordered.find_all(exp.Column):
if (
not column.table
and column.parent is not ordered
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)
order = scope.expression.args.get("order")
if order:
for ordered in order.expressions:
for column in ordered.find_all(exp.Column):
if (
not column.table
and column.parent is not ordered
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
for having in scope.find_all(exp.Having):
having = scope.expression.args.get("having")
if having:
for column in having.find_all(exp.Column):
if (
not column.table
@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver):
column.set("table", column_table)
def _expand_stars(scope, resolver):
def _expand_stars(scope, resolver, using_column_tables):
"""Expand stars to lists of column selections"""
new_selections = []
except_columns = {}
replace_columns = {}
coalesced_columns = set()
for expression in scope.selects:
if isinstance(expression, exp.Star):
@ -286,7 +311,20 @@ def _expand_stars(scope, resolver):
if columns and "*" not in columns:
table_id = id(table)
for name in columns:
if name not in except_columns.get(table_id, set()):
if name in using_column_tables and table in using_column_tables[name]:
if name in coalesced_columns:
continue
coalesced_columns.add(name)
tables = using_column_tables[name]
coalesce = [exp.column(name, table=table) for table in tables]
new_selections.append(
exp.alias_(
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
)
)
elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
new_selections.append(alias(column, alias_) if alias_ != name else column)

View file

@ -160,7 +160,7 @@ class Scope:
Yields:
exp.Expression: nodes
"""
for expression, _, _ in self.walk(bfs=bfs):
for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression

View file

@ -5,11 +5,10 @@ from collections import deque
from decimal import Decimal
from sqlglot import exp
from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import first, while_changing
GENERATOR = Generator(normalize=True, identify=True)
GENERATOR = Generator(normalize=True, identify="safe")
def simplify(expression):
@ -28,18 +27,20 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""
cache = {}
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
node = uniq_sort(node)
node = absorb_and_eliminate(node)
node = uniq_sort(node, cache, root)
node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node)
node = remove_compliments(node)
node = simplify_connectors(node, root)
node = remove_compliments(node, root)
node.parent = expression.parent
node = simplify_literals(node)
node = simplify_literals(node, root)
node = simplify_parens(node)
if root:
expression.replace(node)
@ -70,7 +71,7 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
if isinstance(expression.this, exp.Null):
if is_null(expression.this):
return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
@ -78,11 +79,11 @@ def simplify_not(expression):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Null):
if is_null(condition):
return exp.null()
if always_true(expression.this):
return exp.false()
if expression.this == FALSE:
if is_false(expression.this):
return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
@ -104,42 +105,42 @@ def flatten(expression):
return expression
def simplify_connectors(expression):
def simplify_connectors(expression, root=True):
def _simplify_connectors(expression, left, right):
if isinstance(expression, exp.Connector):
if left == right:
if left == right:
return left
if isinstance(expression, exp.And):
if is_false(left) or is_false(right):
return exp.false()
if is_null(left) or is_null(right):
return exp.null()
if always_true(left) and always_true(right):
return exp.true()
if always_true(left):
return right
if always_true(right):
return left
if isinstance(expression, exp.And):
if FALSE in (left, right):
return exp.false()
if NULL in (left, right):
return exp.null()
if always_true(left) and always_true(right):
return exp.true()
if always_true(left):
return right
if always_true(right):
return left
return _simplify_comparison(expression, left, right)
elif isinstance(expression, exp.Or):
if always_true(left) or always_true(right):
return exp.true()
if left == FALSE and right == FALSE:
return exp.false()
if (
(left == NULL and right == NULL)
or (left == NULL and right == FALSE)
or (left == FALSE and right == NULL)
):
return exp.null()
if left == FALSE:
return right
if right == FALSE:
return left
return _simplify_comparison(expression, left, right, or_=True)
return None
return _simplify_comparison(expression, left, right)
elif isinstance(expression, exp.Or):
if always_true(left) or always_true(right):
return exp.true()
if is_false(left) and is_false(right):
return exp.false()
if (
(is_null(left) and is_null(right))
or (is_null(left) and is_false(right))
or (is_false(left) and is_null(right))
):
return exp.null()
if is_false(left):
return right
if is_false(right):
return left
return _simplify_comparison(expression, left, right, or_=True)
return _flat_simplify(expression, _simplify_connectors)
if isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_connectors, root)
return expression
LT_LTE = (exp.LT, exp.LTE)
@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False):
return None
def remove_compliments(expression):
def remove_compliments(expression, root=True):
"""
Removing compliments.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
@ -236,23 +237,23 @@ def remove_compliments(expression):
return expression
def uniq_sort(expression):
def uniq_sort(expression, cache=None, root=True):
"""
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
"""
if isinstance(expression, exp.Connector):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
deduped = {GENERATOR.generate(e): e for e in flattened}
deduped = {GENERATOR.generate(e, cache): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
expression = result_func(*(e for _, e in sorted(arr)))
break
else:
# we didn't have to sort but maybe we need to dedup
@ -262,7 +263,7 @@ def uniq_sort(expression):
return expression
def absorb_and_eliminate(expression):
def absorb_and_eliminate(expression, root=True):
"""
absorption:
A AND (A OR B) -> A
@ -273,7 +274,7 @@ def absorb_and_eliminate(expression):
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
"""
if isinstance(expression, exp.Connector):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
kind = exp.Or if isinstance(expression, exp.And) else exp.And
for a, b in itertools.permutations(expression.flatten(), 2):
@ -302,9 +303,9 @@ def absorb_and_eliminate(expression):
return expression
def simplify_literals(expression):
if isinstance(expression, exp.Binary):
return _flat_simplify(expression, _simplify_binary)
def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b):
c = b
not_ = False
if c == NULL:
if is_null(c):
if isinstance(a, exp.Literal):
return exp.true() if not_ else exp.false()
if a == NULL:
if is_null(a):
return exp.false() if not_ else exp.true()
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
return None
elif NULL in (a, b):
elif is_null(a) or is_null(b):
return exp.null()
if a.is_number and b.is_number:
@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
elif a.is_string and b.is_string:
boolean = eval_boolean(expression, a, b)
boolean = eval_boolean(expression, a.this, b.this)
if boolean:
return boolean
@ -381,7 +382,7 @@ def simplify_parens(expression):
and not isinstance(expression.this, exp.Select)
and (
not isinstance(expression.parent, (exp.Condition, exp.Binary))
or isinstance(expression.this, (exp.Is, exp.Like))
or isinstance(expression.this, exp.Predicate)
or not isinstance(expression.this, exp.Binary)
)
):
@ -400,13 +401,23 @@ def remove_where_true(expression):
def always_true(expression):
return expression == TRUE or isinstance(expression, exp.Literal)
return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
expression, exp.Literal
)
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a
def is_false(a: exp.Expression) -> bool:
return type(a) is exp.Boolean and not a.this
def is_null(a: exp.Expression) -> bool:
return type(a) is exp.Null
def eval_boolean(expression, a, b):
if isinstance(expression, (exp.EQ, exp.Is)):
return boolean_literal(a == b)
@ -466,24 +477,27 @@ def boolean_literal(condition):
return exp.true() if condition else exp.false()
def _flat_simplify(expression, simplifier):
operands = []
queue = deque(expression.flatten(unnest=False))
size = len(queue)
def _flat_simplify(expression, simplifier, root=True):
if root or not expression.same_parent:
operands = []
queue = deque(expression.flatten(unnest=False))
size = len(queue)
while queue:
a = queue.popleft()
while queue:
a = queue.popleft()
for b in queue:
result = simplifier(expression, a, b)
for b in queue:
result = simplifier(expression, a, b)
if result:
queue.remove(b)
queue.append(result)
break
else:
operands.append(a)
if result:
queue.remove(b)
queue.append(result)
break
else:
operands.append(a)
if len(operands) < size:
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
if len(operands) < size:
return functools.reduce(
lambda a, b: expression.__class__(this=a, expression=b), operands
)
return expression

View file

@ -19,7 +19,7 @@ from sqlglot.trie import in_trie, new_trie
logger = logging.getLogger("sqlglot")
def parse_var_map(args):
def parse_var_map(args: t.Sequence) -> exp.Expression:
keys = []
values = []
for i in range(0, len(args), 2):
@ -31,6 +31,11 @@ def parse_var_map(args):
)
def parse_like(args):
like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
def binary_range_parser(
expr_type: t.Type[exp.Expression],
) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]:
@ -77,6 +82,9 @@ class Parser(metaclass=_Parser):
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
"IFNULL": exp.Coalesce.from_arg_list,
"LIKE": parse_like,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
@ -90,7 +98,6 @@ class Parser(metaclass=_Parser):
length=exp.Literal.number(10),
),
"VAR_MAP": parse_var_map,
"IFNULL": exp.Coalesce.from_arg_list,
}
NO_PAREN_FUNCTIONS = {
@ -211,6 +218,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FOLLOWING,
TokenType.FORMAT,
TokenType.FULL,
TokenType.IF,
TokenType.ISNULL,
TokenType.INTERVAL,
@ -226,8 +234,10 @@ class Parser(metaclass=_Parser):
TokenType.ONLY,
TokenType.OPTIONS,
TokenType.ORDINALITY,
TokenType.PARTITION,
TokenType.PERCENT,
TokenType.PIVOT,
TokenType.PRAGMA,
TokenType.PRECEDING,
TokenType.RANGE,
TokenType.REFERENCES,
@ -257,6 +267,7 @@ class Parser(metaclass=_Parser):
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
TokenType.FULL,
TokenType.LEFT,
TokenType.NATURAL,
TokenType.OFFSET,
@ -277,6 +288,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FIRST,
TokenType.FORMAT,
TokenType.GLOB,
TokenType.IDENTIFIER,
TokenType.INDEX,
TokenType.ISNULL,
@ -461,6 +473,7 @@ class Parser(metaclass=_Parser):
TokenType.INSERT: lambda self: self._parse_insert(),
TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
TokenType.MERGE: lambda self: self._parse_merge(),
TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.SET: lambda self: self._parse_set(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
@ -662,6 +675,8 @@ class Parser(metaclass=_Parser):
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"EXTRACT": lambda self: self._parse_extract(),
"JSON_OBJECT": lambda self: self._parse_json_object(),
"LOG": lambda self: self._parse_logarithm(),
"POSITION": lambda self: self._parse_position(),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
@ -719,6 +734,9 @@ class Parser(metaclass=_Parser):
CONVERT_TYPE_FIRST = False
LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False
__slots__ = (
"error_level",
"error_message_context",
@ -1032,6 +1050,7 @@ class Parser(metaclass=_Parser):
temporary=temporary,
materialized=materialized,
cascade=self._match(TokenType.CASCADE),
constraints=self._match_text_seq("CONSTRAINTS"),
)
def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
@ -1221,7 +1240,7 @@ class Parser(metaclass=_Parser):
if not identified_property:
break
for p in ensure_collection(identified_property):
for p in ensure_list(identified_property):
properties.append(p)
if properties:
@ -1704,6 +1723,11 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.SELECT):
comments = self._prev_comments
kind = (
self._match(TokenType.ALIAS)
and self._match_texts(("STRUCT", "VALUE"))
and self._prev.text
)
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT)
@ -1722,6 +1746,7 @@ class Parser(metaclass=_Parser):
this = self.expression(
exp.Select,
kind=kind,
hint=hint,
distinct=distinct,
expressions=expressions,
@ -2785,7 +2810,6 @@ class Parser(metaclass=_Parser):
this = seq_get(expressions, 0)
self._parse_query_modifiers(this)
self._match_r_paren()
if isinstance(this, exp.Subqueryable):
this = self._parse_set_operations(
@ -2794,7 +2818,9 @@ class Parser(metaclass=_Parser):
elif len(expressions) > 1:
this = self.expression(exp.Tuple, expressions=expressions)
else:
this = self.expression(exp.Paren, this=this)
this = self.expression(exp.Paren, this=self._parse_set_operations(this))
self._match_r_paren()
if this and comments:
this.comments = comments
@ -3318,6 +3344,60 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_json_key_value(self) -> t.Optional[exp.Expression]:
self._match_text_seq("KEY")
key = self._parse_field()
self._match(TokenType.COLON)
self._match_text_seq("VALUE")
value = self._parse_field()
if not key and not value:
return None
return self.expression(exp.JSONKeyValue, this=key, expression=value)
def _parse_json_object(self) -> exp.Expression:
expressions = self._parse_csv(self._parse_json_key_value)
null_handling = None
if self._match_text_seq("NULL", "ON", "NULL"):
null_handling = "NULL ON NULL"
elif self._match_text_seq("ABSENT", "ON", "NULL"):
null_handling = "ABSENT ON NULL"
unique_keys = None
if self._match_text_seq("WITH", "UNIQUE"):
unique_keys = True
elif self._match_text_seq("WITHOUT", "UNIQUE"):
unique_keys = False
self._match_text_seq("KEYS")
return_type = self._match_text_seq("RETURNING") and self._parse_type()
format_json = self._match_text_seq("FORMAT", "JSON")
encoding = self._match_text_seq("ENCODING") and self._parse_var()
return self.expression(
exp.JSONObject,
expressions=expressions,
null_handling=null_handling,
unique_keys=unique_keys,
return_type=return_type,
format_json=format_json,
encoding=encoding,
)
def _parse_logarithm(self) -> exp.Expression:
# Default argument order is base, expression
args = self._parse_csv(self._parse_range)
if len(args) > 1:
if not self.LOG_BASE_FIRST:
args.reverse()
return exp.Log.from_arg_list(args)
return self.expression(
exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0)
)
def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
@ -3654,7 +3734,7 @@ class Parser(metaclass=_Parser):
return parse_result
def _parse_select_or_expression(self) -> t.Optional[exp.Expression]:
return self._parse_select() or self._parse_expression()
return self._parse_select() or self._parse_set_operations(self._parse_expression())
def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
return self._parse_set_operations(
@ -3741,6 +3821,8 @@ class Parser(metaclass=_Parser):
expression = self._parse_foreign_key()
elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY):
expression = self._parse_primary_key()
else:
expression = None
return self.expression(exp.AddConstraint, this=this, expression=expression)
@ -3799,12 +3881,15 @@ class Parser(metaclass=_Parser):
parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
if parser:
return self.expression(
exp.AlterTable,
this=this,
exists=exists,
actions=ensure_list(parser(self)),
)
actions = ensure_list(parser(self))
if not self._curr:
return self.expression(
exp.AlterTable,
this=this,
exists=exists,
actions=actions,
)
return self._parse_as_command(start)
def _parse_merge(self) -> exp.Expression:

View file

@ -175,7 +175,7 @@ class Step:
}
for projection in projections:
for i, e in aggregate.group.items():
for child, _, _ in projection.walk():
for child, *_ in projection.walk():
if child == e:
child.replace(exp.column(i, step.name))
aggregate.add_dependency(step)

View file

@ -306,11 +306,11 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return self._type_mapping_cache[schema_type]
def ensure_schema(schema: t.Any) -> Schema:
def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
if isinstance(schema, Schema):
return schema
return MappingSchema(schema)
return MappingSchema(schema, dialect=dialect)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):

View file

@ -252,6 +252,7 @@ class TokenType(AutoName):
PERCENT = auto()
PIVOT = auto()
PLACEHOLDER = auto()
PRAGMA = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
PROCEDURE = auto()
@ -346,7 +347,8 @@ class Token:
self.token_type = token_type
self.text = text
self.line = line
self.col = max(col - len(text), 1)
self.col = col - len(text)
self.col = self.col if self.col > 1 else 1
self.comments = comments
def __repr__(self) -> str:
@ -586,6 +588,7 @@ class Tokenizer(metaclass=_Tokenizer):
"PARTITIONED_BY": TokenType.PARTITION_BY,
"PERCENT": TokenType.PERCENT,
"PIVOT": TokenType.PIVOT,
"PRAGMA": TokenType.PRAGMA,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"PROCEDURE": TokenType.PROCEDURE,
@ -654,6 +657,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LONG": TokenType.BIGINT,
"BIGINT": TokenType.BIGINT,
"INT8": TokenType.BIGINT,
"DEC": TokenType.DECIMAL,
"DECIMAL": TokenType.DECIMAL,
"MAP": TokenType.MAP,
"NULLABLE": TokenType.NULLABLE,
@ -714,7 +718,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VACUUM": TokenType.COMMAND,
}
WHITE_SPACE: t.Dict[str, TokenType] = {
WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = {
" ": TokenType.SPACE,
"\t": TokenType.SPACE,
"\n": TokenType.BREAK,
@ -813,11 +817,8 @@ class Tokenizer(metaclass=_Tokenizer):
return self.sql[start:end]
return ""
def _line_break(self, char: t.Optional[str]) -> bool:
return self.WHITE_SPACE.get(char) == TokenType.BREAK # type: ignore
def _advance(self, i: int = 1) -> None:
if self._line_break(self._char):
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
self._set_new_line()
self._col += i
@ -939,7 +940,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
self._advance(comment_end_size - 1)
else:
while not self._end and not self._line_break(self._peek):
while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
self._advance()
self._comments.append(self._text[comment_start_size:]) # type: ignore