Merging upstream version 9.0.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
66ef36a209
commit
b1dc5c6faf
22 changed files with 742 additions and 223 deletions
|
@ -24,7 +24,7 @@ from sqlglot.parser import Parser
|
|||
from sqlglot.schema import MappingSchema
|
||||
from sqlglot.tokens import Tokenizer, TokenType
|
||||
|
||||
__version__ = "9.0.1"
|
||||
__version__ = "9.0.3"
|
||||
|
||||
pretty = False
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import typing as t
|
|||
import sqlglot
|
||||
from sqlglot import expressions as exp
|
||||
from sqlglot.dataframe.sql.types import DataType
|
||||
from sqlglot.helper import flatten
|
||||
from sqlglot.helper import flatten, is_iterable
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot.dataframe.sql._typing import ColumnOrLiteral
|
||||
|
@ -134,10 +134,14 @@ class Column:
|
|||
cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
|
||||
) -> Column:
|
||||
ensured_column = None if column is None else cls.ensure_col(column)
|
||||
ensure_expression_values = {
|
||||
k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
new_expression = (
|
||||
callable_expression(**kwargs)
|
||||
callable_expression(**ensure_expression_values)
|
||||
if ensured_column is None
|
||||
else callable_expression(this=ensured_column.column_expression, **kwargs)
|
||||
else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
|
||||
)
|
||||
return Column(new_expression)
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from inspect import signature
|
||||
|
||||
from sqlglot import expressions as glotexp
|
||||
from sqlglot.dataframe.sql.column import Column
|
||||
|
@ -24,17 +23,15 @@ def lit(value: t.Optional[t.Any] = None) -> Column:
|
|||
|
||||
|
||||
def greatest(*cols: ColumnOrName) -> Column:
|
||||
columns = [Column.ensure_col(col) for col in cols]
|
||||
return Column.invoke_expression_over_column(
|
||||
columns[0], glotexp.Greatest, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
|
||||
)
|
||||
if len(cols) > 1:
|
||||
return Column.invoke_expression_over_column(cols[0], glotexp.Greatest, expressions=cols[1:])
|
||||
return Column.invoke_expression_over_column(cols[0], glotexp.Greatest)
|
||||
|
||||
|
||||
def least(*cols: ColumnOrName) -> Column:
|
||||
columns = [Column.ensure_col(col) for col in cols]
|
||||
return Column.invoke_expression_over_column(
|
||||
columns[0], glotexp.Least, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
|
||||
)
|
||||
if len(cols) > 1:
|
||||
return Column.invoke_expression_over_column(cols[0], glotexp.Least, expressions=cols[1:])
|
||||
return Column.invoke_expression_over_column(cols[0], glotexp.Least)
|
||||
|
||||
|
||||
def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
|
||||
|
@ -194,7 +191,7 @@ def log2(col: ColumnOrName) -> Column:
|
|||
def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
|
||||
if arg2 is None:
|
||||
return Column.invoke_expression_over_column(arg1, glotexp.Ln)
|
||||
return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=Column.ensure_col(arg2).expression)
|
||||
return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=arg2)
|
||||
|
||||
|
||||
def rint(col: ColumnOrName) -> Column:
|
||||
|
@ -310,7 +307,7 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]
|
|||
|
||||
|
||||
def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
|
||||
return Column.invoke_anonymous_function(col1, "POW", col2)
|
||||
return Column.invoke_expression_over_column(col1, glotexp.Pow, power=col2)
|
||||
|
||||
|
||||
def row_number() -> Column:
|
||||
|
@ -340,14 +337,13 @@ def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Col
|
|||
def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
|
||||
if rsd is None:
|
||||
return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct)
|
||||
return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=Column.ensure_col(rsd).expression)
|
||||
return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=rsd)
|
||||
|
||||
|
||||
def coalesce(*cols: ColumnOrName) -> Column:
|
||||
columns = [Column.ensure_col(col) for col in cols]
|
||||
return Column.invoke_expression_over_column(
|
||||
columns[0], glotexp.Coalesce, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
|
||||
)
|
||||
if len(cols) > 1:
|
||||
return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce, expressions=cols[1:])
|
||||
return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce)
|
||||
|
||||
|
||||
def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
|
||||
|
@ -405,11 +401,13 @@ def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column:
|
|||
def percentile_approx(
|
||||
col: ColumnOrName,
|
||||
percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
|
||||
accuracy: t.Optional[t.Union[ColumnOrLiteral]] = None,
|
||||
accuracy: t.Optional[t.Union[ColumnOrLiteral, int]] = None,
|
||||
) -> Column:
|
||||
if accuracy:
|
||||
return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage, accuracy)
|
||||
return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage)
|
||||
return Column.invoke_expression_over_column(
|
||||
col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
|
||||
)
|
||||
return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
|
||||
|
||||
|
||||
def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
|
||||
|
@ -422,7 +420,7 @@ def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
|
|||
|
||||
def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
|
||||
if scale is not None:
|
||||
return Column.invoke_expression_over_column(col, glotexp.Round, decimals=glotexp.convert(scale))
|
||||
return Column.invoke_expression_over_column(col, glotexp.Round, decimals=scale)
|
||||
return Column.invoke_expression_over_column(col, glotexp.Round)
|
||||
|
||||
|
||||
|
@ -433,9 +431,7 @@ def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
|
|||
|
||||
|
||||
def shiftleft(col: ColumnOrName, numBits: int) -> Column:
|
||||
return Column.invoke_expression_over_column(
|
||||
col, glotexp.BitwiseLeftShift, expression=Column.ensure_col(numBits).expression
|
||||
)
|
||||
return Column.invoke_expression_over_column(col, glotexp.BitwiseLeftShift, expression=numBits)
|
||||
|
||||
|
||||
def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
|
||||
|
@ -443,9 +439,7 @@ def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
|
|||
|
||||
|
||||
def shiftright(col: ColumnOrName, numBits: int) -> Column:
|
||||
return Column.invoke_expression_over_column(
|
||||
col, glotexp.BitwiseRightShift, expression=Column.ensure_col(numBits).expression
|
||||
)
|
||||
return Column.invoke_expression_over_column(col, glotexp.BitwiseRightShift, expression=numBits)
|
||||
|
||||
|
||||
def shiftRight(col: ColumnOrName, numBits: int) -> Column:
|
||||
|
@ -466,8 +460,7 @@ def expr(str: str) -> Column:
|
|||
|
||||
def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
|
||||
columns = ensure_list(col) + list(cols)
|
||||
expressions = [Column.ensure_col(column).expression for column in columns]
|
||||
return Column(glotexp.Struct(expressions=expressions))
|
||||
return Column.invoke_expression_over_column(None, glotexp.Struct, expressions=columns)
|
||||
|
||||
|
||||
def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:
|
||||
|
@ -515,7 +508,7 @@ def current_timestamp() -> Column:
|
|||
|
||||
|
||||
def date_format(col: ColumnOrName, format: str) -> Column:
|
||||
return Column.invoke_anonymous_function(col, "DATE_FORMAT", lit(format))
|
||||
return Column.invoke_expression_over_column(col, glotexp.TimeToStr, format=lit(format))
|
||||
|
||||
|
||||
def year(col: ColumnOrName) -> Column:
|
||||
|
@ -563,15 +556,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
|
|||
|
||||
|
||||
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
|
||||
return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=Column.ensure_col(days).expression)
|
||||
return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=days)
|
||||
|
||||
|
||||
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
|
||||
return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=Column.ensure_col(days).expression)
|
||||
return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=days)
|
||||
|
||||
|
||||
def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
|
||||
return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=Column.ensure_col(start).expression)
|
||||
return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=start)
|
||||
|
||||
|
||||
def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
|
||||
|
@ -586,8 +579,8 @@ def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optiona
|
|||
|
||||
def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
|
||||
if format is not None:
|
||||
return Column.invoke_anonymous_function(col, "TO_DATE", lit(format))
|
||||
return Column.invoke_anonymous_function(col, "TO_DATE")
|
||||
return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate, format=lit(format))
|
||||
return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate)
|
||||
|
||||
|
||||
def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
|
||||
|
@ -597,11 +590,11 @@ def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
|
|||
|
||||
|
||||
def trunc(col: ColumnOrName, format: str) -> Column:
|
||||
return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format).expression)
|
||||
return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format))
|
||||
|
||||
|
||||
def date_trunc(format: str, timestamp: ColumnOrName) -> Column:
|
||||
return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format).expression)
|
||||
return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format))
|
||||
|
||||
|
||||
def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
|
||||
|
@ -614,14 +607,14 @@ def last_day(col: ColumnOrName) -> Column:
|
|||
|
||||
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
|
||||
if format is not None:
|
||||
return Column.invoke_anonymous_function(col, "FROM_UNIXTIME", lit(format))
|
||||
return Column.invoke_anonymous_function(col, "FROM_UNIXTIME")
|
||||
return Column.invoke_expression_over_column(col, glotexp.UnixToStr, format=lit(format))
|
||||
return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
|
||||
|
||||
|
||||
def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
|
||||
if format is not None:
|
||||
return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP", lit(format))
|
||||
return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP")
|
||||
return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
|
||||
return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
|
||||
|
||||
|
||||
def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
|
||||
|
@ -738,10 +731,7 @@ def trim(col: ColumnOrName) -> Column:
|
|||
|
||||
|
||||
def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
|
||||
columns = [Column(col) for col in cols]
|
||||
return Column.invoke_expression_over_column(
|
||||
None, glotexp.ConcatWs, expressions=[x.expression for x in [lit(sep)] + list(columns)]
|
||||
)
|
||||
return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
|
||||
|
||||
|
||||
def decode(col: ColumnOrName, charset: str) -> Column:
|
||||
|
@ -798,18 +788,14 @@ def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
|
|||
|
||||
|
||||
def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
|
||||
return Column.invoke_expression_over_column(
|
||||
left, glotexp.Levenshtein, expression=Column.ensure_col(right).expression
|
||||
)
|
||||
return Column.invoke_expression_over_column(left, glotexp.Levenshtein, expression=right)
|
||||
|
||||
|
||||
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
|
||||
substr_col = lit(substr)
|
||||
pos_column = lit(pos)
|
||||
str_column = Column.ensure_col(str)
|
||||
if pos is not None:
|
||||
return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column, pos_column)
|
||||
return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column)
|
||||
return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
|
||||
return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
|
||||
|
||||
|
||||
def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
|
||||
|
@ -821,15 +807,15 @@ def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
|
|||
|
||||
|
||||
def repeat(col: ColumnOrName, n: int) -> Column:
|
||||
return Column.invoke_anonymous_function(col, "REPEAT", n)
|
||||
return Column.invoke_expression_over_column(col, glotexp.Repeat, times=lit(n))
|
||||
|
||||
|
||||
def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
|
||||
if limit is not None:
|
||||
return Column.invoke_expression_over_column(
|
||||
str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=lit(limit).expression
|
||||
str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=limit
|
||||
)
|
||||
return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern).expression)
|
||||
return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern))
|
||||
|
||||
|
||||
def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
|
||||
|
@ -879,9 +865,8 @@ def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column:
|
|||
|
||||
|
||||
def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
|
||||
cols = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
|
||||
cols = [Column.ensure_col(col).expression for col in cols] # type: ignore
|
||||
return Column.invoke_expression_over_column(None, glotexp.Array, expressions=cols)
|
||||
columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols
|
||||
return Column.invoke_expression_over_column(None, glotexp.Array, expressions=columns)
|
||||
|
||||
|
||||
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
|
||||
|
@ -892,7 +877,7 @@ def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
|
|||
|
||||
|
||||
def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
|
||||
return Column.invoke_anonymous_function(col1, "MAP_FROM_ARRAYS", col2)
|
||||
return Column.invoke_expression_over_column(None, glotexp.Map, keys=col1, values=col2)
|
||||
|
||||
|
||||
def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
|
||||
|
@ -970,7 +955,7 @@ def posexplode_outer(col: ColumnOrName) -> Column:
|
|||
|
||||
|
||||
def get_json_object(col: ColumnOrName, path: str) -> Column:
|
||||
return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path).expression)
|
||||
return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path))
|
||||
|
||||
|
||||
def json_tuple(col: ColumnOrName, *fields: str) -> Column:
|
||||
|
@ -1031,11 +1016,17 @@ def array_max(col: ColumnOrName) -> Column:
|
|||
|
||||
def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
|
||||
if asc is not None:
|
||||
return Column.invoke_anonymous_function(col, "SORT_ARRAY", lit(asc))
|
||||
return Column.invoke_anonymous_function(col, "SORT_ARRAY")
|
||||
return Column.invoke_expression_over_column(col, glotexp.SortArray, asc=asc)
|
||||
return Column.invoke_expression_over_column(col, glotexp.SortArray)
|
||||
|
||||
|
||||
def array_sort(col: ColumnOrName) -> Column:
|
||||
def array_sort(
|
||||
col: ColumnOrName,
|
||||
comparator: t.Optional[t.Union[t.Callable[[Column, Column], Column]]] = None,
|
||||
) -> Column:
|
||||
if comparator is not None:
|
||||
f_expression = _get_lambda_from_func(comparator)
|
||||
return Column.invoke_expression_over_column(col, glotexp.ArraySort, expression=f_expression)
|
||||
return Column.invoke_expression_over_column(col, glotexp.ArraySort)
|
||||
|
||||
|
||||
|
@ -1108,130 +1099,53 @@ def aggregate(
|
|||
initialValue: ColumnOrName,
|
||||
merge: t.Callable[[Column, Column], Column],
|
||||
finish: t.Optional[t.Callable[[Column], Column]] = None,
|
||||
accumulator_name: str = "acc",
|
||||
target_row_name: str = "x",
|
||||
) -> Column:
|
||||
merge_exp = glotexp.Lambda(
|
||||
this=merge(Column(accumulator_name), Column(target_row_name)).expression,
|
||||
expressions=[
|
||||
glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name)),
|
||||
glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name)),
|
||||
],
|
||||
)
|
||||
merge_exp = _get_lambda_from_func(merge)
|
||||
if finish is not None:
|
||||
finish_exp = glotexp.Lambda(
|
||||
this=finish(Column(accumulator_name)).expression,
|
||||
expressions=[glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name))],
|
||||
)
|
||||
finish_exp = _get_lambda_from_func(finish)
|
||||
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
|
||||
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
|
||||
|
||||
|
||||
def transform(
|
||||
col: ColumnOrName,
|
||||
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
|
||||
target_row_name: str = "x",
|
||||
row_count_name: str = "i",
|
||||
col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
|
||||
) -> Column:
|
||||
num_arguments = len(signature(f).parameters)
|
||||
expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
|
||||
columns = [Column(target_row_name)]
|
||||
if num_arguments > 1:
|
||||
columns.append(Column(row_count_name))
|
||||
expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name)))
|
||||
|
||||
f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions)
|
||||
f_expression = _get_lambda_from_func(f)
|
||||
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
|
||||
|
||||
|
||||
def exists(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column:
|
||||
f_expression = glotexp.Lambda(
|
||||
this=f(Column(target_row_name)).expression,
|
||||
expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))],
|
||||
)
|
||||
def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
|
||||
f_expression = _get_lambda_from_func(f)
|
||||
return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression))
|
||||
|
||||
|
||||
def forall(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column:
|
||||
f_expression = glotexp.Lambda(
|
||||
this=f(Column(target_row_name)).expression,
|
||||
expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))],
|
||||
)
|
||||
|
||||
def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
|
||||
f_expression = _get_lambda_from_func(f)
|
||||
return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
|
||||
|
||||
|
||||
def filter(
|
||||
col: ColumnOrName,
|
||||
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
|
||||
target_row_name: str = "x",
|
||||
row_count_name: str = "i",
|
||||
) -> Column:
|
||||
num_arguments = len(signature(f).parameters)
|
||||
expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
|
||||
columns = [Column(target_row_name)]
|
||||
if num_arguments > 1:
|
||||
columns.append(Column(row_count_name))
|
||||
expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name)))
|
||||
|
||||
f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions)
|
||||
return Column.invoke_anonymous_function(col, "FILTER", Column(f_expression))
|
||||
def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
|
||||
f_expression = _get_lambda_from_func(f)
|
||||
return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
|
||||
|
||||
|
||||
def zip_with(
|
||||
left: ColumnOrName,
|
||||
right: ColumnOrName,
|
||||
f: t.Callable[[Column, Column], Column],
|
||||
left_name: str = "x",
|
||||
right_name: str = "y",
|
||||
) -> Column:
|
||||
f_expression = glotexp.Lambda(
|
||||
this=f(Column(left_name), Column(right_name)).expression,
|
||||
expressions=[
|
||||
glotexp.to_identifier(left_name, quoted=_lambda_quoted(left_name)),
|
||||
glotexp.to_identifier(right_name, quoted=_lambda_quoted(right_name)),
|
||||
],
|
||||
)
|
||||
|
||||
def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
|
||||
f_expression = _get_lambda_from_func(f)
|
||||
return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
|
||||
|
||||
|
||||
def transform_keys(
|
||||
col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
|
||||
) -> Column:
|
||||
f_expression = glotexp.Lambda(
|
||||
this=f(Column(key_name), Column(value_name)).expression,
|
||||
expressions=[
|
||||
glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
|
||||
glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
|
||||
],
|
||||
)
|
||||
def transform_keys(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
|
||||
f_expression = _get_lambda_from_func(f)
|
||||
return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression))
|
||||
|
||||
|
||||
def transform_values(
|
||||
col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
|
||||
) -> Column:
|
||||
f_expression = glotexp.Lambda(
|
||||
this=f(Column(key_name), Column(value_name)).expression,
|
||||
expressions=[
|
||||
glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
|
||||
glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
|
||||
],
|
||||
)
|
||||
def transform_values(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
|
||||
f_expression = _get_lambda_from_func(f)
|
||||
return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression))
|
||||
|
||||
|
||||
def map_filter(
|
||||
col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
|
||||
) -> Column:
|
||||
f_expression = glotexp.Lambda(
|
||||
this=f(Column(key_name), Column(value_name)).expression,
|
||||
expressions=[
|
||||
glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
|
||||
glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
|
||||
],
|
||||
)
|
||||
def map_filter(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
|
||||
f_expression = _get_lambda_from_func(f)
|
||||
return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression))
|
||||
|
||||
|
||||
|
@ -1239,20 +1153,18 @@ def map_zip_with(
|
|||
col1: ColumnOrName,
|
||||
col2: ColumnOrName,
|
||||
f: t.Union[t.Callable[[Column, Column, Column], Column]],
|
||||
key_name: str = "k",
|
||||
value1: str = "v1",
|
||||
value2: str = "v2",
|
||||
) -> Column:
|
||||
f_expression = glotexp.Lambda(
|
||||
this=f(Column(key_name), Column(value1), Column(value2)).expression,
|
||||
expressions=[
|
||||
glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
|
||||
glotexp.to_identifier(value1, quoted=_lambda_quoted(value1)),
|
||||
glotexp.to_identifier(value2, quoted=_lambda_quoted(value2)),
|
||||
],
|
||||
)
|
||||
f_expression = _get_lambda_from_func(f)
|
||||
return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression))
|
||||
|
||||
|
||||
def _lambda_quoted(value: str) -> t.Optional[bool]:
|
||||
return False if value == "_" else None
|
||||
|
||||
|
||||
def _get_lambda_from_func(lambda_expression: t.Callable):
|
||||
variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
|
||||
return glotexp.Lambda(
|
||||
this=lambda_expression(*[Column(x) for x in variables]).expression,
|
||||
expressions=variables,
|
||||
)
|
||||
|
|
|
@ -18,6 +18,36 @@ from sqlglot.helper import list_get
|
|||
from sqlglot.parser import Parser, parse_var_map
|
||||
from sqlglot.tokens import Tokenizer
|
||||
|
||||
# (FuncType, Multiplier)
|
||||
DATE_DELTA_INTERVAL = {
|
||||
"YEAR": ("ADD_MONTHS", 12),
|
||||
"MONTH": ("ADD_MONTHS", 1),
|
||||
"QUARTER": ("ADD_MONTHS", 3),
|
||||
"WEEK": ("DATE_ADD", 7),
|
||||
"DAY": ("DATE_ADD", 1),
|
||||
}
|
||||
|
||||
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
|
||||
|
||||
|
||||
def _add_date_sql(self, expression):
|
||||
unit = expression.text("unit").upper()
|
||||
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
|
||||
modified_increment = (
|
||||
int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression
|
||||
)
|
||||
modified_increment = exp.Literal.number(modified_increment)
|
||||
return f"{func}({self.format_args(expression.this, modified_increment.this)})"
|
||||
|
||||
|
||||
def _date_diff_sql(self, expression):
|
||||
unit = expression.text("unit").upper()
|
||||
sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
|
||||
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
|
||||
multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
|
||||
diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})"
|
||||
return f"{diff_sql}{multiplier_sql}"
|
||||
|
||||
|
||||
def _array_sort(self, expression):
|
||||
if expression.expression:
|
||||
|
@ -120,10 +150,14 @@ class Hive(Dialect):
|
|||
"m": "%-M",
|
||||
"ss": "%S",
|
||||
"s": "%-S",
|
||||
"S": "%f",
|
||||
"SSSSSS": "%f",
|
||||
"a": "%p",
|
||||
"DD": "%j",
|
||||
"D": "%-j",
|
||||
"E": "%a",
|
||||
"EE": "%a",
|
||||
"EEE": "%a",
|
||||
"EEEE": "%A",
|
||||
}
|
||||
|
||||
date_format = "'yyyy-MM-dd'"
|
||||
|
@ -207,8 +241,8 @@ class Hive(Dialect):
|
|||
exp.ArraySize: rename_func("SIZE"),
|
||||
exp.ArraySort: _array_sort,
|
||||
exp.With: no_recursive_cte_sql,
|
||||
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
|
||||
exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
|
||||
exp.DateAdd: _add_date_sql,
|
||||
exp.DateDiff: _date_diff_sql,
|
||||
exp.DateStrToDate: rename_func("TO_DATE"),
|
||||
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
|
||||
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
|
||||
|
|
|
@ -71,6 +71,7 @@ class Spark(Hive):
|
|||
length=list_get(args, 1),
|
||||
),
|
||||
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
|
||||
"IIF": exp.If.from_arg_list,
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
|
@ -111,6 +112,7 @@ class Spark(Hive):
|
|||
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
|
||||
exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
|
||||
exp.VariancePop: rename_func("VAR_POP"),
|
||||
exp.DateFromParts: rename_func("MAKE_DATE"),
|
||||
}
|
||||
|
||||
WRAP_DERIVED_VALUES = False
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.dialects.dialect import rename_func
|
||||
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
|
||||
from sqlglot.dialects.mysql import MySQL
|
||||
|
||||
|
||||
|
@ -14,6 +14,8 @@ class StarRocks(MySQL):
|
|||
|
||||
TRANSFORMS = {
|
||||
**MySQL.Generator.TRANSFORMS,
|
||||
exp.JSONExtractScalar: arrow_json_extract_sql,
|
||||
exp.JSONExtract: arrow_json_extract_sql,
|
||||
exp.DateDiff: rename_func("DATEDIFF"),
|
||||
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.TimeStrToDate: rename_func("TO_DATE"),
|
||||
|
|
|
@ -1,14 +1,149 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.dialects.dialect import Dialect
|
||||
from sqlglot.dialects.dialect import Dialect, rename_func
|
||||
from sqlglot.expressions import DataType
|
||||
from sqlglot.generator import Generator
|
||||
from sqlglot.helper import list_get
|
||||
from sqlglot.parser import Parser
|
||||
from sqlglot.time import format_time
|
||||
from sqlglot.tokens import Tokenizer, TokenType
|
||||
|
||||
FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"}
|
||||
DATE_DELTA_INTERVAL = {
|
||||
"year": "year",
|
||||
"yyyy": "year",
|
||||
"yy": "year",
|
||||
"quarter": "quarter",
|
||||
"qq": "quarter",
|
||||
"q": "quarter",
|
||||
"month": "month",
|
||||
"mm": "month",
|
||||
"m": "month",
|
||||
"week": "week",
|
||||
"ww": "week",
|
||||
"wk": "week",
|
||||
"day": "day",
|
||||
"dd": "day",
|
||||
"d": "day",
|
||||
}
|
||||
|
||||
|
||||
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
|
||||
def _format_time(args):
|
||||
return exp_class(
|
||||
this=list_get(args, 1),
|
||||
format=exp.Literal.string(
|
||||
format_time(
|
||||
list_get(args, 0).name or (TSQL.time_format if default is True else default),
|
||||
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return _format_time
|
||||
|
||||
|
||||
def parse_date_delta(exp_class):
|
||||
def inner_func(args):
|
||||
unit = DATE_DELTA_INTERVAL.get(list_get(args, 0).name.lower(), "day")
|
||||
return exp_class(this=list_get(args, 2), expression=list_get(args, 1), unit=unit)
|
||||
|
||||
return inner_func
|
||||
|
||||
|
||||
def generate_date_delta(self, e):
|
||||
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
|
||||
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
|
||||
|
||||
|
||||
class TSQL(Dialect):
|
||||
null_ordering = "nulls_are_small"
|
||||
time_format = "'yyyy-mm-dd hh:mm:ss'"
|
||||
|
||||
time_mapping = {
|
||||
"yyyy": "%Y",
|
||||
"yy": "%y",
|
||||
"year": "%Y",
|
||||
"qq": "%q",
|
||||
"q": "%q",
|
||||
"quarter": "%q",
|
||||
"dayofyear": "%j",
|
||||
"day": "%d",
|
||||
"dy": "%d",
|
||||
"y": "%Y",
|
||||
"week": "%W",
|
||||
"ww": "%W",
|
||||
"wk": "%W",
|
||||
"hour": "%h",
|
||||
"hh": "%I",
|
||||
"minute": "%M",
|
||||
"mi": "%M",
|
||||
"n": "%M",
|
||||
"second": "%S",
|
||||
"ss": "%S",
|
||||
"s": "%-S",
|
||||
"millisecond": "%f",
|
||||
"ms": "%f",
|
||||
"weekday": "%W",
|
||||
"dw": "%W",
|
||||
"month": "%m",
|
||||
"mm": "%M",
|
||||
"m": "%-M",
|
||||
"Y": "%Y",
|
||||
"YYYY": "%Y",
|
||||
"YY": "%y",
|
||||
"MMMM": "%B",
|
||||
"MMM": "%b",
|
||||
"MM": "%m",
|
||||
"M": "%-m",
|
||||
"dd": "%d",
|
||||
"d": "%-d",
|
||||
"HH": "%H",
|
||||
"H": "%-H",
|
||||
"h": "%-I",
|
||||
"S": "%f",
|
||||
}
|
||||
|
||||
convert_format_mapping = {
|
||||
"0": "%b %d %Y %-I:%M%p",
|
||||
"1": "%m/%d/%y",
|
||||
"2": "%y.%m.%d",
|
||||
"3": "%d/%m/%y",
|
||||
"4": "%d.%m.%y",
|
||||
"5": "%d-%m-%y",
|
||||
"6": "%d %b %y",
|
||||
"7": "%b %d, %y",
|
||||
"8": "%H:%M:%S",
|
||||
"9": "%b %d %Y %-I:%M:%S:%f%p",
|
||||
"10": "mm-dd-yy",
|
||||
"11": "yy/mm/dd",
|
||||
"12": "yymmdd",
|
||||
"13": "%d %b %Y %H:%M:ss:%f",
|
||||
"14": "%H:%M:%S:%f",
|
||||
"20": "%Y-%m-%d %H:%M:%S",
|
||||
"21": "%Y-%m-%d %H:%M:%S.%f",
|
||||
"22": "%m/%d/%y %-I:%M:%S %p",
|
||||
"23": "%Y-%m-%d",
|
||||
"24": "%H:%M:%S",
|
||||
"25": "%Y-%m-%d %H:%M:%S.%f",
|
||||
"100": "%b %d %Y %-I:%M%p",
|
||||
"101": "%m/%d/%Y",
|
||||
"102": "%Y.%m.%d",
|
||||
"103": "%d/%m/%Y",
|
||||
"104": "%d.%m.%Y",
|
||||
"105": "%d-%m-%Y",
|
||||
"106": "%d %b %Y",
|
||||
"107": "%b %d, %Y",
|
||||
"108": "%H:%M:%S",
|
||||
"109": "%b %d %Y %-I:%M:%S:%f%p",
|
||||
"110": "%m-%d-%Y",
|
||||
"111": "%Y/%m/%d",
|
||||
"112": "%Y%m%d",
|
||||
"113": "%d %b %Y %H:%M:%S:%f",
|
||||
"114": "%H:%M:%S:%f",
|
||||
"120": "%Y-%m-%d %H:%M:%S",
|
||||
"121": "%Y-%m-%d %H:%M:%S.%f",
|
||||
}
|
||||
|
||||
class Tokenizer(Tokenizer):
|
||||
IDENTIFIERS = ['"', ("[", "]")]
|
||||
|
||||
|
@ -29,19 +164,67 @@ class TSQL(Dialect):
|
|||
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
|
||||
"XML": TokenType.XML,
|
||||
"SQL_VARIANT": TokenType.VARIANT,
|
||||
"NVARCHAR(MAX)": TokenType.TEXT,
|
||||
"VARCHAR(MAX)": TokenType.TEXT,
|
||||
}
|
||||
|
||||
class Parser(Parser):
|
||||
FUNCTIONS = {
|
||||
**Parser.FUNCTIONS,
|
||||
"CHARINDEX": exp.StrPosition.from_arg_list,
|
||||
"ISNULL": exp.Coalesce.from_arg_list,
|
||||
"DATEADD": parse_date_delta(exp.DateAdd),
|
||||
"DATEDIFF": parse_date_delta(exp.DateDiff),
|
||||
"DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True),
|
||||
"DATEPART": tsql_format_time_lambda(exp.TimeToStr),
|
||||
"GETDATE": exp.CurrentDate.from_arg_list,
|
||||
"IIF": exp.If.from_arg_list,
|
||||
"LEN": exp.Length.from_arg_list,
|
||||
"REPLICATE": exp.Repeat.from_arg_list,
|
||||
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
|
||||
}
|
||||
|
||||
def _parse_convert(self):
|
||||
VAR_LENGTH_DATATYPES = {
|
||||
DataType.Type.NVARCHAR,
|
||||
DataType.Type.VARCHAR,
|
||||
DataType.Type.CHAR,
|
||||
DataType.Type.NCHAR,
|
||||
}
|
||||
|
||||
def _parse_convert(self, strict):
|
||||
to = self._parse_types()
|
||||
self._match(TokenType.COMMA)
|
||||
this = self._parse_field()
|
||||
return self.expression(exp.Cast, this=this, to=to)
|
||||
|
||||
# Retrieve length of datatype and override to default if not specified
|
||||
if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
|
||||
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
|
||||
|
||||
# Check whether a conversion with format is applicable
|
||||
if self._match(TokenType.COMMA):
|
||||
format_val = self._parse_number().name
|
||||
if format_val not in TSQL.convert_format_mapping:
|
||||
raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}")
|
||||
format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
|
||||
|
||||
# Check whether the convert entails a string to date format
|
||||
if to.this == DataType.Type.DATE:
|
||||
return self.expression(exp.StrToDate, this=this, format=format_norm)
|
||||
# Check whether the convert entails a string to datetime format
|
||||
elif to.this == DataType.Type.DATETIME:
|
||||
return self.expression(exp.StrToTime, this=this, format=format_norm)
|
||||
# Check whether the convert entails a date to string format
|
||||
elif to.this in self.VAR_LENGTH_DATATYPES:
|
||||
return self.expression(
|
||||
exp.Cast if strict else exp.TryCast,
|
||||
to=to,
|
||||
this=self.expression(exp.TimeToStr, this=this, format=format_norm),
|
||||
)
|
||||
elif to.this == DataType.Type.TEXT:
|
||||
return self.expression(exp.TimeToStr, this=this, format=format_norm)
|
||||
|
||||
# Entails a simple cast without any format requirement
|
||||
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
|
||||
|
||||
class Generator(Generator):
|
||||
TYPE_MAPPING = {
|
||||
|
@ -52,3 +235,11 @@ class TSQL(Dialect):
|
|||
exp.DataType.Type.DATETIME: "DATETIME2",
|
||||
exp.DataType.Type.VARIANT: "SQL_VARIANT",
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**Generator.TRANSFORMS,
|
||||
exp.DateAdd: lambda self, e: generate_date_delta(self, e),
|
||||
exp.DateDiff: lambda self, e: generate_date_delta(self, e),
|
||||
exp.CurrentDate: rename_func("GETDATE"),
|
||||
exp.If: rename_func("IIF"),
|
||||
}
|
||||
|
|
|
@ -2411,6 +2411,11 @@ class TimeTrunc(Func, TimeUnit):
|
|||
arg_types = {"this": True, "unit": True, "zone": False}
|
||||
|
||||
|
||||
class DateFromParts(Func):
|
||||
_sql_names = ["DATEFROMPARTS"]
|
||||
arg_types = {"year": True, "month": True, "day": True}
|
||||
|
||||
|
||||
class DateStrToDate(Func):
|
||||
pass
|
||||
|
||||
|
@ -2554,7 +2559,7 @@ class Quantile(AggFunc):
|
|||
|
||||
|
||||
class ApproxQuantile(Quantile):
|
||||
pass
|
||||
arg_types = {"this": True, "quantile": True, "accuracy": False}
|
||||
|
||||
|
||||
class Reduce(Func):
|
||||
|
@ -2569,6 +2574,10 @@ class RegexpSplit(Func):
|
|||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
|
||||
class Repeat(Func):
|
||||
arg_types = {"this": True, "times": True}
|
||||
|
||||
|
||||
class Round(Func):
|
||||
arg_types = {"this": True, "decimals": False}
|
||||
|
||||
|
@ -2690,7 +2699,7 @@ class TsOrDiToDi(Func):
|
|||
|
||||
|
||||
class UnixToStr(Func):
|
||||
arg_types = {"this": True, "format": True}
|
||||
arg_types = {"this": True, "format": False}
|
||||
|
||||
|
||||
class UnixToTime(Func):
|
||||
|
@ -3077,6 +3086,8 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
|
|||
)
|
||||
if from_:
|
||||
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
|
||||
if isinstance(where, Condition):
|
||||
where = Where(this=where)
|
||||
if where:
|
||||
update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
|
||||
return update
|
||||
|
@ -3518,6 +3529,41 @@ def replace_tables(expression, mapping):
|
|||
return expression.transform(_replace_tables)
|
||||
|
||||
|
||||
def replace_placeholders(expression, *args, **kwargs):
|
||||
"""Replace placeholders in an expression.
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): Expression node to be transformed and replaced
|
||||
args: Positional names that will substitute unnamed placeholders in the given order
|
||||
kwargs: Keyword arguments that will substitute named placeholders
|
||||
|
||||
Examples:
|
||||
>>> from sqlglot import exp, parse_one
|
||||
>>> replace_placeholders(
|
||||
... parse_one("select * from :tbl where ? = ?"), "a", "b", tbl="foo"
|
||||
... ).sql()
|
||||
'SELECT * FROM foo WHERE a = b'
|
||||
|
||||
Returns:
|
||||
The mapped expression
|
||||
"""
|
||||
|
||||
def _replace_placeholders(node, args, **kwargs):
|
||||
if isinstance(node, Placeholder):
|
||||
if node.name:
|
||||
new_name = kwargs.get(node.name)
|
||||
if new_name:
|
||||
return to_identifier(new_name)
|
||||
else:
|
||||
try:
|
||||
return to_identifier(next(args))
|
||||
except StopIteration:
|
||||
pass
|
||||
return node
|
||||
|
||||
return expression.transform(_replace_placeholders, iter(args), **kwargs)
|
||||
|
||||
|
||||
TRUE = Boolean(this=True)
|
||||
FALSE = Boolean(this=False)
|
||||
NULL = Null()
|
||||
|
|
|
@ -47,7 +47,8 @@ class Generator:
|
|||
The default is on the smaller end because the length only represents a segment and not the true
|
||||
line length.
|
||||
Default: 80
|
||||
annotations: Whether or not to show annotations in the SQL.
|
||||
annotations: Whether or not to show annotations in the SQL when `pretty` is True.
|
||||
Annotations can only be shown in pretty mode otherwise they may clobber resulting sql.
|
||||
Default: True
|
||||
"""
|
||||
|
||||
|
@ -280,7 +281,7 @@ class Generator:
|
|||
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
|
||||
|
||||
def annotation_sql(self, expression):
|
||||
if self._annotations:
|
||||
if self._annotations and self.pretty:
|
||||
return f"{self.sql(expression, 'expression')} # {expression.name}"
|
||||
return self.sql(expression, "expression")
|
||||
|
||||
|
|
|
@ -194,6 +194,24 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
|
|||
return words + [None] * (min_num_words - len(words))
|
||||
|
||||
|
||||
def is_iterable(value: t.Any) -> bool:
|
||||
"""
|
||||
Checks if the value is an iterable but does not include strings and bytes
|
||||
|
||||
Examples:
|
||||
>>> is_iterable([1,2])
|
||||
True
|
||||
>>> is_iterable("test")
|
||||
False
|
||||
|
||||
Args:
|
||||
value: The value to check if it is an interable
|
||||
|
||||
Returns: Bool indicating if it is an iterable
|
||||
"""
|
||||
return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))
|
||||
|
||||
|
||||
def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]:
|
||||
"""
|
||||
Flattens a list that can contain both iterables and non-iterable elements
|
||||
|
@ -211,7 +229,7 @@ def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generato
|
|||
Yields non-iterable elements (not including str or byte as iterable)
|
||||
"""
|
||||
for value in values:
|
||||
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
|
||||
if is_iterable(value):
|
||||
yield from flatten(value)
|
||||
else:
|
||||
yield value
|
||||
|
|
|
@ -433,7 +433,8 @@ class Parser:
|
|||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
"CONVERT": lambda self: self._parse_convert(),
|
||||
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
|
||||
"TRY_CONVERT": lambda self: self._parse_convert(False),
|
||||
"EXTRACT": lambda self: self._parse_extract(),
|
||||
"POSITION": lambda self: self._parse_position(),
|
||||
"SUBSTRING": lambda self: self._parse_substring(),
|
||||
|
@ -1512,7 +1513,7 @@ class Parser:
|
|||
return this
|
||||
|
||||
def _parse_offset(self, this=None):
|
||||
if not self._match(TokenType.OFFSET):
|
||||
if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
|
||||
return this
|
||||
count = self._parse_number()
|
||||
self._match_set((TokenType.ROW, TokenType.ROWS))
|
||||
|
@ -2134,7 +2135,7 @@ class Parser:
|
|||
|
||||
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
|
||||
|
||||
def _parse_convert(self):
|
||||
def _parse_convert(self, strict):
|
||||
this = self._parse_field()
|
||||
if self._match(TokenType.USING):
|
||||
to = self.expression(exp.CharacterSet, this=self._parse_var())
|
||||
|
@ -2142,7 +2143,7 @@ class Parser:
|
|||
to = self._parse_types()
|
||||
else:
|
||||
to = None
|
||||
return self.expression(exp.Cast, this=this, to=to)
|
||||
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
|
||||
|
||||
def _parse_position(self):
|
||||
args = self._parse_csv(self._parse_bitwise)
|
||||
|
|
|
@ -14,6 +14,8 @@ def format_time(string, mapping, trie=None):
|
|||
mapping: Dictionary of time format to target time format
|
||||
trie: Optional trie, can be passed in for performance
|
||||
"""
|
||||
if not string:
|
||||
return None
|
||||
start = 0
|
||||
end = 1
|
||||
size = len(string)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue