Merging upstream version 10.6.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent d03a55eda6
commit ece6881255
48 changed files with 906 additions and 266 deletions
@@ -33,7 +33,13 @@ from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema, Schema
 from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "10.6.0"
+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+
+    T = t.TypeVar("T", bound=Expression)
+
+
+__version__ = "10.6.3"

 pretty = False
 """Whether to format generated SQL by default."""
@@ -42,9 +48,7 @@ schema = MappingSchema()
 """The default schema used by SQLGlot (e.g. in the optimizer)."""


-def parse(
-    sql: str, read: t.Optional[str | Dialect] = None, **opts
-) -> t.List[t.Optional[Expression]]:
+def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:
     """
     Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.

@@ -60,9 +64,57 @@ def parse(
     return dialect.parse(sql, **opts)


+@t.overload
 def parse_one(
     sql: str,
-    read: t.Optional[str | Dialect] = None,
+    read: None = None,
+    into: t.Type[T] = ...,
+    **opts,
+) -> T:
+    ...
+
+
+@t.overload
+def parse_one(
+    sql: str,
+    read: DialectType,
+    into: t.Type[T],
+    **opts,
+) -> T:
+    ...
+
+
+@t.overload
+def parse_one(
+    sql: str,
+    read: None = None,
+    into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ...,
+    **opts,
+) -> Expression:
+    ...
+
+
+@t.overload
+def parse_one(
+    sql: str,
+    read: DialectType,
+    into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]],
+    **opts,
+) -> Expression:
+    ...
+
+
+@t.overload
+def parse_one(
+    sql: str,
+    **opts,
+) -> Expression:
+    ...
+
+
+def parse_one(
+    sql: str,
+    read: DialectType = None,
     into: t.Optional[exp.IntoType] = None,
     **opts,
 ) -> Expression:
@@ -96,8 +148,8 @@ def parse_one(

 def transpile(
     sql: str,
-    read: t.Optional[str | Dialect] = None,
-    write: t.Optional[str | Dialect] = None,
+    read: DialectType = None,
+    write: DialectType = None,
     identity: bool = True,
     error_level: t.Optional[ErrorLevel] = None,
     **opts,
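The new DialectType alias accepts a dialect name, a Dialect class, a Dialect instance, or None everywhere the public API previously took only a string. A minimal sketch of the updated entry points (outputs not taken from this commit):

    import sqlglot

    # `read` accepts any DialectType; `into` narrows the returned expression type
    expression = sqlglot.parse_one("SELECT 1", read="duckdb")
    statements = sqlglot.parse("SELECT 1; SELECT 2", read="duckdb")
    sql = sqlglot.transpile("SELECT STRPTIME('2020-01-01', '%Y-%m-%d')", read="duckdb", write="hive")[0]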
@@ -260,11 +260,7 @@ class Column:
         """
-        if isinstance(dataType, DataType):
-            dataType = dataType.simpleString()
-        new_expression = exp.Cast(
-            this=self.column_expression,
-            to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"),  # type: ignore
-        )
-        return Column(new_expression)
+        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))

     def startswith(self, value: t.Union[str, Column]) -> Column:
         value = self._lit(value) if not isinstance(value, Column) else value
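Column.cast now delegates to the new exp.cast helper, which parses the type string under the given dialect. A small sketch with a hypothetical column name:

    from sqlglot import exp

    # exp.cast builds a Cast node; DataType.build resolves the type string
    node = exp.cast(exp.column("price"), "double", dialect="spark")
    node.sql("spark")  # expected: 'CAST(price AS DOUBLE)'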
@@ -536,15 +536,15 @@ def month(col: ColumnOrName) -> Column:


 def dayofweek(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "DAYOFWEEK")
+    return Column.invoke_expression_over_column(col, glotexp.DayOfWeek)


 def dayofmonth(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "DAYOFMONTH")
+    return Column.invoke_expression_over_column(col, glotexp.DayOfMonth)


 def dayofyear(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "DAYOFYEAR")
+    return Column.invoke_expression_over_column(col, glotexp.DayOfYear)


 def hour(col: ColumnOrName) -> Column:
@@ -560,7 +560,7 @@ def second(col: ColumnOrName) -> Column:


 def weekofyear(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "WEEKOFYEAR")
+    return Column.invoke_expression_over_column(col, glotexp.WeekOfYear)


 def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
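Because these helpers now build typed expressions (exp.DayOfWeek and friends) instead of anonymous function calls, their SQL can be re-rendered per dialect. A hedged sketch:

    import sqlglot

    # DAYOFWEEK parses into exp.DayOfWeek, which other dialects can rename
    # (e.g. Presto spells it DAY_OF_WEEK); exact output may differ by version
    sqlglot.transpile("SELECT DAYOFWEEK(CAST(x AS DATE))", read="spark", write="presto")[0]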
@@ -1144,10 +1144,16 @@ def aggregate(
     merge_exp = _get_lambda_from_func(merge)
     if finish is not None:
         finish_exp = _get_lambda_from_func(finish)
-        return Column.invoke_anonymous_function(
-            col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
+        return Column.invoke_expression_over_column(
+            col,
+            glotexp.Reduce,
+            initial=initialValue,
+            merge=Column(merge_exp),
+            finish=Column(finish_exp),
         )
-    return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
+    return Column.invoke_expression_over_column(
+        col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp)
+    )


 def transform(
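Spark's AGGREGATE now maps onto exp.Reduce (with an optional finish lambda), so it can transpile to dialects that call it REDUCE. A sketch, with the output shape assumed:

    import sqlglot

    # AGGREGATE(xs, 0, (acc, x) -> acc + x) parses into exp.Reduce
    sqlglot.transpile(
        "SELECT AGGREGATE(xs, 0, (acc, x) -> acc + x)", read="spark", write="presto"
    )[0]  # expected: 'SELECT REDUCE(xs, 0, (acc, x) -> acc + x)'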
@@ -222,14 +222,6 @@ class BigQuery(Dialect):
             exp.DataType.Type.NVARCHAR: "STRING",
         }

-        ROOT_PROPERTIES = {
-            exp.LanguageProperty,
-            exp.ReturnsProperty,
-            exp.VolatilityProperty,
-        }
-
-        WITH_PROPERTIES = {exp.Property}
-
         EXPLICIT_UNION = True

         def array_sql(self, expression: exp.Array) -> str:
@@ -122,9 +122,15 @@ class Dialect(metaclass=_Dialect):
     def get_or_raise(cls, dialect):
         if not dialect:
             return cls
+        if isinstance(dialect, _Dialect):
+            return dialect
+        if isinstance(dialect, Dialect):
+            return dialect.__class__

         result = cls.get(dialect)
         if not result:
             raise ValueError(f"Unknown dialect '{dialect}'")

         return result

     @classmethod
@@ -196,6 +202,10 @@ class Dialect(metaclass=_Dialect):
     )


+if t.TYPE_CHECKING:
+    DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
+
+
 def rename_func(name):
     def _rename(self, expression):
         args = flatten(expression.args.values())
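get_or_raise is the resolver behind DialectType: it now also accepts a dialect class (whose metaclass is _Dialect) or an instance, in addition to a name. A sketch:

    from sqlglot.dialects.dialect import Dialect
    from sqlglot.dialects.duckdb import DuckDB

    # all three resolve to the DuckDB dialect class; unknown names raise ValueError
    Dialect.get_or_raise("duckdb")
    Dialect.get_or_raise(DuckDB)
    Dialect.get_or_raise(DuckDB())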
@@ -137,7 +137,10 @@ class Drill(Dialect):
             exp.DataType.Type.DATETIME: "TIMESTAMP",
         }

-        ROOT_PROPERTIES = {exp.PartitionedByProperty}
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        }

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
@@ -20,10 +20,6 @@ from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType


-def _unix_to_time(self, expression):
-    return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))"
-
-
 def _str_to_time_sql(self, expression):
     return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"

@@ -113,7 +109,7 @@ class DuckDB(Dialect):
             "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "STRUCT_PACK": exp.Struct.from_arg_list,
-            "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
+            "TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
             "UNNEST": exp.Explode.from_arg_list,
         }

@@ -162,9 +158,9 @@ class DuckDB(Dialect):
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
             exp.TsOrDsAdd: _ts_or_ds_add,
             exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
-            exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time,
-            exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)",
+            exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
+            exp.UnixToTime: rename_func("TO_TIMESTAMP"),
+            exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
         }

         TYPE_MAPPING = {
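DuckDB's TO_TIMESTAMP takes unix epoch seconds, so parsing it as exp.TimeStrToTime was a semantic bug; it now round-trips as exp.UnixToTime. A sketch, with the Hive output assumed from its default UnixToTime mapping:

    import sqlglot

    sqlglot.transpile("SELECT TO_TIMESTAMP(1618088028)", read="duckdb", write="hive")[0]
    # expected: 'SELECT FROM_UNIXTIME(1618088028)'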
@@ -322,17 +322,11 @@ class Hive(Dialect):
             exp.LastDateOfMonth: rename_func("LAST_DAY"),
         }

-        WITH_PROPERTIES = {exp.Property}
-
-        ROOT_PROPERTIES = {
-            exp.PartitionedByProperty,
-            exp.FileFormatProperty,
-            exp.SchemaCommentProperty,
-            exp.LocationProperty,
-            exp.TableFormatProperty,
-            exp.RowFormatDelimitedProperty,
-            exp.RowFormatSerdeProperty,
-            exp.SerdeProperties,
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+            exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
         }

         def with_properties(self, properties):
@@ -1,7 +1,5 @@
 from __future__ import annotations

-import typing as t
-
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
@@ -98,6 +96,8 @@ def _date_add_sql(kind):


 class MySQL(Dialect):
+    time_format = "'%Y-%m-%d %T'"
+
     # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
     time_mapping = {
         "%M": "%B",
@@ -110,6 +110,7 @@ class MySQL(Dialect):
         "%u": "%W",
        "%k": "%-H",
         "%l": "%-I",
+        "%T": "%H:%M:%S",
     }

     class Tokenizer(tokens.Tokenizer):
@@ -428,6 +429,7 @@ class MySQL(Dialect):
         )

     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = False

         TRANSFORMS = {
@@ -449,23 +451,12 @@ class MySQL(Dialect):
             exp.StrPosition: strposition_to_locate_sql,
         }

-        ROOT_PROPERTIES = {
-            exp.EngineProperty,
-            exp.AutoIncrementProperty,
-            exp.CharacterSetProperty,
-            exp.CollateProperty,
-            exp.SchemaCommentProperty,
-            exp.LikeProperty,
-        }
-
         TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
         TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
         TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
         TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
         TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)

-        WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
-
         def show_sql(self, expression):
             this = f" {expression.name}"
             full = " FULL" if expression.args.get("full") else ""
@@ -44,6 +44,8 @@ class Oracle(Dialect):
         }

     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TINYINT: "NUMBER",
@@ -69,6 +71,7 @@ class Oracle(Dialect):
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
+            exp.Substring: rename_func("SUBSTR"),
         }

         def query_modifiers(self, expression, *sqls):
@@ -90,6 +93,7 @@ class Oracle(Dialect):
             self.sql(expression, "order"),
             self.sql(expression, "offset"),  # offset before limit in oracle
             self.sql(expression, "limit"),
+            self.sql(expression, "lock"),
             sep="",
         )

@@ -148,6 +148,22 @@ def _serial_to_generated(expression):
     return expression


+def _generate_series(args):
+    # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
+    step = seq_get(args, 2)
+
+    if step is None:
+        # Postgres allows calls with just two arguments -- the "step" argument defaults to 1
+        return exp.GenerateSeries.from_arg_list(args)
+
+    if step.is_string:
+        args[2] = exp.to_interval(step.this)
+    elif isinstance(step, exp.Interval) and not step.args.get("unit"):
+        args[2] = exp.to_interval(step.this.this)
+
+    return exp.GenerateSeries.from_arg_list(args)
+
+
 def _to_timestamp(args):
     # TO_TIMESTAMP accepts either a single double argument or (text, text)
     if len(args) == 1:
@@ -195,29 +211,6 @@ class Postgres(Dialect):
         HEX_STRINGS = [("x'", "'"), ("X'", "'")]
         BYTE_STRINGS = [("e'", "'"), ("E'", "'")]

-        CREATABLES = (
-            "AGGREGATE",
-            "CAST",
-            "CONVERSION",
-            "COLLATION",
-            "DEFAULT CONVERSION",
-            "CONSTRAINT",
-            "DOMAIN",
-            "EXTENSION",
-            "FOREIGN",
-            "FUNCTION",
-            "OPERATOR",
-            "POLICY",
-            "ROLE",
-            "RULE",
-            "SEQUENCE",
-            "TEXT",
-            "TRIGGER",
-            "TYPE",
-            "UNLOGGED",
-            "USER",
-        )
-
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "~~": TokenType.LIKE,
@@ -243,8 +236,6 @@ class Postgres(Dialect):
             "TEMP": TokenType.TEMPORARY,
             "UUID": TokenType.UUID,
             "CSTRING": TokenType.PSEUDO_TYPE,
-            **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
-            **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
         }
         QUOTES = ["'", "$$"]
         SINGLE_TOKENS = {
@@ -257,8 +248,10 @@ class Postgres(Dialect):

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
+            "NOW": exp.CurrentTimestamp.from_arg_list,
             "TO_TIMESTAMP": _to_timestamp,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
+            "GENERATE_SERIES": _generate_series,
         }

         BITWISE = {
@@ -272,6 +265,8 @@ class Postgres(Dialect):
         }

     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TINYINT: "SMALLINT",
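_generate_series normalizes the optional step argument, so both '1 day' and INTERVAL '1 day' become a proper INTERVAL '1' day node. A sketch, with the output shape assumed:

    import sqlglot

    sqlglot.parse_one(
        "SELECT GENERATE_SERIES('2020-01-01'::DATE, '2020-03-01'::DATE, '1 day')",
        read="postgres",
    ).sql("postgres")
    # the step is rendered as INTERVAL '1' day rather than a bare string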
@@ -105,6 +105,29 @@ def _ts_or_ds_add_sql(self, expression):
     return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"


+def _sequence_sql(self, expression):
+    start = expression.args["start"]
+    end = expression.args["end"]
+    step = expression.args.get("step", 1)  # Postgres defaults to 1 for generate_series
+
+    target_type = None
+
+    if isinstance(start, exp.Cast):
+        target_type = start.to
+    elif isinstance(end, exp.Cast):
+        target_type = end.to
+
+    if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
+        to = target_type.copy()
+
+        if target_type is start.to:
+            end = exp.Cast(this=end, to=to)
+        else:
+            start = exp.Cast(this=start, to=to)
+
+    return f"SEQUENCE({self.format_args(start, end, step)})"
+
+
 def _ensure_utf8(charset):
     if charset.name.lower() != "utf-8":
         raise UnsupportedError(f"Unsupported charset {charset}")
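On the generator side, Presto renders exp.GenerateSeries as SEQUENCE, casting both bounds consistently when one side is a TIMESTAMP. A sketch:

    import sqlglot

    sqlglot.transpile("SELECT GENERATE_SERIES(1, 5)", read="postgres", write="presto")[0]
    # expected: 'SELECT SEQUENCE(1, 5)'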
@@ -145,7 +168,7 @@ def _from_unixtime(args):
 class Presto(Dialect):
     index_offset = 1
     null_ordering = "nulls_are_last"
-    time_format = "'%Y-%m-%d %H:%i:%S'"
+    time_format = MySQL.time_format  # type: ignore
     time_mapping = MySQL.time_mapping  # type: ignore

     class Tokenizer(tokens.Tokenizer):
@@ -197,7 +220,10 @@ class Presto(Dialect):
     class Generator(generator.Generator):
         STRUCT_DELIMITER = ("(", ")")

-        ROOT_PROPERTIES = {exp.SchemaCommentProperty}
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
+        }

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
@@ -223,6 +249,7 @@ class Presto(Dialect):
             exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DataType: _datatype_sql,
             exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
             exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
@@ -231,6 +258,7 @@ class Presto(Dialect):
             exp.Decode: _decode_sql,
             exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
             exp.Encode: _encode_sql,
+            exp.GenerateSeries: _sequence_sql,
             exp.Hex: rename_func("TO_HEX"),
             exp.If: if_sql,
             exp.ILike: no_ilike_sql,
@@ -61,14 +61,9 @@ class Redshift(Postgres):
             exp.DataType.Type.INT: "INTEGER",
         }

-        ROOT_PROPERTIES = {
-            exp.DistKeyProperty,
-            exp.SortKeyProperty,
-            exp.DistStyleProperty,
-        }
-
-        WITH_PROPERTIES = {
-            exp.LikeProperty,
+        PROPERTIES_LOCATION = {
+            **Postgres.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH,
         }

         TRANSFORMS = {
@@ -234,15 +234,6 @@ class Snowflake(Dialect):
             "replace": "RENAME",
         }

-        ROOT_PROPERTIES = {
-            exp.PartitionedByProperty,
-            exp.ReturnsProperty,
-            exp.LanguageProperty,
-            exp.SchemaCommentProperty,
-            exp.ExecuteAsProperty,
-            exp.VolatilityProperty,
-        }
-
         def except_op(self, expression):
             if not expression.args.get("distinct", False):
                 self.unsupported("EXCEPT with All is not supported in Snowflake")
@@ -73,6 +73,19 @@ class Spark(Hive):
         ),
         "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
         "IIF": exp.If.from_arg_list,
+        "AGGREGATE": exp.Reduce.from_arg_list,
+        "DAYOFWEEK": lambda args: exp.DayOfWeek(
+            this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+        ),
+        "DAYOFMONTH": lambda args: exp.DayOfMonth(
+            this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+        ),
+        "DAYOFYEAR": lambda args: exp.DayOfYear(
+            this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+        ),
+        "WEEKOFYEAR": lambda args: exp.WeekOfYear(
+            this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+        ),
     }

     FUNCTION_PARSERS = {
@@ -105,6 +118,14 @@ class Spark(Hive):
             exp.DataType.Type.BIGINT: "LONG",
         }

+        PROPERTIES_LOCATION = {
+            **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
+        }
+
         TRANSFORMS = {
             **Hive.Generator.TRANSFORMS,  # type: ignore
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
@@ -126,11 +147,27 @@ class Spark(Hive):
             exp.VariancePop: rename_func("VAR_POP"),
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.LogicalOr: rename_func("BOOL_OR"),
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
+            exp.DayOfMonth: rename_func("DAYOFMONTH"),
+            exp.DayOfYear: rename_func("DAYOFYEAR"),
+            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+            exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
         }
         TRANSFORMS.pop(exp.ArraySort)
         TRANSFORMS.pop(exp.ILike)

         WRAP_DERIVED_VALUES = False

+        def cast_sql(self, expression: exp.Cast) -> str:
+            if isinstance(expression.this, exp.Cast) and expression.this.is_type(
+                exp.DataType.Type.JSON
+            ):
+                schema = f"'{self.sql(expression, 'to')}'"
+                return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})"
+            if expression.to.is_type(exp.DataType.Type.JSON):
+                return f"TO_JSON({self.sql(expression, 'this')})"
+
+            return super(Spark.Generator, self).cast_sql(expression)
+
         class Tokenizer(Hive.Tokenizer):
             HEX_STRINGS = [("X'", "'")]
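The new Spark cast_sql maps casts to and from JSON onto TO_JSON / FROM_JSON. A sketch using the expression builder, with a hypothetical column name:

    from sqlglot import exp

    # a cast to JSON renders as TO_JSON under the Spark generator
    exp.cast(exp.column("payload"), "json").sql("spark")  # expected: 'TO_JSON(payload)'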
@@ -31,6 +31,5 @@ class Tableau(Dialect):
     class Parser(parser.Parser):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
-            "IFNULL": exp.Coalesce.from_arg_list,
             "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
         }
@@ -76,6 +76,14 @@ class Teradata(Dialect):
         )

     class Generator(generator.Generator):
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
+        }
+
+        def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
+            return f"PARTITION BY {self.sql(expression, 'this')}"
+
         # FROM before SET in Teradata UPDATE syntax
         # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
         def update_sql(self, expression: exp.Update) -> str:
@@ -412,6 +412,8 @@ class TSQL(Dialect):
         return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)

     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.BOOLEAN: "BIT",
@@ -14,10 +14,6 @@ from sqlglot import Dialect
 from sqlglot import expressions as exp
 from sqlglot.helper import ensure_collection

-if t.TYPE_CHECKING:
-    T = t.TypeVar("T")
-    Edit = t.Union[Insert, Remove, Move, Update, Keep]
-

 @dataclass(frozen=True)
 class Insert:
@@ -56,6 +52,11 @@ class Keep:
     target: exp.Expression


+if t.TYPE_CHECKING:
+    T = t.TypeVar("T")
+    Edit = t.Union[Insert, Remove, Move, Update, Keep]
+
+
 def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
     """
     Returns the list of changes between the source and the target expressions.
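With the Edit union now defined after the edit dataclasses it refers to, the public diff API reads naturally. A usage sketch:

    from sqlglot import diff, parse_one

    # returns a list of Insert/Remove/Move/Update/Keep edits between the two trees
    edits = diff(parse_one("SELECT a + b"), parse_one("SELECT a - b"))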
@@ -1,5 +1,13 @@
 """
 .. include:: ../../posts/python_sql_engine.md
 ----
 """

+from __future__ import annotations
+
+import logging
+import time
+import typing as t
+
 from sqlglot import maybe_parse
 from sqlglot.errors import ExecuteError
@@ -11,42 +19,63 @@ from sqlglot.schema import ensure_schema

 logger = logging.getLogger("sqlglot")

+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+    from sqlglot.executor.table import Tables
+    from sqlglot.expressions import Expression
+    from sqlglot.schema import Schema
+

-def execute(sql, schema=None, read=None, tables=None):
+def execute(
+    sql: str | Expression,
+    schema: t.Optional[t.Dict | Schema] = None,
+    read: DialectType = None,
+    tables: t.Optional[t.Dict] = None,
+) -> Table:
     """
     Run a sql query against data.

     Args:
-        sql (str|sqlglot.Expression): a sql statement
-        schema (dict|sqlglot.optimizer.Schema): database schema.
-            This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
-            the following forms:
-            1. {table: {col: type}}
-            2. {db: {table: {col: type}}}
-            3. {catalog: {db: {table: {col: type}}}}
-        read (str): the SQL dialect to apply during parsing
-            (eg. "spark", "hive", "presto", "mysql").
-        tables (dict): additional tables to register.
+        sql: a sql statement.
+        schema: database schema.
+            This can either be an instance of `Schema` or a mapping in one of the following forms:
+            1. {table: {col: type}}
+            2. {db: {table: {col: type}}}
+            3. {catalog: {db: {table: {col: type}}}}
+        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
+        tables: additional tables to register.

     Returns:
-        sqlglot.executor.Table: Simple columnar data structure.
+        Simple columnar data structure.
     """
-    tables = ensure_tables(tables)
+    tables_ = ensure_tables(tables)
+
     if not schema:
         schema = {
             name: {column: type(table[0][column]).__name__ for column in table.columns}
-            for name, table in tables.mapping.items()
+            for name, table in tables_.mapping.items()
         }
+
     schema = ensure_schema(schema)
-    if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
+
+    if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
         raise ExecuteError("Tables must support the same table args as schema")
+
     expression = maybe_parse(sql, dialect=read)
+
     now = time.time()
     expression = optimize(expression, schema, leave_tables_isolated=True)
+
     logger.debug("Optimization finished: %f", time.time() - now)
     logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
+
     plan = Plan(expression)
+
     logger.debug("Logical Plan: %s", plan)
+
     now = time.time()
-    result = PythonExecutor(tables=tables).execute(plan)
+    result = PythonExecutor(tables=tables_).execute(plan)
+
     logger.debug("Query finished: %f", time.time() - now)
+
     return result
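The executor's signature is now fully typed, and schema inference still kicks in when only tables are supplied. A sketch with hypothetical data:

    from sqlglot.executor import execute

    # runs the query against in-memory rows; the schema is inferred when omitted
    result = execute(
        "SELECT name, SUM(price) AS total FROM sushi GROUP BY name",
        tables={"sushi": [{"name": "maguro", "price": 1.5}, {"name": "sake", "price": 2.0}]},
    )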
@@ -171,5 +171,6 @@ ENV = {
     "STRPOSITION": str_position,
     "SUB": null_if_any(lambda e, this: e - this),
     "SUBSTRING": substring,
+    "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
     "UPPER": null_if_any(lambda arg: arg.upper()),
 }
@@ -1,5 +1,7 @@
 from __future__ import annotations

+import typing as t
+
 from sqlglot.helper import dict_depth
 from sqlglot.schema import AbstractMappingSchema

@@ -106,11 +108,11 @@ class Tables(AbstractMappingSchema[Table]):
         pass


-def ensure_tables(d: dict | None) -> Tables:
+def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
     return Tables(_ensure_tables(d))


-def _ensure_tables(d: dict | None) -> dict:
+def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict:
     if not d:
         return {}

@@ -127,4 +129,5 @@ def _ensure_tables(d: dict | None) -> dict:
         columns = tuple(table[0]) if table else ()
         rows = [tuple(row[c] for c in columns) for row in table]
         result[name] = Table(columns=columns, rows=rows)
+
     return result
@@ -32,13 +32,7 @@ from sqlglot.helper import (
 from sqlglot.tokens import Token

 if t.TYPE_CHECKING:
-    from sqlglot.dialects.dialect import Dialect
-
-    IntoType = t.Union[
-        str,
-        t.Type[Expression],
-        t.Collection[t.Union[str, t.Type[Expression]]],
-    ]
+    from sqlglot.dialects.dialect import DialectType


 class _Expression(type):
@@ -427,7 +421,7 @@ class Expression(metaclass=_Expression):
     def __repr__(self):
         return self._to_s()

-    def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
+    def sql(self, dialect: DialectType = None, **opts) -> str:
         """
         Returns SQL string representation of this tree.

@@ -595,6 +589,14 @@ class Expression(metaclass=_Expression):
         return load(obj)


+if t.TYPE_CHECKING:
+    IntoType = t.Union[
+        str,
+        t.Type[Expression],
+        t.Collection[t.Union[str, t.Type[Expression]]],
+    ]
+
+
 class Condition(Expression):
     def and_(self, *expressions, dialect=None, **opts):
         """
@@ -1285,6 +1287,18 @@ class Property(Expression):
     arg_types = {"this": True, "value": True}


+class AlgorithmProperty(Property):
+    arg_types = {"this": True}
+
+
+class DefinerProperty(Property):
+    arg_types = {"this": True}
+
+
+class SqlSecurityProperty(Property):
+    arg_types = {"definer": True}
+
+
 class TableFormatProperty(Property):
     arg_types = {"this": True}

@@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property):


 class Properties(Expression):
-    arg_types = {"expressions": True, "before": False}
+    arg_types = {"expressions": True}

     NAME_TO_PROPERTY = {
+        "ALGORITHM": AlgorithmProperty,
         "AUTO_INCREMENT": AutoIncrementProperty,
         "CHARACTER SET": CharacterSetProperty,
         "COLLATE": CollateProperty,
         "COMMENT": SchemaCommentProperty,
+        "DEFINER": DefinerProperty,
         "DISTKEY": DistKeyProperty,
         "DISTSTYLE": DistStyleProperty,
         "ENGINE": EngineProperty,
@@ -1447,6 +1463,14 @@ class Properties(Expression):

     PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}

+    class Location(AutoName):
+        POST_CREATE = auto()
+        PRE_SCHEMA = auto()
+        POST_INDEX = auto()
+        POST_SCHEMA_ROOT = auto()
+        POST_SCHEMA_WITH = auto()
+        UNSUPPORTED = auto()
+
     @classmethod
     def from_dict(cls, properties_dict) -> Properties:
         expressions = []
@@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = {
     "order": False,
     "limit": False,
     "offset": False,
+    "lock": False,
 }


@@ -1713,6 +1738,12 @@ class Schema(Expression):
     arg_types = {"this": False, "expressions": False}


+# Used to represent the FOR UPDATE and FOR SHARE locking read types.
+# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html
+class Lock(Expression):
+    arg_types = {"update": True}
+
+
 class Select(Subqueryable):
     arg_types = {
         "with": False,
@@ -2243,6 +2274,30 @@ class Select(Subqueryable):
             properties=properties_expression,
         )

+    def lock(self, update: bool = True, copy: bool = True) -> Select:
+        """
+        Set the locking read mode for this expression.
+
+        Examples:
+            >>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql")
+            "SELECT x FROM tbl WHERE x = 'a' FOR UPDATE"
+
+            >>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql")
+            "SELECT x FROM tbl WHERE x = 'a' FOR SHARE"
+
+        Args:
+            update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`.
+            copy: if `False`, modify this expression instance in-place.
+
+        Returns:
+            The modified expression.
+        """
+
+        inst = _maybe_copy(self, copy)
+        inst.set("lock", Lock(update=update))
+
+        return inst
+
     @property
     def named_selects(self) -> t.List[str]:
         return [e.output_name for e in self.expressions if e.alias_or_name]
@@ -2456,24 +2511,28 @@ class DataType(Expression):

     @classmethod
     def build(
-        cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
+        cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
     ) -> DataType:
         from sqlglot import parse_one

         if isinstance(dtype, str):
-            data_type_exp: t.Optional[Expression]
             if dtype.upper() in cls.Type.__members__:
-                data_type_exp = DataType(this=DataType.Type[dtype.upper()])
+                data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
             else:
                 data_type_exp = parse_one(dtype, read=dialect, into=DataType)
             if data_type_exp is None:
                 raise ValueError(f"Unparsable data type value: {dtype}")
         elif isinstance(dtype, DataType.Type):
             data_type_exp = DataType(this=dtype)
+        elif isinstance(dtype, DataType):
+            return dtype
         else:
             raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
         return DataType(**{**data_type_exp.args, **kwargs})

+    def is_type(self, dtype: DataType.Type) -> bool:
+        return self.this == dtype
+

 # https://www.postgresql.org/docs/15/datatype-pseudo.html
 class PseudoType(Expression):
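DataType.build now also accepts an existing DataType instance and returns it unchanged, which is what lets exp.cast forward its `to` argument untouched. A sketch:

    from sqlglot import exp

    exp.DataType.build("int").sql()                         # expected: 'INT'
    exp.DataType.build(exp.DataType.Type.TEXT)              # built from the enum
    dt = exp.DataType.build("array<int>", dialect="spark")  # parsed with a dialect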
@@ -2840,6 +2899,10 @@ class Array(Func):
     is_var_len_args = True


+class GenerateSeries(Func):
+    arg_types = {"start": True, "end": True, "step": False}
+
+
 class ArrayAgg(AggFunc):
     pass

@@ -2909,6 +2972,9 @@ class Cast(Func):
     def output_name(self):
         return self.name

+    def is_type(self, dtype: DataType.Type) -> bool:
+        return self.to.is_type(dtype)
+

 class Collate(Binary):
     pass
@@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit):
     arg_types = {"this": True, "unit": True, "zone": False}


+class DayOfWeek(Func):
+    _sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"]
+
+
+class DayOfMonth(Func):
+    _sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"]
+
+
+class DayOfYear(Func):
+    _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]
+
+
+class WeekOfYear(Func):
+    _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
+
+
 class LastDateOfMonth(Func):
     pass

@@ -3239,7 +3321,7 @@ class ReadCSV(Func):


 class Reduce(Func):
-    arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
+    arg_types = {"this": True, "initial": True, "merge": True, "finish": False}


 class RegexpLike(Func):
@@ -3476,7 +3558,7 @@ def maybe_parse(
     sql_or_expression: str | Expression,
     *,
     into: t.Optional[IntoType] = None,
-    dialect: t.Optional[str] = None,
+    dialect: DialectType = None,
     prefix: t.Optional[str] = None,
     **opts,
 ) -> Expression:
@@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
     return identifier


+INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")
+
+
+def to_interval(interval: str | Literal) -> Interval:
+    """Builds an interval expression from a string like '1 day' or '5 months'."""
+    if isinstance(interval, Literal):
+        if not interval.is_string:
+            raise ValueError("Invalid interval string.")
+
+        interval = interval.this
+
+    interval_parts = INTERVAL_STRING_RE.match(interval)  # type: ignore
+
+    if not interval_parts:
+        raise ValueError("Invalid interval string.")
+
+    return Interval(
+        this=Literal.string(interval_parts.group(1)),
+        unit=Var(this=interval_parts.group(2)),
+    )
+
+
 @t.overload
 def to_table(sql_path: str | Table, **kwargs) -> Table:
     ...
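exp.to_interval splits a quantity/unit string into a proper Interval node. A sketch:

    from sqlglot import exp

    exp.to_interval("5 months").sql()  # expected: "INTERVAL '5' months"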
@@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
 def subquery(expression, alias=None, dialect=None, **opts):
     """
     Build a subquery expression.
-    Expample:
+
+    Example:
         >>> subquery('select x from tbl', 'bar').select('x').sql()
         'SELECT x FROM (SELECT x FROM tbl) AS bar'

@@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
 def column(col, table=None, quoted=None) -> Column:
     """
     Build a Column.
+
     Args:
         col (str | Expression): column name
         table (str | Expression): table name
@@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column:
     )


+def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
+    """Cast an expression to a data type.
+
+    Example:
+        >>> cast('x + 1', 'int').sql()
+        'CAST(x + 1 AS INT)'
+
+    Args:
+        expression: The expression to cast.
+        to: The datatype to cast to.
+
+    Returns:
+        A cast node.
+    """
+    expression = maybe_parse(expression, **opts)
+    return Cast(this=expression, to=DataType.build(to, **opts))
+
+
 def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
     """Build a Table.

@@ -4137,7 +4261,7 @@ def values(
     types = list(columns.values())
     expressions[0].set(
         "expressions",
-        [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
+        [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
     )
     return Values(
         expressions=expressions,
@@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
     return expression.transform(_expand, copy=copy)


-def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
+def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
     """
     Returns a Func expression.

@@ -67,6 +67,7 @@ class Generator:
         exp.VolatilityProperty: lambda self, e: e.name,
         exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
         exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
+        exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
     }

     # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@@ -75,6 +76,9 @@ class Generator:
     # Whether or not null ordering is supported in order by
     NULL_ORDERING_SUPPORTED = True

+    # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
+    LOCKING_READS_SUPPORTED = False
+
     # Always do union distinct or union all
     EXPLICIT_UNION = False

@@ -99,34 +103,42 @@ class Generator:

     STRUCT_DELIMITER = ("<", ">")

-    BEFORE_PROPERTIES = {
-        exp.FallbackProperty,
-        exp.WithJournalTableProperty,
-        exp.LogProperty,
-        exp.JournalProperty,
-        exp.AfterJournalProperty,
-        exp.ChecksumProperty,
-        exp.FreespaceProperty,
-        exp.MergeBlockRatioProperty,
-        exp.DataBlocksizeProperty,
-        exp.BlockCompressionProperty,
-        exp.IsolatedLoadingProperty,
-    }
-
-    ROOT_PROPERTIES = {
-        exp.ReturnsProperty,
-        exp.LanguageProperty,
-        exp.DistStyleProperty,
-        exp.DistKeyProperty,
-        exp.SortKeyProperty,
-        exp.LikeProperty,
-    }
-
-    WITH_PROPERTIES = {
-        exp.Property,
-        exp.FileFormatProperty,
-        exp.PartitionedByProperty,
-        exp.TableFormatProperty,
+    PROPERTIES_LOCATION = {
+        exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
+        exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
+        exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.LogProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.Property: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
+        exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA,
     }

     WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
@@ -284,10 +296,10 @@ class Generator:
         )
         return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"

-    def no_identify(self, func: t.Callable[[], str]) -> str:
+    def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str:
         original = self.identify
         self.identify = False
-        result = func()
+        result = func(*args, **kwargs)
         self.identify = original
         return result

@@ -455,19 +467,33 @@ class Generator:

     def create_sql(self, expression: exp.Create) -> str:
         kind = self.sql(expression, "kind").upper()
-        has_before_properties = expression.args.get("properties")
-        has_before_properties = (
-            has_before_properties.args.get("before") if has_before_properties else None
-        )
-        if kind == "TABLE" and has_before_properties:
+        properties = expression.args.get("properties")
+        properties_exp = expression.copy()
+        properties_locs = self.locate_properties(properties) if properties else {}
+        if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get(
+            exp.Properties.Location.POST_SCHEMA_WITH
+        ):
+            properties_exp.set(
+                "properties",
+                exp.Properties(
+                    expressions=[
+                        *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT],
+                        *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH],
+                    ]
+                ),
+            )
+        if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA):
             this_name = self.sql(expression.this, "this")
-            this_properties = self.sql(expression, "properties")
+            this_properties = self.properties(
+                exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]),
+                wrapped=False,
+            )
             this_schema = f"({self.expressions(expression.this)})"
             this = f"{this_name}, {this_properties} {this_schema}"
-            properties = ""
+            properties_sql = ""
         else:
             this = self.sql(expression, "this")
-            properties = self.sql(expression, "properties")
+            properties_sql = self.sql(properties_exp, "properties")
         begin = " BEGIN" if expression.args.get("begin") else ""
         expression_sql = self.sql(expression, "expression")
         expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
@@ -514,11 +540,31 @@ class Generator:
                     if index.args.get("columns")
                     else ""
                 )
+                if index.args.get("primary") and properties_locs.get(
+                    exp.Properties.Location.POST_INDEX
+                ):
+                    postindex_props_sql = self.properties(
+                        exp.Properties(
+                            expressions=properties_locs[exp.Properties.Location.POST_INDEX]
+                        ),
+                        wrapped=False,
+                    )
+                    ind_columns = f"{ind_columns} {postindex_props_sql}"
+
                 indexes_sql.append(
                     f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
                 )
             index_sql = "".join(indexes_sql)

+        postcreate_props_sql = ""
+        if properties_locs.get(exp.Properties.Location.POST_CREATE):
+            postcreate_props_sql = self.properties(
+                exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_CREATE]),
+                sep=" ",
+                prefix=" ",
+                wrapped=False,
+            )
+
         modifiers = "".join(
             (
                 replace,
@@ -531,6 +577,7 @@ class Generator:
                 multiset,
                 global_temporary,
                 volatile,
+                postcreate_props_sql,
             )
         )
         no_schema_binding = (
@@ -539,7 +586,7 @@ class Generator:

         post_expression_modifiers = "".join((data, statistics, no_primary_index))

-        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
+        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
         return self.prepend_ctes(expression, expression_sql)

     def describe_sql(self, expression: exp.Describe) -> str:
@@ -665,24 +712,19 @@ class Generator:
         return f"PARTITION({self.expressions(expression)})"

     def properties_sql(self, expression: exp.Properties) -> str:
-        before_properties = []
         root_properties = []
         with_properties = []

         for p in expression.expressions:
-            p_class = p.__class__
-            if p_class in self.BEFORE_PROPERTIES:
-                before_properties.append(p)
-            elif p_class in self.WITH_PROPERTIES:
+            p_loc = self.PROPERTIES_LOCATION[p.__class__]
+            if p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
                 with_properties.append(p)
-            elif p_class in self.ROOT_PROPERTIES:
+            elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
                 root_properties.append(p)

-        return (
-            self.properties(exp.Properties(expressions=before_properties), before=True)
-            + self.root_properties(exp.Properties(expressions=root_properties))
-            + self.with_properties(exp.Properties(expressions=with_properties))
-        )
+        return self.root_properties(
+            exp.Properties(expressions=root_properties)
+        ) + self.with_properties(exp.Properties(expressions=with_properties))

     def root_properties(self, properties: exp.Properties) -> str:
         if properties.expressions:
@@ -695,17 +737,41 @@ class Generator:
         prefix: str = "",
        sep: str = ", ",
         suffix: str = "",
-        before: bool = False,
+        wrapped: bool = True,
     ) -> str:
         if properties.expressions:
             expressions = self.expressions(properties, sep=sep, indent=False)
-            expressions = expressions if before else self.wrap(expressions)
+            expressions = self.wrap(expressions) if wrapped else expressions
             return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
         return ""

     def with_properties(self, properties: exp.Properties) -> str:
         return self.properties(properties, prefix=self.seg("WITH"))

+    def locate_properties(
+        self, properties: exp.Properties
+    ) -> t.Dict[exp.Properties.Location, list[exp.Property]]:
+        properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = {
+            key: [] for key in exp.Properties.Location
+        }
+
+        for p in properties.expressions:
+            p_loc = self.PROPERTIES_LOCATION[p.__class__]
+            if p_loc == exp.Properties.Location.PRE_SCHEMA:
+                properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p)
+            elif p_loc == exp.Properties.Location.POST_INDEX:
+                properties_locs[exp.Properties.Location.POST_INDEX].append(p)
+            elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
+                properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p)
+            elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
+                properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p)
+            elif p_loc == exp.Properties.Location.POST_CREATE:
+                properties_locs[exp.Properties.Location.POST_CREATE].append(p)
+            elif p_loc == exp.Properties.Location.UNSUPPORTED:
+                self.unsupported(f"Unsupported property {p.key}")
+
+        return properties_locs
+
     def property_sql(self, expression: exp.Property) -> str:
         property_cls = expression.__class__
         if property_cls == exp.Property:
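PROPERTIES_LOCATION replaces the three separate property sets: each property class maps to one Location, and locate_properties buckets a Properties node accordingly, with UNSUPPORTED triggering a warning. A sketch, outputs assumed:

    import sqlglot

    # Hive routes PARTITIONED BY to POST_SCHEMA_ROOT, so it stays after the column list
    sqlglot.transpile(
        "CREATE TABLE t (a INT) PARTITIONED BY (b STRING)", read="hive", write="hive"
    )[0]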
@@ -713,7 +779,7 @@ class Generator:

         property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
         if not property_name:
-            self.unsupported(f"Unsupported property {property_name}")
+            self.unsupported(f"Unsupported property {expression.key}")

         return f"{property_name}={self.sql(expression, 'this')}"

@@ -975,7 +1041,7 @@ class Generator:
         rollup = self.expressions(expression, key="rollup", indent=False)
         rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""

-        return f"{group_by}{grouping_sets}{cube}{rollup}"
+        return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}"

     def having_sql(self, expression: exp.Having) -> str:
         this = self.indent(self.sql(expression, "this"))
@@ -1015,7 +1081,7 @@ class Generator:
     def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
         args = self.expressions(expression, flat=True)
         args = f"({args})" if len(args.split(",")) > 1 else args
-        return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
+        return f"{args} {arrow_sep} {self.sql(expression, 'this')}"

     def lateral_sql(self, expression: exp.Lateral) -> str:
         this = self.sql(expression, "this")
@@ -1043,6 +1109,14 @@ class Generator:
         this = self.sql(expression, "this")
         return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"

+    def lock_sql(self, expression: exp.Lock) -> str:
+        if self.LOCKING_READS_SUPPORTED:
+            lock_type = "UPDATE" if expression.args["update"] else "SHARE"
+            return self.seg(f"FOR {lock_type}")
+
+        self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
+        return ""
+
     def literal_sql(self, expression: exp.Literal) -> str:
         text = expression.this or ""
         if expression.is_string:
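lock_sql emits the FOR UPDATE / FOR SHARE clause only when the dialect opts in via LOCKING_READS_SUPPORTED (MySQL, Oracle, Postgres, and T-SQL in this commit); otherwise the clause is dropped with an "unsupported" warning. A sketch:

    import sqlglot

    sqlglot.transpile("SELECT x FROM t FOR UPDATE", read="mysql", write="mysql")[0]
    # expected: 'SELECT x FROM t FOR UPDATE'; writing to a dialect without
    # locking-read support drops the clause and logs a warning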
@@ -1163,6 +1237,7 @@ class Generator:
             self.sql(expression, "order"),
             self.sql(expression, "limit"),
             self.sql(expression, "offset"),
+            self.sql(expression, "lock"),
             sep="",
         )

@@ -1773,7 +1848,7 @@ class Generator:

     def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
         this = self.sql(expression, "this")
-        expressions = self.no_identify(lambda: self.expressions(expression))
+        expressions = self.no_identify(self.expressions, expression)
         expressions = (
             self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
         )
@@ -9,6 +9,9 @@ from sqlglot.optimizer import Scope, build_scope, optimize
 from sqlglot.optimizer.qualify_columns import qualify_columns
 from sqlglot.optimizer.qualify_tables import qualify_tables

+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+

 @dataclass(frozen=True)
 class Node:
@@ -36,7 +39,7 @@ def lineage(
     schema: t.Optional[t.Dict | Schema] = None,
     sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
     rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
-    dialect: t.Optional[str] = None,
+    dialect: DialectType = None,
 ) -> Node:
     """Build the lineage graph for a column of a SQL query.

@@ -126,7 +129,7 @@ class LineageHTML:
     def __init__(
         self,
         node: Node,
-        dialect: t.Optional[str] = None,
+        dialect: DialectType = None,
         imports: bool = True,
         **opts: t.Any,
     ):
@@ -114,7 +114,7 @@ def _eliminate_union(scope, existing_ctes, taken):
     taken[alias] = scope

     # Try to maintain the selections
-    expressions = scope.expression.args.get("expressions")
+    expressions = scope.selects
     selects = [
         exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
         for e in expressions
@@ -300,7 +300,7 @@ class Scope:
             list[exp.Expression]: expressions
         """
         if isinstance(self.expression, exp.Union):
-            return []
+            return self.expression.unnest().selects
         return self.expression.selects

     @property
@@ -456,8 +456,10 @@ def extract_interval(interval):


 def date_literal(date):
-    expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
-    return exp.Cast(this=exp.Literal.string(date), to=expr_type)
+    return exp.cast(
+        exp.Literal.string(date),
+        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
+    )


 def boolean_literal(condition):
@ -80,6 +80,7 @@ class Parser(metaclass=_Parser):
|
|||
length=exp.Literal.number(10),
|
||||
),
|
||||
"VAR_MAP": parse_var_map,
|
||||
"IFNULL": exp.Coalesce.from_arg_list,
|
||||
}
|
||||
|
||||
NO_PAREN_FUNCTIONS = {
|
||||
|
@ -567,6 +568,8 @@ class Parser(metaclass=_Parser):
|
|||
default=self._prev.text.upper() == "DEFAULT"
|
||||
),
|
||||
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
|
||||
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
|
||||
"DEFINER": lambda self: self._parse_definer(),
|
||||
}
|
||||
|
||||
CONSTRAINT_PARSERS = {
|
||||
|
@ -608,6 +611,7 @@ class Parser(metaclass=_Parser):
|
|||
"order": lambda self: self._parse_order(),
|
||||
"limit": lambda self: self._parse_limit(),
|
||||
"offset": lambda self: self._parse_offset(),
|
||||
"lock": lambda self: self._parse_lock(),
|
||||
}
|
||||
|
||||
SHOW_PARSERS: t.Dict[str, t.Callable] = {}
|
||||
|
@ -850,7 +854,7 @@ class Parser(metaclass=_Parser):
|
|||
self.raise_error(error_message)
|
||||
|
||||
def _find_sql(self, start: Token, end: Token) -> str:
|
||||
return self.sql[self._find_token(start) : self._find_token(end)]
|
||||
return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)]
|
||||
|
||||
def _find_token(self, token: Token) -> int:
|
||||
line = 1
|
||||
|
@ -901,6 +905,7 @@ class Parser(metaclass=_Parser):
|
|||
return expression
|
||||
|
||||
def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
|
||||
start = self._prev
|
||||
temporary = self._match(TokenType.TEMPORARY)
|
||||
materialized = self._match(TokenType.MATERIALIZED)
|
||||
kind = self._match_set(self.CREATABLES) and self._prev.text
|
||||
|
@ -908,8 +913,7 @@ class Parser(metaclass=_Parser):
|
|||
if default_kind:
|
||||
kind = default_kind
|
||||
else:
|
||||
self.raise_error(f"Expected {self.CREATABLES}")
|
||||
return None
|
||||
return self._parse_as_command(start)
|
||||
|
||||
return self.expression(
|
||||
exp.Drop,
|
||||
|
@@ -929,6 +933,7 @@ class Parser(metaclass=_Parser):
         )
 
     def _parse_create(self) -> t.Optional[exp.Expression]:
+        start = self._prev
         replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
         set_ = self._match(TokenType.SET)  # Teradata
         multiset = self._match_text_seq("MULTISET")  # Teradata
@@ -943,16 +948,19 @@ class Parser(metaclass=_Parser):
         if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
             self._match(TokenType.TABLE)
 
         properties = None
         create_token = self._match_set(self.CREATABLES) and self._prev
 
         if not create_token:
-            self.raise_error(f"Expected {self.CREATABLES}")
-            return None
+            properties = self._parse_properties()
+            create_token = self._match_set(self.CREATABLES) and self._prev
+
+            if not properties or not create_token:
+                return self._parse_as_command(start)
 
         exists = self._parse_exists(not_=True)
         this = None
         expression = None
-        properties = None
         data = None
         statistics = None
         no_primary_index = None
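Note: together with the `_parse_drop` change above, CREATE/DROP statements the parser cannot model now degrade to an `exp.Command` carrying the raw SQL instead of raising. A hedged sketch (the statement is deliberately made up, and I have not run this against the release):

    import sqlglot
    from sqlglot import exp

    stmt = sqlglot.parse_one("CREATE EXOTIC THING foo")
    # expected: isinstance(stmt, exp.Command), with the original text preserved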
@@ -1006,6 +1014,14 @@ class Parser(metaclass=_Parser):
             indexes = []
             while True:
                 index = self._parse_create_table_index()
+
+                # post index PARTITION BY property
+                if self._match(TokenType.PARTITION_BY, advance=False):
+                    if properties:
+                        properties.expressions.append(self._parse_property())
+                    else:
+                        properties = self._parse_properties()
+
                 if not index:
                     break
                 else:
|||
)
|
||||
|
||||
def _parse_property_before(self) -> t.Optional[exp.Expression]:
|
||||
self._match(TokenType.COMMA)
|
||||
|
||||
# parsers look to _prev for no/dual/default, so need to consume first
|
||||
self._match_text_seq("NO")
|
||||
self._match_text_seq("DUAL")
|
||||
self._match_text_seq("DEFAULT")
|
||||
|
@@ -1059,6 +1078,9 @@ class Parser(metaclass=_Parser):
         if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
             return self._parse_sortkey(compound=True)
 
+        if self._match_text_seq("SQL", "SECURITY"):
+            return self.expression(exp.SqlSecurityProperty, definer=self._match_text_seq("DEFINER"))
+
         assignment = self._match_pair(
             TokenType.VAR, TokenType.EQ, advance=False
         ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
@@ -1083,7 +1105,6 @@ class Parser(metaclass=_Parser):
 
         while True:
             if before:
-                self._match(TokenType.COMMA)
                 identified_property = self._parse_property_before()
             else:
                 identified_property = self._parse_property()
@@ -1094,7 +1115,7 @@ class Parser(metaclass=_Parser):
                 properties.append(p)
 
         if properties:
-            return self.expression(exp.Properties, expressions=properties, before=before)
+            return self.expression(exp.Properties, expressions=properties)
 
         return None
 
@@ -1118,6 +1139,19 @@ class Parser(metaclass=_Parser):
 
         return self._parse_withisolatedloading()
 
+    # https://dev.mysql.com/doc/refman/8.0/en/create-view.html
+    def _parse_definer(self) -> t.Optional[exp.Expression]:
+        self._match(TokenType.EQ)
+
+        user = self._parse_id_var()
+        self._match(TokenType.PARAMETER)
+        host = self._parse_id_var() or (self._match(TokenType.MOD) and self._prev.text)
+
+        if not user or not host:
+            return None
+
+        return exp.DefinerProperty(this=f"{user}@{host}")
+
     def _parse_withjournaltable(self) -> exp.Expression:
         self._match_text_seq("WITH", "JOURNAL", "TABLE")
         self._match(TokenType.EQ)
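Note: with the ALGORITHM/DEFINER property parsers and `SqlSecurityProperty` registered above, the full MySQL CREATE VIEW preamble should now parse. An untested sketch:

    import sqlglot

    sql = "CREATE ALGORITHM=MERGE DEFINER=admin@localhost SQL SECURITY DEFINER VIEW v AS SELECT 1"
    sqlglot.parse_one(sql, read="mysql")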
@@ -1695,12 +1729,10 @@ class Parser(metaclass=_Parser):
                     paren += 1
                 if self._curr.token_type == TokenType.R_PAREN:
                     paren -= 1
+                    end = self._prev
                 self._advance()
             if paren > 0:
                 self.raise_error("Expecting )", self._curr)
-            if not self._curr:
-                self.raise_error("Expecting pattern", self._curr)
-            end = self._prev
             pattern = exp.Var(this=self._find_sql(start, end))
         else:
             pattern = None
@@ -2044,9 +2076,16 @@ class Parser(metaclass=_Parser):
         expressions = self._parse_csv(self._parse_conjunction)
         grouping_sets = self._parse_grouping_sets()
 
+        self._match(TokenType.COMMA)
         with_ = self._match(TokenType.WITH)
-        cube = self._match(TokenType.CUBE) and (with_ or self._parse_wrapped_id_vars())
-        rollup = self._match(TokenType.ROLLUP) and (with_ or self._parse_wrapped_id_vars())
+        cube = self._match(TokenType.CUBE) and (
+            with_ or self._parse_wrapped_csv(self._parse_column)
+        )
+
+        self._match(TokenType.COMMA)
+        rollup = self._match(TokenType.ROLLUP) and (
+            with_ or self._parse_wrapped_csv(self._parse_column)
+        )
 
         return self.expression(
             exp.Group,
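Note: two behavior changes here: a comma is now tolerated before and between trailing CUBE/ROLLUP clauses, and their parenthesized lists go through `_parse_column`, so qualified names are accepted where only bare identifiers parsed before. Assuming I am reading the grammar right, something like this should now parse:

    import sqlglot

    sqlglot.parse_one("SELECT SUM(x) FROM t GROUP BY CUBE (a), ROLLUP (t.b)")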
@@ -2149,6 +2188,14 @@ class Parser(metaclass=_Parser):
         self._match_set((TokenType.ROW, TokenType.ROWS))
         return self.expression(exp.Offset, this=this, expression=count)
 
+    def _parse_lock(self) -> t.Optional[exp.Expression]:
+        if self._match_text_seq("FOR", "UPDATE"):
+            return self.expression(exp.Lock, update=True)
+        if self._match_text_seq("FOR", "SHARE"):
+            return self.expression(exp.Lock, update=False)
+
+        return None
+
     def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
         if not self._match_set(self.SET_OPERATIONS):
             return this
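Note: registered under the `"lock"` query modifier above, this models SELECT ... FOR UPDATE / FOR SHARE as an `exp.Lock` node. What I would expect, sketched (not verified against the release):

    import sqlglot

    select = sqlglot.parse_one("SELECT x FROM t FOR UPDATE")
    lock = select.args.get("lock")
    # expected: lock.args.get("update") is True; FOR SHARE would give False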
@@ -2330,12 +2377,21 @@ class Parser(metaclass=_Parser):
             maybe_func = True
 
         if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
-            return exp.DataType(
+            this = exp.DataType(
                 this=exp.DataType.Type.ARRAY,
                 expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
                 nested=True,
             )
+
+            while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
+                this = exp.DataType(
+                    this=exp.DataType.Type.ARRAY,
+                    expressions=[this],
+                    nested=True,
+                )
+
+            return this
 
         if self._match(TokenType.L_BRACKET):
             self._retreat(index)
             return None
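Note: `_parse_types` used to return after the first `[]` pair, so only one level of array nesting parsed; the loop folds each further `[]` into another ARRAY wrapper. A sketch, assuming Postgres-style array syntax:

    import sqlglot

    sqlglot.parse_one("SELECT CAST(x AS INT[][])", read="postgres")
    # expected type shape: ARRAY(ARRAY(INT))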
@@ -2430,7 +2486,12 @@ class Parser(metaclass=_Parser):
             self.raise_error("Expected type")
         elif op:
             self._advance()
-            field = exp.Literal.string(self._prev.text)
+            value = self._prev.text
+            field = (
+                exp.Literal.number(value)
+                if self._prev.token_type == TokenType.NUMBER
+                else exp.Literal.string(value)
+            )
         else:
             field = self._parse_star() or self._parse_function() or self._parse_id_var()
 
@@ -2752,7 +2813,23 @@ class Parser(metaclass=_Parser):
             if not self._curr:
                 break
 
-            if self._match_text_seq("NOT", "ENFORCED"):
+            if self._match(TokenType.ON):
+                action = None
+                on = self._advance_any() and self._prev.text
+
+                if self._match(TokenType.NO_ACTION):
+                    action = "NO ACTION"
+                elif self._match(TokenType.CASCADE):
+                    action = "CASCADE"
+                elif self._match_pair(TokenType.SET, TokenType.NULL):
+                    action = "SET NULL"
+                elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
+                    action = "SET DEFAULT"
+                else:
+                    self.raise_error("Invalid key constraint")
+
+                options.append(f"ON {on} {action}")
+            elif self._match_text_seq("NOT", "ENFORCED"):
                 options.append("NOT ENFORCED")
             elif self._match_text_seq("DEFERRABLE"):
                 options.append("DEFERRABLE")
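Note: the old code only recognized the literal NO ACTION forms (removed in the next hunk); the rewrite accepts any ON UPDATE / ON DELETE referential action. Hedged example of constraint options that should now survive a round-trip:

    import sqlglot

    sqlglot.transpile(
        "CREATE TABLE t (a INT, FOREIGN KEY (a) REFERENCES s (b) ON DELETE CASCADE ON UPDATE SET NULL)"
    )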
@@ -2762,10 +2839,6 @@ class Parser(metaclass=_Parser):
                 options.append("NORELY")
             elif self._match_text_seq("MATCH", "FULL"):
                 options.append("MATCH FULL")
-            elif self._match_text_seq("ON", "UPDATE", "NO ACTION"):
-                options.append("ON UPDATE NO ACTION")
-            elif self._match_text_seq("ON", "DELETE", "NO ACTION"):
-                options.append("ON DELETE NO ACTION")
             else:
                 break
 
@@ -3158,7 +3231,9 @@ class Parser(metaclass=_Parser):
             prefix += self._prev.text
 
         if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
-            return exp.Identifier(this=prefix + self._prev.text, quoted=False)
+            quoted = self._prev.token_type == TokenType.STRING
+            return exp.Identifier(this=prefix + self._prev.text, quoted=quoted)
+
         return None
 
     def _parse_string(self) -> t.Optional[exp.Expression]:
@@ -3486,6 +3561,11 @@ class Parser(metaclass=_Parser):
     def _parse_set(self) -> exp.Expression:
         return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
 
+    def _parse_as_command(self, start: Token) -> exp.Command:
+        while self._curr:
+            self._advance()
+        return exp.Command(this=self._find_sql(start, self._prev))
+
     def _find_parser(
         self, parsers: t.Dict[str, t.Callable], trie: t.Dict
     ) -> t.Optional[t.Callable]:
@@ -11,6 +11,7 @@ from sqlglot.trie import in_trie, new_trie
 
 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.types import StructType
+    from sqlglot.dialects.dialect import DialectType
 
 ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 
@@ -153,7 +154,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         self,
         schema: t.Optional[t.Dict] = None,
         visible: t.Optional[t.Dict] = None,
-        dialect: t.Optional[str] = None,
+        dialect: DialectType = None,
     ) -> None:
         self.dialect = dialect
         self.visible = visible or {}
@@ -665,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "STRING": TokenType.TEXT,
         "TEXT": TokenType.TEXT,
         "CLOB": TokenType.TEXT,
+        "LONGVARCHAR": TokenType.TEXT,
         "BINARY": TokenType.BINARY,
         "BLOB": TokenType.VARBINARY,
         "BYTEA": TokenType.VARBINARY,
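Note: LONGVARCHAR (a JDBC-style type name) now tokenizes like the other TEXT synonyms, so it should transpile the same way CLOB already does; sketch:

    import sqlglot

    print(sqlglot.transpile("CREATE TABLE t (c LONGVARCHAR)")[0])
    # expected: CREATE TABLE t (c TEXT)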