Merging upstream version 20.3.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
2945bcc4f7
commit
4d9376ba93
132 changed files with 55125 additions and 51576 deletions
|
@ -95,9 +95,6 @@ def eliminate_subqueries(expression):
|
|||
|
||||
|
||||
def _eliminate(scope, existing_ctes, taken):
|
||||
if scope.is_union:
|
||||
return _eliminate_union(scope, existing_ctes, taken)
|
||||
|
||||
if scope.is_derived_table:
|
||||
return _eliminate_derived_table(scope, existing_ctes, taken)
|
||||
|
||||
|
@ -105,36 +102,6 @@ def _eliminate(scope, existing_ctes, taken):
|
|||
return _eliminate_cte(scope, existing_ctes, taken)
|
||||
|
||||
|
||||
def _eliminate_union(scope, existing_ctes, taken):
|
||||
duplicate_cte_alias = existing_ctes.get(scope.expression)
|
||||
|
||||
alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")
|
||||
|
||||
taken[alias] = scope
|
||||
|
||||
# Try to maintain the selections
|
||||
expressions = scope.expression.selects
|
||||
selects = [
|
||||
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
|
||||
for e in expressions
|
||||
if e.alias_or_name
|
||||
]
|
||||
# If not all selections have an alias, just select *
|
||||
if len(selects) != len(expressions):
|
||||
selects = ["*"]
|
||||
|
||||
scope.expression.replace(
|
||||
exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False))
|
||||
)
|
||||
|
||||
if not duplicate_cte_alias:
|
||||
existing_ctes[scope.expression] = alias
|
||||
return exp.CTE(
|
||||
this=scope.expression,
|
||||
alias=exp.TableAlias(this=exp.to_identifier(alias)),
|
||||
)
|
||||
|
||||
|
||||
def _eliminate_derived_table(scope, existing_ctes, taken):
|
||||
# This makes sure that we don't:
|
||||
# - drop the "pivot" arg from a pivoted subquery
|
||||
|
|
|
@ -174,6 +174,22 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
|||
for col in inner_projections[selection].find_all(exp.Column)
|
||||
)
|
||||
|
||||
def _is_recursive():
|
||||
# Recursive CTEs look like this:
|
||||
# WITH RECURSIVE cte AS (
|
||||
# SELECT * FROM x <-- inner scope
|
||||
# UNION ALL
|
||||
# SELECT * FROM cte <-- outer scope
|
||||
# )
|
||||
cte = inner_scope.expression.parent
|
||||
node = outer_scope.expression.parent
|
||||
|
||||
while node:
|
||||
if node is cte:
|
||||
return True
|
||||
node = node.parent
|
||||
return False
|
||||
|
||||
return (
|
||||
isinstance(outer_scope.expression, exp.Select)
|
||||
and not outer_scope.expression.is_star
|
||||
|
@ -197,6 +213,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
|||
)
|
||||
and not _outer_select_joins_on_inner_select_join()
|
||||
and not _is_a_window_expression_in_unmergable_operation()
|
||||
and not _is_recursive()
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from sqlglot.optimizer.scope import build_scope, find_in_scope
|
|||
from sqlglot.optimizer.simplify import simplify
|
||||
|
||||
|
||||
def pushdown_predicates(expression):
|
||||
def pushdown_predicates(expression, dialect=None):
|
||||
"""
|
||||
Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
|
||||
|
||||
|
@ -36,7 +36,7 @@ def pushdown_predicates(expression):
|
|||
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
|
||||
selected_sources = {k: (node, source)}
|
||||
break
|
||||
pushdown(where.this, selected_sources, scope_ref_count)
|
||||
pushdown(where.this, selected_sources, scope_ref_count, dialect)
|
||||
|
||||
# joins should only pushdown into itself, not to other joins
|
||||
# so we limit the selected sources to only itself
|
||||
|
@ -44,17 +44,20 @@ def pushdown_predicates(expression):
|
|||
name = join.alias_or_name
|
||||
if name in scope.selected_sources:
|
||||
pushdown(
|
||||
join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count
|
||||
join.args.get("on"),
|
||||
{name: scope.selected_sources[name]},
|
||||
scope_ref_count,
|
||||
dialect,
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def pushdown(condition, sources, scope_ref_count):
|
||||
def pushdown(condition, sources, scope_ref_count, dialect):
|
||||
if not condition:
|
||||
return
|
||||
|
||||
condition = condition.replace(simplify(condition))
|
||||
condition = condition.replace(simplify(condition, dialect=dialect))
|
||||
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
|
||||
|
||||
predicates = list(
|
||||
|
|
|
@ -37,6 +37,7 @@ class Scope:
|
|||
For example:
|
||||
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
|
||||
The LATERAL VIEW EXPLODE gets x as a source.
|
||||
cte_sources (dict[str, Scope]): Sources from CTES
|
||||
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
|
||||
defines a column list of it's alias of this scope, this is that list of columns.
|
||||
For example:
|
||||
|
@ -61,11 +62,14 @@ class Scope:
|
|||
parent=None,
|
||||
scope_type=ScopeType.ROOT,
|
||||
lateral_sources=None,
|
||||
cte_sources=None,
|
||||
):
|
||||
self.expression = expression
|
||||
self.sources = sources or {}
|
||||
self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
|
||||
self.lateral_sources = lateral_sources or {}
|
||||
self.cte_sources = cte_sources or {}
|
||||
self.sources.update(self.lateral_sources)
|
||||
self.sources.update(self.cte_sources)
|
||||
self.outer_column_list = outer_column_list or []
|
||||
self.parent = parent
|
||||
self.scope_type = scope_type
|
||||
|
@ -92,13 +96,17 @@ class Scope:
|
|||
self._pivots = None
|
||||
self._references = None
|
||||
|
||||
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
|
||||
def branch(
|
||||
self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
|
||||
):
|
||||
"""Branch from the current scope to a new, inner scope"""
|
||||
return Scope(
|
||||
expression=expression.unnest(),
|
||||
sources={**self.cte_sources, **(chain_sources or {})},
|
||||
sources=sources.copy() if sources else None,
|
||||
parent=self,
|
||||
scope_type=scope_type,
|
||||
cte_sources={**self.cte_sources, **(cte_sources or {})},
|
||||
lateral_sources=lateral_sources.copy() if lateral_sources else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -305,20 +313,6 @@ class Scope:
|
|||
|
||||
return self._references
|
||||
|
||||
@property
|
||||
def cte_sources(self):
|
||||
"""
|
||||
Sources that are CTEs.
|
||||
|
||||
Returns:
|
||||
dict[str, Scope]: Mapping of source alias to Scope
|
||||
"""
|
||||
return {
|
||||
alias: scope
|
||||
for alias, scope in self.sources.items()
|
||||
if isinstance(scope, Scope) and scope.is_cte
|
||||
}
|
||||
|
||||
@property
|
||||
def external_columns(self):
|
||||
"""
|
||||
|
@ -515,7 +509,10 @@ def _traverse_scope(scope):
|
|||
elif isinstance(scope.expression, exp.Union):
|
||||
yield from _traverse_union(scope)
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
yield from _traverse_subqueries(scope)
|
||||
if scope.is_root:
|
||||
yield from _traverse_select(scope)
|
||||
else:
|
||||
yield from _traverse_subqueries(scope)
|
||||
elif isinstance(scope.expression, exp.Table):
|
||||
yield from _traverse_tables(scope)
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
|
@ -572,7 +569,7 @@ def _traverse_ctes(scope):
|
|||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
cte.this,
|
||||
chain_sources=sources,
|
||||
cte_sources=sources,
|
||||
outer_column_list=cte.alias_column_names,
|
||||
scope_type=ScopeType.CTE,
|
||||
)
|
||||
|
@ -584,12 +581,14 @@ def _traverse_ctes(scope):
|
|||
|
||||
if recursive_scope:
|
||||
child_scope.add_source(alias, recursive_scope)
|
||||
child_scope.cte_sources[alias] = recursive_scope
|
||||
|
||||
# append the final child_scope yielded
|
||||
if child_scope:
|
||||
scope.cte_scopes.append(child_scope)
|
||||
|
||||
scope.sources.update(sources)
|
||||
scope.cte_sources.update(sources)
|
||||
|
||||
|
||||
def _is_derived_table(expression: exp.Subquery) -> bool:
|
||||
|
@ -725,7 +724,7 @@ def _traverse_ddl(scope):
|
|||
yield from _traverse_ctes(scope)
|
||||
|
||||
query_scope = scope.branch(
|
||||
scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
|
||||
scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
|
||||
)
|
||||
query_scope._collect()
|
||||
query_scope._ctes = scope.ctes + query_scope._ctes
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import itertools
|
||||
|
@ -6,10 +8,17 @@ from collections import deque
|
|||
from decimal import Decimal
|
||||
|
||||
import sqlglot
|
||||
from sqlglot import exp
|
||||
from sqlglot import Dialect, exp
|
||||
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
|
||||
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot.dialects.dialect import DialectType
|
||||
|
||||
DateTruncBinaryTransform = t.Callable[
|
||||
[exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
|
||||
]
|
||||
|
||||
# Final means that an expression should not be simplified
|
||||
FINAL = "final"
|
||||
|
||||
|
@ -18,7 +27,9 @@ class UnsupportedUnit(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def simplify(expression, constant_propagation=False):
|
||||
def simplify(
|
||||
expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
|
||||
):
|
||||
"""
|
||||
Rewrite sqlglot AST to simplify expressions.
|
||||
|
||||
|
@ -36,15 +47,18 @@ def simplify(expression, constant_propagation=False):
|
|||
sqlglot.Expression: simplified expression
|
||||
"""
|
||||
|
||||
dialect = Dialect.get_or_raise(dialect)
|
||||
|
||||
# group by expressions cannot be simplified, for example
|
||||
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
|
||||
# the projection must exactly match the group by key
|
||||
for group in expression.find_all(exp.Group):
|
||||
select = group.parent
|
||||
assert select
|
||||
groups = set(group.expressions)
|
||||
group.meta[FINAL] = True
|
||||
|
||||
for e in select.selects:
|
||||
for e in select.expressions:
|
||||
for node, *_ in e.walk():
|
||||
if node in groups:
|
||||
e.meta[FINAL] = True
|
||||
|
@ -84,7 +98,8 @@ def simplify(expression, constant_propagation=False):
|
|||
node = simplify_literals(node, root)
|
||||
node = simplify_equality(node)
|
||||
node = simplify_parens(node)
|
||||
node = simplify_datetrunc_predicate(node)
|
||||
node = simplify_datetrunc(node, dialect)
|
||||
node = sort_comparison(node)
|
||||
|
||||
if root:
|
||||
expression.replace(node)
|
||||
|
@ -117,14 +132,30 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
|
|||
This is done because comparison simplification is only done on lt/lte/gt/gte.
|
||||
"""
|
||||
if isinstance(expression, exp.Between):
|
||||
return exp.and_(
|
||||
negate = isinstance(expression.parent, exp.Not)
|
||||
|
||||
expression = exp.and_(
|
||||
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
|
||||
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
|
||||
copy=False,
|
||||
)
|
||||
|
||||
if negate:
|
||||
expression = exp.paren(expression, copy=False)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
COMPLEMENT_COMPARISONS = {
|
||||
exp.LT: exp.GTE,
|
||||
exp.GT: exp.LTE,
|
||||
exp.LTE: exp.GT,
|
||||
exp.GTE: exp.LT,
|
||||
exp.EQ: exp.NEQ,
|
||||
exp.NEQ: exp.EQ,
|
||||
}
|
||||
|
||||
|
||||
def simplify_not(expression):
|
||||
"""
|
||||
Demorgan's Law
|
||||
|
@ -132,10 +163,15 @@ def simplify_not(expression):
|
|||
NOT (x AND y) -> NOT x OR NOT y
|
||||
"""
|
||||
if isinstance(expression, exp.Not):
|
||||
if is_null(expression.this):
|
||||
this = expression.this
|
||||
if is_null(this):
|
||||
return exp.null()
|
||||
if isinstance(expression.this, exp.Paren):
|
||||
condition = expression.this.unnest()
|
||||
if this.__class__ in COMPLEMENT_COMPARISONS:
|
||||
return COMPLEMENT_COMPARISONS[this.__class__](
|
||||
this=this.this, expression=this.expression
|
||||
)
|
||||
if isinstance(this, exp.Paren):
|
||||
condition = this.unnest()
|
||||
if isinstance(condition, exp.And):
|
||||
return exp.or_(
|
||||
exp.not_(condition.left, copy=False),
|
||||
|
@ -150,14 +186,14 @@ def simplify_not(expression):
|
|||
)
|
||||
if is_null(condition):
|
||||
return exp.null()
|
||||
if always_true(expression.this):
|
||||
if always_true(this):
|
||||
return exp.false()
|
||||
if is_false(expression.this):
|
||||
if is_false(this):
|
||||
return exp.true()
|
||||
if isinstance(expression.this, exp.Not):
|
||||
if isinstance(this, exp.Not):
|
||||
# double negation
|
||||
# NOT NOT x -> x
|
||||
return expression.this.this
|
||||
return this.this
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -249,12 +285,6 @@ def _simplify_comparison(expression, left, right, or_=False):
|
|||
except StopIteration:
|
||||
return expression
|
||||
|
||||
# make sure the comparison is always of the form x > 1 instead of 1 < x
|
||||
if left.__class__ in INVERSE_COMPARISONS and l == ll:
|
||||
left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
|
||||
if right.__class__ in INVERSE_COMPARISONS and r == rl:
|
||||
right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
|
||||
|
||||
if l.is_number and r.is_number:
|
||||
l = float(l.name)
|
||||
r = float(r.name)
|
||||
|
@ -397,13 +427,7 @@ def propagate_constants(expression, root=True):
|
|||
# TODO: create a helper that can be used to detect nested literal expressions such
|
||||
# as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
|
||||
if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
|
||||
pass
|
||||
elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
|
||||
l, r = r, l
|
||||
else:
|
||||
continue
|
||||
|
||||
constant_mapping[l] = (id(l), r)
|
||||
constant_mapping[l] = (id(l), r)
|
||||
|
||||
if constant_mapping:
|
||||
for column in find_all_in_scope(expression, exp.Column):
|
||||
|
@ -458,11 +482,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
|
|||
if isinstance(expression, COMPARISONS):
|
||||
l, r = expression.left, expression.right
|
||||
|
||||
if l.__class__ in INVERSE_OPS:
|
||||
pass
|
||||
elif r.__class__ in INVERSE_OPS:
|
||||
l, r = r, l
|
||||
else:
|
||||
if not l.__class__ in INVERSE_OPS:
|
||||
return expression
|
||||
|
||||
if r.is_number:
|
||||
|
@ -650,7 +670,7 @@ def simplify_coalesce(expression):
|
|||
|
||||
# Find the first constant arg
|
||||
for arg_index, arg in enumerate(coalesce.expressions):
|
||||
if _is_constant(other):
|
||||
if _is_constant(arg):
|
||||
break
|
||||
else:
|
||||
return expression
|
||||
|
@ -752,7 +772,7 @@ def simplify_conditionals(expression):
|
|||
DateRange = t.Tuple[datetime.date, datetime.date]
|
||||
|
||||
|
||||
def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
|
||||
def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
|
||||
"""
|
||||
Get the date range for a DATE_TRUNC equality comparison:
|
||||
|
||||
|
@ -761,7 +781,7 @@ def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
|
|||
Returns:
|
||||
tuple of [min, max) or None if a value can never be equal to `date` for `unit`
|
||||
"""
|
||||
floor = date_floor(date, unit)
|
||||
floor = date_floor(date, unit, dialect)
|
||||
|
||||
if date != floor:
|
||||
# This will always be False, except for NULL values.
|
||||
|
@ -780,9 +800,9 @@ def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Exp
|
|||
|
||||
|
||||
def _datetrunc_eq(
|
||||
left: exp.Expression, date: datetime.date, unit: str
|
||||
left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
|
||||
) -> t.Optional[exp.Expression]:
|
||||
drange = _datetrunc_range(date, unit)
|
||||
drange = _datetrunc_range(date, unit, dialect)
|
||||
if not drange:
|
||||
return None
|
||||
|
||||
|
@ -790,9 +810,9 @@ def _datetrunc_eq(
|
|||
|
||||
|
||||
def _datetrunc_neq(
|
||||
left: exp.Expression, date: datetime.date, unit: str
|
||||
left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
|
||||
) -> t.Optional[exp.Expression]:
|
||||
drange = _datetrunc_range(date, unit)
|
||||
drange = _datetrunc_range(date, unit, dialect)
|
||||
if not drange:
|
||||
return None
|
||||
|
||||
|
@ -803,41 +823,39 @@ def _datetrunc_neq(
|
|||
)
|
||||
|
||||
|
||||
DateTruncBinaryTransform = t.Callable[
|
||||
[exp.Expression, datetime.date, str], t.Optional[exp.Expression]
|
||||
]
|
||||
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
|
||||
exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
|
||||
exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
|
||||
exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
|
||||
exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
|
||||
exp.LT: lambda l, dt, u, d: l
|
||||
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
|
||||
exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
|
||||
exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
|
||||
exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
|
||||
exp.EQ: _datetrunc_eq,
|
||||
exp.NEQ: _datetrunc_neq,
|
||||
}
|
||||
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
|
||||
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
|
||||
|
||||
|
||||
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
|
||||
return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
|
||||
return isinstance(left, DATETRUNCS) and _is_date_literal(right)
|
||||
|
||||
|
||||
@catch(ModuleNotFoundError, UnsupportedUnit)
|
||||
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
|
||||
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
|
||||
"""Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
|
||||
comparison = expression.__class__
|
||||
|
||||
if comparison not in DATETRUNC_COMPARISONS:
|
||||
if isinstance(expression, DATETRUNCS):
|
||||
date = extract_date(expression.this)
|
||||
if date and expression.unit:
|
||||
return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
|
||||
elif comparison not in DATETRUNC_COMPARISONS:
|
||||
return expression
|
||||
|
||||
if isinstance(expression, exp.Binary):
|
||||
l, r = expression.left, expression.right
|
||||
|
||||
if _is_datetrunc_predicate(l, r):
|
||||
pass
|
||||
elif _is_datetrunc_predicate(r, l):
|
||||
comparison = INVERSE_COMPARISONS.get(comparison, comparison)
|
||||
l, r = r, l
|
||||
else:
|
||||
if not _is_datetrunc_predicate(l, r):
|
||||
return expression
|
||||
|
||||
l = t.cast(exp.DateTrunc, l)
|
||||
|
@ -847,7 +865,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
|
|||
if not date:
|
||||
return expression
|
||||
|
||||
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
|
||||
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
|
||||
elif isinstance(expression, exp.In):
|
||||
l = expression.this
|
||||
rs = expression.expressions
|
||||
|
@ -861,7 +879,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
|
|||
date = extract_date(r)
|
||||
if not date:
|
||||
return expression
|
||||
drange = _datetrunc_range(date, unit)
|
||||
drange = _datetrunc_range(date, unit, dialect)
|
||||
if drange:
|
||||
ranges.append(drange)
|
||||
|
||||
|
@ -875,6 +893,23 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
|
|||
return expression
|
||||
|
||||
|
||||
def sort_comparison(expression: exp.Expression) -> exp.Expression:
|
||||
if expression.__class__ in COMPLEMENT_COMPARISONS:
|
||||
l, r = expression.this, expression.expression
|
||||
l_column = isinstance(l, exp.Column)
|
||||
r_column = isinstance(r, exp.Column)
|
||||
l_const = _is_constant(l)
|
||||
r_const = _is_constant(r)
|
||||
|
||||
if (l_column and not r_column) or (r_const and not l_const):
|
||||
return expression
|
||||
if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
|
||||
return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
|
||||
this=r, expression=l
|
||||
)
|
||||
return expression
|
||||
|
||||
|
||||
# CROSS joins result in an empty table if the right table is empty.
|
||||
# So we can only simplify certain types of joins to CROSS.
|
||||
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
|
||||
|
@ -1034,7 +1069,7 @@ def interval(unit: str, n: int = 1):
|
|||
raise UnsupportedUnit(f"Unsupported unit: {unit}")
|
||||
|
||||
|
||||
def date_floor(d: datetime.date, unit: str) -> datetime.date:
|
||||
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
|
||||
if unit == "year":
|
||||
return d.replace(month=1, day=1)
|
||||
if unit == "quarter":
|
||||
|
@ -1050,15 +1085,15 @@ def date_floor(d: datetime.date, unit: str) -> datetime.date:
|
|||
return d.replace(month=d.month, day=1)
|
||||
if unit == "week":
|
||||
# Assuming week starts on Monday (0) and ends on Sunday (6)
|
||||
return d - datetime.timedelta(days=d.weekday())
|
||||
return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
|
||||
if unit == "day":
|
||||
return d
|
||||
|
||||
raise UnsupportedUnit(f"Unsupported unit: {unit}")
|
||||
|
||||
|
||||
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
|
||||
floor = date_floor(d, unit)
|
||||
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
|
||||
floor = date_floor(d, unit, dialect)
|
||||
|
||||
if floor == d:
|
||||
return d
|
||||
|
|
|
@ -65,6 +65,8 @@ def unnest(select, parent_select, next_alias_name):
|
|||
)
|
||||
):
|
||||
column = exp.Max(this=column)
|
||||
elif not isinstance(select.parent, exp.Subquery):
|
||||
return
|
||||
|
||||
_replace(select.parent, column)
|
||||
parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue