Merging upstream version 11.1.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 8c1c1864c5
commit fb546b57e5
95 changed files with 32569 additions and 30081 deletions
@@ -40,7 +40,7 @@ if t.TYPE_CHECKING:
     T = t.TypeVar("T", bound=Expression)
 
 
-__version__ = "11.0.1"
+__version__ = "11.1.3"
 
 pretty = False
 """Whether to format generated SQL by default."""

@@ -143,6 +143,7 @@ class Column:
             if is_iterable(v)
             else Column.ensure_col(v).expression
             for k, v in kwargs.items()
+            if v is not None
         }
         new_expression = (
             callable_expression(**ensure_expression_values)

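The guard added above means keyword arguments left as None no longer reach Column.ensure_col at all. A small illustrative sketch of the pattern (simplified, assumed behavior; not the actual class):

import typing as t

# Sketch of the guarded comprehension: kwargs that were left as None are
# dropped before the expression class is instantiated.
def build_kwargs(**kwargs: t.Any) -> dict:
    return {k: v for k, v in kwargs.items() if v is not None}

print(build_kwargs(this="col1", decimals=None))  # {'this': 'col1'}
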
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import typing as t
 
-from sqlglot import expressions as glotexp
+from sqlglot import exp as expression
 from sqlglot.dataframe.sql.column import Column
 from sqlglot.helper import ensure_list
 from sqlglot.helper import flatten as _flatten

@@ -18,25 +18,29 @@ def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
 
 def lit(value: t.Optional[t.Any] = None) -> Column:
     if isinstance(value, str):
-        return Column(glotexp.Literal.string(str(value)))
+        return Column(expression.Literal.string(str(value)))
     return Column(value)
 
 
 def greatest(*cols: ColumnOrName) -> Column:
     if len(cols) > 1:
-        return Column.invoke_expression_over_column(cols[0], glotexp.Greatest, expressions=cols[1:])
-    return Column.invoke_expression_over_column(cols[0], glotexp.Greatest)
+        return Column.invoke_expression_over_column(
+            cols[0], expression.Greatest, expressions=cols[1:]
+        )
+    return Column.invoke_expression_over_column(cols[0], expression.Greatest)
 
 
 def least(*cols: ColumnOrName) -> Column:
     if len(cols) > 1:
-        return Column.invoke_expression_over_column(cols[0], glotexp.Least, expressions=cols[1:])
-    return Column.invoke_expression_over_column(cols[0], glotexp.Least)
+        return Column.invoke_expression_over_column(cols[0], expression.Least, expressions=cols[1:])
+    return Column.invoke_expression_over_column(cols[0], expression.Least)
 
 
 def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
     columns = [Column.ensure_col(x) for x in [col] + list(cols)]
-    return Column(glotexp.Count(this=glotexp.Distinct(expressions=[x.expression for x in columns])))
+    return Column(
+        expression.Count(this=expression.Distinct(expressions=[x.expression for x in columns]))
+    )
 
 
 def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:

@@ -46,8 +50,8 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
 def when(condition: Column, value: t.Any) -> Column:
     true_value = value if isinstance(value, Column) else lit(value)
     return Column(
-        glotexp.Case(
-            ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
+        expression.Case(
+            ifs=[expression.If(this=condition.column_expression, true=true_value.column_expression)]
         )
     )

@@ -65,19 +69,19 @@ def broadcast(df: DataFrame) -> DataFrame:
 
 
 def sqrt(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Sqrt)
+    return Column.invoke_expression_over_column(col, expression.Sqrt)
 
 
 def abs(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Abs)
+    return Column.invoke_expression_over_column(col, expression.Abs)
 
 
 def max(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Max)
+    return Column.invoke_expression_over_column(col, expression.Max)
 
 
 def min(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Min)
+    return Column.invoke_expression_over_column(col, expression.Min)
 
 
 def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column:

@@ -89,15 +93,15 @@ def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
 
 
 def count(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Count)
+    return Column.invoke_expression_over_column(col, expression.Count)
 
 
 def sum(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Sum)
+    return Column.invoke_expression_over_column(col, expression.Sum)
 
 
 def avg(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Avg)
+    return Column.invoke_expression_over_column(col, expression.Avg)
 
 
 def mean(col: ColumnOrName) -> Column:

@@ -149,7 +153,7 @@ def cbrt(col: ColumnOrName) -> Column:
 
 
 def ceil(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Ceil)
+    return Column.invoke_expression_over_column(col, expression.Ceil)
 
 
 def cos(col: ColumnOrName) -> Column:

@@ -169,7 +173,7 @@ def csc(col: ColumnOrName) -> Column:
 
 
 def exp(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Exp)
+    return Column.invoke_expression_over_column(col, expression.Exp)
 
 
 def expm1(col: ColumnOrName) -> Column:

@@ -177,11 +181,11 @@ def expm1(col: ColumnOrName) -> Column:
 
 
 def floor(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Floor)
+    return Column.invoke_expression_over_column(col, expression.Floor)
 
 
 def log10(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Log10)
+    return Column.invoke_expression_over_column(col, expression.Log10)
 
 
 def log1p(col: ColumnOrName) -> Column:

@@ -189,13 +193,13 @@ def log1p(col: ColumnOrName) -> Column:
 
 
 def log2(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Log2)
+    return Column.invoke_expression_over_column(col, expression.Log2)
 
 
 def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
     if arg2 is None:
-        return Column.invoke_expression_over_column(arg1, glotexp.Ln)
-    return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=arg2)
+        return Column.invoke_expression_over_column(arg1, expression.Ln)
+    return Column.invoke_expression_over_column(arg1, expression.Log, expression=arg2)
 
 
 def rint(col: ColumnOrName) -> Column:

@@ -247,7 +251,7 @@ def bitwiseNOT(col: ColumnOrName) -> Column:
 
 
 def bitwise_not(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.BitwiseNot)
+    return Column.invoke_expression_over_column(col, expression.BitwiseNot)
 
 
 def asc_nulls_first(col: ColumnOrName) -> Column:

@@ -267,27 +271,27 @@ def desc_nulls_last(col: ColumnOrName) -> Column:
 
 
 def stddev(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Stddev)
+    return Column.invoke_expression_over_column(col, expression.Stddev)
 
 
 def stddev_samp(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.StddevSamp)
+    return Column.invoke_expression_over_column(col, expression.StddevSamp)
 
 
 def stddev_pop(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.StddevPop)
+    return Column.invoke_expression_over_column(col, expression.StddevPop)
 
 
 def variance(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Variance)
+    return Column.invoke_expression_over_column(col, expression.Variance)
 
 
 def var_samp(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Variance)
+    return Column.invoke_expression_over_column(col, expression.Variance)
 
 
 def var_pop(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.VariancePop)
+    return Column.invoke_expression_over_column(col, expression.VariancePop)
 
 
 def skewness(col: ColumnOrName) -> Column:

@@ -299,11 +303,11 @@ def kurtosis(col: ColumnOrName) -> Column:
 
 
 def collect_list(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.ArrayAgg)
+    return Column.invoke_expression_over_column(col, expression.ArrayAgg)
 
 
 def collect_set(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.SetAgg)
+    return Column.invoke_expression_over_column(col, expression.SetAgg)
 
 
 def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:

@@ -311,27 +315,27 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]
 
 
 def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
-    return Column.invoke_expression_over_column(col1, glotexp.Pow, expression=col2)
+    return Column.invoke_expression_over_column(col1, expression.Pow, expression=col2)
 
 
 def row_number() -> Column:
-    return Column(glotexp.Anonymous(this="ROW_NUMBER"))
+    return Column(expression.Anonymous(this="ROW_NUMBER"))
 
 
 def dense_rank() -> Column:
-    return Column(glotexp.Anonymous(this="DENSE_RANK"))
+    return Column(expression.Anonymous(this="DENSE_RANK"))
 
 
 def rank() -> Column:
-    return Column(glotexp.Anonymous(this="RANK"))
+    return Column(expression.Anonymous(this="RANK"))
 
 
 def cume_dist() -> Column:
-    return Column(glotexp.Anonymous(this="CUME_DIST"))
+    return Column(expression.Anonymous(this="CUME_DIST"))
 
 
 def percent_rank() -> Column:
-    return Column(glotexp.Anonymous(this="PERCENT_RANK"))
+    return Column(expression.Anonymous(this="PERCENT_RANK"))
 
 
 def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:

@@ -340,14 +344,16 @@ def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Col
 
 def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
     if rsd is None:
-        return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct)
-    return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=rsd)
+        return Column.invoke_expression_over_column(col, expression.ApproxDistinct)
+    return Column.invoke_expression_over_column(col, expression.ApproxDistinct, accuracy=rsd)
 
 
 def coalesce(*cols: ColumnOrName) -> Column:
     if len(cols) > 1:
-        return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce, expressions=cols[1:])
-    return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce)
+        return Column.invoke_expression_over_column(
+            cols[0], expression.Coalesce, expressions=cols[1:]
+        )
+    return Column.invoke_expression_over_column(cols[0], expression.Coalesce)
 
 
 def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:

@@ -409,10 +415,10 @@ def percentile_approx(
 ) -> Column:
     if accuracy:
         return Column.invoke_expression_over_column(
-            col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
+            col, expression.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
         )
     return Column.invoke_expression_over_column(
-        col, glotexp.ApproxQuantile, quantile=lit(percentage)
+        col, expression.ApproxQuantile, quantile=lit(percentage)
     )

@@ -426,8 +432,8 @@ def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
 
 def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
     if scale is not None:
-        return Column.invoke_expression_over_column(col, glotexp.Round, decimals=scale)
-    return Column.invoke_expression_over_column(col, glotexp.Round)
+        return Column.invoke_expression_over_column(col, expression.Round, decimals=scale)
+    return Column.invoke_expression_over_column(col, expression.Round)
 
 
 def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:

@@ -437,7 +443,9 @@ def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
 
 
 def shiftleft(col: ColumnOrName, numBits: int) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.BitwiseLeftShift, expression=numBits)
+    return Column.invoke_expression_over_column(
+        col, expression.BitwiseLeftShift, expression=numBits
+    )
 
 
 def shiftLeft(col: ColumnOrName, numBits: int) -> Column:

@@ -445,7 +453,9 @@ def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
 
 
 def shiftright(col: ColumnOrName, numBits: int) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.BitwiseRightShift, expression=numBits)
+    return Column.invoke_expression_over_column(
+        col, expression.BitwiseRightShift, expression=numBits
+    )
 
 
 def shiftRight(col: ColumnOrName, numBits: int) -> Column:

@@ -466,7 +476,7 @@ def expr(str: str) -> Column:
 
 def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
     columns = ensure_list(col) + list(cols)
-    return Column.invoke_expression_over_column(None, glotexp.Struct, expressions=columns)
+    return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns)
 
 
 def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:

@@ -512,19 +522,19 @@ def ntile(n: int) -> Column:
 
 
 def current_date() -> Column:
-    return Column.invoke_expression_over_column(None, glotexp.CurrentDate)
+    return Column.invoke_expression_over_column(None, expression.CurrentDate)
 
 
 def current_timestamp() -> Column:
-    return Column.invoke_expression_over_column(None, glotexp.CurrentTimestamp)
+    return Column.invoke_expression_over_column(None, expression.CurrentTimestamp)
 
 
 def date_format(col: ColumnOrName, format: str) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.TimeToStr, format=lit(format))
+    return Column.invoke_expression_over_column(col, expression.TimeToStr, format=lit(format))
 
 
 def year(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Year)
+    return Column.invoke_expression_over_column(col, expression.Year)
 
 
 def quarter(col: ColumnOrName) -> Column:

@@ -532,19 +542,19 @@ def quarter(col: ColumnOrName) -> Column:
 
 
 def month(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Month)
+    return Column.invoke_expression_over_column(col, expression.Month)
 
 
 def dayofweek(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.DayOfWeek)
+    return Column.invoke_expression_over_column(col, expression.DayOfWeek)
 
 
 def dayofmonth(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.DayOfMonth)
+    return Column.invoke_expression_over_column(col, expression.DayOfMonth)
 
 
 def dayofyear(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.DayOfYear)
+    return Column.invoke_expression_over_column(col, expression.DayOfYear)
 
 
 def hour(col: ColumnOrName) -> Column:

@@ -560,7 +570,7 @@ def second(col: ColumnOrName) -> Column:
 
 
 def weekofyear(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.WeekOfYear)
+    return Column.invoke_expression_over_column(col, expression.WeekOfYear)
 
 
 def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:

@@ -568,15 +578,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
 
 
 def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=days)
+    return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days)
 
 
 def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=days)
+    return Column.invoke_expression_over_column(col, expression.DateSub, expression=days)
 
 
 def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=start)
+    return Column.invoke_expression_over_column(end, expression.DateDiff, expression=start)
 
 
 def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:

@@ -593,8 +603,10 @@ def months_between(
 
 def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
     if format is not None:
-        return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate, format=lit(format))
-    return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate)
+        return Column.invoke_expression_over_column(
+            col, expression.TsOrDsToDate, format=lit(format)
+        )
+    return Column.invoke_expression_over_column(col, expression.TsOrDsToDate)
 
 
 def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:

@@ -604,11 +616,13 @@ def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
 
 
 def trunc(col: ColumnOrName, format: str) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format))
+    return Column.invoke_expression_over_column(col, expression.DateTrunc, unit=lit(format))
 
 
 def date_trunc(format: str, timestamp: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format))
+    return Column.invoke_expression_over_column(
+        timestamp, expression.TimestampTrunc, unit=lit(format)
+    )
 
 
 def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:

@@ -621,8 +635,8 @@ def last_day(col: ColumnOrName) -> Column:
 
 def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
     if format is not None:
-        return Column.invoke_expression_over_column(col, glotexp.UnixToStr, format=lit(format))
-    return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
+        return Column.invoke_expression_over_column(col, expression.UnixToStr, format=lit(format))
+    return Column.invoke_expression_over_column(col, expression.UnixToStr)
 
 
 def unix_timestamp(

@@ -630,9 +644,9 @@ def unix_timestamp(
 ) -> Column:
     if format is not None:
         return Column.invoke_expression_over_column(
-            timestamp, glotexp.StrToUnix, format=lit(format)
+            timestamp, expression.StrToUnix, format=lit(format)
         )
-    return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
+    return Column.invoke_expression_over_column(timestamp, expression.StrToUnix)
 
 
 def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:

@@ -719,11 +733,11 @@ def raise_error(errorMsg: ColumnOrName) -> Column:
 
 
 def upper(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Upper)
+    return Column.invoke_expression_over_column(col, expression.Upper)
 
 
 def lower(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Lower)
+    return Column.invoke_expression_over_column(col, expression.Lower)
 
 
 def ascii(col: ColumnOrLiteral) -> Column:

@@ -747,24 +761,24 @@ def rtrim(col: ColumnOrName) -> Column:
 
 
 def trim(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Trim)
+    return Column.invoke_expression_over_column(col, expression.Trim)
 
 
 def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
     return Column.invoke_expression_over_column(
-        None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
+        None, expression.ConcatWs, expressions=[lit(sep)] + list(cols)
     )
 
 
 def decode(col: ColumnOrName, charset: str) -> Column:
     return Column.invoke_expression_over_column(
-        col, glotexp.Decode, charset=glotexp.Literal.string(charset)
+        col, expression.Decode, charset=expression.Literal.string(charset)
     )
 
 
 def encode(col: ColumnOrName, charset: str) -> Column:
     return Column.invoke_expression_over_column(
-        col, glotexp.Encode, charset=glotexp.Literal.string(charset)
+        col, expression.Encode, charset=expression.Literal.string(charset)
     )

@@ -816,16 +830,16 @@ def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
 
 
 def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(left, glotexp.Levenshtein, expression=right)
+    return Column.invoke_expression_over_column(left, expression.Levenshtein, expression=right)
 
 
 def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
     substr_col = lit(substr)
     if pos is not None:
         return Column.invoke_expression_over_column(
-            str, glotexp.StrPosition, substr=substr_col, position=pos
+            str, expression.StrPosition, substr=substr_col, position=pos
         )
-    return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
+    return Column.invoke_expression_over_column(str, expression.StrPosition, substr=substr_col)
 
 
 def lpad(col: ColumnOrName, len: int, pad: str) -> Column:

@@ -837,21 +851,26 @@ def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
 
 
 def repeat(col: ColumnOrName, n: int) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Repeat, times=lit(n))
+    return Column.invoke_expression_over_column(col, expression.Repeat, times=lit(n))
 
 
 def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
     if limit is not None:
         return Column.invoke_expression_over_column(
-            str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=limit
+            str, expression.RegexpSplit, expression=lit(pattern).expression, limit=limit
         )
-    return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern))
+    return Column.invoke_expression_over_column(
+        str, expression.RegexpSplit, expression=lit(pattern)
+    )
 
 
 def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
-    if idx is not None:
-        return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern), idx)
-    return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern))
+    return Column.invoke_expression_over_column(
+        str,
+        expression.RegexpExtract,
+        expression=lit(pattern),
+        group=idx,
+    )
 
 
 def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:

@@ -859,7 +878,7 @@ def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
 
 
 def initcap(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Initcap)
+    return Column.invoke_expression_over_column(col, expression.Initcap)
 
 
 def soundex(col: ColumnOrName) -> Column:

@@ -871,15 +890,15 @@ def bin(col: ColumnOrName) -> Column:
 
 
 def hex(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Hex)
+    return Column.invoke_expression_over_column(col, expression.Hex)
 
 
 def unhex(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Unhex)
+    return Column.invoke_expression_over_column(col, expression.Unhex)
 
 
 def length(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Length)
+    return Column.invoke_expression_over_column(col, expression.Length)
 
 
 def octet_length(col: ColumnOrName) -> Column:

@@ -896,27 +915,27 @@ def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column:
 
 def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
     columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols
-    return Column.invoke_expression_over_column(None, glotexp.Array, expressions=columns)
+    return Column.invoke_expression_over_column(None, expression.Array, expressions=columns)
 
 
 def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
     cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols  # type: ignore
     return Column.invoke_expression_over_column(
         None,
-        glotexp.VarMap,
+        expression.VarMap,
         keys=array(*cols[::2]).expression,
         values=array(*cols[1::2]).expression,
     )
 
 
 def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(None, glotexp.Map, keys=col1, values=col2)
+    return Column.invoke_expression_over_column(None, expression.Map, keys=col1, values=col2)
 
 
 def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
     value_col = value if isinstance(value, Column) else lit(value)
     return Column.invoke_expression_over_column(
-        col, glotexp.ArrayContains, expression=value_col.expression
+        col, expression.ArrayContains, expression=value_col.expression
     )

@@ -943,7 +962,7 @@ def array_join(
 
 
 def concat(*cols: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols)
+    return Column.invoke_expression_over_column(None, expression.Concat, expressions=cols)
 
 
 def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:

@@ -978,11 +997,11 @@ def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column:
 
 
 def explode(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Explode)
+    return Column.invoke_expression_over_column(col, expression.Explode)
 
 
 def posexplode(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.Posexplode)
+    return Column.invoke_expression_over_column(col, expression.Posexplode)
 
 
 def explode_outer(col: ColumnOrName) -> Column:

@@ -994,7 +1013,7 @@ def posexplode_outer(col: ColumnOrName) -> Column:
 
 
 def get_json_object(col: ColumnOrName, path: str) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path))
+    return Column.invoke_expression_over_column(col, expression.JSONExtract, path=lit(path))
 
 
 def json_tuple(col: ColumnOrName, *fields: str) -> Column:

@@ -1042,7 +1061,7 @@ def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> C
 
 
 def size(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, glotexp.ArraySize)
+    return Column.invoke_expression_over_column(col, expression.ArraySize)
 
 
 def array_min(col: ColumnOrName) -> Column:

@@ -1055,8 +1074,8 @@ def array_max(col: ColumnOrName) -> Column:
 
 def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
     if asc is not None:
-        return Column.invoke_expression_over_column(col, glotexp.SortArray, asc=asc)
-    return Column.invoke_expression_over_column(col, glotexp.SortArray)
+        return Column.invoke_expression_over_column(col, expression.SortArray, asc=asc)
+    return Column.invoke_expression_over_column(col, expression.SortArray)
 
 
 def array_sort(

@@ -1065,8 +1084,10 @@ def array_sort(
 ) -> Column:
     if comparator is not None:
         f_expression = _get_lambda_from_func(comparator)
-        return Column.invoke_expression_over_column(col, glotexp.ArraySort, expression=f_expression)
-    return Column.invoke_expression_over_column(col, glotexp.ArraySort)
+        return Column.invoke_expression_over_column(
+            col, expression.ArraySort, expression=f_expression
+        )
+    return Column.invoke_expression_over_column(col, expression.ArraySort)
 
 
 def shuffle(col: ColumnOrName) -> Column:

@@ -1146,13 +1167,13 @@ def aggregate(
         finish_exp = _get_lambda_from_func(finish)
         return Column.invoke_expression_over_column(
             col,
-            glotexp.Reduce,
+            expression.Reduce,
             initial=initialValue,
             merge=Column(merge_exp),
             finish=Column(finish_exp),
         )
     return Column.invoke_expression_over_column(
-        col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp)
+        col, expression.Reduce, initial=initialValue, merge=Column(merge_exp)
     )

@@ -1179,7 +1200,9 @@ def filter(
     f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
 ) -> Column:
     f_expression = _get_lambda_from_func(f)
-    return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
+    return Column.invoke_expression_over_column(
+        col, expression.ArrayFilter, expression=f_expression
+    )
 
 
 def zip_with(

@@ -1219,10 +1242,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:
 
 def _get_lambda_from_func(lambda_expression: t.Callable):
     variables = [
-        glotexp.to_identifier(x, quoted=_lambda_quoted(x))
+        expression.to_identifier(x, quoted=_lambda_quoted(x))
        for x in lambda_expression.__code__.co_varnames
     ]
-    return glotexp.Lambda(
+    return expression.Lambda(
        this=lambda_expression(*[Column(x) for x in variables]).expression,
        expressions=variables,
    )

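Taken together, the hunks above are a mechanical rename of the `glotexp` import alias to `expression`, plus reflowing of calls that no longer fit on one line and a rewrite of regexp_extract onto expression.RegexpExtract. A brief usage sketch of these helpers (assumed API, inferred from the functions above; the printed SQL is illustrative):

from sqlglot.dataframe.sql import functions as F

# Each helper returns a Column wrapping a sqlglot expression tree; the
# wrapped tree is reachable via .expression and can be rendered to SQL.
c = F.coalesce("a", "b")
print(c.expression.sql())  # e.g. COALESCE(a, b)

d = F.date_add("dt", 7)
print(d.expression.sql())  # e.g. DATE_ADD(dt, 7)
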
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import re
 import typing as t
 
 from sqlglot import exp, generator, parser, tokens, transforms

@@ -31,13 +32,6 @@ def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
     return func
 
 
-def _date_trunc(args: t.Sequence) -> exp.Expression:
-    unit = seq_get(args, 1)
-    if isinstance(unit, exp.Column):
-        unit = exp.Var(this=unit.name)
-    return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
-
-
 def _date_add_sql(
     data_type: str, kind: str
 ) -> t.Callable[[generator.Generator, exp.Expression], str]:

@@ -158,11 +152,23 @@ class BigQuery(Dialect):
     class Parser(parser.Parser):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
-            "DATE_TRUNC": _date_trunc,
+            "DATE_TRUNC": lambda args: exp.DateTrunc(
+                unit=exp.Literal.string(seq_get(args, 1).name),  # type: ignore
+                this=seq_get(args, 0),
+            ),
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATETIME_ADD": _date_add(exp.DatetimeAdd),
             "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
+            "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
+            "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
+                position=seq_get(args, 2),
+                occurrence=seq_get(args, 3),
+                group=exp.Literal.number(1)
+                if re.compile(str(seq_get(args, 1))).groups == 1
+                else None,
+            ),
             "TIME_ADD": _date_add(exp.TimeAdd),
             "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
             "DATE_SUB": _date_add(exp.DateSub),

@@ -214,6 +220,7 @@ class BigQuery(Dialect):
             exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
             exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
             exp.DateStrToDate: datestrtodate_sql,
+            exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
             exp.GroupConcat: rename_func("STRING_AGG"),
             exp.ILike: no_ilike_sql,
             exp.IntDiv: rename_func("DIV"),

@@ -226,11 +233,12 @@ class BigQuery(Dialect):
             exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
             exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
+            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.VariancePop: rename_func("VAR_POP"),
             exp.Values: _derived_table_values_to_unnest,
             exp.ReturnsProperty: _returnsproperty_sql,
             exp.Create: _create_sql,
-            exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
+            exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
             exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
             if e.name == "IMMUTABLE"
             else "NOT DETERMINISTIC",

@@ -251,6 +259,10 @@ class BigQuery(Dialect):
             exp.DataType.Type.VARCHAR: "STRING",
             exp.DataType.Type.NVARCHAR: "STRING",
         }
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+        }
 
         EXPLICIT_UNION = True

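The default group for the new REGEXP_EXTRACT parser hinges on re.compile(pattern).groups, the standard-library count of capturing groups in the compiled pattern:

import re

print(re.compile(r"foo(\d+)").groups)  # 1 -> default group becomes Literal.number(1)
print(re.compile(r"(a)(b)").groups)    # 2 -> group is left as None
print(re.compile(r"(?:x)y").groups)    # 0 -> non-capturing groups don't count
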
@@ -4,6 +4,7 @@ from sqlglot import exp
 from sqlglot.dialects.dialect import parse_date_delta
 from sqlglot.dialects.spark import Spark
 from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
+from sqlglot.tokens import TokenType
 
 
 class Databricks(Spark):

@@ -21,3 +22,11 @@ class Databricks(Spark):
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
         }
+
+        PARAMETER_TOKEN = "$"
+
+    class Tokenizer(Spark.Tokenizer):
+        SINGLE_TOKENS = {
+            **Spark.Tokenizer.SINGLE_TOKENS,
+            "$": TokenType.PARAMETER,
+        }

@@ -215,24 +215,19 @@ DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 
 
 def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
-    def _rename(self, expression):
-        args = flatten(expression.args.values())
-        return f"{self.normalize_func(name)}({self.format_args(*args)})"
-
-    return _rename
+    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 
 
 def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
     if expression.args.get("accuracy"):
         self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
-    return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
+    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 
 
 def if_sql(self: Generator, expression: exp.If) -> str:
-    expressions = self.format_args(
-        expression.this, expression.args.get("true"), expression.args.get("false")
+    return self.func(
+        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
     )
-    return f"IF({expressions})"
 
 
 def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:

@@ -318,13 +313,13 @@ def var_map_sql(
 
     if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
         self.unsupported("Cannot convert array columns into map.")
-        return f"{map_func_name}({self.format_args(keys, values)})"
+        return self.func(map_func_name, keys, values)
 
     args = []
     for key, value in zip(keys.expressions, values.expressions):
         args.append(self.sql(key))
         args.append(self.sql(value))
-    return f"{map_func_name}({self.format_args(*args)})"
+    return self.func(map_func_name, *args)
 
 
 def format_time_lambda(

@@ -400,10 +395,9 @@ def locate_to_strposition(args: t.Sequence) -> exp.Expression:
 
 
 def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
-    args = self.format_args(
-        expression.args.get("substr"), expression.this, expression.args.get("position")
+    return self.func(
+        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
     )
-    return f"LOCATE({args})"
 
 
 def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:

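Across the dialects in this commit, hand-rolled f-strings of the form f"NAME({self.format_args(...)})" are repeatedly replaced with self.func("NAME", ...). A minimal sketch of what that helper presumably does, inferred from how the call sites use it (an assumption, not the verbatim sqlglot source):

class GeneratorSketch:
    # Hypothetical stand-ins for the real Generator machinery.
    def normalize_func(self, name: str) -> str:
        return name.upper()

    def format_args(self, *args) -> str:
        return ", ".join(str(a) for a in args if a is not None)

    def func(self, name: str, *args) -> str:
        # One place for name normalization and None-filtering argument
        # formatting, which lets the call sites drop their f-string plumbing.
        return f"{self.normalize_func(name)}({self.format_args(*args)})"

print(GeneratorSketch().func("if", "a", "b", None))  # IF(a, b)
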
@@ -39,23 +39,6 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
     return func
 
 
-def if_sql(self: generator.Generator, expression: exp.If) -> str:
-    """
-    Drill requires backticks around certain SQL reserved words, IF being one of them, This function
-    adds the backticks around the keyword IF.
-    Args:
-        self: The Drill dialect
-        expression: The input IF expression
-
-    Returns: The expression with IF in backticks.
-
-    """
-    expressions = self.format_args(
-        expression.this, expression.args.get("true"), expression.args.get("false")
-    )
-    return f"`IF`({expressions})"
-
-
 def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)

@@ -134,7 +117,7 @@ class Drill(Dialect):
 
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
-            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
         }
 
         TRANSFORMS = {

@@ -148,7 +131,7 @@ class Drill(Dialect):
             exp.DateSub: _date_add_sql("SUB"),
             exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
-            exp.If: if_sql,
+            exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
             exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
             exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",

@@ -73,11 +73,24 @@ def _datatype_sql(self, expression):
     return self.datatype_sql(expression)
 
 
+def _regexp_extract_sql(self, expression):
+    bad_args = list(filter(expression.args.get, ("position", "occurrence")))
+    if bad_args:
+        self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")
+    return self.func(
+        "REGEXP_EXTRACT",
+        expression.args.get("this"),
+        expression.args.get("expression"),
+        expression.args.get("group"),
+    )
+
+
 class DuckDB(Dialect):
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             ":=": TokenType.EQ,
             "ATTACH": TokenType.COMMAND,
             "CHARACTER VARYING": TokenType.VARCHAR,
         }

@@ -117,7 +130,7 @@ class DuckDB(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             exp.ApproxDistinct: approx_count_distinct_sql,
-            exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+            exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
             if isinstance(seq_get(e.expressions, 0), exp.Select)
             else rename_func("LIST_VALUE")(self, e),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),

@@ -125,7 +138,9 @@ class DuckDB(Dialect):
             exp.ArraySum: rename_func("LIST_SUM"),
             exp.DataType: _datatype_sql,
             exp.DateAdd: _date_add,
-            exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""",
+            exp.DateDiff: lambda self, e: self.func(
+                "DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
+            ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",

@@ -137,6 +152,7 @@ class DuckDB(Dialect):
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.Pivot: no_pivot_sql,
             exp.Properties: no_properties_sql,
+            exp.RegexpExtract: _regexp_extract_sql,
             exp.RegexpLike: rename_func("REGEXP_MATCHES"),
             exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
             exp.SafeDivide: no_safe_divide_sql,

@@ -43,7 +43,7 @@ def _add_date_sql(self, expression):
         else expression.expression
     )
     modified_increment = exp.Literal.number(modified_increment)
-    return f"{func}({self.format_args(expression.this, modified_increment.this)})"
+    return self.func(func, expression.this, modified_increment.this)
 
 
 def _date_diff_sql(self, expression):

@@ -66,7 +66,7 @@ def _property_sql(self, expression):
 
 
 def _str_to_unix(self, expression):
-    return f"UNIX_TIMESTAMP({self.format_args(expression.this, _time_format(self, expression))})"
+    return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
 
 
 def _str_to_date(self, expression):

@@ -312,7 +312,9 @@ class Hive(Dialect):
             exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.TsOrDsToDate: _to_date_sql,
             exp.TryCast: no_trycast_sql,
-            exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
+            exp.UnixToStr: lambda self, e: self.func(
+                "FROM_UNIXTIME", e.this, _time_format(self, e)
+            ),
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",

@@ -324,9 +326,9 @@ class Hive(Dialect):
 
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
-            exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
-            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
-            exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+            exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+            exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
         }
 
         def with_properties(self, properties):

@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
     no_paren_current_date_sql,
     no_tablesample_sql,
     no_trycast_sql,
+    rename_func,
     strposition_to_locate_sql,
 )
 from sqlglot.helper import seq_get

@@ -22,9 +23,8 @@ def _show_parser(*args, **kwargs):
 
 
 def _date_trunc_sql(self, expression):
-    unit = expression.name.lower()
-
-    expr = self.sql(expression.expression)
+    expr = self.sql(expression, "this")
+    unit = expression.text("unit")
 
     if unit == "day":
         return f"DATE({expr})"

@@ -42,7 +42,7 @@ def _date_trunc_sql(self, expression):
         concat = f"CONCAT(YEAR({expr}), ' 1 1')"
         date_format = "%Y %c %e"
     else:
-        self.unsupported("Unexpected interval unit: {unit}")
+        self.unsupported(f"Unexpected interval unit: {unit}")
         return f"DATE({expr})"
 
     return f"STR_TO_DATE({concat}, '{date_format}')"

@@ -443,6 +443,10 @@ class MySQL(Dialect):
             exp.DateAdd: _date_add_sql("ADD"),
             exp.DateSub: _date_add_sql("SUB"),
             exp.DateTrunc: _date_trunc_sql,
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
+            exp.DayOfMonth: rename_func("DAYOFMONTH"),
+            exp.DayOfYear: rename_func("DAYOFYEAR"),
+            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
             exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
             exp.StrToDate: _str_to_date_sql,
             exp.StrToTime: _str_to_date_sql,

@@ -1,15 +1,49 @@
 from __future__ import annotations
 
+import typing as t
+
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
 from sqlglot.helper import csv
 from sqlglot.tokens import TokenType
 
+PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
+    TokenType.COLUMN,
+    TokenType.RETURNING,
+}
+
 
 def _limit_sql(self, expression):
     return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression))
 
 
+def _parse_xml_table(self) -> exp.XMLTable:
+    this = self._parse_string()
+
+    passing = None
+    columns = None
+
+    if self._match_text_seq("PASSING"):
+        # The BY VALUE keywords are optional and are provided for semantic clarity
+        self._match_text_seq("BY", "VALUE")
+        passing = self._parse_csv(
+            lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS)
+        )
+
+    by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
+
+    if self._match_text_seq("COLUMNS"):
+        columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))
+
+    return self.expression(
+        exp.XMLTable,
+        this=this,
+        passing=passing,
+        columns=columns,
+        by_ref=by_ref,
+    )
+
+
 class Oracle(Dialect):
     # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
     # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes

@@ -43,6 +77,11 @@ class Oracle(Dialect):
             "DECODE": exp.Matches.from_arg_list,
         }
 
+        FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
+            **parser.Parser.FUNCTION_PARSERS,
+            "XMLTABLE": _parse_xml_table,
+        }
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
 

@@ -74,7 +113,7 @@ class Oracle(Dialect):
             exp.Substring: rename_func("SUBSTR"),
         }
 
-        def query_modifiers(self, expression, *sqls):
+        def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
             return csv(
                 *sqls,
                 *[self.sql(sql) for sql in expression.args.get("joins") or []],

@@ -97,19 +136,32 @@ class Oracle(Dialect):
             sep="",
         )
 
-        def offset_sql(self, expression):
+        def offset_sql(self, expression: exp.Offset) -> str:
             return f"{super().offset_sql(expression)} ROWS"
 
-        def table_sql(self, expression):
-            return super().table_sql(expression, sep=" ")
+        def table_sql(self, expression: exp.Table, sep: str = " ") -> str:
+            return super().table_sql(expression, sep=sep)
+
+        def xmltable_sql(self, expression: exp.XMLTable) -> str:
+            this = self.sql(expression, "this")
+            passing = self.expressions(expression, "passing")
+            passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else ""
+            columns = self.expressions(expression, "columns")
+            columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else ""
+            by_ref = (
+                f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else ""
+            )
+            return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
 
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
+            "COLUMNS": TokenType.COLUMN,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
             "MINUS": TokenType.EXCEPT,
+            "NVARCHAR2": TokenType.NVARCHAR,
+            "RETURNING": TokenType.RETURNING,
             "START": TokenType.BEGIN,
             "TOP": TokenType.TOP,
             "VARCHAR2": TokenType.VARCHAR,
-            "NVARCHAR2": TokenType.NVARCHAR,
         }

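With _parse_xml_table registered under FUNCTION_PARSERS and xmltable_sql added to the generator, Oracle's XMLTABLE should now survive a parse/generate round trip. A hedged example (assumed to work after this change; the SQL is illustrative, not taken from the diff):

import sqlglot

# Hypothetical input: XMLTABLE with PASSING and COLUMNS clauses.
sql = "SELECT * FROM XMLTABLE('/rows/row' PASSING x COLUMNS a NUMBER)"
print(sqlglot.transpile(sql, read="oracle", write="oracle")[0])
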
@@ -58,17 +58,17 @@ def _date_diff_sql(self, expression):
     age = f"AGE({end}, {start})"
 
     if unit == "WEEK":
-        extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
+        unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
     elif unit == "MONTH":
-        extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
+        unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
     elif unit == "QUARTER":
-        extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
+        unit = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
     elif unit == "YEAR":
-        extract = f"EXTRACT(year FROM {age})"
+        unit = f"EXTRACT(year FROM {age})"
     else:
         self.unsupported(f"Unsupported DATEDIFF unit {unit}")
+        unit = age
 
-    return f"CAST({extract} AS BIGINT)"
+    return f"CAST({unit} AS BIGINT)"
 
 
 def _substring_sql(self, expression):

@@ -206,6 +206,8 @@ class Postgres(Dialect):
     }
 
     class Tokenizer(tokens.Tokenizer):
+        QUOTES = ["'", "$$"]
+
         BIT_STRINGS = [("b'", "'"), ("B'", "'")]
         HEX_STRINGS = [("x'", "'"), ("X'", "'")]
         BYTE_STRINGS = [("e'", "'"), ("E'", "'")]

@@ -236,7 +238,7 @@ class Postgres(Dialect):
             "UUID": TokenType.UUID,
             "CSTRING": TokenType.PSEUDO_TYPE,
         }
-        QUOTES = ["'", "$$"]
 
         SINGLE_TOKENS = {
             **tokens.Tokenizer.SINGLE_TOKENS,
             "$": TokenType.PARAMETER,

@@ -265,6 +267,7 @@ class Postgres(Dialect):
 
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
+        PARAMETER_TOKEN = "$"
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore

@@ -52,7 +52,7 @@ def _initcap_sql(self, expression):
 
 def _decode_sql(self, expression):
     _ensure_utf8(expression.args.get("charset"))
-    return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})"
+    return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
 
 
 def _encode_sql(self, expression):

@@ -65,8 +65,7 @@ def _no_sort_array(self, expression):
         comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
     else:
         comparator = None
-    args = self.format_args(expression.this, comparator)
-    return f"ARRAY_SORT({args})"
+    return self.func("ARRAY_SORT", expression.this, comparator)
 
 
 def _schema_sql(self, expression):

@@ -125,7 +124,7 @@ def _sequence_sql(self, expression):
     else:
         start = exp.Cast(this=start, to=to)
 
-    return f"SEQUENCE({self.format_args(start, end, step)})"
+    return self.func("SEQUENCE", start, end, step)
 
 
 def _ensure_utf8(charset):

@@ -5,6 +5,7 @@ import typing as t
 from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import rename_func
 from sqlglot.dialects.postgres import Postgres
+from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 
 

@@ -19,6 +20,11 @@ class Redshift(Postgres):
     class Parser(Postgres.Parser):
         FUNCTIONS = {
             **Postgres.Parser.FUNCTIONS,  # type: ignore
+            "DATEDIFF": lambda args: exp.DateDiff(
+                this=seq_get(args, 2),
+                expression=seq_get(args, 1),
+                unit=seq_get(args, 0),
+            ),
             "DECODE": exp.Matches.from_arg_list,
             "NVL": exp.Coalesce.from_arg_list,
         }

@@ -41,7 +47,6 @@ class Redshift(Postgres):
 
         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,  # type: ignore
-            "ENCODE": TokenType.ENCODE,
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "HLLSKETCH": TokenType.HLLSKETCH,

@@ -62,12 +67,15 @@ class Redshift(Postgres):
 
         PROPERTIES_LOCATION = {
             **Postgres.Generator.PROPERTIES_LOCATION,  # type: ignore
-            exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+            exp.LikeProperty: exp.Properties.Location.POST_WITH,
         }
 
         TRANSFORMS = {
             **Postgres.Generator.TRANSFORMS,  # type: ignore
             **transforms.ELIMINATE_DISTINCT_ON,  # type: ignore
+            exp.DateDiff: lambda self, e: self.func(
+                "DATEDIFF", e.args.get("unit") or "day", e.expression, e.this
+            ),
             exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
             exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
             exp.DistStyleProperty: lambda self, e: self.naked_property(e),

@@ -178,18 +178,25 @@ class Snowflake(Dialect):
            ),
        }

        RANGE_PARSERS = {
            **parser.Parser.RANGE_PARSERS, # type: ignore
            TokenType.LIKE_ANY: lambda self, this: self._parse_escape(
                self.expression(exp.LikeAny, this=this, expression=self._parse_bitwise())
            ),
            TokenType.ILIKE_ANY: lambda self, this: self._parse_escape(
                self.expression(exp.ILikeAny, this=this, expression=self._parse_bitwise())
            ),
        }

    class Tokenizer(tokens.Tokenizer):
        QUOTES = ["'", "$$"]
        STRING_ESCAPES = ["\\", "'"]

        SINGLE_TOKENS = {
            **tokens.Tokenizer.SINGLE_TOKENS,
            "$": TokenType.PARAMETER,
        }

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "EXCLUDE": TokenType.EXCEPT,
            "ILIKE ANY": TokenType.ILIKE_ANY,
            "LIKE ANY": TokenType.LIKE_ANY,
            "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
            "PUT": TokenType.COMMAND,
            "RENAME": TokenType.REPLACE,

@@ -201,8 +208,14 @@ class Snowflake(Dialect):
            "SAMPLE": TokenType.TABLE_SAMPLE,
        }

        SINGLE_TOKENS = {
            **tokens.Tokenizer.SINGLE_TOKENS,
            "$": TokenType.PARAMETER,
        }

    class Generator(generator.Generator):
        CREATE_TRANSIENT = True
        PARAMETER_TOKEN = "$"

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS, # type: ignore

@@ -214,14 +227,15 @@ class Snowflake(Dialect):
            exp.If: rename_func("IFF"),
            exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
            exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
            exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
            exp.Matches: rename_func("DECODE"),
            exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
            exp.StrPosition: lambda self, e: self.func(
                "POSITION", e.args.get("substr"), e.this, e.args.get("position")
            ),
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeStrToTime: timestrtotime_sql,
            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
            exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
            exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
            exp.UnixToTime: _unix_to_time_sql,
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
        }

@@ -236,6 +250,12 @@ class Snowflake(Dialect):
            "replace": "RENAME",
        }

        def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
            return self.binary(expression, "ILIKE ANY")

        def likeany_sql(self, expression: exp.LikeAny) -> str:
            return self.binary(expression, "LIKE ANY")

        def except_op(self, expression):
            if not expression.args.get("distinct", False):
                self.unsupported("EXCEPT with All is not supported in Snowflake")
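The LIKE ANY / ILIKE ANY support above spans all three layers: the tokenizer (multi-word keywords), the parser (RANGE_PARSERS building exp.LikeAny / exp.ILikeAny), and the generator (binary operators). A small sketch of the intended round trip (the exact SQL shape is assumed):

    import sqlglot
    from sqlglot import exp

    sql = "SELECT * FROM t WHERE name ILIKE ANY ('%ann%', '%bob%')"
    ast = sqlglot.parse_one(sql, read="snowflake")
    print(type(ast.find(exp.ILikeAny)).__name__)  # ILikeAny
    print(ast.sql(dialect="snowflake"))           # round-trips the predicate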
@@ -86,6 +86,11 @@ class Spark(Hive):
            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
            ),
            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                this=seq_get(args, 1),
                unit=exp.var(seq_get(args, 0)),
            ),
            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
        }

        FUNCTION_PARSERS = {

@@ -133,7 +138,7 @@ class Spark(Hive):
            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
            exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
            exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
            exp.DateTrunc: rename_func("TRUNC"),
            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",

@@ -142,7 +147,9 @@ class Spark(Hive):
            exp.Map: _map_sql,
            exp.Reduce: rename_func("AGGREGATE"),
            exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
            exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
            exp.TimestampTrunc: lambda self, e: self.func(
                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
            ),
            exp.Trim: trim_sql,
            exp.VariancePop: rename_func("VAR_POP"),
            exp.DateFromParts: rename_func("MAKE_DATE"),

@@ -157,16 +164,16 @@ class Spark(Hive):
        TRANSFORMS.pop(exp.ILike)

        WRAP_DERIVED_VALUES = False
        CREATE_FUNCTION_AS = False
        CREATE_FUNCTION_RETURN_AS = False

        def cast_sql(self, expression: exp.Cast) -> str:
            if isinstance(expression.this, exp.Cast) and expression.this.is_type(
                exp.DataType.Type.JSON
            ):
                schema = f"'{self.sql(expression, 'to')}'"
                return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})"
                return self.func("FROM_JSON", expression.this.this, schema)
            if expression.to.is_type(exp.DataType.Type.JSON):
                return f"TO_JSON({self.sql(expression, 'this')})"
                return self.func("TO_JSON", expression.this)

            return super(Spark.Generator, self).cast_sql(expression)
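Spark now distinguishes DATE_TRUNC(unit, ts), parsed as exp.TimestampTrunc, from TRUNC(date, unit), parsed as exp.DateTrunc. A quick sketch of the distinction (output shapes assumed):

    import sqlglot

    # unit-first timestamp truncation round-trips through exp.TimestampTrunc
    print(sqlglot.transpile("SELECT DATE_TRUNC('month', ts) FROM t",
                            read="spark", write="spark")[0])

    # date-first TRUNC round-trips through exp.DateTrunc
    print(sqlglot.transpile("SELECT TRUNC(dt, 'year') FROM t",
                            read="spark", write="spark")[0])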
@@ -39,7 +39,7 @@ def _date_add_sql(self, expression):
    modifier = expression.name if modifier.is_string else self.sql(modifier)
    unit = expression.args.get("unit")
    modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
    return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})"
    return self.func("DATE", expression.this, modifier)


class SQLite(Dialect):
@@ -1,11 +1,33 @@
from __future__ import annotations

from sqlglot import exp, generator, parser
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType


class Teradata(Dialect):
    class Tokenizer(tokens.Tokenizer):
        # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "BYTEINT": TokenType.SMALLINT,
            "SEL": TokenType.SELECT,
            "INS": TokenType.INSERT,
            "MOD": TokenType.MOD,
            "LT": TokenType.LT,
            "LE": TokenType.LTE,
            "GT": TokenType.GT,
            "GE": TokenType.GTE,
            "^=": TokenType.NEQ,
            "NE": TokenType.NEQ,
            "NOT=": TokenType.NEQ,
            "ST_GEOMETRY": TokenType.GEOMETRY,
        }

        # teradata does not support % for modulus
        SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS}
        SINGLE_TOKENS.pop("%")

    class Parser(parser.Parser):
        CHARSET_TRANSLATORS = {
            "GRAPHIC_TO_KANJISJIS",

@@ -42,6 +64,14 @@ class Teradata(Dialect):
            "UNICODE_TO_UNICODE_NFKD",
        }

        FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS}
        FUNC_TOKENS.remove(TokenType.REPLACE)

        STATEMENT_PARSERS = {
            **parser.Parser.STATEMENT_PARSERS, # type: ignore
            TokenType.REPLACE: lambda self: self._parse_create(),
        }

        FUNCTION_PARSERS = {
            **parser.Parser.FUNCTION_PARSERS, # type: ignore
            "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),

@@ -76,6 +106,11 @@ class Teradata(Dialect):
        )

    class Generator(generator.Generator):
        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING, # type: ignore
            exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
        }

        PROPERTIES_LOCATION = {
            **generator.Generator.PROPERTIES_LOCATION, # type: ignore
            exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,

@@ -93,3 +128,11 @@ class Teradata(Dialect):
            where_sql = self.sql(expression, "where")
            sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}"
            return self.prepend_ctes(expression, sql)

        def mod_sql(self, expression: exp.Mod) -> str:
            return self.binary(expression, "MOD")

        def datatype_sql(self, expression: exp.DataType) -> str:
            type_sql = super().datatype_sql(expression)
            prefix_sql = expression.args.get("prefix")
            return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql
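The new Teradata tokenizer maps the dialect's shorthand keywords and comparison aliases onto standard tokens, and MOD becomes an infix operator in both directions since % is removed from SINGLE_TOKENS. A brief sketch (expected outputs are assumptions):

    import sqlglot

    # SEL expands to SELECT, GE to >=, when reading Teradata
    print(sqlglot.transpile("SEL id FROM t WHERE id GE 10", read="teradata")[0])

    # modulus round-trips as MOD, since % is not a Teradata operator
    print(sqlglot.transpile("SELECT 5 MOD 2", read="teradata", write="teradata")[0])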
@@ -92,7 +92,7 @@ def _parse_eomonth(args):

def generate_date_delta_with_unit_sql(self, e):
    func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
    return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
    return self.func(func, e.text("unit"), e.expression, e.this)


def _format_sql(self, e):

@@ -101,7 +101,7 @@ def _format_sql(self, e):
        if isinstance(e, exp.NumberToStr)
        else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping))
    )
    return f"FORMAT({self.format_args(e.this, fmt)})"
    return self.func("FORMAT", e.this, fmt)


def _string_agg_sql(self, e):

@@ -408,7 +408,7 @@ class TSQL(Dialect):
            ):
                return this

            expressions = self._parse_csv(self._parse_udf_kwarg)
            expressions = self._parse_csv(self._parse_function_parameter)
            return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)

    class Generator(generator.Generator):
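generate_date_delta_with_unit_sql now routes through the shared func() helper, which normalizes the function name and drops None arguments. A sketch of the DATEADD path (output casing assumed):

    import sqlglot

    print(sqlglot.transpile("SELECT DATEADD(day, 1, due_date) FROM t",
                            read="tsql", write="tsql")[0])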
@@ -62,10 +62,8 @@ def execute(
    if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
        raise ExecuteError("Tables must support the same table args as schema")

    expression = maybe_parse(sql, dialect=read)

    now = time.time()
    expression = optimize(expression, schema, leave_tables_isolated=True)
    expression = optimize(sql, schema, leave_tables_isolated=True, dialect=read)

    logger.debug("Optimization finished: %f", time.time() - now)
    logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
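execute() no longer pre-parses the SQL itself; the raw string and read dialect are handed straight to optimize(), which now accepts both. A minimal sketch of the public entry point this hunk touches (table data invented for illustration):

    from sqlglot.executor import execute

    result = execute(
        "SELECT a, COUNT(*) AS n FROM t GROUP BY a",
        tables={"t": [{"a": "x"}, {"a": "x"}, {"a": "y"}]},
    )
    print(result.rows)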
@@ -376,7 +376,7 @@ def _rename(self, e):
            this = self.sql(e, "this")
            this = f"{this}, " if this else ""
            return f"{e.key.upper()}({this}{self.expressions(e)})"
        return f"{e.key.upper()}({self.format_args(*e.args.values())})"
        return self.func(e.key, *e.args.values())
    except Exception as ex:
        raise Exception(f"Could not rename {repr(e)}") from ex
@@ -128,7 +128,7 @@ class Expression(metaclass=_Expression):
        """
        return self.args.get("expressions") or []

    def text(self, key):
    def text(self, key) -> str:
        """
        Returns a textual representation of the argument corresponding to "key". This can only be used
        for args that are strings or leaf Expression instances, such as identifiers and literals.

@@ -143,21 +143,21 @@ class Expression(metaclass=_Expression):
        return ""

    @property
    def is_string(self):
    def is_string(self) -> bool:
        """
        Checks whether a Literal expression is a string.
        """
        return isinstance(self, Literal) and self.args["is_string"]

    @property
    def is_number(self):
    def is_number(self) -> bool:
        """
        Checks whether a Literal expression is a number.
        """
        return isinstance(self, Literal) and not self.args["is_string"]

    @property
    def is_int(self):
    def is_int(self) -> bool:
        """
        Checks whether a Literal expression is an integer.
        """

@@ -170,7 +170,12 @@ class Expression(metaclass=_Expression):
        return False

    @property
    def alias(self):
    def is_star(self) -> bool:
        """Checks whether an expression is a star."""
        return isinstance(self, Star) or (isinstance(self, Column) and isinstance(self.this, Star))

    @property
    def alias(self) -> str:
        """
        Returns the alias of the expression, or an empty string if it's not aliased.
        """
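is_star covers both a bare star and a qualified t.* (a Column wrapping a Star), which is why several optimizer hunks later in this diff can switch to it. A tiny illustration:

    from sqlglot import parse_one

    assert parse_one("SELECT * FROM t").selects[0].is_star
    assert parse_one("SELECT t.* FROM t").selects[0].is_star
    assert not parse_one("SELECT a FROM t").selects[0].is_star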
@@ -825,10 +830,6 @@ class UserDefinedFunction(Expression):
    arg_types = {"this": True, "expressions": False, "wrapped": False}


class UserDefinedFunctionKwarg(Expression):
    arg_types = {"this": True, "kind": True, "default": False}


class CharacterSet(Expression):
    arg_types = {"this": True, "default": False}

@@ -870,14 +871,22 @@ class ByteString(Condition):


class Column(Condition):
    arg_types = {"this": True, "table": False}
    arg_types = {"this": True, "table": False, "db": False, "catalog": False}

    @property
    def table(self):
    def table(self) -> str:
        return self.text("table")

    @property
    def output_name(self):
    def db(self) -> str:
        return self.text("db")

    @property
    def catalog(self) -> str:
        return self.text("catalog")

    @property
    def output_name(self) -> str:
        return self.name
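Columns can now carry db and catalog qualifiers alongside the table. A short sketch of reading them off a fully qualified reference (parser support for four-part names is assumed here, inferred from the matching column_sql change further down):

    from sqlglot import exp, parse_one

    col = parse_one("SELECT my_catalog.my_db.my_table.my_col").find(exp.Column)
    print(col.catalog, col.db, col.table, col.name)
    # my_catalog my_db my_table my_col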
@@ -917,6 +926,14 @@ class AutoIncrementColumnConstraint(ColumnConstraintKind):
    pass


class CaseSpecificColumnConstraint(ColumnConstraintKind):
    arg_types = {"not_": True}


class CharacterSetColumnConstraint(ColumnConstraintKind):
    arg_types = {"this": True}


class CheckColumnConstraint(ColumnConstraintKind):
    pass

@@ -929,6 +946,10 @@ class CommentColumnConstraint(ColumnConstraintKind):
    pass


class DateFormatColumnConstraint(ColumnConstraintKind):
    arg_types = {"this": True}


class DefaultColumnConstraint(ColumnConstraintKind):
    pass

@@ -939,7 +960,14 @@ class EncodeColumnConstraint(ColumnConstraintKind):

class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
    # this: True -> ALWAYS, this: False -> BY DEFAULT
    arg_types = {"this": False, "start": False, "increment": False}
    arg_types = {
        "this": False,
        "start": False,
        "increment": False,
        "minvalue": False,
        "maxvalue": False,
        "cycle": False,
    }


class NotNullColumnConstraint(ColumnConstraintKind):

@@ -950,7 +978,19 @@ class PrimaryKeyColumnConstraint(ColumnConstraintKind):
    arg_types = {"desc": False}


class TitleColumnConstraint(ColumnConstraintKind):
    pass


class UniqueColumnConstraint(ColumnConstraintKind):
    arg_types: t.Dict[str, t.Any] = {}


class UppercaseColumnConstraint(ColumnConstraintKind):
    arg_types: t.Dict[str, t.Any] = {}


class PathColumnConstraint(ColumnConstraintKind):
    pass
@@ -1063,6 +1103,7 @@ class Insert(Expression):
        "overwrite": False,
        "exists": False,
        "partition": False,
        "alternative": False,
    }

@@ -1438,6 +1479,16 @@ class IsolatedLoadingProperty(Property):
    }


class LockingProperty(Property):
    arg_types = {
        "this": False,
        "kind": True,
        "for_or_in": True,
        "lock_type": True,
        "override": False,
    }


class Properties(Expression):
    arg_types = {"expressions": True}
@@ -1463,12 +1514,26 @@ class Properties(Expression):

    PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}

    # CREATE property locations
    # Form: schema specified
    #   create [POST_CREATE]
    #   table a [POST_NAME]
    #   (b int) [POST_SCHEMA]
    #   with ([POST_WITH])
    #   index (b) [POST_INDEX]
    #
    # Form: alias selection
    #   create [POST_CREATE]
    #   table a [POST_NAME]
    #   as [POST_ALIAS] (select * from b)
    #   index (c) [POST_INDEX]
    class Location(AutoName):
        POST_CREATE = auto()
        PRE_SCHEMA = auto()
        POST_NAME = auto()
        POST_SCHEMA = auto()
        POST_WITH = auto()
        POST_ALIAS = auto()
        POST_INDEX = auto()
        POST_SCHEMA_ROOT = auto()
        POST_SCHEMA_WITH = auto()
        UNSUPPORTED = auto()

    @classmethod
@@ -1633,6 +1698,14 @@ class Table(Expression):
        "system_time": False,
    }

    @property
    def db(self) -> str:
        return self.text("db")

    @property
    def catalog(self) -> str:
        return self.text("catalog")


# See the TSQL "Querying data in a system-versioned temporal table" page
class SystemTime(Expression):
@@ -1678,6 +1751,40 @@ class Union(Subqueryable):
            .limit(expression, dialect=dialect, copy=False, **opts)
        )

    def select(
        self,
        *expressions: str | Expression,
        append: bool = True,
        dialect: DialectType = None,
        copy: bool = True,
        **opts,
    ) -> Union:
        """Append to or set the SELECT of the union recursively.

        Example:
            >>> from sqlglot import parse_one
            >>> parse_one("select a from x union select a from y union select a from z").select("b").sql()
            'SELECT a, b FROM x UNION SELECT a, b FROM y UNION SELECT a, b FROM z'

        Args:
            *expressions: the SQL code strings to parse.
                If an `Expression` instance is passed, it will be used as-is.
            append: if `True`, add to any existing expressions.
                Otherwise, this resets the expressions.
            dialect: the dialect used to parse the input expressions.
            copy: if `False`, modify this expression instance in-place.
            opts: other options to use to parse the input expressions.

        Returns:
            Union: the modified expression.
        """
        this = self.copy() if copy else self
        this.this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts)
        this.expression.unnest().select(
            *expressions, append=append, dialect=dialect, copy=False, **opts
        )
        return this

    @property
    def named_selects(self):
        return self.this.unnest().named_selects
@@ -1985,7 +2092,14 @@ class Select(Subqueryable):
            **opts,
        )

    def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
    def select(
        self,
        *expressions: str | Expression,
        append: bool = True,
        dialect: DialectType = None,
        copy: bool = True,
        **opts,
    ) -> Select:
        """
        Append to or set the SELECT expressions.

@@ -1994,13 +2108,13 @@ class Select(Subqueryable):
            'SELECT x, y'

        Args:
            *expressions (str | Expression): the SQL code strings to parse.
            *expressions: the SQL code strings to parse.
                If an `Expression` instance is passed, it will be used as-is.
            append (bool): if `True`, add to any existing expressions.
            append: if `True`, add to any existing expressions.
                Otherwise, this resets the expressions.
            dialect (str): the dialect used to parse the input expressions.
            copy (bool): if `False`, modify this expression instance in-place.
            opts (kwargs): other options to use to parse the input expressions.
            dialect: the dialect used to parse the input expressions.
            copy: if `False`, modify this expression instance in-place.
            opts: other options to use to parse the input expressions.

        Returns:
            Select: the modified expression.
@@ -2399,7 +2513,7 @@ class Star(Expression):


class Parameter(Expression):
    pass
    arg_types = {"this": True, "wrapped": False}


class SessionParameter(Expression):

@@ -2428,6 +2542,7 @@ class DataType(Expression):
        "expressions": False,
        "nested": False,
        "values": False,
        "prefix": False,
    }

    class Type(AutoName):

@@ -2693,6 +2808,10 @@ class ILike(Binary, Predicate):
    pass


class ILikeAny(Binary, Predicate):
    pass


class IntDiv(Binary):
    pass

@@ -2709,6 +2828,10 @@ class Like(Binary, Predicate):
    pass


class LikeAny(Binary, Predicate):
    pass


class LT(Binary, Predicate):
    pass

@@ -3042,7 +3165,7 @@ class DateDiff(Func, TimeUnit):


class DateTrunc(Func):
    arg_types = {"this": True, "expression": True, "zone": False}
    arg_types = {"unit": True, "this": True, "zone": False}


class DatetimeAdd(Func, TimeUnit):

@@ -3330,6 +3453,16 @@ class Reduce(Func):
    arg_types = {"this": True, "initial": True, "merge": True, "finish": False}


class RegexpExtract(Func):
    arg_types = {
        "this": True,
        "expression": True,
        "position": False,
        "occurrence": False,
        "group": False,
    }


class RegexpLike(Func):
    arg_types = {"this": True, "expression": True, "flag": False}

@@ -3519,6 +3652,10 @@ class Week(Func):
    arg_types = {"this": True, "mode": False}


class XMLTable(Func):
    arg_types = {"this": True, "passing": False, "columns": False, "by_ref": False}


class Year(Func):
    pass
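exp.DateTrunc now names its first argument unit, matching how the Spark hunk earlier builds it. A node-level sketch (the generic-dialect output shape is an assumption):

    from sqlglot import exp

    node = exp.DateTrunc(unit=exp.Literal.string("month"), this=exp.column("dt"))
    print(node.sql())  # e.g. DATE_TRUNC('month', dt) in the default dialect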
@@ -3566,6 +3703,7 @@ def maybe_parse(
    into: t.Optional[IntoType] = None,
    dialect: DialectType = None,
    prefix: t.Optional[str] = None,
    copy: bool = False,
    **opts,
) -> Expression:
    """Gracefully handle a possible string or expression.

@@ -3583,6 +3721,7 @@ def maybe_parse(
            input expression is a SQL string).
        prefix: a string to prefix the sql with before it gets parsed
            (automatically includes a space)
        copy: whether or not to copy the expression.
        **opts: other options to use to parse the input expressions (again, in the case
            that an input expression is a SQL string).

@@ -3590,6 +3729,8 @@ def maybe_parse(
        Expression: the parsed or given expression.
    """
    if isinstance(sql_or_expression, Expression):
        if copy:
            return sql_or_expression.copy()
        return sql_or_expression

    import sqlglot
@@ -3818,7 +3959,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
    return Except(this=left, expression=right, distinct=distinct)


def select(*expressions, dialect=None, **opts) -> Select:
def select(*expressions: str | Expression, dialect: DialectType = None, **opts) -> Select:
    """
    Initializes a syntax tree from one or multiple SELECT expressions.

@@ -3827,9 +3968,9 @@ def select(*expressions, dialect=None, **opts) -> Select:
        'SELECT col1, col2 FROM tbl'

    Args:
        *expressions (str | Expression): the SQL code string to parse as the expressions of a
        *expressions: the SQL code string to parse as the expressions of a
            SELECT statement. If an Expression instance is passed, this is used as-is.
        dialect (str): the dialect used to parse the input expressions (in the case that an
        dialect: the dialect used to parse the input expressions (in the case that an
            input expression is a SQL string).
        **opts: other options to use to parse the input expressions (again, in the case
            that an input expression is a SQL string).
@@ -4219,19 +4360,27 @@ def subquery(expression, alias=None, dialect=None, **opts):
    return Select().from_(expression, dialect=dialect, **opts)


def column(col, table=None, quoted=None) -> Column:
def column(
    col: str | Identifier,
    table: t.Optional[str | Identifier] = None,
    schema: t.Optional[str | Identifier] = None,
    quoted: t.Optional[bool] = None,
) -> Column:
    """
    Build a Column.

    Args:
        col (str | Expression): column name
        table (str | Expression): table name
        col: column name
        table: table name
        schema: schema name
        quoted: whether or not to force quote each part
    Returns:
        Column: column instance
    """
    return Column(
        this=to_identifier(col, quoted=quoted),
        table=to_identifier(table, quoted=quoted),
        schema=to_identifier(schema, quoted=quoted),
    )
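The column() builder grows a schema parameter alongside the multi-part Column support above. A minimal sketch of the table-qualified case (the quoted output shape is assumed):

    from sqlglot import exp

    c = exp.column("price", table="orders", quoted=True)
    print(c.sql())  # "orders"."price"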
@@ -4314,6 +4463,30 @@ def values(
    )


def var(name: t.Optional[str | Expression]) -> Var:
    """Build a SQL variable.

    Example:
        >>> repr(var('x'))
        '(VAR this: x)'

        >>> repr(var(column('x', table='y')))
        '(VAR this: x)'

    Args:
        name: The name of the var or an expression whose name will become the var.

    Returns:
        The new variable node.
    """
    if not name:
        raise ValueError(f"Cannot convert empty name into var.")

    if isinstance(name, Expression):
        name = name.name
    return Var(this=name)


def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
    """Build ALTER TABLE... RENAME... expression
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
|
||||
from sqlglot.helper import apply_index_offset, csv
|
||||
from sqlglot.helper import apply_index_offset, csv, seq_get
|
||||
from sqlglot.time import format_time
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)")
|
||||
|
||||
|
||||
class Generator:
|
||||
"""
|
||||
|
@@ -59,10 +56,14 @@ class Generator:
    """

    TRANSFORMS = {
        exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
        exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
        exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
        exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
        exp.DateAdd: lambda self, e: self.func(
            "DATE_ADD", e.this, e.expression, e.args.get("unit")
        ),
        exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
        exp.TsOrDsAdd: lambda self, e: self.func(
            "TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit")
        ),
        exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
        exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
        exp.LanguageProperty: lambda self, e: self.naked_property(e),
        exp.LocationProperty: lambda self, e: self.naked_property(e),

@@ -72,6 +73,17 @@ class Generator:
        exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
        exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
        exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
        exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
        exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
        exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
        exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
        exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
        exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
        exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
        exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
        exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
        exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
        exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
    }

    # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@@ -89,8 +101,8 @@ class Generator:
    # Wrap derived values in parens, usually standard but spark doesn't support it
    WRAP_DERIVED_VALUES = True

    # Whether or not create function uses an AS before the def.
    CREATE_FUNCTION_AS = True
    # Whether or not create function uses an AS before the RETURN
    CREATE_FUNCTION_RETURN_AS = True

    TYPE_MAPPING = {
        exp.DataType.Type.NCHAR: "CHAR",
@@ -110,42 +122,46 @@ class Generator:

    STRUCT_DELIMITER = ("<", ">")

    PARAMETER_TOKEN = "@"

    PROPERTIES_LOCATION = {
        exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.AfterJournalProperty: exp.Properties.Location.POST_NAME,
        exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
        exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
        exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
        exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
        exp.ChecksumProperty: exp.Properties.Location.POST_NAME,
        exp.CollateProperty: exp.Properties.Location.POST_SCHEMA,
        exp.Cluster: exp.Properties.Location.POST_SCHEMA,
        exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
        exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
        exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
        exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.LogProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH,
        exp.Property: exp.Properties.Location.POST_SCHEMA_WITH,
        exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA,
        exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
        exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
        exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA,
        exp.FallbackProperty: exp.Properties.Location.POST_NAME,
        exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
        exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
        exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
        exp.JournalProperty: exp.Properties.Location.POST_NAME,
        exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,
        exp.LikeProperty: exp.Properties.Location.POST_SCHEMA,
        exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
        exp.LockingProperty: exp.Properties.Location.POST_ALIAS,
        exp.LogProperty: exp.Properties.Location.POST_NAME,
        exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
        exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
        exp.Property: exp.Properties.Location.POST_WITH,
        exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
        exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
        exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
        exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
        exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
        exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
        exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
        exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA,
        exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
        exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA,
        exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
    }

    WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
@@ -173,7 +189,6 @@ class Generator:
        "null_ordering",
        "max_unsupported",
        "_indent",
        "_replace_backslash",
        "_escaped_quote_end",
        "_escaped_identifier_end",
        "_leading_comma",

@@ -230,7 +245,6 @@ class Generator:
        self.max_unsupported = max_unsupported
        self.null_ordering = null_ordering
        self._indent = indent
        self._replace_backslash = self.string_escape == "\\"
        self._escaped_quote_end = self.string_escape + self.quote_end
        self._escaped_identifier_end = self.identifier_escape + self.identifier_end
        self._leading_comma = leading_comma
@@ -403,12 +417,13 @@ class Generator:

    def column_sql(self, expression: exp.Column) -> str:
        return ".".join(
            part
            for part in [
                self.sql(expression, "db"),
                self.sql(expression, "table"),
                self.sql(expression, "this"),
            ]
            self.sql(part)
            for part in (
                expression.args.get("catalog"),
                expression.args.get("db"),
                expression.args.get("table"),
                expression.args.get("this"),
            )
            if part
        )
@@ -430,26 +445,6 @@ class Generator:
    def autoincrementcolumnconstraint_sql(self, _) -> str:
        return self.token_sql(TokenType.AUTO_INCREMENT)

    def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str:
        this = self.sql(expression, "this")
        return f"CHECK ({this})"

    def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str:
        comment = self.sql(expression, "this")
        return f"COMMENT {comment}"

    def collatecolumnconstraint_sql(self, expression: exp.CollateColumnConstraint) -> str:
        collate = self.sql(expression, "this")
        return f"COLLATE {collate}"

    def encodecolumnconstraint_sql(self, expression: exp.EncodeColumnConstraint) -> str:
        encode = self.sql(expression, "this")
        return f"ENCODE {encode}"

    def defaultcolumnconstraint_sql(self, expression: exp.DefaultColumnConstraint) -> str:
        default = self.sql(expression, "this")
        return f"DEFAULT {default}"

    def generatedasidentitycolumnconstraint_sql(
        self, expression: exp.GeneratedAsIdentityColumnConstraint
    ) -> str:
@@ -459,10 +454,19 @@ class Generator:
        start = expression.args.get("start")
        start = f"START WITH {start}" if start else ""
        increment = expression.args.get("increment")
        increment = f"INCREMENT BY {increment}" if increment else ""
        increment = f" INCREMENT BY {increment}" if increment else ""
        minvalue = expression.args.get("minvalue")
        minvalue = f" MINVALUE {minvalue}" if minvalue else ""
        maxvalue = expression.args.get("maxvalue")
        maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else ""
        cycle = expression.args.get("cycle")
        cycle_sql = ""
        if cycle is not None:
            cycle_sql = f"{' NO' if not cycle else ''} CYCLE"
            cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql
        sequence_opts = ""
        if start or increment:
            sequence_opts = f"{start} {increment}"
        if start or increment or cycle_sql:
            sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}"
            sequence_opts = f" ({sequence_opts.strip()})"
        return f"GENERATED{this}AS IDENTITY{sequence_opts}"
@@ -483,22 +487,22 @@ class Generator:
        properties = expression.args.get("properties")
        properties_exp = expression.copy()
        properties_locs = self.locate_properties(properties) if properties else {}
        if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get(
            exp.Properties.Location.POST_SCHEMA_WITH
        if properties_locs.get(exp.Properties.Location.POST_SCHEMA) or properties_locs.get(
            exp.Properties.Location.POST_WITH
        ):
            properties_exp.set(
                "properties",
                exp.Properties(
                    expressions=[
                        *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT],
                        *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH],
                        *properties_locs[exp.Properties.Location.POST_SCHEMA],
                        *properties_locs[exp.Properties.Location.POST_WITH],
                    ]
                ),
            )
        if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA):
        if kind == "TABLE" and properties_locs.get(exp.Properties.Location.POST_NAME):
            this_name = self.sql(expression.this, "this")
            this_properties = self.properties(
                exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]),
                exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_NAME]),
                wrapped=False,
            )
            this_schema = f"({self.expressions(expression.this)})"

@@ -512,8 +516,17 @@ class Generator:
        if expression_sql:
            expression_sql = f"{begin}{self.sep()}{expression_sql}"

            if self.CREATE_FUNCTION_AS or kind != "FUNCTION":
                expression_sql = f" AS{expression_sql}"
            if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return):
                if properties_locs.get(exp.Properties.Location.POST_ALIAS):
                    postalias_props_sql = self.properties(
                        exp.Properties(
                            expressions=properties_locs[exp.Properties.Location.POST_ALIAS]
                        ),
                        wrapped=False,
                    )
                    expression_sql = f" AS {postalias_props_sql}{expression_sql}"
                else:
                    expression_sql = f" AS{expression_sql}"

        temporary = " TEMPORARY" if expression.args.get("temporary") else ""
        transient = (
|
|||
|
||||
for p in expression.expressions:
|
||||
p_loc = self.PROPERTIES_LOCATION[p.__class__]
|
||||
if p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
|
||||
if p_loc == exp.Properties.Location.POST_WITH:
|
||||
with_properties.append(p)
|
||||
elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
|
||||
elif p_loc == exp.Properties.Location.POST_SCHEMA:
|
||||
root_properties.append(p)
|
||||
|
||||
return self.root_properties(
|
||||
|
@@ -776,16 +789,18 @@ class Generator:

        for p in properties.expressions:
            p_loc = self.PROPERTIES_LOCATION[p.__class__]
            if p_loc == exp.Properties.Location.PRE_SCHEMA:
                properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p)
            if p_loc == exp.Properties.Location.POST_NAME:
                properties_locs[exp.Properties.Location.POST_NAME].append(p)
            elif p_loc == exp.Properties.Location.POST_INDEX:
                properties_locs[exp.Properties.Location.POST_INDEX].append(p)
            elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
                properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p)
            elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
                properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p)
            elif p_loc == exp.Properties.Location.POST_SCHEMA:
                properties_locs[exp.Properties.Location.POST_SCHEMA].append(p)
            elif p_loc == exp.Properties.Location.POST_WITH:
                properties_locs[exp.Properties.Location.POST_WITH].append(p)
            elif p_loc == exp.Properties.Location.POST_CREATE:
                properties_locs[exp.Properties.Location.POST_CREATE].append(p)
            elif p_loc == exp.Properties.Location.POST_ALIAS:
                properties_locs[exp.Properties.Location.POST_ALIAS].append(p)
            elif p_loc == exp.Properties.Location.UNSUPPORTED:
                self.unsupported(f"Unsupported property {p.key}")
@@ -899,6 +914,14 @@ class Generator:
            for_ = " FOR NONE"
        return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"

    def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
        kind = expression.args.get("kind")
        this: str = f" {self.sql(expression, 'this')}" if expression.this else ""
        for_or_in = expression.args.get("for_or_in")
        lock_type = expression.args.get("lock_type")
        override = " OVERRIDE" if expression.args.get("override") else ""
        return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}"

    def insert_sql(self, expression: exp.Insert) -> str:
        overwrite = expression.args.get("overwrite")

@@ -907,14 +930,17 @@ class Generator:
        else:
            this = "OVERWRITE TABLE " if overwrite else "INTO "

        alternative = expression.args.get("alternative")
        alternative = f" OR {alternative} " if alternative else " "
        this = f"{this}{self.sql(expression, 'this')}"

        exists = " IF EXISTS " if expression.args.get("exists") else " "
        partition_sql = (
            self.sql(expression, "partition") if expression.args.get("partition") else ""
        )
        expression_sql = self.sql(expression, "expression")
        sep = self.sep() if partition_sql else ""
        sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
        sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}"
        return self.prepend_ctes(expression, sql)

    def intersect_sql(self, expression: exp.Intersect) -> str:
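insert_sql now threads the new alternative arg through, which is how dialects spell INSERT OR REPLACE / OR IGNORE. A sketch (assuming the SQLite reader populates alternative):

    import sqlglot

    print(sqlglot.transpile("INSERT OR IGNORE INTO t VALUES (1)",
                            read="sqlite", write="sqlite")[0])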
@@ -1046,21 +1072,26 @@ class Generator:
            f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
        )

        cube = expression.args.get("cube")
        if cube is True:
            cube = self.seg("WITH CUBE")
        cube = expression.args.get("cube", [])
        if seq_get(cube, 0) is True:
            return f"{group_by}{self.seg('WITH CUBE')}"
        else:
            cube = self.expressions(expression, key="cube", indent=False)
            cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
            cube_sql = self.expressions(expression, key="cube", indent=False)
            cube_sql = f"{self.seg('CUBE')} {self.wrap(cube_sql)}" if cube_sql else ""

        rollup = expression.args.get("rollup")
        if rollup is True:
            rollup = self.seg("WITH ROLLUP")
        rollup = expression.args.get("rollup", [])
        if seq_get(rollup, 0) is True:
            return f"{group_by}{self.seg('WITH ROLLUP')}"
        else:
            rollup = self.expressions(expression, key="rollup", indent=False)
            rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
            rollup_sql = self.expressions(expression, key="rollup", indent=False)
            rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else ""

        return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}"
        groupings = csv(grouping_sets, cube_sql, rollup_sql, sep=",")

        if expression.args.get("expressions") and groupings:
            group_by = f"{group_by},"

        return f"{group_by}{groupings}"

    def having_sql(self, expression: exp.Having) -> str:
        this = self.indent(self.sql(expression, "this"))
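cube and rollup are now lists, so explicit CUBE (...) / ROLLUP (...) expressions can coexist with plain grouping keys, while bare WITH CUBE / WITH ROLLUP is detected via the first element. A sketch of both shapes (parse support and outputs are assumptions):

    import sqlglot

    print(sqlglot.transpile("SELECT a, SUM(b) FROM t GROUP BY a WITH ROLLUP",
                            read="mysql", write="mysql")[0])
    print(sqlglot.transpile("SELECT a, b, SUM(c) FROM t GROUP BY CUBE (a, b)")[0])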
@@ -1139,8 +1170,6 @@ class Generator:
    def literal_sql(self, expression: exp.Literal) -> str:
        text = expression.this or ""
        if expression.is_string:
            if self._replace_backslash:
                text = BACKSLASH_RE.sub(r"\\\\", text)
            text = text.replace(self.quote_end, self._escaped_quote_end)
            if self.pretty:
                text = text.replace("\n", self.SENTINEL_LINE_BREAK)
@@ -1291,7 +1320,9 @@ class Generator:
        return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"

    def parameter_sql(self, expression: exp.Parameter) -> str:
        return f"@{self.sql(expression, 'this')}"
        this = self.sql(expression, "this")
        this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}"
        return f"{self.PARAMETER_TOKEN}{this}"

    def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
        this = self.sql(expression, "this")

@@ -1405,7 +1436,10 @@ class Generator:
        return f"ALL {self.wrap(expression)}"

    def any_sql(self, expression: exp.Any) -> str:
        return f"ANY {self.wrap(expression)}"
        this = self.sql(expression, "this")
        if isinstance(expression.this, exp.Subqueryable):
            this = self.wrap(this)
        return f"ANY {this}"

    def exists_sql(self, expression: exp.Exists) -> str:
        return f"EXISTS{self.wrap(expression)}"
@@ -1444,11 +1478,11 @@ class Generator:
        trim_type = self.sql(expression, "position")

        if trim_type == "LEADING":
            return f"{self.normalize_func('LTRIM')}({self.format_args(expression.this)})"
            return self.func("LTRIM", expression.this)
        elif trim_type == "TRAILING":
            return f"{self.normalize_func('RTRIM')}({self.format_args(expression.this)})"
            return self.func("RTRIM", expression.this)
        else:
            return f"{self.normalize_func('TRIM')}({self.format_args(expression.this, expression.expression)})"
            return self.func("TRIM", expression.this, expression.expression)

    def concat_sql(self, expression: exp.Concat) -> str:
        if len(expression.expressions) == 1:

@@ -1530,8 +1564,7 @@ class Generator:
        return f"REFERENCES {this}{expressions}{options}"

    def anonymous_sql(self, expression: exp.Anonymous) -> str:
        args = self.format_args(*expression.expressions)
        return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
        return self.func(expression.name, *expression.expressions)

    def paren_sql(self, expression: exp.Paren) -> str:
        if isinstance(expression.unnest(), exp.Select):

@@ -1792,7 +1825,10 @@ class Generator:
        else:
            args.append(arg_value)

        return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"
        return self.func(expression.sql_name(), *args)

    def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str:
        return f"{self.normalize_func(name)}({self.format_args(*args)})"

    def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
        arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
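The new Generator.func helper is the pattern the dialect hunks above keep switching to: it normalizes the function name, and format_args silently drops None, so optional arguments disappear instead of rendering as empty slots. A hypothetical transform using it (exp.Abs is just a stand-in target, not from this diff):

    from sqlglot import exp
    from sqlglot.dialects.hive import Hive

    # hypothetical override: emit ABS via the helper; a missing optional
    # argument passed as None would simply be omitted from the call
    Hive.Generator.TRANSFORMS[exp.Abs] = lambda self, e: self.func("ABS", e.this)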
@@ -1848,6 +1884,7 @@ class Generator:
        return self.indent(result_sql, skip_first=False) if indent else result_sql

    def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str:
        flat = flat or isinstance(expression.parent, exp.Properties)
        expressions_sql = self.expressions(expression, flat=flat)
        if flat:
            return f"{op} {expressions_sql}"
@@ -1880,11 +1917,6 @@ class Generator:
        )
        return f"{this}{expressions}"

    def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str:
        this = self.sql(expression, "this")
        kind = self.sql(expression, "kind")
        return f"{this} {kind}"

    def joinhint_sql(self, expression: exp.JoinHint) -> str:
        this = self.sql(expression, "this")
        expressions = self.expressions(expression, flat=True)
@@ -280,6 +280,9 @@ class TypeAnnotator:
        }
        # First annotate the current scope's column references
        for col in scope.columns:
            if not col.table:
                continue

            source = scope.sources.get(col.table)
            if isinstance(source, exp.Table):
                col.type = self.schema.get_column_type(source, col)
@@ -81,9 +81,7 @@ def eliminate_subqueries(expression):
            new_ctes.append(cte_scope.expression.parent)

    # Now append the rest
    for scope in itertools.chain(
        root.union_scopes, root.subquery_scopes, root.derived_table_scopes
    ):
    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
        for child_scope in scope.traverse():
            new_cte = _eliminate(child_scope, existing_ctes, taken)
            if new_cte:

@@ -99,7 +97,7 @@ def _eliminate(scope, existing_ctes, taken):
    if scope.is_union:
        return _eliminate_union(scope, existing_ctes, taken)

    if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
    if scope.is_derived_table:
        return _eliminate_derived_table(scope, existing_ctes, taken)

    if scope.is_cte:
@@ -1,4 +1,10 @@
from __future__ import annotations

import typing as t

import sqlglot
from sqlglot import Schema, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes

@@ -24,8 +30,8 @@ RULES = (
    isolate_table_selects,
    qualify_columns,
    expand_laterals,
    validate_qualify_columns,
    pushdown_projections,
    validate_qualify_columns,
    normalize,
    unnest_subqueries,
    expand_multi_table_selects,

@@ -40,22 +46,31 @@ RULES = (
)


def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
def optimize(
    expression: str | exp.Expression,
    schema: t.Optional[dict | Schema] = None,
    db: t.Optional[str] = None,
    catalog: t.Optional[str] = None,
    dialect: DialectType = None,
    rules: t.Sequence[t.Callable] = RULES,
    **kwargs,
):
    """
    Rewrite a sqlglot AST into an optimized form.

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema (dict|sqlglot.optimizer.Schema): database schema.
        expression: expression to optimize
        schema: database schema.
            This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
            the following forms:
                1. {table: {col: type}}
                2. {db: {table: {col: type}}}
                3. {catalog: {db: {table: {col: type}}}}
            If no schema is provided then the default schema defined at `sqlglot.schema` will be used
        db (str): specify the default database, as might be set by a `USE DATABASE db` statement
        catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
        rules (sequence): sequence of optimizer rules to use.
        db: specify the default database, as might be set by a `USE DATABASE db` statement
        catalog: specify the default catalog, as might be set by a `USE CATALOG c` statement
        dialect: The dialect to parse the sql string.
        rules: sequence of optimizer rules to use.
            Many of the rules require tables and columns to be qualified.
            Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
            what you're doing!

@@ -65,7 +80,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
    """
    schema = ensure_schema(schema or sqlglot.schema)
    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
    expression = expression.copy()
    expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
    for rule in rules:
        # Find any additional rule parameters, beyond `expression`
        rule_params = rule.__code__.co_varnames
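optimize() now parses raw SQL itself when given a string, so the dialect can be passed straight through instead of pre-parsing. A minimal sketch (schema and query invented for illustration):

    from sqlglot.optimizer import optimize

    optimized = optimize(
        "SELECT a FROM y WHERE b = 1",
        schema={"y": {"a": "INT", "b": "INT"}},
    )
    print(optimized.sql())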
@ -1,7 +1,10 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.helper import flatten
|
||||
from sqlglot.optimizer.qualify_columns import Resolver
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import ensure_schema
|
||||
|
||||
# Sentinel value that means an outer query selecting ALL columns
|
||||
SELECT_ALL = object()
|
||||
|
@ -10,7 +13,7 @@ SELECT_ALL = object()
|
|||
DEFAULT_SELECTION = lambda: alias("1", "_")
|
||||
|
||||
|
||||
def pushdown_projections(expression):
|
||||
def pushdown_projections(expression, schema=None):
|
||||
"""
|
||||
Rewrite sqlglot AST to remove unused columns projections.
|
||||
|
||||
|
@ -27,9 +30,9 @@ def pushdown_projections(expression):
|
|||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
# Map of Scope to all columns being selected by outer queries.
|
||||
schema = ensure_schema(schema)
|
||||
referenced_columns = defaultdict(set)
|
||||
left_union = None
|
||||
right_union = None
|
||||
|
||||
# We build the scope tree (which is traversed in DFS postorder), then iterate
|
||||
# over the result in reverse order. This should ensure that the set of selected
|
||||
# columns for a particular scope are completely build by the time we get to it.
|
||||
|
@ -41,16 +44,20 @@ def pushdown_projections(expression):
|
|||
parent_selections = {SELECT_ALL}
|
||||
|
||||
if isinstance(scope.expression, exp.Union):
|
||||
left_union, right_union = scope.union_scopes
|
||||
referenced_columns[left_union] = parent_selections
|
||||
referenced_columns[right_union] = parent_selections
|
||||
left, right = scope.union_scopes
|
||||
referenced_columns[left] = parent_selections
|
||||
|
||||
if isinstance(scope.expression, exp.Select) and scope != right_union:
|
||||
removed_indexes = _remove_unused_selections(scope, parent_selections)
|
||||
# The left union is used for column names to select and if we remove columns from the left
|
||||
# we need to also remove those same columns in the right that were at the same position
|
||||
if scope is left_union:
|
||||
_remove_indexed_selections(right_union, removed_indexes)
|
||||
if any(select.is_star for select in right.selects):
|
||||
referenced_columns[right] = parent_selections
|
||||
elif not any(select.is_star for select in left.selects):
|
||||
referenced_columns[right] = [
|
||||
right.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.selects)
|
||||
if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
|
||||
]
|
||||
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
_remove_unused_selections(scope, parent_selections, schema)
|
||||
|
||||
# Group columns by source name
|
||||
selects = defaultdict(set)
|
||||
|
@ -68,8 +75,7 @@ def pushdown_projections(expression):
|
|||
return expression
|
||||
|
||||
|
||||
def _remove_unused_selections(scope, parent_selections):
|
||||
removed_indexes = []
|
||||
def _remove_unused_selections(scope, parent_selections, schema):
|
||||
order = scope.expression.args.get("order")
|
||||
|
||||
if order:
|
||||
|
@@ -78,33 +84,33 @@ def _remove_unused_selections(scope, parent_selections):
     else:
         order_refs = set()
 
-    new_selections = []
+    new_selections = defaultdict(list)
     removed = False
-    for i, selection in enumerate(scope.selects):
-        if (
-            SELECT_ALL in parent_selections
-            or selection.alias_or_name in parent_selections
-            or selection.alias_or_name in order_refs
-        ):
-            new_selections.append(selection)
+    star = False
+    for selection in scope.selects:
+        name = selection.alias_or_name
+
+        if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
+            new_selections[name].append(selection)
         else:
-            removed_indexes.append(i)
+            if selection.is_star:
+                star = True
             removed = True
 
+    if star:
+        resolver = Resolver(scope, schema)
+
+        for name in sorted(parent_selections):
+            if name not in new_selections:
+                new_selections[name].append(
+                    alias(exp.column(name, table=resolver.get_table(name)), name)
+                )
+
     # If there are no remaining selections, just select a single constant
     if not new_selections:
-        new_selections.append(DEFAULT_SELECTION())
+        new_selections[""].append(DEFAULT_SELECTION())
+
+    scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
 
-    scope.expression.set("expressions", new_selections)
     if removed:
         scope.clear_cache()
-    return removed_indexes
-
-
-def _remove_indexed_selections(scope, indexes_to_remove):
-    new_selections = [
-        selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
-    ]
-    if not new_selections:
-        new_selections.append(DEFAULT_SELECTION())
-    scope.expression.set("expressions", new_selections)
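
A minimal usage sketch (editorial, not part of the diff; schema and query below are assumed): with column information available, the pushdown can now expand and prune stars inside derived tables.

    import sqlglot
    from sqlglot.optimizer import optimize

    sql = "SELECT a FROM (SELECT * FROM x) AS y"
    # With the schema, the inner star should be narrowed to the referenced column `a`.
    print(optimize(sqlglot.parse_one(sql), schema={"x": {"a": "INT", "b": "INT"}}).sql())
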
@@ -27,17 +27,16 @@ def qualify_columns(expression, schema):
     schema = ensure_schema(schema)
 
     for scope in traverse_scope(expression):
-        resolver = _Resolver(scope, schema)
+        resolver = Resolver(scope, schema)
         _pop_table_column_aliases(scope.ctes)
         _pop_table_column_aliases(scope.derived_tables)
         _expand_using(scope, resolver)
-        _expand_group_by(scope, resolver)
         _qualify_columns(scope, resolver)
-        _expand_order_by(scope)
         if not isinstance(scope.expression, exp.UDTF):
             _expand_stars(scope, resolver)
             _qualify_outputs(scope)
+        _expand_group_by(scope, resolver)
+        _expand_order_by(scope)
     return expression
@@ -48,7 +47,8 @@ def validate_qualify_columns(expression):
         if isinstance(scope.expression, exp.Select):
             unqualified_columns.extend(scope.unqualified_columns)
             if scope.external_columns and not scope.is_correlated_subquery:
-                raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
+                column = scope.external_columns[0]
+                raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
 
     if unqualified_columns:
         raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
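
Hedged sketch of the richer diagnostic (the query and schema are assumed examples):

    from sqlglot import parse_one
    from sqlglot.errors import OptimizeError
    from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

    try:
        # `z` is not a known source, so validation should name both the table and the column.
        expression = qualify_columns(parse_one("SELECT z.a FROM x"), schema={"x": {"a": "INT"}})
        validate_qualify_columns(expression)
    except OptimizeError as error:
        print(error)  # e.g. Unknown table: 'z' for column 'z.a'
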
@@ -62,8 +62,6 @@ def _pop_table_column_aliases(derived_tables):
     (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
     """
     for derived_table in derived_tables:
-        if isinstance(derived_table.unnest(), exp.UDTF):
-            continue
         table_alias = derived_table.args.get("alias")
         if table_alias:
             table_alias.args.pop("columns", None)
@@ -206,7 +204,7 @@ def _qualify_columns(scope, resolver):
 
         if column_table and column_table in scope.sources:
             source_columns = resolver.get_source_columns(column_table)
-            if source_columns and column_name not in source_columns:
+            if source_columns and column_name not in source_columns and "*" not in source_columns:
                 raise OptimizeError(f"Unknown column: {column_name}")
 
         if not column_table:
@@ -256,7 +254,7 @@ def _expand_stars(scope, resolver):
             tables = list(scope.selected_sources)
             _add_except_columns(expression, tables, except_columns)
             _add_replace_columns(expression, tables, replace_columns)
-        elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
+        elif expression.is_star:
             tables = [expression.table]
             _add_except_columns(expression.this, tables, except_columns)
             _add_replace_columns(expression.this, tables, replace_columns)
@@ -268,17 +266,16 @@ def _expand_stars(scope, resolver):
         if table not in scope.sources:
             raise OptimizeError(f"Unknown table: {table}")
         columns = resolver.get_source_columns(table, only_visible=True)
-        if not columns:
-            raise OptimizeError(
-                f"Table has no schema/columns. Cannot expand star for table: {table}."
-            )
-        table_id = id(table)
-        for name in columns:
-            if name not in except_columns.get(table_id, set()):
-                alias_ = replace_columns.get(table_id, {}).get(name, name)
-                column = exp.column(name, table)
-                new_selections.append(alias(column, alias_) if alias_ != name else column)
+
+        if columns and "*" not in columns:
+            table_id = id(table)
+            for name in columns:
+                if name not in except_columns.get(table_id, set()):
+                    alias_ = replace_columns.get(table_id, {}).get(name, name)
+                    column = exp.column(name, table)
+                    new_selections.append(alias(column, alias_) if alias_ != name else column)
+        else:
+            return
     scope.expression.set("expressions", new_selections)
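
Sketch of star expansion after this change (schema assumed): sources without a usable column list now leave the star intact instead of raising.

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_columns

    print(qualify_columns(parse_one("SELECT * FROM x"), schema={"x": {"a": "INT"}}).sql())
    # e.g. SELECT x.a AS a FROM x
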
@@ -316,7 +313,7 @@ def _qualify_outputs(scope):
         if isinstance(selection, exp.Subquery):
             if not selection.output_name:
                 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
-        elif not isinstance(selection, exp.Alias):
+        elif not isinstance(selection, exp.Alias) and not selection.is_star:
             alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
             alias_.set("this", selection)
             selection = alias_
@@ -329,7 +326,7 @@ def _qualify_outputs(scope):
     scope.expression.set("expressions", new_selections)
 
 
-class _Resolver:
+class Resolver:
     """
     Helper for resolving columns.
 
@@ -361,7 +358,9 @@ class _Resolver:
 
         if not table:
             sources_without_schema = tuple(
-                source for source, columns in self._get_all_source_columns().items() if not columns
+                source
+                for source, columns in self._get_all_source_columns().items()
+                if not columns or "*" in columns
             )
             if len(sources_without_schema) == 1:
                 return sources_without_schema[0]
@@ -397,7 +396,8 @@ class _Resolver:
     def _get_all_source_columns(self):
         if self._source_columns is None:
             self._source_columns = {
-                k: self.get_source_columns(k) for k in self.scope.selected_sources
+                k: self.get_source_columns(k)
+                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
             }
         return self._source_columns
 
@@ -436,7 +436,7 @@ class _Resolver:
         Find the unique columns in a list of columns.
 
         Example:
-            >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
+            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
             ['a', 'c']
 
         This is necessary because duplicate column names are ambiguous.
@@ -28,7 +28,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
     next_name = lambda: f"_q_{next(sequence)}"
 
     for scope in traverse_scope(expression):
-        for derived_table in scope.ctes + scope.derived_tables:
+        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
             if not derived_table.args.get("alias"):
                 alias_ = f"_q_{next(sequence)}"
                 derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
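
Usage sketch (behavior is unchanged by this refactor; shown for context, query assumed):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables

    print(qualify_tables(parse_one("SELECT * FROM (SELECT 1)")).sql())
    # e.g. SELECT * FROM (SELECT 1) AS _q_0
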
@@ -26,6 +26,10 @@ class Scope:
                 SELECT * FROM x                     {"x": Table(this="x")}
                 SELECT * FROM x AS y                {"y": Table(this="x")}
                 SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
+        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
+            For example:
+                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
+            The LATERAL VIEW EXPLODE gets x as a source.
         outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
             defines a column list of it's alias of this scope, this is that list of columns.
             For example:
@@ -34,8 +38,10 @@ class Scope:
         parent (Scope): Parent scope
         scope_type (ScopeType): Type of this scope, relative to it's parent
         subquery_scopes (list[Scope]): List of all child scopes for subqueries
-        cte_scopes = (list[Scope]) List of all child scopes for CTEs
-        derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables
+        cte_scopes (list[Scope]): List of all child scopes for CTEs
+        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
+        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
+        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
         union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
             a list of the left and right child scopes.
     """
@@ -47,22 +53,28 @@ class Scope:
         outer_column_list=None,
         parent=None,
         scope_type=ScopeType.ROOT,
+        lateral_sources=None,
     ):
         self.expression = expression
         self.sources = sources or {}
+        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
+        self.sources.update(self.lateral_sources)
         self.outer_column_list = outer_column_list or []
         self.parent = parent
         self.scope_type = scope_type
         self.subquery_scopes = []
         self.derived_table_scopes = []
+        self.table_scopes = []
         self.cte_scopes = []
         self.union_scopes = []
+        self.udtf_scopes = []
         self.clear_cache()
 
     def clear_cache(self):
         self._collected = False
         self._raw_columns = None
         self._derived_tables = None
+        self._udtfs = None
         self._tables = None
         self._ctes = None
         self._subqueries = None
@@ -86,6 +98,7 @@ class Scope:
         self._ctes = []
         self._subqueries = []
         self._derived_tables = []
+        self._udtfs = []
         self._raw_columns = []
         self._join_hints = []
 
@@ -99,7 +112,7 @@ class Scope:
             elif isinstance(node, exp.JoinHint):
                 self._join_hints.append(node)
             elif isinstance(node, exp.UDTF):
-                self._derived_tables.append(node)
+                self._udtfs.append(node)
             elif isinstance(node, exp.CTE):
                 self._ctes.append(node)
             elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
@@ -199,6 +212,17 @@ class Scope:
         self._ensure_collected()
         return self._derived_tables
 
+    @property
+    def udtfs(self):
+        """
+        List of "User Defined Tabular Functions" in this scope.
+
+        Returns:
+            list[exp.UDTF]: UDTFs
+        """
+        self._ensure_collected()
+        return self._udtfs
+
     @property
     def subqueries(self):
         """
@@ -227,7 +251,9 @@ class Scope:
             columns = self._raw_columns
 
             external_columns = [
-                column for scope in self.subquery_scopes for column in scope.external_columns
+                column
+                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
+                for column in scope.external_columns
             ]
 
             named_selects = set(self.expression.named_selects)
@@ -262,9 +288,8 @@ class Scope:
 
         for table in self.tables:
             referenced_names.append((table.alias_or_name, table))
-        for derived_table in self.derived_tables:
-            referenced_names.append((derived_table.alias, derived_table.unnest()))
-
+        for expression in itertools.chain(self.derived_tables, self.udtfs):
+            referenced_names.append((expression.alias, expression.unnest()))
         result = {}
 
         for name, node in referenced_names:
@@ -414,7 +439,7 @@ class Scope:
             Scope: scope instances in depth-first-search post-order
         """
         for child_scope in itertools.chain(
-            self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
+            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
         ):
             yield from child_scope.traverse()
         yield self
@@ -480,24 +505,23 @@ def _traverse_scope(scope):
         yield from _traverse_select(scope)
     elif isinstance(scope.expression, exp.Union):
         yield from _traverse_union(scope)
-    elif isinstance(scope.expression, exp.UDTF):
-        _set_udtf_scope(scope)
     elif isinstance(scope.expression, exp.Subquery):
         yield from _traverse_subqueries(scope)
+    elif isinstance(scope.expression, exp.UDTF):
+        pass
     else:
         raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
     yield scope
 
 
 def _traverse_select(scope):
-    yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
-    yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
+    yield from _traverse_ctes(scope)
+    yield from _traverse_tables(scope)
     yield from _traverse_subqueries(scope)
-    _add_table_sources(scope)
 
 
 def _traverse_union(scope):
-    yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)
+    yield from _traverse_ctes(scope)
 
     # The last scope to be yield should be the top most scope
     left = None
@@ -511,44 +535,84 @@ def _traverse_union(scope):
     scope.union_scopes = [left, right]
 
 
-def _set_udtf_scope(scope):
-    parent = scope.expression.parent
-    from_ = parent.args.get("from")
-
-    if not from_:
-        return
-
-    for table in from_.expressions:
-        if isinstance(table, exp.Table):
-            scope.tables.append(table)
-        elif isinstance(table, exp.Subquery):
-            scope.subqueries.append(table)
-    _add_table_sources(scope)
-    _traverse_subqueries(scope)
-
-
-def _traverse_derived_tables(derived_tables, scope, scope_type):
+def _traverse_ctes(scope):
     sources = {}
-    is_cte = scope_type == ScopeType.CTE
 
-    for derived_table in derived_tables:
+    for cte in scope.ctes:
         recursive_scope = None
 
         # if the scope is a recursive cte, it must be in the form of
        # base_case UNION recursive. thus the recursive scope is the first
         # section of the union.
-        if is_cte and scope.expression.args["with"].recursive:
-            union = derived_table.this
+        if scope.expression.args["with"].recursive:
+            union = cte.this
 
             if isinstance(union, exp.Union):
                 recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
 
         for child_scope in _traverse_scope(
             scope.branch(
-                derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
-                chain_sources=sources if scope_type == ScopeType.CTE else None,
-                outer_column_list=derived_table.alias_column_names,
-                scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
+                cte.this,
+                chain_sources=sources,
+                outer_column_list=cte.alias_column_names,
+                scope_type=ScopeType.CTE,
             )
         ):
             yield child_scope
 
+            alias = cte.alias
+            sources[alias] = child_scope
+
+            if recursive_scope:
+                child_scope.add_source(alias, recursive_scope)
+
+        # append the final child_scope yielded
+        scope.cte_scopes.append(child_scope)
+
+    scope.sources.update(sources)
+
+
+def _traverse_tables(scope):
+    sources = {}
+
+    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
+    expressions = []
+    from_ = scope.expression.args.get("from")
+    if from_:
+        expressions.extend(from_.expressions)
+
+    for join in scope.expression.args.get("joins") or []:
+        expressions.append(join.this)
+
+    expressions.extend(scope.expression.args.get("laterals") or [])
+
+    for expression in expressions:
+        if isinstance(expression, exp.Table):
+            table_name = expression.name
+            source_name = expression.alias_or_name
+
+            if table_name in scope.sources:
+                # This is a reference to a parent source (e.g. a CTE), not an actual table.
+                sources[source_name] = scope.sources[table_name]
+            else:
+                sources[source_name] = expression
+            continue
+
+        if isinstance(expression, exp.UDTF):
+            lateral_sources = sources
+            scope_type = ScopeType.UDTF
+            scopes = scope.udtf_scopes
+        else:
+            lateral_sources = None
+            scope_type = ScopeType.DERIVED_TABLE
+            scopes = scope.derived_table_scopes
+
+        for child_scope in _traverse_scope(
+            scope.branch(
+                expression,
+                lateral_sources=lateral_sources,
+                outer_column_list=expression.alias_column_names,
+                scope_type=scope_type,
+            )
+        ):
+            yield child_scope
@@ -557,36 +621,12 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
         # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
         # Until then, this means that only a single, unaliased derived table is allowed (rather,
         # the latest one wins.
-        alias = derived_table.alias
+        alias = expression.alias
         sources[alias] = child_scope
 
-        if recursive_scope:
-            child_scope.add_source(alias, recursive_scope)
-
-        # append the final child_scope yielded
-        if is_cte:
-            scope.cte_scopes.append(child_scope)
-        else:
-            scope.derived_table_scopes.append(child_scope)
-
-    scope.sources.update(sources)
-
-
-def _add_table_sources(scope):
-    sources = {}
-    for table in scope.tables:
-        table_name = table.name
-
-        if table.alias:
-            source_name = table.alias
-        else:
-            source_name = table_name
-
-        if table_name in scope.sources:
-            # This is a reference to a parent source (e.g. a CTE), not an actual table.
-            scope.sources[source_name] = scope.sources[table_name]
-        else:
-            sources[source_name] = table
+        scopes.append(child_scope)
+        scope.table_scopes.append(child_scope)
 
     scope.sources.update(sources)
@@ -624,9 +664,10 @@ def walk_in_scope(expression, bfs=True):
 
         if node is expression:
             continue
-        elif isinstance(node, exp.CTE):
-            prune = True
-        elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
-            prune = True
-        elif isinstance(node, exp.Subqueryable):
+        if (
+            isinstance(node, exp.CTE)
+            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
+            or isinstance(node, exp.UDTF)
+            or isinstance(node, exp.Subqueryable)
+        ):
             prune = True
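
Sketch of the new scoping behavior (Hive-style query assumed): UDTFs such as LATERAL VIEW EXPLODE now get their own child scopes via _traverse_tables, with the FROM sources visible as lateral sources.

    from sqlglot import parse_one
    from sqlglot.optimizer.scope import traverse_scope

    expression = parse_one("SELECT c FROM x LATERAL VIEW EXPLODE(a) t AS c", read="hive")
    for scope in traverse_scope(expression):
        print(scope.scope_type, sorted(scope.sources))
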
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import typing as t
+from collections import defaultdict
 
 from sqlglot import exp
 from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
@@ -157,7 +158,6 @@ class Parser(metaclass=_Parser):
 
     ID_VAR_TOKENS = {
         TokenType.VAR,
-        TokenType.ALWAYS,
         TokenType.ANTI,
         TokenType.APPLY,
         TokenType.AUTO_INCREMENT,
@@ -186,8 +186,6 @@ class Parser(metaclass=_Parser):
         TokenType.FOLLOWING,
         TokenType.FORMAT,
         TokenType.FUNCTION,
-        TokenType.GENERATED,
-        TokenType.IDENTITY,
         TokenType.IF,
         TokenType.INDEX,
         TokenType.ISNULL,
@@ -213,7 +211,6 @@ class Parser(metaclass=_Parser):
         TokenType.ROW,
         TokenType.ROWS,
         TokenType.SCHEMA,
-        TokenType.SCHEMA_COMMENT,
         TokenType.SEED,
         TokenType.SEMI,
         TokenType.SET,
@@ -481,9 +478,7 @@ class Parser(metaclass=_Parser):
 
     PLACEHOLDER_PARSERS = {
         TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
-        TokenType.PARAMETER: lambda self: self.expression(
-            exp.Parameter, this=self._parse_var() or self._parse_primary()
-        ),
+        TokenType.PARAMETER: lambda self: self._parse_parameter(),
         TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
         if self._match_set((TokenType.NUMBER, TokenType.VAR))
         else None,
@@ -516,6 +511,9 @@ class Parser(metaclass=_Parser):
     PROPERTY_PARSERS = {
         "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
         "CHARACTER SET": lambda self: self._parse_character_set(),
+        "CLUSTER BY": lambda self: self.expression(
+            exp.Cluster, expressions=self._parse_csv(self._parse_ordered)
+        ),
         "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
         "PARTITION BY": lambda self: self._parse_partitioned_by(),
         "PARTITIONED BY": lambda self: self._parse_partitioned_by(),
@@ -576,20 +574,54 @@ class Parser(metaclass=_Parser):
         "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
+        "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
+        "DEFINER": lambda self: self._parse_definer(),
+        "LOCK": lambda self: self._parse_locking(),
+        "LOCKING": lambda self: self._parse_locking(),
     }
 
     CONSTRAINT_PARSERS = {
-        TokenType.CHECK: lambda self: self.expression(
-            exp.Check, this=self._parse_wrapped(self._parse_conjunction)
-        ),
-        TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
-        TokenType.UNIQUE: lambda self: self._parse_unique(),
-        TokenType.LIKE: lambda self: self._parse_create_like(),
+        "AUTOINCREMENT": lambda self: self._parse_auto_increment(),
+        "AUTO_INCREMENT": lambda self: self._parse_auto_increment(),
+        "CASESPECIFIC": lambda self: self.expression(exp.CaseSpecificColumnConstraint, not_=False),
+        "CHARACTER SET": lambda self: self.expression(
+            exp.CharacterSetColumnConstraint, this=self._parse_var_or_string()
+        ),
+        "CHECK": lambda self: self.expression(
+            exp.CheckColumnConstraint, this=self._parse_wrapped(self._parse_conjunction)
+        ),
+        "COLLATE": lambda self: self.expression(
+            exp.CollateColumnConstraint, this=self._parse_var()
+        ),
+        "COMMENT": lambda self: self.expression(
+            exp.CommentColumnConstraint, this=self._parse_string()
+        ),
+        "DEFAULT": lambda self: self.expression(
+            exp.DefaultColumnConstraint, this=self._parse_bitwise()
+        ),
+        "ENCODE": lambda self: self.expression(exp.EncodeColumnConstraint, this=self._parse_var()),
+        "FOREIGN KEY": lambda self: self._parse_foreign_key(),
+        "FORMAT": lambda self: self.expression(
+            exp.DateFormatColumnConstraint, this=self._parse_var_or_string()
+        ),
+        "GENERATED": lambda self: self._parse_generated_as_identity(),
+        "IDENTITY": lambda self: self._parse_auto_increment(),
+        "LIKE": lambda self: self._parse_create_like(),
+        "NOT": lambda self: self._parse_not_constraint(),
+        "NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True),
+        "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
+        "PRIMARY KEY": lambda self: self._parse_primary_key(),
+        "TITLE": lambda self: self.expression(
+            exp.TitleColumnConstraint, this=self._parse_var_or_string()
+        ),
+        "UNIQUE": lambda self: self._parse_unique(),
+        "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
     }
 
+    SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}
+
     NO_PAREN_FUNCTION_PARSERS = {
         TokenType.CASE: lambda self: self._parse_case(),
         TokenType.IF: lambda self: self._parse_if(),
         TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
     }
 
     FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@@ -637,6 +669,8 @@ class Parser(metaclass=_Parser):
 
     TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
 
+    INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
+
     WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
 
     ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
@@ -940,7 +974,9 @@ class Parser(metaclass=_Parser):
 
     def _parse_create(self) -> t.Optional[exp.Expression]:
         start = self._prev
-        replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
+        replace = self._prev.text.upper() == "REPLACE" or self._match_pair(
+            TokenType.OR, TokenType.REPLACE
+        )
         set_ = self._match(TokenType.SET)  # Teradata
         multiset = self._match_text_seq("MULTISET")  # Teradata
         global_temporary = self._match_text_seq("GLOBAL", "TEMPORARY")  # Teradata
@@ -958,7 +994,7 @@ class Parser(metaclass=_Parser):
         create_token = self._match_set(self.CREATABLES) and self._prev
 
         if not create_token:
-            properties = self._parse_properties()
+            properties = self._parse_properties()  # exp.Properties.Location.POST_CREATE
             create_token = self._match_set(self.CREATABLES) and self._prev
 
             if not properties or not create_token:
@@ -994,15 +1030,37 @@ class Parser(metaclass=_Parser):
         ):
             table_parts = self._parse_table_parts(schema=True)
 
-            if self._match(TokenType.COMMA):  # comma-separated properties before schema definition
-                properties = self._parse_properties(before=True)
+            # exp.Properties.Location.POST_NAME
+            if self._match(TokenType.COMMA):
+                temp_properties = self._parse_properties(before=True)
+                if properties and temp_properties:
+                    properties.expressions.append(temp_properties.expressions)
+                elif temp_properties:
+                    properties = temp_properties
 
             this = self._parse_schema(this=table_parts)
 
-            if not properties:  # properties after schema definition
-                properties = self._parse_properties()
+            # exp.Properties.Location.POST_SCHEMA and POST_WITH
+            temp_properties = self._parse_properties()
+            if properties and temp_properties:
+                properties.expressions.append(temp_properties.expressions)
+            elif temp_properties:
+                properties = temp_properties
 
             self._match(TokenType.ALIAS)
 
+            # exp.Properties.Location.POST_ALIAS
+            if not (
+                self._match(TokenType.SELECT, advance=False)
+                or self._match(TokenType.WITH, advance=False)
+                or self._match(TokenType.L_PAREN, advance=False)
+            ):
+                temp_properties = self._parse_properties()
+                if properties and temp_properties:
+                    properties.expressions.append(temp_properties.expressions)
+                elif temp_properties:
+                    properties = temp_properties
+
             expression = self._parse_ddl_select()
 
         if create_token.token_type == TokenType.TABLE:
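
Hedged sketch of the property-merging flow (the statement is an assumed example): properties found before the name, after the schema, and after the alias now accumulate into a single exp.Properties node.

    import sqlglot

    expression = sqlglot.parse_one("CREATE TABLE t (a INT) PARTITIONED BY (b INT)")
    print(expression.args.get("properties"))
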
@@ -1022,12 +1080,13 @@ class Parser(metaclass=_Parser):
             while True:
                 index = self._parse_create_table_index()
 
-                # post index PARTITION BY property
+                # exp.Properties.Location.POST_INDEX
                 if self._match(TokenType.PARTITION_BY, advance=False):
-                    if properties:
-                        properties.expressions.append(self._parse_property())
-                    else:
-                        properties = self._parse_properties()
+                    temp_properties = self._parse_properties()
+                    if properties and temp_properties:
+                        properties.expressions.append(temp_properties.expressions)
+                    elif temp_properties:
+                        properties = temp_properties
 
                 if not index:
                     break
@@ -1080,7 +1139,7 @@ class Parser(metaclass=_Parser):
             return self.PROPERTY_PARSERS[self._prev.text.upper()](self)
 
         if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
-            return self._parse_character_set(True)
+            return self._parse_character_set(default=True)
 
         if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
             return self._parse_sortkey(compound=True)
@@ -1240,7 +1299,7 @@ class Parser(metaclass=_Parser):
     def _parse_blockcompression(self) -> exp.Expression:
         self._match_text_seq("BLOCKCOMPRESSION")
         self._match(TokenType.EQ)
-        always = self._match(TokenType.ALWAYS)
+        always = self._match_text_seq("ALWAYS")
         manual = self._match_text_seq("MANUAL")
         never = self._match_text_seq("NEVER")
         default = self._match_text_seq("DEFAULT")
@@ -1274,6 +1333,56 @@ class Parser(metaclass=_Parser):
             for_none=for_none,
         )
 
+    def _parse_locking(self) -> exp.Expression:
+        if self._match(TokenType.TABLE):
+            kind = "TABLE"
+        elif self._match(TokenType.VIEW):
+            kind = "VIEW"
+        elif self._match(TokenType.ROW):
+            kind = "ROW"
+        elif self._match_text_seq("DATABASE"):
+            kind = "DATABASE"
+        else:
+            kind = None
+
+        if kind in ("DATABASE", "TABLE", "VIEW"):
+            this = self._parse_table_parts()
+        else:
+            this = None
+
+        if self._match(TokenType.FOR):
+            for_or_in = "FOR"
+        elif self._match(TokenType.IN):
+            for_or_in = "IN"
+        else:
+            for_or_in = None
+
+        if self._match_text_seq("ACCESS"):
+            lock_type = "ACCESS"
+        elif self._match_texts(("EXCL", "EXCLUSIVE")):
+            lock_type = "EXCLUSIVE"
+        elif self._match_text_seq("SHARE"):
+            lock_type = "SHARE"
+        elif self._match_text_seq("READ"):
+            lock_type = "READ"
+        elif self._match_text_seq("WRITE"):
+            lock_type = "WRITE"
+        elif self._match_text_seq("CHECKSUM"):
+            lock_type = "CHECKSUM"
+        else:
+            lock_type = None
+
+        override = self._match_text_seq("OVERRIDE")
+
+        return self.expression(
+            exp.LockingProperty,
+            this=this,
+            kind=kind,
+            for_or_in=for_or_in,
+            lock_type=lock_type,
+            override=override,
+        )
+
     def _parse_partition_by(self) -> t.List[t.Optional[exp.Expression]]:
         if self._match(TokenType.PARTITION_BY):
             return self._parse_csv(self._parse_conjunction)
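
Hedged sketch for the new LOCKING parser (Teradata-flavored DDL; the exact supported statement shape is an assumption here):

    import sqlglot

    sql = "CREATE TABLE a, LOCKING ROW FOR ACCESS (b INT)"
    print(sqlglot.parse_one(sql, read="teradata").sql(dialect="teradata"))
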
@@ -1351,6 +1460,7 @@ class Parser(metaclass=_Parser):
 
         this: t.Optional[exp.Expression]
 
+        alternative = None
         if self._match_text_seq("DIRECTORY"):
             this = self.expression(
                 exp.Directory,
@@ -1359,6 +1469,9 @@ class Parser(metaclass=_Parser):
                 row_format=self._parse_row_format(match_row=True),
             )
         else:
+            if self._match(TokenType.OR):
+                alternative = self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text
+
             self._match(TokenType.INTO)
             self._match(TokenType.TABLE)
             this = self._parse_table(schema=True)
@@ -1370,6 +1483,7 @@ class Parser(metaclass=_Parser):
             partition=self._parse_partition(),
             expression=self._parse_ddl_select(),
             overwrite=overwrite,
+            alternative=alternative,
         )
 
     def _parse_row(self) -> t.Optional[exp.Expression]:
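
Sketch of the INSERT alternatives now threaded through (SQLite-style syntax assumed):

    import sqlglot

    print(sqlglot.parse_one("INSERT OR IGNORE INTO t VALUES (1)").sql())
    # the parsed exp.Insert should carry alternative='IGNORE'
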
@@ -1607,7 +1721,7 @@ class Parser(metaclass=_Parser):
         index = self._index
 
         if self._match(TokenType.L_PAREN):
-            columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var()))
+            columns = self._parse_csv(self._parse_function_parameter)
             self._match_r_paren() if columns else self._retreat(index)
         else:
             columns = None
@@ -2080,27 +2194,33 @@ class Parser(metaclass=_Parser):
         if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
             return None
 
-        expressions = self._parse_csv(self._parse_conjunction)
-        grouping_sets = self._parse_grouping_sets()
+        elements = defaultdict(list)
 
-        self._match(TokenType.COMMA)
-        with_ = self._match(TokenType.WITH)
-        cube = self._match(TokenType.CUBE) and (
-            with_ or self._parse_wrapped_csv(self._parse_column)
-        )
+        while True:
+            expressions = self._parse_csv(self._parse_conjunction)
+            if expressions:
+                elements["expressions"].extend(expressions)
 
-        self._match(TokenType.COMMA)
-        rollup = self._match(TokenType.ROLLUP) and (
-            with_ or self._parse_wrapped_csv(self._parse_column)
-        )
+            grouping_sets = self._parse_grouping_sets()
+            if grouping_sets:
+                elements["grouping_sets"].extend(grouping_sets)
 
-        return self.expression(
-            exp.Group,
-            expressions=expressions,
-            grouping_sets=grouping_sets,
-            cube=cube,
-            rollup=rollup,
-        )
+            rollup = None
+            cube = None
+
+            with_ = self._match(TokenType.WITH)
+            if self._match(TokenType.ROLLUP):
+                rollup = with_ or self._parse_wrapped_csv(self._parse_column)
+                elements["rollup"].extend(ensure_list(rollup))
+
+            if self._match(TokenType.CUBE):
+                cube = with_ or self._parse_wrapped_csv(self._parse_column)
+                elements["cube"].extend(ensure_list(cube))
+
+            if not (expressions or grouping_sets or rollup or cube):
+                break
+
+        return self.expression(exp.Group, **elements)  # type: ignore
 
     def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
         if not self._match(TokenType.GROUPING_SETS):
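
Sketch of the more permissive GROUP BY parsing (query assumed): the elements are now accumulated into lists, so plain expressions, GROUPING SETS, ROLLUP, and CUBE can be combined.

    import sqlglot

    group = sqlglot.parse_one("SELECT a, SUM(b) FROM t GROUP BY ROLLUP (a)").args["group"]
    print(group.args.get("rollup"))  # the rolled-up columns, collected via the defaultdict
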
@@ -2357,6 +2477,8 @@ class Parser(metaclass=_Parser):
     def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
         index = self._index
 
+        prefix = self._match_text_seq("SYSUDTLIB", ".")
+
         if not self._match_set(self.TYPE_TOKENS):
             return None
 
@@ -2458,6 +2580,7 @@ class Parser(metaclass=_Parser):
             expressions=expressions,
             nested=nested,
             values=values,
+            prefix=prefix,
         )
 
     def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
@@ -2512,8 +2635,14 @@ class Parser(metaclass=_Parser):
 
             if op:
                 this = op(self, this, field)
-            elif isinstance(this, exp.Column) and not this.table:
-                this = self.expression(exp.Column, this=field, table=this.this)
+            elif isinstance(this, exp.Column) and not this.args.get("catalog"):
+                this = self.expression(
+                    exp.Column,
+                    this=field,
+                    table=this.this,
+                    db=this.args.get("table"),
+                    catalog=this.args.get("db"),
+                )
             else:
                 this = self.expression(exp.Dot, this=this, expression=field)
             this = self._parse_bracket(this)
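
Sketch of the deeper dotted-column handling (identifiers assumed): a fourth dotted part now lands in the column's catalog slot instead of producing a generic Dot.

    import sqlglot

    column = sqlglot.parse_one("SELECT c.db.tbl.col").selects[0]
    print(column.args.get("catalog"), column.args.get("db"), column.args.get("table"))
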
@@ -2632,6 +2761,9 @@ class Parser(metaclass=_Parser):
         self._match_r_paren(this)
         return self._parse_window(this)
 
+    def _parse_function_parameter(self) -> t.Optional[exp.Expression]:
+        return self._parse_column_def(self._parse_id_var())
+
     def _parse_user_defined_function(
         self, kind: t.Optional[TokenType] = None
     ) -> t.Optional[exp.Expression]:
@@ -2643,7 +2775,7 @@ class Parser(metaclass=_Parser):
         if not self._match(TokenType.L_PAREN):
             return this
 
-        expressions = self._parse_csv(self._parse_udf_kwarg)
+        expressions = self._parse_csv(self._parse_function_parameter)
         self._match_r_paren()
         return self.expression(
             exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True
@@ -2669,15 +2801,6 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.SessionParameter, this=this, kind=kind)
 
-    def _parse_udf_kwarg(self) -> t.Optional[exp.Expression]:
-        this = self._parse_id_var()
-        kind = self._parse_types()
-
-        if not kind:
-            return this
-
-        return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind)
-
     def _parse_lambda(self) -> t.Optional[exp.Expression]:
         index = self._index
 
@@ -2726,6 +2849,9 @@ class Parser(metaclass=_Parser):
     def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
         kind = self._parse_types()
 
+        if self._match_text_seq("FOR", "ORDINALITY"):
+            return self.expression(exp.ColumnDef, this=this, ordinality=True)
+
         constraints = []
         while True:
             constraint = self._parse_column_constraint()
@@ -2738,79 +2864,78 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
 
+    def _parse_auto_increment(self) -> exp.Expression:
+        start = None
+        increment = None
+
+        if self._match(TokenType.L_PAREN, advance=False):
+            args = self._parse_wrapped_csv(self._parse_bitwise)
+            start = seq_get(args, 0)
+            increment = seq_get(args, 1)
+        elif self._match_text_seq("START"):
+            start = self._parse_bitwise()
+            self._match_text_seq("INCREMENT")
+            increment = self._parse_bitwise()
+
+        if start and increment:
+            return exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment)
+
+        return exp.AutoIncrementColumnConstraint()
+
+    def _parse_generated_as_identity(self) -> exp.Expression:
+        if self._match(TokenType.BY_DEFAULT):
+            this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
+        else:
+            self._match_text_seq("ALWAYS")
+            this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
+
+        self._match_text_seq("AS", "IDENTITY")
+        if self._match(TokenType.L_PAREN):
+            if self._match_text_seq("START", "WITH"):
+                this.set("start", self._parse_bitwise())
+            if self._match_text_seq("INCREMENT", "BY"):
+                this.set("increment", self._parse_bitwise())
+            if self._match_text_seq("MINVALUE"):
+                this.set("minvalue", self._parse_bitwise())
+            if self._match_text_seq("MAXVALUE"):
+                this.set("maxvalue", self._parse_bitwise())
+
+            if self._match_text_seq("CYCLE"):
+                this.set("cycle", True)
+            elif self._match_text_seq("NO", "CYCLE"):
+                this.set("cycle", False)
+
+            self._match_r_paren()
+
+        return this
+
+    def _parse_not_constraint(self) -> t.Optional[exp.Expression]:
+        if self._match_text_seq("NULL"):
+            return self.expression(exp.NotNullColumnConstraint)
+        if self._match_text_seq("CASESPECIFIC"):
+            return self.expression(exp.CaseSpecificColumnConstraint, not_=True)
+        return None
+
     def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
         this = self._parse_references()
 
         if this:
             return this
 
         if self._match(TokenType.CONSTRAINT):
             this = self._parse_id_var()
 
-        kind: exp.Expression
-
+        if self._match_texts(self.CONSTRAINT_PARSERS):
+            return self.expression(
+                exp.ColumnConstraint,
+                this=this,
+                kind=self.CONSTRAINT_PARSERS[self._prev.text.upper()](self),
+            )
 
-        if self._match_set((TokenType.AUTO_INCREMENT, TokenType.IDENTITY)):
-            start = None
-            increment = None
-
-            if self._match(TokenType.L_PAREN, advance=False):
-                args = self._parse_wrapped_csv(self._parse_bitwise)
-                start = seq_get(args, 0)
-                increment = seq_get(args, 1)
-            elif self._match_text_seq("START"):
-                start = self._parse_bitwise()
-                self._match_text_seq("INCREMENT")
-                increment = self._parse_bitwise()
-
-            if start and increment:
-                kind = exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment)
-            else:
-                kind = exp.AutoIncrementColumnConstraint()
-        elif self._match(TokenType.CHECK):
-            constraint = self._parse_wrapped(self._parse_conjunction)
-            kind = self.expression(exp.CheckColumnConstraint, this=constraint)
-        elif self._match(TokenType.COLLATE):
-            kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
-        elif self._match(TokenType.ENCODE):
-            kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
-        elif self._match(TokenType.DEFAULT):
-            kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise())
-        elif self._match_pair(TokenType.NOT, TokenType.NULL):
-            kind = exp.NotNullColumnConstraint()
-        elif self._match(TokenType.NULL):
-            kind = exp.NotNullColumnConstraint(allow_null=True)
-        elif self._match(TokenType.SCHEMA_COMMENT):
-            kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
-        elif self._match(TokenType.PRIMARY_KEY):
-            desc = None
-            if self._match(TokenType.ASC) or self._match(TokenType.DESC):
-                desc = self._prev.token_type == TokenType.DESC
-            kind = exp.PrimaryKeyColumnConstraint(desc=desc)
-        elif self._match(TokenType.UNIQUE):
-            kind = exp.UniqueColumnConstraint()
-        elif self._match(TokenType.GENERATED):
-            if self._match(TokenType.BY_DEFAULT):
-                kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
-            else:
-                self._match(TokenType.ALWAYS)
-                kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
-            self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
-
-            if self._match(TokenType.L_PAREN):
-                if self._match_text_seq("START", "WITH"):
-                    kind.set("start", self._parse_bitwise())
-                if self._match_text_seq("INCREMENT", "BY"):
-                    kind.set("increment", self._parse_bitwise())
-
-                self._match_r_paren()
-        else:
-            return this
-
-        return self.expression(exp.ColumnConstraint, this=this, kind=kind)
+        return this
 
     def _parse_constraint(self) -> t.Optional[exp.Expression]:
         if not self._match(TokenType.CONSTRAINT):
-            return self._parse_unnamed_constraint()
+            return self._parse_unnamed_constraint(constraints=self.SCHEMA_UNNAMED_CONSTRAINTS)
 
         this = self._parse_id_var()
         expressions = []
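
Sketch of the keyword-driven constraint parsing (the DDL below is an assumed example):

    import sqlglot

    sql = "CREATE TABLE t (id INT NOT NULL UNIQUE, v INT DEFAULT 0)"
    print(sqlglot.parse_one(sql).sql())  # constraints should round-trip via CONSTRAINT_PARSERS
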
@@ -2823,12 +2948,21 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.Constraint, this=this, expressions=expressions)
 
-    def _parse_unnamed_constraint(self) -> t.Optional[exp.Expression]:
-        if not self._match_set(self.CONSTRAINT_PARSERS):
+    def _parse_unnamed_constraint(
+        self, constraints: t.Optional[t.Collection[str]] = None
+    ) -> t.Optional[exp.Expression]:
+        if not self._match_texts(constraints or self.CONSTRAINT_PARSERS):
             return None
-        return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
+
+        constraint = self._prev.text.upper()
+        if constraint not in self.CONSTRAINT_PARSERS:
+            self.raise_error(f"No parser found for schema constraint {constraint}.")
+
+        return self.CONSTRAINT_PARSERS[constraint](self)
 
     def _parse_unique(self) -> exp.Expression:
+        if not self._match(TokenType.L_PAREN, advance=False):
+            return self.expression(exp.UniqueColumnConstraint)
         return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
 
     def _parse_key_constraint_options(self) -> t.List[str]:
@@ -2908,6 +3042,14 @@ class Parser(metaclass=_Parser):
         )
 
     def _parse_primary_key(self) -> exp.Expression:
+        desc = (
+            self._match_set((TokenType.ASC, TokenType.DESC))
+            and self._prev.token_type == TokenType.DESC
+        )
+
+        if not self._match(TokenType.L_PAREN, advance=False):
+            return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc)
+
         expressions = self._parse_wrapped_id_vars()
         options = self._parse_key_constraint_options()
         return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
@@ -3306,6 +3448,12 @@ class Parser(metaclass=_Parser):
             return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
         return None
 
+    def _parse_parameter(self) -> exp.Expression:
+        wrapped = self._match(TokenType.L_BRACE)
+        this = self._parse_var() or self._parse_primary()
+        self._match(TokenType.R_BRACE)
+        return self.expression(exp.Parameter, this=this, wrapped=wrapped)
+
     def _parse_placeholder(self) -> t.Optional[exp.Expression]:
         if self._match_set(self.PLACEHOLDER_PARSERS):
             placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self)
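
Sketch for _parse_parameter (parameter syntax assumed): bare and brace-wrapped parameters both become exp.Parameter nodes, with `wrapped` recording the braces.

    import sqlglot

    print(repr(sqlglot.parse_one("SELECT @x").selects[0]))
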
@@ -3449,7 +3597,7 @@ class Parser(metaclass=_Parser):
         if kind == TokenType.CONSTRAINT:
             this = self._parse_id_var()
 
-            if self._match(TokenType.CHECK):
+            if self._match_text_seq("CHECK"):
                 expression = self._parse_wrapped(self._parse_conjunction)
                 enforced = self._match_text_seq("ENFORCED")
 
@@ -138,7 +138,6 @@ class TokenType(AutoName):
     CASCADE = auto()
     CASE = auto()
     CHARACTER_SET = auto()
-    CHECK = auto()
     CLUSTER_BY = auto()
     COLLATE = auto()
     COMMAND = auto()
@@ -164,7 +163,6 @@ class TokenType(AutoName):
     DIV = auto()
     DROP = auto()
     ELSE = auto()
-    ENCODE = auto()
     END = auto()
     ESCAPE = auto()
     EXCEPT = auto()
@@ -182,17 +180,16 @@ class TokenType(AutoName):
     FROM = auto()
     FULL = auto()
     FUNCTION = auto()
-    GENERATED = auto()
     GLOB = auto()
     GLOBAL = auto()
     GROUP_BY = auto()
     GROUPING_SETS = auto()
     HAVING = auto()
     HINT = auto()
-    IDENTITY = auto()
     IF = auto()
     IGNORE_NULLS = auto()
     ILIKE = auto()
+    ILIKE_ANY = auto()
     IN = auto()
     INDEX = auto()
     INNER = auto()
@@ -211,6 +208,7 @@ class TokenType(AutoName):
     LEADING = auto()
     LEFT = auto()
     LIKE = auto()
+    LIKE_ANY = auto()
     LIMIT = auto()
     LOAD_DATA = auto()
     LOCAL = auto()
@@ -253,6 +251,7 @@ class TokenType(AutoName):
     RECURSIVE = auto()
     REPLACE = auto()
     RESPECT_NULLS = auto()
+    RETURNING = auto()
     REFERENCES = auto()
     RIGHT = auto()
     RLIKE = auto()
@@ -260,7 +259,6 @@ class TokenType(AutoName):
     ROLLUP = auto()
     ROW = auto()
     ROWS = auto()
-    SCHEMA_COMMENT = auto()
     SEED = auto()
     SELECT = auto()
     SEMI = auto()
@@ -441,7 +439,7 @@ class Tokenizer(metaclass=_Tokenizer):
     KEYWORDS = {
         **{
             f"{key}{postfix}": TokenType.BLOCK_START
-            for key in ("{{", "{%", "{#")
+            for key in ("{%", "{#")
             for postfix in ("", "+", "-")
         },
         **{
@@ -449,6 +447,8 @@ class Tokenizer(metaclass=_Tokenizer):
             for key in ("%}", "#}")
             for prefix in ("", "+", "-")
         },
+        "{{+": TokenType.BLOCK_START,
+        "{{-": TokenType.BLOCK_START,
         "+}}": TokenType.BLOCK_END,
         "-}}": TokenType.BLOCK_END,
         "/*+": TokenType.HINT,
@@ -486,11 +486,9 @@ class Tokenizer(metaclass=_Tokenizer):
         "CASE": TokenType.CASE,
         "CASCADE": TokenType.CASCADE,
         "CHARACTER SET": TokenType.CHARACTER_SET,
-        "CHECK": TokenType.CHECK,
-        "CLUSTER BY": TokenType.CLUSTER_BY,
         "COLLATE": TokenType.COLLATE,
         "COLUMN": TokenType.COLUMN,
         "COMMENT": TokenType.SCHEMA_COMMENT,
         "COMMIT": TokenType.COMMIT,
         "COMPOUND": TokenType.COMPOUND,
         "CONSTRAINT": TokenType.CONSTRAINT,
@@ -526,12 +524,10 @@ class Tokenizer(metaclass=_Tokenizer):
         "FOREIGN KEY": TokenType.FOREIGN_KEY,
         "FORMAT": TokenType.FORMAT,
         "FROM": TokenType.FROM,
-        "GENERATED": TokenType.GENERATED,
         "GLOB": TokenType.GLOB,
         "GROUP BY": TokenType.GROUP_BY,
         "GROUPING SETS": TokenType.GROUPING_SETS,
         "HAVING": TokenType.HAVING,
-        "IDENTITY": TokenType.IDENTITY,
         "IF": TokenType.IF,
         "ILIKE": TokenType.ILIKE,
         "IGNORE NULLS": TokenType.IGNORE_NULLS,
@@ -747,11 +743,9 @@ class Tokenizer(metaclass=_Tokenizer):
         "_prev_token_line",
         "_prev_token_comments",
         "_prev_token_type",
-        "_replace_backslash",
     )
 
     def __init__(self) -> None:
-        self._replace_backslash = "\\" in self._STRING_ESCAPES
         self.reset()
 
     def reset(self) -> None:
@@ -855,7 +849,7 @@ class Tokenizer(metaclass=_Tokenizer):
     def _scan_keywords(self) -> None:
         size = 0
         word = None
-        chars = self._text
+        chars: t.Optional[str] = self._text
         char = chars
         prev_space = False
         skip = False
@@ -887,7 +881,7 @@ class Tokenizer(metaclass=_Tokenizer):
             else:
                 skip = True
         else:
-            chars = None  # type: ignore
+            chars = None
 
         if not word:
             if self._char in self.SINGLE_TOKENS:
@@ -1015,7 +1009,6 @@ class Tokenizer(metaclass=_Tokenizer):
             self._advance(len(quote))
             text = self._extract_string(quote_end)
             text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text  # type: ignore
-            text = text.replace("\\\\", "\\") if self._replace_backslash else text
             self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
             return True
 
@@ -1091,13 +1084,18 @@ class Tokenizer(metaclass=_Tokenizer):
         delim_size = len(delimiter)
 
         while True:
-            if (
-                self._char in self._STRING_ESCAPES
-                and self._peek
-                and (self._peek == delimiter or self._peek in self._STRING_ESCAPES)
+            if self._char in self._STRING_ESCAPES and (
+                self._peek == delimiter or self._peek in self._STRING_ESCAPES
             ):
-                text += self._peek
-                self._advance(2)
+                if self._peek == delimiter:
+                    text += self._peek  # type: ignore
+                else:
+                    text += self._char + self._peek  # type: ignore
+
+                if self._current + 1 < self.size:
+                    self._advance(2)
+                else:
+                    raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._current}")
             else:
                 if self._chars(delim_size) == delimiter:
                     if delim_size > 1:
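
Sketch of the tightened escape handling (example assumed): doubled quotes inside string literals still round-trip, while a dangling escape at end of input now raises instead of scanning past the buffer.

    import sqlglot

    print(sqlglot.parse_one("SELECT 'it''s'").sql())  # SELECT 'it''s'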