2025-02-13 06:15:54 +01:00
|
|
|
import datetime
|
2025-02-13 14:54:32 +01:00
|
|
|
import inspect
|
2025-02-13 06:15:54 +01:00
|
|
|
import re
|
|
|
|
import statistics
|
2025-02-13 14:54:32 +01:00
|
|
|
from functools import wraps
|
2025-02-13 06:15:54 +01:00
|
|
|
|
2025-02-13 14:54:32 +01:00
|
|
|
from sqlglot import exp
|
2025-02-13 14:53:05 +01:00
|
|
|
from sqlglot.helper import PYTHON_VERSION
|
|
|
|
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
class reverse_key:
    """Sort-key wrapper that inverts the natural ordering of the wrapped value.

    Wrapping values in ``reverse_key`` makes an ascending sort produce them in
    descending order: ``a < b`` on two wrappers is evaluated as
    ``b.obj < a.obj``. Equality compares the wrapped values unchanged.
    The comparison methods read ``other.obj`` directly, so comparing against an
    object without an ``obj`` attribute raises ``AttributeError``.
    """

    def __init__(self, obj):
        # The underlying value whose ordering is inverted.
        self.obj = obj

    def __eq__(self, other):
        theirs = other.obj
        return theirs == self.obj

    def __lt__(self, other):
        # Swap the operands: this wrapper is "smaller" when its value is larger.
        theirs = other.obj
        return theirs < self.obj
|
|
|
|
|
|
|
|
|
2025-02-13 14:58:37 +01:00
|
|
|
def filter_nulls(func, empty_null=True):
    """Wrap an aggregate callable so it only ever sees non-NULL values.

    Args:
        func: aggregate that receives a tuple of the non-``None`` input values.
        empty_null: when true and no non-``None`` value remains (including an
            empty input), return ``None`` instead of calling ``func``.

    Returns:
        A wrapper taking a single iterable of values.
    """

    @wraps(func)
    def _func(values):
        non_null = tuple(filter(lambda value: value is not None, values))
        if empty_null and not non_null:
            return None
        return func(non_null)

    return _func
|
|
|
|
|
|
|
|
|
|
|
|
def null_if_any(*required):
    """Make a function short-circuit to ``None`` when a required argument is ``None``.

    With parameter names, only the positional arguments bound to those
    parameters are checked::

        @null_if_any("substr", "this")
        def str_position(substr, this, position=None): ...

    Used bare, every positional argument is checked::

        @null_if_any
        def foo(a, b): ...

    The returned wrapper accepts positional arguments only.
    """
    target = None
    names = required
    if len(names) == 1 and callable(names[0]):
        # Bare ``@null_if_any``: the single "argument" is the function itself.
        target, names = names[0], ()

    def decorator(func):
        if names:
            # Resolve the required parameter names to positional indices once,
            # at decoration time.
            params = list(inspect.signature(func).parameters)
            positions = [idx for idx, param in enumerate(params) if param in names]
        else:
            positions = None  # check every argument

        def has_null(args):
            candidates = args if positions is None else (args[idx] for idx in positions)
            return any(value is None for value in candidates)

        @wraps(func)
        def _func(*args):
            if has_null(args):
                return None
            return func(*args)

        return _func

    if target:
        return decorator(target)
    return decorator
|
|
|
|
|
|
|
|
|
|
|
|
@null_if_any("substr", "this")
def str_position(substr, this, position=None):
    """Return the 1-based index of ``substr`` within ``this``, or 0 if absent.

    ``position`` is an optional 1-based offset at which the search begins.
    Returns ``None`` when ``substr`` or ``this`` is ``None``.
    """
    start = None if position is None else position - 1
    # str.find yields -1 on a miss, so adding one maps "not found" to 0.
    return 1 + this.find(substr, start)
|
|
|
|
|
|
|
|
|
|
|
|
@null_if_any("this")
def substring(this, start=None, length=None):
    """SQL-style SUBSTRING using a 1-based ``start``.

    - ``start`` omitted: the whole string is returned.
    - ``start == 0``: the empty string is returned.
    - ``start < 0``: the offset counts back from the end of the string.
    - ``length`` omitted: everything from ``start`` onward is returned.

    Returns ``None`` when ``this`` is ``None``.
    """
    if start is None:
        return this
    if start == 0:
        return ""

    # Translate the SQL offset into a 0-based Python index.
    offset = len(this) + start if start < 0 else start - 1
    stop = offset + length if length is not None else None
    return this[offset:stop]
|
|
|
|
|
|
|
|
|
|
|
|
@null_if_any
def cast(this, to):
    """Convert ``this`` to the Python representation of SQL type ``to``.

    Args:
        this: the value to convert, e.g. an ISO string, a number, or a
            ``date``/``datetime`` that is already typed.
        to: an ``exp.DataType.Type`` member naming the target type.

    Returns:
        ``datetime.date`` for DATE, ``datetime.datetime`` for DATETIME, ``str``
        for text types, ``float`` for FLOAT/DOUBLE, and ``int`` for any other
        type in ``exp.DataType.NUMERIC_TYPES``. Returns ``None`` if either
        argument is ``None``.

    Raises:
        NotImplementedError: if ``to`` is not one of the handled types.
    """
    if to == exp.DataType.Type.DATE:
        # ``datetime`` subclasses ``date``, so it must be checked first.
        if isinstance(this, datetime.datetime):
            return this.date()
        if isinstance(this, datetime.date):
            return this
        return datetime.date.fromisoformat(this)
    if to == exp.DataType.Type.DATETIME:
        if isinstance(this, datetime.datetime):
            return this
        if isinstance(this, datetime.date):
            # Promote a bare date to midnight of that day.
            return datetime.datetime(this.year, this.month, this.day)
        return datetime.datetime.fromisoformat(this)
    if to in exp.DataType.TEXT_TYPES:
        return str(this)
    if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
        return float(this)
    if to in exp.DataType.NUMERIC_TYPES:
        # NOTE(review): non-integer numeric types (e.g. DECIMAL), if present in
        # NUMERIC_TYPES, are converted with int() here -- confirm intended.
        return int(this)
    raise NotImplementedError(f"Casting to '{to}' not implemented.")
|
|
|
|
|
|
|
|
|
|
|
|
def ordered(this, desc, nulls_first):
    """Build the sort key for a single ORDER BY term.

    Args:
        this: the value being ordered on.
        desc: when truthy, ``this`` is wrapped in ``reverse_key`` so that an
            ascending sort yields descending order.
        nulls_first: accepted for interface compatibility; it is not consulted
            by this function.

    Returns:
        ``reverse_key(this)`` when ``desc`` is truthy, otherwise ``this``.
    """
    return reverse_key(this) if desc else this
|
|
|
|
|
|
|
|
|
|
|
|
@null_if_any
def interval(this, unit):
    """Convert an INTERVAL value into a ``datetime.timedelta``.

    Args:
        this: the interval magnitude; anything ``float()`` accepts.
        unit: the unit name, compared with ``==`` against upper-case names
            such as ``"DAY"``.

    Returns:
        The corresponding ``timedelta``, or ``None`` if either argument is
        ``None``.

    Raises:
        NotImplementedError: for units a fixed-length ``timedelta`` cannot
            represent (e.g. MONTH, YEAR) or that are unknown.
    """
    # Only fixed-length units map onto timedelta. Calendar units such as
    # MONTH/YEAR have variable length and stay unsupported.
    units = (
        ("WEEK", "weeks"),
        ("DAY", "days"),
        ("HOUR", "hours"),
        ("MINUTE", "minutes"),
        ("SECOND", "seconds"),
        ("MILLISECOND", "milliseconds"),
        ("MICROSECOND", "microseconds"),
    )
    for name, kwarg in units:
        if unit == name:
            return datetime.timedelta(**{kwarg: float(this)})
    raise NotImplementedError(f"Interval unit '{unit}' not implemented.")
|
|
|
|
|
|
|
|
|
2025-02-13 06:15:54 +01:00
|
|
|
def _sql_like(this, pattern):
    """Evaluate SQL ``this LIKE pattern`` (case-sensitive, no ESCAPE support).

    ``%`` matches any run of characters (including newlines), ``_`` matches
    exactly one character, and every other character matches itself literally.
    The whole of ``this`` must match, as SQL LIKE requires.
    """
    regex = "".join(
        ".*" if char == "%" else "." if char == "_" else re.escape(char) for char in pattern
    )
    return bool(re.fullmatch(regex, this, flags=re.DOTALL))


# Name -> implementation table used by the Python executor. Aggregates receive an
# iterable of values. Scalar functions receive positional arguments, and most are
# wrapped in null_if_any so a NULL input yields NULL.
ENV = {
    "exp": exp,
    # aggs
    "ARRAYAGG": list,
    "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
    "MAX": filter_nulls(max),
    "MIN": filter_nulls(min),
    "SUM": filter_nulls(sum),
    # scalar functions
    "ABS": null_if_any(lambda this: abs(this)),
    "ADD": null_if_any(lambda e, this: e + this),
    "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
    "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
    "BITWISEAND": null_if_any(lambda this, e: this & e),
    "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
    "BITWISEOR": null_if_any(lambda this, e: this | e),
    "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
    "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
    "CAST": cast,
    "COALESCE": lambda *args: next((a for a in args if a is not None), None),
    "CONCAT": null_if_any(lambda *args: "".join(args)),
    "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
    "DIV": null_if_any(lambda e, this: e / this),
    "EQ": null_if_any(lambda this, e: this == e),
    "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
    "GT": null_if_any(lambda this, e: this > e),
    "GTE": null_if_any(lambda this, e: this >= e),
    "IFNULL": lambda e, alt: alt if e is None else e,
    "IF": lambda predicate, true, false: true if predicate else false,
    "INTDIV": null_if_any(lambda e, this: e // this),
    "INTERVAL": interval,
    # Anchored, metacharacter-safe translation of the LIKE pattern.
    "LIKE": null_if_any(_sql_like),
    "LOWER": null_if_any(lambda arg: arg.lower()),
    "LT": null_if_any(lambda this, e: this < e),
    "LTE": null_if_any(lambda this, e: this <= e),
    "MOD": null_if_any(lambda e, this: e % this),
    "MUL": null_if_any(lambda e, this: e * this),
    "NEQ": null_if_any(lambda this, e: this != e),
    "ORD": null_if_any(ord),
    "ORDERED": ordered,
    # NULL-propagating like the other arithmetic operators.
    "POW": null_if_any(pow),
    "STRPOSITION": str_position,
    "SUB": null_if_any(lambda e, this: e - this),
    "SUBSTRING": substring,
    "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
    "UPPER": null_if_any(lambda arg: arg.upper()),
    "YEAR": null_if_any(lambda arg: arg.year),
    "MONTH": null_if_any(lambda arg: arg.month),
    "DAY": null_if_any(lambda arg: arg.day),
    "CURRENTDATETIME": datetime.datetime.now,
    "CURRENTTIMESTAMP": datetime.datetime.now,
    "CURRENTTIME": datetime.datetime.now,
    "CURRENTDATE": datetime.date.today,
}
|