1
0
Fork 0

Merging upstream version 9.0.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:50:31 +01:00
parent 66ef36a209
commit b1dc5c6faf
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
22 changed files with 742 additions and 223 deletions

View file

@ -316,7 +316,7 @@ Dialect["custom"]
## Run Tests and Lint ## Run Tests and Lint
``` ```
pip install -r requirements.txt pip install -r dev-requirements.txt
# set `SKIP_INTEGRATION=1` to skip integration tests # set `SKIP_INTEGRATION=1` to skip integration tests
./run_checks.sh ./run_checks.sh
``` ```

View file

@ -24,7 +24,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType from sqlglot.tokens import Tokenizer, TokenType
__version__ = "9.0.1" __version__ = "9.0.3"
pretty = False pretty = False

View file

@ -5,7 +5,7 @@ import typing as t
import sqlglot import sqlglot
from sqlglot import expressions as exp from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType from sqlglot.dataframe.sql.types import DataType
from sqlglot.helper import flatten from sqlglot.helper import flatten, is_iterable
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnOrLiteral from sqlglot.dataframe.sql._typing import ColumnOrLiteral
@ -134,10 +134,14 @@ class Column:
cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
) -> Column: ) -> Column:
ensured_column = None if column is None else cls.ensure_col(column) ensured_column = None if column is None else cls.ensure_col(column)
ensure_expression_values = {
k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
for k, v in kwargs.items()
}
new_expression = ( new_expression = (
callable_expression(**kwargs) callable_expression(**ensure_expression_values)
if ensured_column is None if ensured_column is None
else callable_expression(this=ensured_column.column_expression, **kwargs) else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
) )
return Column(new_expression) return Column(new_expression)

View file

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import typing as t import typing as t
from inspect import signature
from sqlglot import expressions as glotexp from sqlglot import expressions as glotexp
from sqlglot.dataframe.sql.column import Column from sqlglot.dataframe.sql.column import Column
@ -24,17 +23,15 @@ def lit(value: t.Optional[t.Any] = None) -> Column:
def greatest(*cols: ColumnOrName) -> Column: def greatest(*cols: ColumnOrName) -> Column:
columns = [Column.ensure_col(col) for col in cols] if len(cols) > 1:
return Column.invoke_expression_over_column( return Column.invoke_expression_over_column(cols[0], glotexp.Greatest, expressions=cols[1:])
columns[0], glotexp.Greatest, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None return Column.invoke_expression_over_column(cols[0], glotexp.Greatest)
)
def least(*cols: ColumnOrName) -> Column: def least(*cols: ColumnOrName) -> Column:
columns = [Column.ensure_col(col) for col in cols] if len(cols) > 1:
return Column.invoke_expression_over_column( return Column.invoke_expression_over_column(cols[0], glotexp.Least, expressions=cols[1:])
columns[0], glotexp.Least, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None return Column.invoke_expression_over_column(cols[0], glotexp.Least)
)
def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
@ -194,7 +191,7 @@ def log2(col: ColumnOrName) -> Column:
def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column: def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
if arg2 is None: if arg2 is None:
return Column.invoke_expression_over_column(arg1, glotexp.Ln) return Column.invoke_expression_over_column(arg1, glotexp.Ln)
return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=Column.ensure_col(arg2).expression) return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=arg2)
def rint(col: ColumnOrName) -> Column: def rint(col: ColumnOrName) -> Column:
@ -310,7 +307,7 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]
def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
return Column.invoke_anonymous_function(col1, "POW", col2) return Column.invoke_expression_over_column(col1, glotexp.Pow, power=col2)
def row_number() -> Column: def row_number() -> Column:
@ -340,14 +337,13 @@ def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Col
def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
if rsd is None: if rsd is None:
return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct) return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct)
return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=Column.ensure_col(rsd).expression) return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=rsd)
def coalesce(*cols: ColumnOrName) -> Column: def coalesce(*cols: ColumnOrName) -> Column:
columns = [Column.ensure_col(col) for col in cols] if len(cols) > 1:
return Column.invoke_expression_over_column( return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce, expressions=cols[1:])
columns[0], glotexp.Coalesce, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce)
)
def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
@ -405,11 +401,13 @@ def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def percentile_approx( def percentile_approx(
col: ColumnOrName, col: ColumnOrName,
percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]], percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
accuracy: t.Optional[t.Union[ColumnOrLiteral]] = None, accuracy: t.Optional[t.Union[ColumnOrLiteral, int]] = None,
) -> Column: ) -> Column:
if accuracy: if accuracy:
return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage, accuracy) return Column.invoke_expression_over_column(
return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage) col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
)
return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column: def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
@ -422,7 +420,7 @@ def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
if scale is not None: if scale is not None:
return Column.invoke_expression_over_column(col, glotexp.Round, decimals=glotexp.convert(scale)) return Column.invoke_expression_over_column(col, glotexp.Round, decimals=scale)
return Column.invoke_expression_over_column(col, glotexp.Round) return Column.invoke_expression_over_column(col, glotexp.Round)
@ -433,9 +431,7 @@ def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
def shiftleft(col: ColumnOrName, numBits: int) -> Column: def shiftleft(col: ColumnOrName, numBits: int) -> Column:
return Column.invoke_expression_over_column( return Column.invoke_expression_over_column(col, glotexp.BitwiseLeftShift, expression=numBits)
col, glotexp.BitwiseLeftShift, expression=Column.ensure_col(numBits).expression
)
def shiftLeft(col: ColumnOrName, numBits: int) -> Column: def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
@ -443,9 +439,7 @@ def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
def shiftright(col: ColumnOrName, numBits: int) -> Column: def shiftright(col: ColumnOrName, numBits: int) -> Column:
return Column.invoke_expression_over_column( return Column.invoke_expression_over_column(col, glotexp.BitwiseRightShift, expression=numBits)
col, glotexp.BitwiseRightShift, expression=Column.ensure_col(numBits).expression
)
def shiftRight(col: ColumnOrName, numBits: int) -> Column: def shiftRight(col: ColumnOrName, numBits: int) -> Column:
@ -466,8 +460,7 @@ def expr(str: str) -> Column:
def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column: def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
columns = ensure_list(col) + list(cols) columns = ensure_list(col) + list(cols)
expressions = [Column.ensure_col(column).expression for column in columns] return Column.invoke_expression_over_column(None, glotexp.Struct, expressions=columns)
return Column(glotexp.Struct(expressions=expressions))
def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column: def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:
@ -515,7 +508,7 @@ def current_timestamp() -> Column:
def date_format(col: ColumnOrName, format: str) -> Column: def date_format(col: ColumnOrName, format: str) -> Column:
return Column.invoke_anonymous_function(col, "DATE_FORMAT", lit(format)) return Column.invoke_expression_over_column(col, glotexp.TimeToStr, format=lit(format))
def year(col: ColumnOrName) -> Column: def year(col: ColumnOrName) -> Column:
@ -563,15 +556,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=Column.ensure_col(days).expression) return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=days)
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=Column.ensure_col(days).expression) return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=days)
def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column: def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=Column.ensure_col(start).expression) return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=start)
def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column: def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
@ -586,8 +579,8 @@ def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optiona
def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
if format is not None: if format is not None:
return Column.invoke_anonymous_function(col, "TO_DATE", lit(format)) return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate, format=lit(format))
return Column.invoke_anonymous_function(col, "TO_DATE") return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate)
def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
@ -597,11 +590,11 @@ def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
def trunc(col: ColumnOrName, format: str) -> Column: def trunc(col: ColumnOrName, format: str) -> Column:
return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format).expression) return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format))
def date_trunc(format: str, timestamp: ColumnOrName) -> Column: def date_trunc(format: str, timestamp: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format).expression) return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format))
def next_day(col: ColumnOrName, dayOfWeek: str) -> Column: def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
@ -614,14 +607,14 @@ def last_day(col: ColumnOrName) -> Column:
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
if format is not None: if format is not None:
return Column.invoke_anonymous_function(col, "FROM_UNIXTIME", lit(format)) return Column.invoke_expression_over_column(col, glotexp.UnixToStr, format=lit(format))
return Column.invoke_anonymous_function(col, "FROM_UNIXTIME") return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column: def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
if format is not None: if format is not None:
return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP", lit(format)) return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP") return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
@ -738,10 +731,7 @@ def trim(col: ColumnOrName) -> Column:
def concat_ws(sep: str, *cols: ColumnOrName) -> Column: def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
columns = [Column(col) for col in cols] return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
return Column.invoke_expression_over_column(
None, glotexp.ConcatWs, expressions=[x.expression for x in [lit(sep)] + list(columns)]
)
def decode(col: ColumnOrName, charset: str) -> Column: def decode(col: ColumnOrName, charset: str) -> Column:
@ -798,18 +788,14 @@ def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
return Column.invoke_expression_over_column( return Column.invoke_expression_over_column(left, glotexp.Levenshtein, expression=right)
left, glotexp.Levenshtein, expression=Column.ensure_col(right).expression
)
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column: def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
substr_col = lit(substr) substr_col = lit(substr)
pos_column = lit(pos)
str_column = Column.ensure_col(str)
if pos is not None: if pos is not None:
return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column, pos_column) return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column) return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
def lpad(col: ColumnOrName, len: int, pad: str) -> Column: def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
@ -821,15 +807,15 @@ def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
def repeat(col: ColumnOrName, n: int) -> Column: def repeat(col: ColumnOrName, n: int) -> Column:
return Column.invoke_anonymous_function(col, "REPEAT", n) return Column.invoke_expression_over_column(col, glotexp.Repeat, times=lit(n))
def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column: def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
if limit is not None: if limit is not None:
return Column.invoke_expression_over_column( return Column.invoke_expression_over_column(
str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=lit(limit).expression str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=limit
) )
return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern).expression) return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern))
def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column: def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
@ -879,9 +865,8 @@ def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column:
def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
cols = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols # type: ignore columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols
cols = [Column.ensure_col(col).expression for col in cols] # type: ignore return Column.invoke_expression_over_column(None, glotexp.Array, expressions=columns)
return Column.invoke_expression_over_column(None, glotexp.Array, expressions=cols)
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
@ -892,7 +877,7 @@ def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "MAP_FROM_ARRAYS", col2) return Column.invoke_expression_over_column(None, glotexp.Map, keys=col1, values=col2)
def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column: def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@ -970,7 +955,7 @@ def posexplode_outer(col: ColumnOrName) -> Column:
def get_json_object(col: ColumnOrName, path: str) -> Column: def get_json_object(col: ColumnOrName, path: str) -> Column:
return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path).expression) return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path))
def json_tuple(col: ColumnOrName, *fields: str) -> Column: def json_tuple(col: ColumnOrName, *fields: str) -> Column:
@ -1031,11 +1016,17 @@ def array_max(col: ColumnOrName) -> Column:
def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column: def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
if asc is not None: if asc is not None:
return Column.invoke_anonymous_function(col, "SORT_ARRAY", lit(asc)) return Column.invoke_expression_over_column(col, glotexp.SortArray, asc=asc)
return Column.invoke_anonymous_function(col, "SORT_ARRAY") return Column.invoke_expression_over_column(col, glotexp.SortArray)
def array_sort(col: ColumnOrName) -> Column: def array_sort(
col: ColumnOrName,
comparator: t.Optional[t.Union[t.Callable[[Column, Column], Column]]] = None,
) -> Column:
if comparator is not None:
f_expression = _get_lambda_from_func(comparator)
return Column.invoke_expression_over_column(col, glotexp.ArraySort, expression=f_expression)
return Column.invoke_expression_over_column(col, glotexp.ArraySort) return Column.invoke_expression_over_column(col, glotexp.ArraySort)
@ -1108,130 +1099,53 @@ def aggregate(
initialValue: ColumnOrName, initialValue: ColumnOrName,
merge: t.Callable[[Column, Column], Column], merge: t.Callable[[Column, Column], Column],
finish: t.Optional[t.Callable[[Column], Column]] = None, finish: t.Optional[t.Callable[[Column], Column]] = None,
accumulator_name: str = "acc",
target_row_name: str = "x",
) -> Column: ) -> Column:
merge_exp = glotexp.Lambda( merge_exp = _get_lambda_from_func(merge)
this=merge(Column(accumulator_name), Column(target_row_name)).expression,
expressions=[
glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name)),
glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name)),
],
)
if finish is not None: if finish is not None:
finish_exp = glotexp.Lambda( finish_exp = _get_lambda_from_func(finish)
this=finish(Column(accumulator_name)).expression,
expressions=[glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name))],
)
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)) return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp)) return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
def transform( def transform(
col: ColumnOrName, col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
target_row_name: str = "x",
row_count_name: str = "i",
) -> Column: ) -> Column:
num_arguments = len(signature(f).parameters) f_expression = _get_lambda_from_func(f)
expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
columns = [Column(target_row_name)]
if num_arguments > 1:
columns.append(Column(row_count_name))
expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name)))
f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions)
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression)) return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
def exists(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column: def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
f_expression = glotexp.Lambda( f_expression = _get_lambda_from_func(f)
this=f(Column(target_row_name)).expression,
expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))],
)
return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression)) return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression))
def forall(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column: def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
f_expression = glotexp.Lambda( f_expression = _get_lambda_from_func(f)
this=f(Column(target_row_name)).expression,
expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))],
)
return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression)) return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
def filter( def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
col: ColumnOrName, f_expression = _get_lambda_from_func(f)
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
target_row_name: str = "x",
row_count_name: str = "i",
) -> Column:
num_arguments = len(signature(f).parameters)
expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
columns = [Column(target_row_name)]
if num_arguments > 1:
columns.append(Column(row_count_name))
expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name)))
f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions)
return Column.invoke_anonymous_function(col, "FILTER", Column(f_expression))
def zip_with( def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
left: ColumnOrName, f_expression = _get_lambda_from_func(f)
right: ColumnOrName,
f: t.Callable[[Column, Column], Column],
left_name: str = "x",
right_name: str = "y",
) -> Column:
f_expression = glotexp.Lambda(
this=f(Column(left_name), Column(right_name)).expression,
expressions=[
glotexp.to_identifier(left_name, quoted=_lambda_quoted(left_name)),
glotexp.to_identifier(right_name, quoted=_lambda_quoted(right_name)),
],
)
return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression)) return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
def transform_keys( def transform_keys(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" f_expression = _get_lambda_from_func(f)
) -> Column:
f_expression = glotexp.Lambda(
this=f(Column(key_name), Column(value_name)).expression,
expressions=[
glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
],
)
return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression)) return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression))
def transform_values( def transform_values(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" f_expression = _get_lambda_from_func(f)
) -> Column:
f_expression = glotexp.Lambda(
this=f(Column(key_name), Column(value_name)).expression,
expressions=[
glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
],
)
return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression)) return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression))
def map_filter( def map_filter(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" f_expression = _get_lambda_from_func(f)
) -> Column:
f_expression = glotexp.Lambda(
this=f(Column(key_name), Column(value_name)).expression,
expressions=[
glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
],
)
return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression)) return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression))
@ -1239,20 +1153,18 @@ def map_zip_with(
col1: ColumnOrName, col1: ColumnOrName,
col2: ColumnOrName, col2: ColumnOrName,
f: t.Union[t.Callable[[Column, Column, Column], Column]], f: t.Union[t.Callable[[Column, Column, Column], Column]],
key_name: str = "k",
value1: str = "v1",
value2: str = "v2",
) -> Column: ) -> Column:
f_expression = glotexp.Lambda( f_expression = _get_lambda_from_func(f)
this=f(Column(key_name), Column(value1), Column(value2)).expression,
expressions=[
glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
glotexp.to_identifier(value1, quoted=_lambda_quoted(value1)),
glotexp.to_identifier(value2, quoted=_lambda_quoted(value2)),
],
)
return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression)) return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression))
def _lambda_quoted(value: str) -> t.Optional[bool]: def _lambda_quoted(value: str) -> t.Optional[bool]:
return False if value == "_" else None return False if value == "_" else None
def _get_lambda_from_func(lambda_expression: t.Callable):
    """Build a ``glotexp.Lambda`` node from a Python callable.

    The callable's parameter names become the lambda's variables. The callable is
    invoked once with one ``Column`` per parameter, and the ``.expression`` of its
    result becomes the lambda body.
    """
    code = lambda_expression.__code__
    # ``co_varnames`` holds the parameters first and then every other local name, so it
    # has to be cut at ``co_argcount``. Without the slice, a callable that defines any
    # local variable would be called with more arguments than it accepts.
    variables = [
        glotexp.to_identifier(name, quoted=_lambda_quoted(name))
        for name in code.co_varnames[: code.co_argcount]
    ]
    return glotexp.Lambda(
        this=lambda_expression(*[Column(x) for x in variables]).expression,
        expressions=variables,
    )

View file

@ -18,6 +18,36 @@ from sqlglot.helper import list_get
from sqlglot.parser import Parser, parse_var_map from sqlglot.parser import Parser, parse_var_map
from sqlglot.tokens import Tokenizer from sqlglot.tokens import Tokenizer
# Unit name -> (function name, multiplier).
# Looked up by `_add_date_sql` and `_date_diff_sql` with the upper-cased unit text.
DATE_DELTA_INTERVAL = {
    "YEAR": ("ADD_MONTHS", 12),
    "MONTH": ("ADD_MONTHS", 1),
    "QUARTER": ("ADD_MONTHS", 3),
    "WEEK": ("DATE_ADD", 7),
    "DAY": ("DATE_ADD", 1),
}

# Units for which `_date_diff_sql` emits MONTHS_BETWEEN instead of DATEDIFF.
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
def _add_date_sql(self, expression):
    """Generate the SQL call for a date-add expression.

    The upper-cased unit selects the target function and a multiplier through
    ``DATE_DELTA_INTERVAL``; unknown units fall back to ``DATE_ADD`` with a
    multiplier of 1.
    """
    unit = expression.text("unit").upper()
    func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
    increment = expression.expression

    if increment.is_number:
        # Numeric literals are scaled while generating, e.g. 2 with unit YEAR -> 24.
        increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
    elif multiplier != 1:
        # A non-literal increment cannot be folded, so the scaling is emitted as SQL
        # instead of being silently dropped. Binary operands are parenthesized so the
        # multiplication applies to the whole increment.
        if isinstance(increment, exp.Binary):
            increment = exp.Paren(this=increment)
        increment = exp.Mul(this=increment, expression=exp.Literal.number(multiplier))

    # Passing the node (not its pre-rendered text) lets the active generator render it.
    return f"{func}({self.format_args(expression.this, increment)})"
def _date_diff_sql(self, expression):
    """Generate the SQL for a date-difference expression.

    Units listed in ``DIFF_MONTH_SWITCH`` use ``MONTHS_BETWEEN``; every other unit uses
    ``DATEDIFF``. When the unit's multiplier in ``DATE_DELTA_INTERVAL`` is greater than 1,
    the call is followed by a division by that multiplier.
    """
    unit = expression.text("unit").upper()

    if unit in DIFF_MONTH_SWITCH:
        sql_func = "MONTHS_BETWEEN"
    else:
        sql_func = "DATEDIFF"

    # Only the multiplier half of the table entry matters here.
    multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))[1]

    arguments = self.format_args(expression.this, expression.expression)
    diff_sql = f"{sql_func}({arguments})"

    if multiplier > 1:
        return f"{diff_sql} / {multiplier}"
    return diff_sql
def _array_sort(self, expression): def _array_sort(self, expression):
if expression.expression: if expression.expression:
@ -120,10 +150,14 @@ class Hive(Dialect):
"m": "%-M", "m": "%-M",
"ss": "%S", "ss": "%S",
"s": "%-S", "s": "%-S",
"S": "%f", "SSSSSS": "%f",
"a": "%p", "a": "%p",
"DD": "%j", "DD": "%j",
"D": "%-j", "D": "%-j",
"E": "%a",
"EE": "%a",
"EEE": "%a",
"EEEE": "%A",
} }
date_format = "'yyyy-MM-dd'" date_format = "'yyyy-MM-dd'"
@ -207,8 +241,8 @@ class Hive(Dialect):
exp.ArraySize: rename_func("SIZE"), exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort, exp.ArraySort: _array_sort,
exp.With: no_recursive_cte_sql, exp.With: no_recursive_cte_sql,
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.DateAdd: _add_date_sql,
exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: rename_func("TO_DATE"), exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",

View file

@ -71,6 +71,7 @@ class Spark(Hive):
length=list_get(args, 1), length=list_get(args, 1),
), ),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
} }
FUNCTION_PARSERS = { FUNCTION_PARSERS = {
@ -111,6 +112,7 @@ class Spark(Hive):
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})", exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
exp.VariancePop: rename_func("VAR_POP"), exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
} }
WRAP_DERIVED_VALUES = False WRAP_DERIVED_VALUES = False

View file

@ -1,5 +1,5 @@
from sqlglot import exp from sqlglot import exp
from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.mysql import MySQL from sqlglot.dialects.mysql import MySQL
@ -14,6 +14,8 @@ class StarRocks(MySQL):
TRANSFORMS = { TRANSFORMS = {
**MySQL.Generator.TRANSFORMS, **MySQL.Generator.TRANSFORMS,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"), exp.DateDiff: rename_func("DATEDIFF"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: rename_func("TO_DATE"), exp.TimeStrToDate: rename_func("TO_DATE"),

View file

@ -1,14 +1,149 @@
from sqlglot import exp from sqlglot import exp
from sqlglot.dialects.dialect import Dialect from sqlglot.dialects.dialect import Dialect, rename_func
from sqlglot.expressions import DataType
from sqlglot.generator import Generator from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer, TokenType from sqlglot.tokens import Tokenizer, TokenType
# Overrides merged on top of TSQL.time_mapping by tsql_format_time_lambda when it is
# built with full_format_mapping=True: these dateparts then map to full weekday (%A)
# and full month (%B) names.
FULL_FORMAT_TIME_MAPPING = {
    "weekday": "%A",
    "dw": "%A",
    "w": "%A",
    "month": "%B",
    "mm": "%B",
    "m": "%B",
}

# Maps every accepted datepart spelling to its canonical interval unit.
# parse_date_delta falls back to "day" for spellings that are not listed here.
DATE_DELTA_INTERVAL = {
    alias: unit
    for unit, aliases in (
        ("year", ("year", "yyyy", "yy")),
        ("quarter", ("quarter", "qq", "q")),
        ("month", ("month", "mm", "m")),
        ("week", ("week", "ww", "wk")),
        ("day", ("day", "dd", "d")),
    )
    for alias in aliases
}
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
    """Build a function-argument parser that wraps a T-SQL datepart in a time format.

    Args:
        exp_class: Expression class to instantiate with `this` and `format`.
        full_format_mapping: When truthy, FULL_FORMAT_TIME_MAPPING is layered over
            TSQL.time_mapping before the datepart is translated.
        default: Fallback format used when the first argument has an empty name;
            passing True selects TSQL.time_format.

    Returns:
        A callable taking the parsed argument list, where argument 0 is the
        datepart/format and argument 1 is the value being formatted.
    """

    def _format_time(args):
        if full_format_mapping:
            mapping = {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
        else:
            mapping = TSQL.time_mapping

        fallback = TSQL.time_format if default is True else default
        source_format = list_get(args, 0).name or fallback
        target_format = exp.Literal.string(format_time(source_format, mapping))
        return exp_class(this=list_get(args, 1), format=target_format)

    return _format_time
def parse_date_delta(exp_class):
    """Build a parser for three-argument date delta calls of the form (unit, amount, date).

    Args:
        exp_class: Expression class to instantiate (registered in this file for
            DATEADD and DATEDIFF).

    Returns:
        A callable taking the parsed argument list. Argument 0 names the unit,
        argument 1 becomes `expression` and argument 2 becomes `this`.
    """

    def inner_func(args):
        unit_name = list_get(args, 0).name.lower()
        # Unknown unit spellings silently fall back to "day".
        canonical_unit = DATE_DELTA_INTERVAL.get(unit_name, "day")
        return exp_class(
            this=list_get(args, 2),
            expression=list_get(args, 1),
            unit=canonical_unit,
        )

    return inner_func
def generate_date_delta(self, e):
    """Render a date delta expression as DATEADD(...) or DATEDIFF(...).

    Args:
        self: Generator providing `format_args`.
        e: Expression to render; anything that is not an exp.DateAdd is
            rendered as DATEDIFF.

    Returns:
        The SQL string with the arguments ordered as unit, expression, this.
    """
    if isinstance(e, exp.DateAdd):
        func = "DATEADD"
    else:
        func = "DATEDIFF"

    arguments = self.format_args(e.text("unit"), e.expression, e.this)
    return f"{func}({arguments})"
class TSQL(Dialect): class TSQL(Dialect):
null_ordering = "nulls_are_small" null_ordering = "nulls_are_small"
time_format = "'yyyy-mm-dd hh:mm:ss'" time_format = "'yyyy-mm-dd hh:mm:ss'"
# Translation table from T-SQL datepart names / format tokens to strftime-style tokens.
time_mapping = {
    # year / quarter
    "yyyy": "%Y",
    "yy": "%y",
    "year": "%Y",
    "qq": "%q",
    "q": "%q",
    "quarter": "%q",
    # day
    "dayofyear": "%j",
    "day": "%d",
    # NOTE(review): T-SQL documents "dy" and "y" as abbreviations of dayofyear,
    # yet they map to day-of-month and year here - confirm this is intended.
    "dy": "%d",
    "y": "%Y",
    # week
    "week": "%W",
    "ww": "%W",
    "wk": "%W",
    # time of day
    # NOTE(review): "%h" differs from the "%H"/"%I" hour tokens used below - confirm.
    "hour": "%h",
    "hh": "%I",
    "minute": "%M",
    "mi": "%M",
    "n": "%M",
    "second": "%S",
    "ss": "%S",
    "s": "%-S",
    "millisecond": "%f",
    "ms": "%f",
    # NOTE(review): weekday shares "%W" with week - confirm.
    "weekday": "%W",
    "dw": "%W",
    # month
    # NOTE(review): "mm"/"m" resolve to "%M"/"%-M" (same token family as "minute"
    # above) while "month" resolves to "%m" - confirm.
    "month": "%m",
    "mm": "%M",
    "m": "%-M",
    # case-sensitive format tokens
    "Y": "%Y",
    "YYYY": "%Y",
    "YY": "%y",
    "MMMM": "%B",
    "MMM": "%b",
    "MM": "%m",
    "M": "%-m",
    "dd": "%d",
    "d": "%-d",
    "HH": "%H",
    "H": "%-H",
    "h": "%-I",
    "S": "%f",
}
# CONVERT(type, value, style) style codes mapped to strftime-style formats.
# _parse_convert wraps these values directly in a string literal (no further
# translation), so every entry must be a strftime pattern. Styles 1-14 are the
# two-digit-year counterparts of 101-114.
convert_format_mapping = {
    "0": "%b %d %Y %-I:%M%p",
    "1": "%m/%d/%y",
    "2": "%y.%m.%d",
    "3": "%d/%m/%y",
    "4": "%d.%m.%y",
    "5": "%d-%m-%y",
    "6": "%d %b %y",
    "7": "%b %d, %y",
    "8": "%H:%M:%S",
    "9": "%b %d %Y %-I:%M:%S:%f%p",
    # 10-13 previously held raw T-SQL picture strings ("mm-dd-yy", "yy/mm/dd",
    # "yymmdd") and a stray "ss"; they now mirror 110-113.
    "10": "%m-%d-%y",
    "11": "%y/%m/%d",
    "12": "%y%m%d",
    "13": "%d %b %Y %H:%M:%S:%f",
    "14": "%H:%M:%S:%f",
    "20": "%Y-%m-%d %H:%M:%S",
    "21": "%Y-%m-%d %H:%M:%S.%f",
    "22": "%m/%d/%y %-I:%M:%S %p",
    "23": "%Y-%m-%d",
    "24": "%H:%M:%S",
    "25": "%Y-%m-%d %H:%M:%S.%f",
    "100": "%b %d %Y %-I:%M%p",
    "101": "%m/%d/%Y",
    "102": "%Y.%m.%d",
    "103": "%d/%m/%Y",
    "104": "%d.%m.%Y",
    "105": "%d-%m-%Y",
    "106": "%d %b %Y",
    "107": "%b %d, %Y",
    "108": "%H:%M:%S",
    "109": "%b %d %Y %-I:%M:%S:%f%p",
    "110": "%m-%d-%Y",
    "111": "%Y/%m/%d",
    "112": "%Y%m%d",
    "113": "%d %b %Y %H:%M:%S:%f",
    "114": "%H:%M:%S:%f",
    "120": "%Y-%m-%d %H:%M:%S",
    "121": "%Y-%m-%d %H:%M:%S.%f",
}
class Tokenizer(Tokenizer): class Tokenizer(Tokenizer):
IDENTIFIERS = ['"', ("[", "]")] IDENTIFIERS = ['"', ("[", "]")]
@ -29,19 +164,67 @@ class TSQL(Dialect):
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"XML": TokenType.XML, "XML": TokenType.XML,
"SQL_VARIANT": TokenType.VARIANT, "SQL_VARIANT": TokenType.VARIANT,
"NVARCHAR(MAX)": TokenType.TEXT,
"VARCHAR(MAX)": TokenType.TEXT,
} }
class Parser(Parser): class Parser(Parser):
FUNCTIONS = { FUNCTIONS = {
**Parser.FUNCTIONS, **Parser.FUNCTIONS,
"CHARINDEX": exp.StrPosition.from_arg_list, "CHARINDEX": exp.StrPosition.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd),
"DATEDIFF": parse_date_delta(exp.DateDiff),
"DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": tsql_format_time_lambda(exp.TimeToStr),
"GETDATE": exp.CurrentDate.from_arg_list,
"IIF": exp.If.from_arg_list,
"LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
} }
def _parse_convert(self): VAR_LENGTH_DATATYPES = {
DataType.Type.NVARCHAR,
DataType.Type.VARCHAR,
DataType.Type.CHAR,
DataType.Type.NCHAR,
}
def _parse_convert(self, strict):
to = self._parse_types() to = self._parse_types()
self._match(TokenType.COMMA) self._match(TokenType.COMMA)
this = self._parse_field() this = self._parse_field()
return self.expression(exp.Cast, this=this, to=to)
# Retrieve length of datatype and override to default if not specified
if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
# Check whether a conversion with format is applicable
if self._match(TokenType.COMMA):
format_val = self._parse_number().name
if format_val not in TSQL.convert_format_mapping:
raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}")
format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
# Check whether the convert entails a string to date format
if to.this == DataType.Type.DATE:
return self.expression(exp.StrToDate, this=this, format=format_norm)
# Check whether the convert entails a string to datetime format
elif to.this == DataType.Type.DATETIME:
return self.expression(exp.StrToTime, this=this, format=format_norm)
# Check whether the convert entails a date to string format
elif to.this in self.VAR_LENGTH_DATATYPES:
return self.expression(
exp.Cast if strict else exp.TryCast,
to=to,
this=self.expression(exp.TimeToStr, this=this, format=format_norm),
)
elif to.this == DataType.Type.TEXT:
return self.expression(exp.TimeToStr, this=this, format=format_norm)
# Entails a simple cast without any format requirement
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
class Generator(Generator): class Generator(Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
@ -52,3 +235,11 @@ class TSQL(Dialect):
exp.DataType.Type.DATETIME: "DATETIME2", exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.VARIANT: "SQL_VARIANT", exp.DataType.Type.VARIANT: "SQL_VARIANT",
} }
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.DateAdd: lambda self, e: generate_date_delta(self, e),
exp.DateDiff: lambda self, e: generate_date_delta(self, e),
exp.CurrentDate: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
}

View file

@ -2411,6 +2411,11 @@ class TimeTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False} arg_types = {"this": True, "unit": True, "zone": False}
class DateFromParts(Func):
    # SQL spelling of this function.
    _sql_names = ["DATEFROMPARTS"]
    # Three arguments, each flagged True (presumably "required" - confirm against Func).
    arg_types = {
        "year": True,
        "month": True,
        "day": True,
    }
class DateStrToDate(Func): class DateStrToDate(Func):
pass pass
@ -2554,7 +2559,7 @@ class Quantile(AggFunc):
class ApproxQuantile(Quantile): class ApproxQuantile(Quantile):
pass arg_types = {"this": True, "quantile": True, "accuracy": False}
class Reduce(Func): class Reduce(Func):
@ -2569,6 +2574,10 @@ class RegexpSplit(Func):
arg_types = {"this": True, "expression": True} arg_types = {"this": True, "expression": True}
class Repeat(Func):
    # `this` is the value to repeat and `times` the repetition count; both are
    # flagged True (presumably "required" - confirm against Func).
    arg_types = {
        "this": True,
        "times": True,
    }
class Round(Func): class Round(Func):
arg_types = {"this": True, "decimals": False} arg_types = {"this": True, "decimals": False}
@ -2690,7 +2699,7 @@ class TsOrDiToDi(Func):
class UnixToStr(Func): class UnixToStr(Func):
arg_types = {"this": True, "format": True} arg_types = {"this": True, "format": False}
class UnixToTime(Func): class UnixToTime(Func):
@ -3077,6 +3086,8 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
) )
if from_: if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
if isinstance(where, Condition):
where = Where(this=where)
if where: if where:
update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts)) update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
return update return update
@ -3518,6 +3529,41 @@ def replace_tables(expression, mapping):
return expression.transform(_replace_tables) return expression.transform(_replace_tables)
def replace_placeholders(expression, *args, **kwargs):
    """Substitute the placeholders of an expression with identifiers.

    Args:
        expression (sqlglot.Expression): Expression node to be transformed and replaced
        args: Positional names that replace unnamed placeholders, consumed in order
        kwargs: Names that replace named placeholders, keyed by placeholder name

    Examples:
        >>> from sqlglot import exp, parse_one
        >>> replace_placeholders(
        ...     parse_one("select * from :tbl where ? = ?"), "a", "b", tbl="foo"
        ... ).sql()
        'SELECT * FROM foo WHERE a = b'

    Returns:
        The mapped expression. Placeholders without a matching (truthy) substitute
        are left in place.
    """

    def _replace_placeholders(node, args, **kwargs):
        if not isinstance(node, Placeholder):
            return node

        if node.name:
            # Named placeholder: only replaced when a truthy value was supplied.
            replacement = kwargs.get(node.name)
            return to_identifier(replacement) if replacement else node

        # Unnamed placeholder: take the next positional name, if any remain.
        try:
            return to_identifier(next(args))
        except StopIteration:
            return node

    # A single shared iterator so positional names are consumed across all nodes.
    positional = iter(args)
    return expression.transform(_replace_placeholders, positional, **kwargs)
TRUE = Boolean(this=True) TRUE = Boolean(this=True)
FALSE = Boolean(this=False) FALSE = Boolean(this=False)
NULL = Null() NULL = Null()

View file

@ -47,7 +47,8 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true The default is on the smaller end because the length only represents a segment and not the true
line length. line length.
Default: 80 Default: 80
annotations: Whether or not to show annotations in the SQL. annotations: Whether or not to show annotations in the SQL when `pretty` is True.
Annotations can only be shown in pretty mode otherwise they may clobber resulting sql.
Default: True Default: True
""" """
@ -280,7 +281,7 @@ class Generator:
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
def annotation_sql(self, expression): def annotation_sql(self, expression):
if self._annotations: if self._annotations and self.pretty:
return f"{self.sql(expression, 'expression')} # {expression.name}" return f"{self.sql(expression, 'expression')} # {expression.name}"
return self.sql(expression, "expression") return self.sql(expression, "expression")

View file

@ -194,6 +194,24 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
return words + [None] * (min_num_words - len(words)) return words + [None] * (min_num_words - len(words))
def is_iterable(value: t.Any) -> bool:
    """
    Tells whether the value is an iterable, treating strings and bytes as non-iterable

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable

    Returns: Bool indicating if it is an iterable
    """
    if isinstance(value, (str, bytes)):
        return False
    return hasattr(value, "__iter__")
def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]: def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]:
""" """
Flattens a list that can contain both iterables and non-iterable elements Flattens a list that can contain both iterables and non-iterable elements
@ -211,7 +229,7 @@ def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generato
Yields non-iterable elements (not including str or byte as iterable) Yields non-iterable elements (not including str or byte as iterable)
""" """
for value in values: for value in values:
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)): if is_iterable(value):
yield from flatten(value) yield from flatten(value)
else: else:
yield value yield value

View file

@ -433,7 +433,8 @@ class Parser:
} }
FUNCTION_PARSERS = { FUNCTION_PARSERS = {
"CONVERT": lambda self: self._parse_convert(), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"TRY_CONVERT": lambda self: self._parse_convert(False),
"EXTRACT": lambda self: self._parse_extract(), "EXTRACT": lambda self: self._parse_extract(),
"POSITION": lambda self: self._parse_position(), "POSITION": lambda self: self._parse_position(),
"SUBSTRING": lambda self: self._parse_substring(), "SUBSTRING": lambda self: self._parse_substring(),
@ -1512,7 +1513,7 @@ class Parser:
return this return this
def _parse_offset(self, this=None): def _parse_offset(self, this=None):
if not self._match(TokenType.OFFSET): if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
return this return this
count = self._parse_number() count = self._parse_number()
self._match_set((TokenType.ROW, TokenType.ROWS)) self._match_set((TokenType.ROW, TokenType.ROWS))
@ -2134,7 +2135,7 @@ class Parser:
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_convert(self): def _parse_convert(self, strict):
this = self._parse_field() this = self._parse_field()
if self._match(TokenType.USING): if self._match(TokenType.USING):
to = self.expression(exp.CharacterSet, this=self._parse_var()) to = self.expression(exp.CharacterSet, this=self._parse_var())
@ -2142,7 +2143,7 @@ class Parser:
to = self._parse_types() to = self._parse_types()
else: else:
to = None to = None
return self.expression(exp.Cast, this=this, to=to) return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_position(self): def _parse_position(self):
args = self._parse_csv(self._parse_bitwise) args = self._parse_csv(self._parse_bitwise)

View file

@ -14,6 +14,8 @@ def format_time(string, mapping, trie=None):
mapping: Dictionary of time format to target time format mapping: Dictionary of time format to target time format
trie: Optional trie, can be passed in for performance trie: Optional trie, can be passed in for performance
""" """
if not string:
return None
start = 0 start = 0
end = 1 end = 1
size = len(string) size = len(string)

View file

@ -9,7 +9,6 @@ from sqlglot.errors import ErrorLevel
class TestFunctions(unittest.TestCase): class TestFunctions(unittest.TestCase):
@unittest.skip("not yet fixed.")
def test_invoke_anonymous(self): def test_invoke_anonymous(self):
for name, func in inspect.getmembers(SF, inspect.isfunction): for name, func in inspect.getmembers(SF, inspect.isfunction):
with self.subTest(f"{name} should not invoke anonymous_function"): with self.subTest(f"{name} should not invoke anonymous_function"):
@ -438,13 +437,13 @@ class TestFunctions(unittest.TestCase):
def test_pow(self): def test_pow(self):
col_str = SF.pow("cola", "colb") col_str = SF.pow("cola", "colb")
self.assertEqual("POW(cola, colb)", col_str.sql()) self.assertEqual("POWER(cola, colb)", col_str.sql())
col = SF.pow(SF.col("cola"), SF.col("colb")) col = SF.pow(SF.col("cola"), SF.col("colb"))
self.assertEqual("POW(cola, colb)", col.sql()) self.assertEqual("POWER(cola, colb)", col.sql())
col_float = SF.pow(10.10, "colb") col_float = SF.pow(10.10, "colb")
self.assertEqual("POW(10.1, colb)", col_float.sql()) self.assertEqual("POWER(10.1, colb)", col_float.sql())
col_float2 = SF.pow("cola", 10.10) col_float2 = SF.pow("cola", 10.10)
self.assertEqual("POW(cola, 10.1)", col_float2.sql()) self.assertEqual("POWER(cola, 10.1)", col_float2.sql())
def test_row_number(self): def test_row_number(self):
col_str = SF.row_number() col_str = SF.row_number()
@ -493,6 +492,8 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("COALESCE(cola, colb, colc)", col_str.sql()) self.assertEqual("COALESCE(cola, colb, colc)", col_str.sql())
col = SF.coalesce(SF.col("cola"), "colb", SF.col("colc")) col = SF.coalesce(SF.col("cola"), "colb", SF.col("colc"))
self.assertEqual("COALESCE(cola, colb, colc)", col.sql()) self.assertEqual("COALESCE(cola, colb, colc)", col.sql())
col_single = SF.coalesce("cola")
self.assertEqual("COALESCE(cola)", col_single.sql())
def test_corr(self): def test_corr(self):
col_str = SF.corr("cola", "colb") col_str = SF.corr("cola", "colb")
@ -843,8 +844,8 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TO_DATE(cola)", col_str.sql()) self.assertEqual("TO_DATE(cola)", col_str.sql())
col = SF.to_date(SF.col("cola")) col = SF.to_date(SF.col("cola"))
self.assertEqual("TO_DATE(cola)", col.sql()) self.assertEqual("TO_DATE(cola)", col.sql())
col_with_format = SF.to_date("cola", "yyyy-MM-dd") col_with_format = SF.to_date("cola", "yy-MM-dd")
self.assertEqual("TO_DATE(cola, 'yyyy-MM-dd')", col_with_format.sql()) self.assertEqual("TO_DATE(cola, 'yy-MM-dd')", col_with_format.sql())
def test_to_timestamp(self): def test_to_timestamp(self):
col_str = SF.to_timestamp("cola") col_str = SF.to_timestamp("cola")
@ -883,16 +884,16 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FROM_UNIXTIME(cola)", col_str.sql()) self.assertEqual("FROM_UNIXTIME(cola)", col_str.sql())
col = SF.from_unixtime(SF.col("cola")) col = SF.from_unixtime(SF.col("cola"))
self.assertEqual("FROM_UNIXTIME(cola)", col.sql()) self.assertEqual("FROM_UNIXTIME(cola)", col.sql())
col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm:ss") col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm")
self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm')", col_format.sql())
def test_unix_timestamp(self): def test_unix_timestamp(self):
col_str = SF.unix_timestamp("cola") col_str = SF.unix_timestamp("cola")
self.assertEqual("UNIX_TIMESTAMP(cola)", col_str.sql()) self.assertEqual("UNIX_TIMESTAMP(cola)", col_str.sql())
col = SF.unix_timestamp(SF.col("cola")) col = SF.unix_timestamp(SF.col("cola"))
self.assertEqual("UNIX_TIMESTAMP(cola)", col.sql()) self.assertEqual("UNIX_TIMESTAMP(cola)", col.sql())
col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm:ss") col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm")
self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm')", col_format.sql())
col_current = SF.unix_timestamp() col_current = SF.unix_timestamp()
self.assertEqual("UNIX_TIMESTAMP()", col_current.sql()) self.assertEqual("UNIX_TIMESTAMP()", col_current.sql())
@ -1427,6 +1428,13 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("ARRAY_SORT(cola)", col_str.sql()) self.assertEqual("ARRAY_SORT(cola)", col_str.sql())
col = SF.array_sort(SF.col("cola")) col = SF.array_sort(SF.col("cola"))
self.assertEqual("ARRAY_SORT(cola)", col.sql()) self.assertEqual("ARRAY_SORT(cola)", col.sql())
col_comparator = SF.array_sort(
"cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
)
self.assertEqual(
"ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
col_comparator.sql(),
)
def test_reverse(self): def test_reverse(self):
col_str = SF.reverse("cola") col_str = SF.reverse("cola")
@ -1514,8 +1522,6 @@ class TestFunctions(unittest.TestCase):
SF.lit(0), SF.lit(0),
lambda accumulator, target: accumulator + target, lambda accumulator, target: accumulator + target,
lambda accumulator: accumulator * 2, lambda accumulator: accumulator * 2,
"accumulator",
"target",
) )
self.assertEqual( self.assertEqual(
"AGGREGATE(cola, 0, (accumulator, target) -> accumulator + target, accumulator -> accumulator * 2)", "AGGREGATE(cola, 0, (accumulator, target) -> accumulator + target, accumulator -> accumulator * 2)",
@ -1527,7 +1533,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM(cola, x -> x * 2)", col_str.sql()) self.assertEqual("TRANSFORM(cola, x -> x * 2)", col_str.sql())
col = SF.transform(SF.col("cola"), lambda x, i: x * i) col = SF.transform(SF.col("cola"), lambda x, i: x * i)
self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql()) self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count, "target", "row_count") col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)
self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()) self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
@ -1536,7 +1542,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col_str.sql()) self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col_str.sql())
col = SF.exists(SF.col("cola"), lambda x: x % 2 == 0) col = SF.exists(SF.col("cola"), lambda x: x % 2 == 0)
self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col.sql()) self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col.sql())
col_custom_name = SF.exists("cola", lambda target: target > 0, "target") col_custom_name = SF.exists("cola", lambda target: target > 0)
self.assertEqual("EXISTS(cola, target -> target > 0)", col_custom_name.sql()) self.assertEqual("EXISTS(cola, target -> target > 0)", col_custom_name.sql())
def test_forall(self): def test_forall(self):
@ -1544,7 +1550,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col_str.sql()) self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col_str.sql())
col = SF.forall(SF.col("cola"), lambda x: x.rlike("foo")) col = SF.forall(SF.col("cola"), lambda x: x.rlike("foo"))
self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col.sql()) self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col.sql())
col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"), "target") col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"))
self.assertEqual("FORALL(cola, target -> target RLIKE 'foo')", col_custom_name.sql()) self.assertEqual("FORALL(cola, target -> target RLIKE 'foo')", col_custom_name.sql())
def test_filter(self): def test_filter(self):
@ -1552,9 +1558,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql()) self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i)) col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql()) self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
col_custom_names = SF.filter( col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
"cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count), "target", "row_count"
)
self.assertEqual( self.assertEqual(
"FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql() "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
@ -1565,7 +1569,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col_str.sql()) self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col_str.sql())
col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y)) col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql()) self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r), "l", "r") col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()) self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())
def test_transform_keys(self): def test_transform_keys(self):
@ -1573,7 +1577,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col_str.sql()) self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col_str.sql())
col = SF.transform_keys(SF.col("cola"), lambda k, v: SF.upper(k)) col = SF.transform_keys(SF.col("cola"), lambda k, v: SF.upper(k))
self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col.sql()) self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col.sql())
col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key), "key", "_") col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key))
self.assertEqual("TRANSFORM_KEYS(cola, (key, _) -> UPPER(key))", col_custom_names.sql()) self.assertEqual("TRANSFORM_KEYS(cola, (key, _) -> UPPER(key))", col_custom_names.sql())
def test_transform_values(self): def test_transform_values(self):
@ -1581,7 +1585,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col_str.sql()) self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col_str.sql())
col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v)) col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql()) self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value), "_", "value") col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()) self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())
def test_map_filter(self): def test_map_filter(self):
@ -1589,5 +1593,9 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col_str.sql()) self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col_str.sql())
col = SF.map_filter(SF.col("cola"), lambda k, v: k > v) col = SF.map_filter(SF.col("cola"), lambda k, v: k > v)
self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col.sql()) self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col.sql())
col_custom_names = SF.map_filter("cola", lambda key, value: key > value, "key", "value") col_custom_names = SF.map_filter("cola", lambda key, value: key > value)
self.assertEqual("MAP_FILTER(cola, (key, value) -> key > value)", col_custom_names.sql()) self.assertEqual("MAP_FILTER(cola, (key, value) -> key > value)", col_custom_names.sql())
def test_map_zip_with(self):
    # The three lambda parameter names (k, v1, v2) are expected to carry over
    # into the generated SQL lambda.
    merged = SF.map_zip_with("base", "ratio", lambda k, v1, v2: SF.round(v1 * v2, 2))
    expected_sql = "MAP_ZIP_WITH(base, ratio, (k, v1, v2) -> ROUND(v1 * v2, 2))"
    self.assertEqual(expected_sql, merged.sql())

View file

@ -105,6 +105,15 @@ class TestBigQuery(Validator):
"spark": "x IS NULL", "spark": "x IS NULL",
}, },
) )
self.validate_all(
"CURRENT_DATE",
read={
"tsql": "GETDATE()",
},
write={
"tsql": "GETDATE()",
},
)
self.validate_all( self.validate_all(
"current_datetime", "current_datetime",
write={ write={

View file

@ -434,12 +434,7 @@ class TestDialect(Validator):
"presto": "DATE_ADD('day', 1, x)", "presto": "DATE_ADD('day', 1, x)",
"spark": "DATE_ADD(x, 1)", "spark": "DATE_ADD(x, 1)",
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
}, "tsql": "DATEADD(day, 1, x)",
)
self.validate_all(
"DATE_ADD(x, y, 'day')",
write={
"postgres": UnsupportedError,
}, },
) )
self.validate_all( self.validate_all(
@ -634,11 +629,13 @@ class TestDialect(Validator):
read={ read={
"postgres": "x->'y'", "postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, 'y')", "presto": "JSON_EXTRACT(x, 'y')",
"starrocks": "x->'y'",
}, },
write={ write={
"oracle": "JSON_EXTRACT(x, 'y')", "oracle": "JSON_EXTRACT(x, 'y')",
"postgres": "x->'y'", "postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, 'y')", "presto": "JSON_EXTRACT(x, 'y')",
"starrocks": "x->'y'",
}, },
) )
self.validate_all( self.validate_all(
@ -983,6 +980,7 @@ class TestDialect(Validator):
) )
def test_limit(self): def test_limit(self):
self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"})
self.validate_all( self.validate_all(
"SELECT x FROM y LIMIT 10", "SELECT x FROM y LIMIT 10",
write={ write={

View file

@ -282,3 +282,6 @@ TBLPROPERTIES (
"spark": "SELECT ARRAY_SORT(x)", "spark": "SELECT ARRAY_SORT(x)",
}, },
) )
def test_iif(self):
    # IIF is expected to be written back out as IF for spark.
    expected = {"spark": "SELECT IF(cond, 'True', 'False')"}
    self.validate_all("SELECT IIF(cond, 'True', 'False')", write=expected)

View file

@ -71,3 +71,226 @@ class TestTSQL(Validator):
"spark": "LOCATE('sub', 'testsubstring')", "spark": "LOCATE('sub', 'testsubstring')",
}, },
) )
def test_len(self):
    # LEN is expected to be written as LENGTH for spark.
    expected = {"spark": "LENGTH(x)"}
    self.validate_all("LEN(x)", write=expected)
def test_replicate(self):
    # REPLICATE is expected to be written as REPEAT for spark.
    expected = {"spark": "REPEAT('x', 2)"}
    self.validate_all("REPLICATE('x', 2)", write=expected)
def test_isnull(self):
    # Two-argument ISNULL is expected to be written as COALESCE for spark.
    expected = {"spark": "COALESCE(x, y)"}
    self.validate_all("ISNULL(x, y)", write=expected)
def test_jsonvalue(self):
    # JSON_VALUE is expected to be written as GET_JSON_OBJECT for spark.
    source_sql = "JSON_VALUE(r.JSON, '$.Attr_INT')"
    expected = {"spark": "GET_JSON_OBJECT(r.JSON, '$.Attr_INT')"}
    self.validate_all(source_sql, write=expected)
def test_datefromparts(self):
    # DATEFROMPARTS is expected to be written as MAKE_DATE for spark, with the
    # arguments passed through unchanged.
    source_sql = "SELECT DATEFROMPARTS('2020', 10, 01)"
    expected = {"spark": "SELECT MAKE_DATE('2020', 10, 01)"}
    self.validate_all(source_sql, write=expected)
def test_datename(self):
    # DATENAME is expected to yield full names: "mm" -> MMMM (month name) and
    # "dw" -> EEEE (weekday name) in spark's DATE_FORMAT.
    cases = (
        ("SELECT DATENAME(mm,'01-01-1970')", "SELECT DATE_FORMAT('01-01-1970', 'MMMM')"),
        ("SELECT DATENAME(dw,'01-01-1970')", "SELECT DATE_FORMAT('01-01-1970', 'EEEE')"),
    )
    for source_sql, spark_sql in cases:
        self.validate_all(source_sql, write={"spark": spark_sql})
def test_datepart(self):
    # DATEPART(month, ...) is expected to yield the numeric month pattern (MM).
    source_sql = "SELECT DATEPART(month,'01-01-1970')"
    expected = {"spark": "SELECT DATE_FORMAT('01-01-1970', 'MM')"}
    self.validate_all(source_sql, write=expected)
def test_convert_date_format(self):
    # Table of (input SQL, expected Spark SQL) pairs, checked in order. Each pair is
    # validated with a fresh {"spark": ...} write mapping, one validate_all call per row.
    cases = (
        # Character targets without a style argument: plain CASTs. The expectations
        # show N-prefixed types mapping to their non-N Spark counterparts, a missing
        # length becoming 30, and (MAX) becoming STRING.
        ("CONVERT(NVARCHAR(200), x)", "CAST(x AS VARCHAR(200))"),
        ("CONVERT(NVARCHAR, x)", "CAST(x AS VARCHAR(30))"),
        ("CONVERT(NVARCHAR(MAX), x)", "CAST(x AS STRING)"),
        ("CONVERT(VARCHAR(200), x)", "CAST(x AS VARCHAR(200))"),
        ("CONVERT(VARCHAR, x)", "CAST(x AS VARCHAR(30))"),
        ("CONVERT(VARCHAR(MAX), x)", "CAST(x AS STRING)"),
        ("CONVERT(CHAR(40), x)", "CAST(x AS CHAR(40))"),
        ("CONVERT(CHAR, x)", "CAST(x AS CHAR(30))"),
        ("CONVERT(NCHAR(40), x)", "CAST(x AS CHAR(40))"),
        ("CONVERT(NCHAR, x)", "CAST(x AS CHAR(30))"),
        # Character targets with style 121: the value is wrapped in DATE_FORMAT,
        # and the outer CAST is dropped for (MAX).
        (
            "CONVERT(VARCHAR, x, 121)",
            "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))",
        ),
        (
            "CONVERT(VARCHAR(40), x, 121)",
            "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))",
        ),
        (
            "CONVERT(VARCHAR(MAX), x, 121)",
            "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
        ),
        (
            "CONVERT(NVARCHAR, x, 121)",
            "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))",
        ),
        (
            "CONVERT(NVARCHAR(40), x, 121)",
            "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))",
        ),
        (
            "CONVERT(NVARCHAR(MAX), x, 121)",
            "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
        ),
        # Date/time targets with style 121: parsed with TO_DATE / TO_TIMESTAMP.
        ("CONVERT(DATE, x, 121)", "TO_DATE(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')"),
        ("CONVERT(DATETIME, x, 121)", "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')"),
        ("CONVERT(DATETIME2, x, 121)", "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')"),
        # Non-temporal, non-character target: the style argument has no effect.
        ("CONVERT(INT, x)", "CAST(x AS INT)"),
        ("CONVERT(INT, x, 121)", "CAST(x AS INT)"),
        # TRY_CONVERT / TRY_CAST are expected to produce the same Spark output as
        # their non-TRY forms.
        (
            "TRY_CONVERT(NVARCHAR, x, 121)",
            "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))",
        ),
        ("TRY_CONVERT(INT, x)", "CAST(x AS INT)"),
        ("TRY_CAST(x AS INT)", "CAST(x AS INT)"),
        ("CAST(x AS INT)", "CAST(x AS INT)"),
    )
    for source_sql, spark_sql in cases:
        self.validate_all(source_sql, write={"spark": spark_sql})
def test_add_date(self):
    # First run the identity check on the DATEADD statement itself.
    self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")

    # Then check the Spark output per unit: year -> 12 months, qq -> 3 months,
    # wk -> 7 days (via DATE_ADD rather than ADD_MONTHS).
    cases = (
        ("SELECT DATEADD(year, 1, '2017/08/25')", "SELECT ADD_MONTHS('2017/08/25', 12)"),
        ("SELECT DATEADD(qq, 1, '2017/08/25')", "SELECT ADD_MONTHS('2017/08/25', 3)"),
        ("SELECT DATEADD(wk, 1, '2017/08/25')", "SELECT DATE_ADD('2017/08/25', 7)"),
    )
    for source_sql, spark_sql in cases:
        self.validate_all(source_sql, write={"spark": spark_sql})
def test_date_diff(self):
    # First run the identity check on the DATEDIFF statement itself.
    self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')")

    # Year difference: Spark output swaps the operands and divides months by 12.
    year_expectations = {
        "tsql": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
        "spark": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12",
    }
    self.validate_all(
        "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
        write=year_expectations,
    )

    # Month difference: the input has no space after the second comma, while the
    # expected tsql output does.
    month_expectations = {
        "spark": "SELECT MONTHS_BETWEEN('end', 'start')",
        "tsql": "SELECT DATEDIFF(month, 'start', 'end')",
    }
    self.validate_all("SELECT DATEDIFF(month, 'start','end')", write=month_expectations)

    # Quarter difference: months divided by 3.
    quarter_expectations = {"spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3"}
    self.validate_all("SELECT DATEDIFF(quarter, 'start', 'end')", write=quarter_expectations)
def test_iif(self):
    # Identity check on the bare IIF statement.
    self.validate_identity("SELECT IIF(cond, 'True', 'False')")

    # The input here carries a trailing semicolon; the expected Spark output
    # uses IF(...) and has no semicolon.
    spark_expectation = {"spark": "SELECT IF(cond, 'True', 'False')"}
    self.validate_all("SELECT IIF(cond, 'True', 'False');", write=spark_expectation)

View file

@ -149,7 +149,6 @@ SELECT 1 AS count FROM test
SELECT 1 AS comment FROM test SELECT 1 AS comment FROM test
SELECT 1 AS numeric FROM test SELECT 1 AS numeric FROM test
SELECT 1 AS number FROM test SELECT 1 AS number FROM test
SELECT 1 AS number # annotation
SELECT t.count SELECT t.count
SELECT DISTINCT x FROM test SELECT DISTINCT x FROM test
SELECT DISTINCT x, y FROM test SELECT DISTINCT x, y FROM test

View file

@ -329,6 +329,10 @@ class TestBuild(unittest.TestCase):
lambda: exp.update("tbl", {"x": 1}, where="y > 0"), lambda: exp.update("tbl", {"x": 1}, where="y > 0"),
"UPDATE tbl SET x = 1 WHERE y > 0", "UPDATE tbl SET x = 1 WHERE y > 0",
), ),
(
lambda: exp.update("tbl", {"x": 1}, where=exp.condition("y > 0")),
"UPDATE tbl SET x = 1 WHERE y > 0",
),
( (
lambda: exp.update("tbl", {"x": 1}, from_="tbl2"), lambda: exp.update("tbl", {"x": 1}, from_="tbl2"),
"UPDATE tbl SET x = 1 FROM tbl2", "UPDATE tbl SET x = 1 FROM tbl2",

View file

@ -135,6 +135,53 @@ class TestExpressions(unittest.TestCase):
"SELECT * FROM a1 AS a JOIN b.a JOIN c.a2 JOIN d2 JOIN e.a", "SELECT * FROM a1 AS a JOIN b.a JOIN c.a2 JOIN d2 JOIN e.a",
) )
def test_replace_placeholders(self):
    # Each scenario is (input SQL, positional replacements, named replacements,
    # expected SQL after replacement). Scenarios run in order, one assertEqual each.
    scenarios = (
        # Named placeholders only, every name supplied.
        (
            "select * from :tbl1 JOIN :tbl2 ON :col1 = :col2 WHERE :col3 > 100",
            (),
            {"tbl1": "foo", "tbl2": "bar", "col1": "a", "col2": "b", "col3": "c"},
            "SELECT * FROM foo JOIN bar ON a = b WHERE c > 100",
        ),
        # Positional "?" placeholders only, one value per placeholder.
        (
            "select * from ? JOIN ? ON ? = ? WHERE ? > 100",
            ("foo", "bar", "a", "b", "c"),
            {},
            "SELECT * FROM foo JOIN bar ON a = b WHERE c > 100",
        ),
        # Fewer positional values than placeholders: the surplus "?" is expected to remain.
        (
            "select * from ? WHERE ? > 100",
            ("foo",),
            {},
            "SELECT * FROM foo WHERE ? > 100",
        ),
        # A keyword that matches no placeholder: nothing is expected to change.
        (
            "select * from :name WHERE ? > 100",
            (),
            {"another_name": "bla"},
            "SELECT * FROM :name WHERE ? > 100",
        ),
        # Mixed named and positional placeholders with surplus values of both kinds.
        (
            "select * from (SELECT :col1 FROM ?) WHERE :col2 > 100",
            ("tbl1", "tbl2", "tbl3"),
            {"col1": "a", "col2": "b", "col3": "c"},
            "SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100",
        ),
    )
    for source_sql, positional, named, expected_sql in scenarios:
        replaced = exp.replace_placeholders(parse_one(source_sql), *positional, **named)
        self.assertEqual(replaced.sql(), expected_sql)
def test_named_selects(self): def test_named_selects(self):
expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])
@ -504,9 +551,24 @@ class TestExpressions(unittest.TestCase):
[e.alias_or_name for e in expression.expressions], [e.alias_or_name for e in expression.expressions],
["a", "B", "c", "D"], ["a", "B", "c", "D"],
) )
self.assertEqual(expression.sql(), sql) self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D")
self.assertEqual(expression.expressions[2].name, "comment") self.assertEqual(expression.expressions[2].name, "comment")
self.assertEqual(expression.sql(annotations=False), "SELECT a, b AS B, c, d AS D") self.assertEqual(
expression.sql(pretty=True, annotations=False),
"""SELECT
a,
b AS B,
c,
d AS D""",
)
self.assertEqual(
expression.sql(pretty=True),
"""SELECT
a,
b AS B,
c # comment,
d AS D # another_comment FROM foo""",
)
def test_to_table(self): def test_to_table(self):
table_only = exp.to_table("table_name") table_only = exp.to_table("table_name")

View file

@ -5,7 +5,7 @@ from sqlglot.time import format_time
class TestTime(unittest.TestCase): class TestTime(unittest.TestCase):
def test_format_time(self): def test_format_time(self):
self.assertEqual(format_time("", {}), "") self.assertEqual(format_time("", {}), None)
self.assertEqual(format_time(" ", {}), " ") self.assertEqual(format_time(" ", {}), " ")
mapping = {"a": "b", "aa": "c"} mapping = {"a": "b", "aa": "c"}
self.assertEqual(format_time("a", mapping), "b") self.assertEqual(format_time("a", mapping), "b")