Merging upstream version 10.6.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:09:58 +01:00
parent d03a55eda6
commit ece6881255
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
48 changed files with 906 additions and 266 deletions

View file

@ -33,7 +33,13 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema, Schema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.6.0"
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
T = t.TypeVar("T", bound=Expression)
__version__ = "10.6.3"
pretty = False
"""Whether to format generated SQL by default."""
@ -42,9 +48,7 @@ schema = MappingSchema()
"""The default schema used by SQLGlot (e.g. in the optimizer)."""
def parse(
sql: str, read: t.Optional[str | Dialect] = None, **opts
) -> t.List[t.Optional[Expression]]:
def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:
"""
Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
@ -60,9 +64,57 @@ def parse(
return dialect.parse(sql, **opts)
@t.overload
def parse_one(
sql: str,
read: t.Optional[str | Dialect] = None,
read: None = None,
into: t.Type[T] = ...,
**opts,
) -> T:
...
@t.overload
def parse_one(
sql: str,
read: DialectType,
into: t.Type[T],
**opts,
) -> T:
...
@t.overload
def parse_one(
sql: str,
read: None = None,
into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ...,
**opts,
) -> Expression:
...
@t.overload
def parse_one(
sql: str,
read: DialectType,
into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]],
**opts,
) -> Expression:
...
@t.overload
def parse_one(
sql: str,
**opts,
) -> Expression:
...
def parse_one(
sql: str,
read: DialectType = None,
into: t.Optional[exp.IntoType] = None,
**opts,
) -> Expression:
@ -96,8 +148,8 @@ def parse_one(
def transpile(
sql: str,
read: t.Optional[str | Dialect] = None,
write: t.Optional[str | Dialect] = None,
read: DialectType = None,
write: DialectType = None,
identity: bool = True,
error_level: t.Optional[ErrorLevel] = None,
**opts,
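
The new `parse_one` overloads above let a type checker narrow the return type when `into` is a single expression class, while `DialectType` widens `read`/`write` to accept a name, a `Dialect` subclass, an instance, or `None`. A minimal usage sketch:

    import sqlglot
    from sqlglot import exp

    # plain call: statically typed as a generic Expression
    ast = sqlglot.parse_one("SELECT a FROM t", read="spark")

    # with `into`, the overloads narrow the static type to exp.Select
    select = sqlglot.parse_one("SELECT a FROM t", into=exp.Select)
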

View file

@ -260,11 +260,7 @@ class Column:
"""
if isinstance(dataType, DataType):
dataType = dataType.simpleString()
new_expression = exp.Cast(
this=self.column_expression,
to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore
)
return Column(new_expression)
return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
def startswith(self, value: t.Union[str, Column]) -> Column:
value = self._lit(value) if not isinstance(value, Column) else value
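
The `Column.cast` change above delegates to the new `exp.cast` builder added to `expressions.py` later in this commit; a minimal sketch based on its docstring:

    from sqlglot import exp

    node = exp.cast("x + 1", "int")   # parses the expression and wraps it in a Cast node
    print(node.sql())                 # CAST(x + 1 AS INT)
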

View file

@ -536,15 +536,15 @@ def month(col: ColumnOrName) -> Column:
def dayofweek(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "DAYOFWEEK")
return Column.invoke_expression_over_column(col, glotexp.DayOfWeek)
def dayofmonth(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "DAYOFMONTH")
return Column.invoke_expression_over_column(col, glotexp.DayOfMonth)
def dayofyear(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "DAYOFYEAR")
return Column.invoke_expression_over_column(col, glotexp.DayOfYear)
def hour(col: ColumnOrName) -> Column:
@ -560,7 +560,7 @@ def second(col: ColumnOrName) -> Column:
def weekofyear(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "WEEKOFYEAR")
return Column.invoke_expression_over_column(col, glotexp.WeekOfYear)
def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
@ -1144,10 +1144,16 @@ def aggregate(
merge_exp = _get_lambda_from_func(merge)
if finish is not None:
finish_exp = _get_lambda_from_func(finish)
return Column.invoke_anonymous_function(
col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
return Column.invoke_expression_over_column(
col,
glotexp.Reduce,
initial=initialValue,
merge=Column(merge_exp),
finish=Column(finish_exp),
)
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
return Column.invoke_expression_over_column(
col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp)
)
def transform(
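
These dataframe functions now build typed expressions (`exp.DayOfWeek`, `exp.Reduce`, ...) via `invoke_expression_over_column` instead of anonymous function calls, so they participate in dialect-aware generation. A hedged sketch (the `.expression` attribute and the exact output are assumptions):

    from sqlglot.dataframe.sql import functions as F

    col = F.dayofweek("ts")                      # builds exp.DayOfWeek, not ANONYMOUS("DAYOFWEEK")
    print(col.expression.sql(dialect="spark"))   # e.g. DAYOFWEEK(ts), via Spark's rename_func
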

View file

@ -222,14 +222,6 @@ class BigQuery(Dialect):
exp.DataType.Type.NVARCHAR: "STRING",
}
ROOT_PROPERTIES = {
exp.LanguageProperty,
exp.ReturnsProperty,
exp.VolatilityProperty,
}
WITH_PROPERTIES = {exp.Property}
EXPLICIT_UNION = True
def array_sql(self, expression: exp.Array) -> str:

View file

@ -122,9 +122,15 @@ class Dialect(metaclass=_Dialect):
def get_or_raise(cls, dialect):
if not dialect:
return cls
if isinstance(dialect, _Dialect):
return dialect
if isinstance(dialect, Dialect):
return dialect.__class__
result = cls.get(dialect)
if not result:
raise ValueError(f"Unknown dialect '{dialect}'")
return result
@classmethod
@ -196,6 +202,10 @@ class Dialect(metaclass=_Dialect):
)
if t.TYPE_CHECKING:
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
def rename_func(name):
def _rename(self, expression):
args = flatten(expression.args.values())
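
`Dialect.get_or_raise` now accepts everything the new `DialectType` alias covers: a dialect name, a `Dialect` subclass, or an instance. A minimal sketch:

    from sqlglot.dialects.dialect import Dialect
    from sqlglot.dialects.mysql import MySQL

    Dialect.get_or_raise("mysql")    # name lookup -> the MySQL class
    Dialect.get_or_raise(MySQL)      # a subclass is returned as-is
    Dialect.get_or_raise(MySQL())    # an instance resolves to its class
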

View file

@ -137,7 +137,10 @@ class Drill(Dialect):
exp.DataType.Type.DATETIME: "TIMESTAMP",
}
ROOT_PROPERTIES = {exp.PartitionedByProperty}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore

View file

@ -20,10 +20,6 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _unix_to_time(self, expression):
return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))"
def _str_to_time_sql(self, expression):
return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
@ -113,7 +109,7 @@ class DuckDB(Dialect):
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRUCT_PACK": exp.Struct.from_arg_list,
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
}
@ -162,9 +158,9 @@ class DuckDB(Dialect):
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)",
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
}
TYPE_MAPPING = {
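
DuckDB's `TO_TIMESTAMP` is now parsed as `exp.UnixToTime` (seconds since epoch) rather than `exp.TimeStrToTime`, and the generator emits `TO_TIMESTAMP` directly instead of a manual BIGINT cast. A sketch of the round trip (the Presto output assumes its usual `FROM_UNIXTIME` mapping):

    import sqlglot

    sqlglot.transpile("SELECT TO_TIMESTAMP(1672531200)", read="duckdb", write="presto")
    # e.g. ['SELECT FROM_UNIXTIME(1672531200)']
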

View file

@ -322,17 +322,11 @@ class Hive(Dialect):
exp.LastDateOfMonth: rename_func("LAST_DAY"),
}
WITH_PROPERTIES = {exp.Property}
ROOT_PROPERTIES = {
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.SchemaCommentProperty,
exp.LocationProperty,
exp.TableFormatProperty,
exp.RowFormatDelimitedProperty,
exp.RowFormatSerdeProperty,
exp.SerdeProperties,
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
}
def with_properties(self, properties):

View file

@ -1,7 +1,5 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
@ -98,6 +96,8 @@ def _date_add_sql(kind):
class MySQL(Dialect):
time_format = "'%Y-%m-%d %T'"
# https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
time_mapping = {
"%M": "%B",
@ -110,6 +110,7 @@ class MySQL(Dialect):
"%u": "%W",
"%k": "%-H",
"%l": "%-I",
"%T": "%H:%M:%S",
}
class Tokenizer(tokens.Tokenizer):
@ -428,6 +429,7 @@ class MySQL(Dialect):
)
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
@ -449,23 +451,12 @@ class MySQL(Dialect):
exp.StrPosition: strposition_to_locate_sql,
}
ROOT_PROPERTIES = {
exp.EngineProperty,
exp.AutoIncrementProperty,
exp.CharacterSetProperty,
exp.CollateProperty,
exp.SchemaCommentProperty,
exp.LikeProperty,
}
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
def show_sql(self, expression):
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""
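
The new `"%T": "%H:%M:%S"` entry lets MySQL's `%T` token round-trip through sqlglot's internal time format. A hedged sketch (the exact DuckDB output is an assumption):

    import sqlglot

    sqlglot.transpile("SELECT DATE_FORMAT(x, '%T')", read="mysql", write="duckdb")
    # e.g. ["SELECT STRFTIME(x, '%H:%M:%S')"]
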

View file

@ -44,6 +44,8 @@ class Oracle(Dialect):
}
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "NUMBER",
@ -69,6 +71,7 @@ class Oracle(Dialect):
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
exp.Substring: rename_func("SUBSTR"),
}
def query_modifiers(self, expression, *sqls):
@ -90,6 +93,7 @@ class Oracle(Dialect):
self.sql(expression, "order"),
self.sql(expression, "offset"), # offset before limit in oracle
self.sql(expression, "limit"),
self.sql(expression, "lock"),
sep="",
)
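
With `LOCKING_READS_SUPPORTED` enabled and `lock` added to `query_modifiers`, Oracle now renders locking reads, and `exp.Substring` maps to `SUBSTR`. A minimal sketch:

    import sqlglot
    from sqlglot import exp

    exp.select("x").from_("t").lock().sql("oracle")
    # 'SELECT x FROM t FOR UPDATE'

    sqlglot.transpile("SELECT SUBSTRING(x, 1, 2)", write="oracle")
    # ['SELECT SUBSTR(x, 1, 2)']
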

View file

@ -148,6 +148,22 @@ def _serial_to_generated(expression):
return expression
def _generate_series(args):
# The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
step = seq_get(args, 2)
if step is None:
# Postgres allows calls with just two arguments -- the "step" argument defaults to 1
return exp.GenerateSeries.from_arg_list(args)
if step.is_string:
args[2] = exp.to_interval(step.this)
elif isinstance(step, exp.Interval) and not step.args.get("unit"):
args[2] = exp.to_interval(step.this.this)
return exp.GenerateSeries.from_arg_list(args)
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1:
@ -195,29 +211,6 @@ class Postgres(Dialect):
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
CREATABLES = (
"AGGREGATE",
"CAST",
"CONVERSION",
"COLLATION",
"DEFAULT CONVERSION",
"CONSTRAINT",
"DOMAIN",
"EXTENSION",
"FOREIGN",
"FUNCTION",
"OPERATOR",
"POLICY",
"ROLE",
"RULE",
"SEQUENCE",
"TEXT",
"TRIGGER",
"TYPE",
"UNLOGGED",
"USER",
)
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~~": TokenType.LIKE,
@ -243,8 +236,6 @@ class Postgres(Dialect):
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
"CSTRING": TokenType.PSEUDO_TYPE,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
@ -257,8 +248,10 @@ class Postgres(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"GENERATE_SERIES": _generate_series,
}
BITWISE = {
@ -272,6 +265,8 @@ class Postgres(Dialect):
}
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "SMALLINT",
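
`_generate_series` normalizes the step argument so that both `'1 day'` and `INTERVAL '1 day'` become the canonical `INTERVAL '1' day`, using the new `exp.to_interval` helper. A minimal sketch:

    import sqlglot

    ast = sqlglot.parse_one("SELECT GENERATE_SERIES(a, b, '1 day')", read="postgres")
    # the string step is rewritten to INTERVAL '1' day before building exp.GenerateSeries
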

View file

@ -105,6 +105,29 @@ def _ts_or_ds_add_sql(self, expression):
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
def _sequence_sql(self, expression):
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
target_type = None
if isinstance(start, exp.Cast):
target_type = start.to
elif isinstance(end, exp.Cast):
target_type = end.to
if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
to = target_type.copy()
if target_type is start.to:
end = exp.Cast(this=end, to=to)
else:
start = exp.Cast(this=start, to=to)
return f"SEQUENCE({self.format_args(start, end, step)})"
def _ensure_utf8(charset):
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
@ -145,7 +168,7 @@ def _from_unixtime(args):
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
time_format = "'%Y-%m-%d %H:%i:%S'"
time_format = MySQL.time_format # type: ignore
time_mapping = MySQL.time_mapping # type: ignore
class Tokenizer(tokens.Tokenizer):
@ -197,7 +220,10 @@ class Presto(Dialect):
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
ROOT_PROPERTIES = {exp.SchemaCommentProperty}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@ -223,6 +249,7 @@ class Presto(Dialect):
exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
@ -231,6 +258,7 @@ class Presto(Dialect):
exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.GenerateSeries: _sequence_sql,
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
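
`_sequence_sql` propagates a `TIMESTAMP` cast from one bound of `GENERATE_SERIES` to the other, since Presto's `SEQUENCE` needs both bounds to share a type. A hedged sketch (the exact output shape is an assumption):

    import sqlglot

    sqlglot.transpile(
        "SELECT GENERATE_SERIES(CAST('2020-01-01' AS TIMESTAMP), b, INTERVAL '1 day')",
        read="postgres",
        write="presto",
    )
    # e.g. ["SELECT SEQUENCE(CAST('2020-01-01' AS TIMESTAMP), CAST(b AS TIMESTAMP), INTERVAL '1' day)"]
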

View file

@ -61,14 +61,9 @@ class Redshift(Postgres):
exp.DataType.Type.INT: "INTEGER",
}
ROOT_PROPERTIES = {
exp.DistKeyProperty,
exp.SortKeyProperty,
exp.DistStyleProperty,
}
WITH_PROPERTIES = {
exp.LikeProperty,
PROPERTIES_LOCATION = {
**Postgres.Generator.PROPERTIES_LOCATION, # type: ignore
exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH,
}
TRANSFORMS = {

View file

@ -234,15 +234,6 @@ class Snowflake(Dialect):
"replace": "RENAME",
}
ROOT_PROPERTIES = {
exp.PartitionedByProperty,
exp.ReturnsProperty,
exp.LanguageProperty,
exp.SchemaCommentProperty,
exp.ExecuteAsProperty,
exp.VolatilityProperty,
}
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")

View file

@ -73,6 +73,19 @@ class Spark(Hive):
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
"AGGREGATE": exp.Reduce.from_arg_list,
"DAYOFWEEK": lambda args: exp.DayOfWeek(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFMONTH": lambda args: exp.DayOfMonth(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DAYOFYEAR": lambda args: exp.DayOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
}
FUNCTION_PARSERS = {
@ -105,6 +118,14 @@ class Spark(Hive):
exp.DataType.Type.BIGINT: "LONG",
}
PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION, # type: ignore
exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
@ -126,11 +147,27 @@ class Spark(Hive):
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
}
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
exp.DataType.Type.JSON
):
schema = f"'{self.sql(expression, 'to')}'"
return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})"
if expression.to.is_type(exp.DataType.Type.JSON):
return f"TO_JSON({self.sql(expression, 'this')})"
return super(Spark.Generator, self).cast_sql(expression)
class Tokenizer(Hive.Tokenizer):
HEX_STRINGS = [("X'", "'")]
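
Spark has no `JSON` data type, so the new `cast_sql` rewrites JSON casts into `TO_JSON`/`FROM_JSON` calls. A minimal sketch (the `FROM_JSON` schema-string form is taken from the code above):

    import sqlglot

    sqlglot.transpile("SELECT CAST(x AS JSON)", write="spark")
    # ['SELECT TO_JSON(x)']

    # a JSON value cast back to a struct becomes FROM_JSON with the schema as a string:
    # CAST(CAST(x AS JSON) AS STRUCT<a: INT>) -> FROM_JSON(x, 'STRUCT<a: INT>')
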

View file

@ -31,6 +31,5 @@ class Tableau(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}

View file

@ -76,6 +76,14 @@ class Teradata(Dialect):
)
class Generator(generator.Generator):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
}
def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
return f"PARTITION BY {self.sql(expression, 'this')}"
# FROM before SET in Teradata UPDATE syntax
# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
def update_sql(self, expression: exp.Update) -> str:

View file

@ -412,6 +412,8 @@ class TSQL(Dialect):
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "BIT",

View file

@ -14,10 +14,6 @@ from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_collection
if t.TYPE_CHECKING:
T = t.TypeVar("T")
Edit = t.Union[Insert, Remove, Move, Update, Keep]
@dataclass(frozen=True)
class Insert:
@ -56,6 +52,11 @@ class Keep:
target: exp.Expression
if t.TYPE_CHECKING:
T = t.TypeVar("T")
Edit = t.Union[Insert, Remove, Move, Update, Keep]
def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
"""
Returns the list of changes between the source and the target expressions.

View file

@ -1,5 +1,13 @@
"""
.. include:: ../../posts/python_sql_engine.md
----
"""
from __future__ import annotations
import logging
import time
import typing as t
from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
@ -11,42 +19,63 @@ from sqlglot.schema import ensure_schema
logger = logging.getLogger("sqlglot")
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
from sqlglot.executor.table import Tables
from sqlglot.expressions import Expression
from sqlglot.schema import Schema
def execute(sql, schema=None, read=None, tables=None):
def execute(
sql: str | Expression,
schema: t.Optional[t.Dict | Schema] = None,
read: DialectType = None,
tables: t.Optional[t.Dict] = None,
) -> Table:
"""
Run a sql query against data.
Args:
sql (str|sqlglot.Expression): a sql statement
schema (dict|sqlglot.optimizer.Schema): database schema.
This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
read (str): the SQL dialect to apply during parsing
(eg. "spark", "hive", "presto", "mysql").
tables (dict): additional tables to register.
sql: a sql statement.
schema: database schema.
This can either be an instance of `Schema` or a mapping in one of the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
tables: additional tables to register.
Returns:
sqlglot.executor.Table: Simple columnar data structure.
Simple columnar data structure.
"""
tables = ensure_tables(tables)
tables_ = ensure_tables(tables)
if not schema:
schema = {
name: {column: type(table[0][column]).__name__ for column in table.columns}
for name, table in tables.mapping.items()
for name, table in tables_.mapping.items()
}
schema = ensure_schema(schema)
if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
raise ExecuteError("Tables must support the same table args as schema")
expression = maybe_parse(sql, dialect=read)
now = time.time()
expression = optimize(expression, schema, leave_tables_isolated=True)
logger.debug("Optimization finished: %f", time.time() - now)
logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
plan = Plan(expression)
logger.debug("Logical Plan: %s", plan)
now = time.time()
result = PythonExecutor(tables=tables).execute(plan)
result = PythonExecutor(tables=tables_).execute(plan)
logger.debug("Query finished: %f", time.time() - now)
return result
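
The retyped `execute` keeps the same behavior; a minimal usage sketch based on the docstring above (the printed values are illustrative):

    from sqlglot.executor import execute

    tables = {"t": [{"x": 1}, {"x": 2}]}
    result = execute("SELECT SUM(x) AS s FROM t", tables=tables)
    print(result.columns, result.rows)   # e.g. ('s',) and [(3,)]
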

View file

@ -171,5 +171,6 @@ ENV = {
"STRPOSITION": str_position,
"SUB": null_if_any(lambda e, this: e - this),
"SUBSTRING": substring,
"TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
"UPPER": null_if_any(lambda arg: arg.upper()),
}

View file

@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot.helper import dict_depth
from sqlglot.schema import AbstractMappingSchema
@ -106,11 +108,11 @@ class Tables(AbstractMappingSchema[Table]):
pass
def ensure_tables(d: dict | None) -> Tables:
def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
return Tables(_ensure_tables(d))
def _ensure_tables(d: dict | None) -> dict:
def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict:
if not d:
return {}
@ -127,4 +129,5 @@ def _ensure_tables(d: dict | None) -> dict:
columns = tuple(table[0]) if table else ()
rows = [tuple(row[c] for c in columns) for row in table]
result[name] = Table(columns=columns, rows=rows)
return result

View file

@ -32,13 +32,7 @@ from sqlglot.helper import (
from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect
IntoType = t.Union[
str,
t.Type[Expression],
t.Collection[t.Union[str, t.Type[Expression]]],
]
from sqlglot.dialects.dialect import DialectType
class _Expression(type):
@ -427,7 +421,7 @@ class Expression(metaclass=_Expression):
def __repr__(self):
return self._to_s()
def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
def sql(self, dialect: DialectType = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
@ -595,6 +589,14 @@ class Expression(metaclass=_Expression):
return load(obj)
if t.TYPE_CHECKING:
IntoType = t.Union[
str,
t.Type[Expression],
t.Collection[t.Union[str, t.Type[Expression]]],
]
class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts):
"""
@ -1285,6 +1287,18 @@ class Property(Expression):
arg_types = {"this": True, "value": True}
class AlgorithmProperty(Property):
arg_types = {"this": True}
class DefinerProperty(Property):
arg_types = {"this": True}
class SqlSecurityProperty(Property):
arg_types = {"definer": True}
class TableFormatProperty(Property):
arg_types = {"this": True}
@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property):
class Properties(Expression):
arg_types = {"expressions": True, "before": False}
arg_types = {"expressions": True}
NAME_TO_PROPERTY = {
"ALGORITHM": AlgorithmProperty,
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER SET": CharacterSetProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
"DEFINER": DefinerProperty,
"DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty,
"ENGINE": EngineProperty,
@ -1447,6 +1463,14 @@ class Properties(Expression):
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
class Location(AutoName):
POST_CREATE = auto()
PRE_SCHEMA = auto()
POST_INDEX = auto()
POST_SCHEMA_ROOT = auto()
POST_SCHEMA_WITH = auto()
UNSUPPORTED = auto()
@classmethod
def from_dict(cls, properties_dict) -> Properties:
expressions = []
@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = {
"order": False,
"limit": False,
"offset": False,
"lock": False,
}
@ -1713,6 +1738,12 @@ class Schema(Expression):
arg_types = {"this": False, "expressions": False}
# Used to represent the FOR UPDATE and FOR SHARE locking read types.
# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html
class Lock(Expression):
arg_types = {"update": True}
class Select(Subqueryable):
arg_types = {
"with": False,
@ -2243,6 +2274,30 @@ class Select(Subqueryable):
properties=properties_expression,
)
def lock(self, update: bool = True, copy: bool = True) -> Select:
"""
Set the locking read mode for this expression.
Examples:
>>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql")
"SELECT x FROM tbl WHERE x = 'a' FOR UPDATE"
>>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql")
"SELECT x FROM tbl WHERE x = 'a' FOR SHARE"
Args:
update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`.
copy: if `False`, modify this expression instance in-place.
Returns:
The modified expression.
"""
inst = _maybe_copy(self, copy)
inst.set("lock", Lock(update=update))
return inst
@property
def named_selects(self) -> t.List[str]:
return [e.output_name for e in self.expressions if e.alias_or_name]
@ -2456,24 +2511,28 @@ class DataType(Expression):
@classmethod
def build(
cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
) -> DataType:
from sqlglot import parse_one
if isinstance(dtype, str):
data_type_exp: t.Optional[Expression]
if dtype.upper() in cls.Type.__members__:
data_type_exp = DataType(this=DataType.Type[dtype.upper()])
data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
if data_type_exp is None:
raise ValueError(f"Unparsable data type value: {dtype}")
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
elif isinstance(dtype, DataType):
return dtype
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
return DataType(**{**data_type_exp.args, **kwargs})
def is_type(self, dtype: DataType.Type) -> bool:
return self.this == dtype
# https://www.postgresql.org/docs/15/datatype-pseudo.html
class PseudoType(Expression):
@ -2840,6 +2899,10 @@ class Array(Func):
is_var_len_args = True
class GenerateSeries(Func):
arg_types = {"start": True, "end": True, "step": False}
class ArrayAgg(AggFunc):
pass
@ -2909,6 +2972,9 @@ class Cast(Func):
def output_name(self):
return self.name
def is_type(self, dtype: DataType.Type) -> bool:
return self.to.is_type(dtype)
class Collate(Binary):
pass
@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}
class DayOfWeek(Func):
_sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"]
class DayOfMonth(Func):
_sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"]
class DayOfYear(Func):
_sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]
class WeekOfYear(Func):
_sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
class LastDateOfMonth(Func):
pass
@ -3239,7 +3321,7 @@ class ReadCSV(Func):
class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
arg_types = {"this": True, "initial": True, "merge": True, "finish": False}
class RegexpLike(Func):
@ -3476,7 +3558,7 @@ def maybe_parse(
sql_or_expression: str | Expression,
*,
into: t.Optional[IntoType] = None,
dialect: t.Optional[str] = None,
dialect: DialectType = None,
prefix: t.Optional[str] = None,
**opts,
) -> Expression:
@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
return identifier
INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")
def to_interval(interval: str | Literal) -> Interval:
"""Builds an interval expression from a string like '1 day' or '5 months'."""
if isinstance(interval, Literal):
if not interval.is_string:
raise ValueError("Invalid interval string.")
interval = interval.this
interval_parts = INTERVAL_STRING_RE.match(interval) # type: ignore
if not interval_parts:
raise ValueError("Invalid interval string.")
return Interval(
this=Literal.string(interval_parts.group(1)),
unit=Var(this=interval_parts.group(2)),
)
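
A minimal sketch of the new helper (the rendered SQL is illustrative):

    from sqlglot import exp

    exp.to_interval("5 months").sql()
    # e.g. "INTERVAL '5' months"
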
@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table:
...
@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
def subquery(expression, alias=None, dialect=None, **opts):
"""
Build a subquery expression.
Expample:
Example:
>>> subquery('select x from tbl', 'bar').select('x').sql()
'SELECT x FROM (SELECT x FROM tbl) AS bar'
@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
def column(col, table=None, quoted=None) -> Column:
"""
Build a Column.
Args:
col (str | Expression): column name
table (str | Expression): table name
@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column:
)
def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
"""Cast an expression to a data type.
Example:
>>> cast('x + 1', 'int').sql()
'CAST(x + 1 AS INT)'
Args:
expression: The expression to cast.
to: The datatype to cast to.
Returns:
A cast node.
"""
expression = maybe_parse(expression, **opts)
return Cast(this=expression, to=DataType.build(to, **opts))
def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table.
@ -4137,7 +4261,7 @@ def values(
types = list(columns.values())
expressions[0].set(
"expressions",
[Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
[cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
)
return Values(
expressions=expressions,
@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
return expression.transform(_expand, copy=copy)
def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
"""
Returns a Func expression.

View file

@ -67,6 +67,7 @@ class Generator:
exp.VolatilityProperty: lambda self, e: e.name,
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
}
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@ -75,6 +76,9 @@ class Generator:
# Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
# Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False
# Always do union distinct or union all
EXPLICIT_UNION = False
@ -99,34 +103,42 @@ class Generator:
STRUCT_DELIMITER = ("<", ">")
BEFORE_PROPERTIES = {
exp.FallbackProperty,
exp.WithJournalTableProperty,
exp.LogProperty,
exp.JournalProperty,
exp.AfterJournalProperty,
exp.ChecksumProperty,
exp.FreespaceProperty,
exp.MergeBlockRatioProperty,
exp.DataBlocksizeProperty,
exp.BlockCompressionProperty,
exp.IsolatedLoadingProperty,
}
ROOT_PROPERTIES = {
exp.ReturnsProperty,
exp.LanguageProperty,
exp.DistStyleProperty,
exp.DistKeyProperty,
exp.SortKeyProperty,
exp.LikeProperty,
}
WITH_PROPERTIES = {
exp.Property,
exp.FileFormatProperty,
exp.PartitionedByProperty,
exp.TableFormatProperty,
PROPERTIES_LOCATION = {
exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA,
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA,
exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA,
exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA,
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA,
exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA,
exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA,
exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.LogProperty: exp.Properties.Location.PRE_SCHEMA,
exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH,
exp.Property: exp.Properties.Location.POST_SCHEMA_WITH,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA,
}
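
This table replaces the old `BEFORE_PROPERTIES`/`ROOT_PROPERTIES`/`WITH_PROPERTIES` sets: each property class maps to a single `exp.Properties.Location`, and dialects relocate or disable one entry at a time (as Presto does above for `LocationProperty`). A minimal sketch:

    from sqlglot import exp
    from sqlglot.generator import Generator

    Generator.PROPERTIES_LOCATION[exp.EngineProperty]
    # Location.POST_SCHEMA_ROOT
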
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
@ -284,10 +296,10 @@ class Generator:
)
return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"
def no_identify(self, func: t.Callable[[], str]) -> str:
def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str:
original = self.identify
self.identify = False
result = func()
result = func(*args, **kwargs)
self.identify = original
return result
@ -455,19 +467,33 @@ class Generator:
def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind").upper()
has_before_properties = expression.args.get("properties")
has_before_properties = (
has_before_properties.args.get("before") if has_before_properties else None
)
if kind == "TABLE" and has_before_properties:
properties = expression.args.get("properties")
properties_exp = expression.copy()
properties_locs = self.locate_properties(properties) if properties else {}
if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get(
exp.Properties.Location.POST_SCHEMA_WITH
):
properties_exp.set(
"properties",
exp.Properties(
expressions=[
*properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT],
*properties_locs[exp.Properties.Location.POST_SCHEMA_WITH],
]
),
)
if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA):
this_name = self.sql(expression.this, "this")
this_properties = self.sql(expression, "properties")
this_properties = self.properties(
exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]),
wrapped=False,
)
this_schema = f"({self.expressions(expression.this)})"
this = f"{this_name}, {this_properties} {this_schema}"
properties = ""
properties_sql = ""
else:
this = self.sql(expression, "this")
properties = self.sql(expression, "properties")
properties_sql = self.sql(properties_exp, "properties")
begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
@ -514,11 +540,31 @@ class Generator:
if index.args.get("columns")
else ""
)
if index.args.get("primary") and properties_locs.get(
exp.Properties.Location.POST_INDEX
):
postindex_props_sql = self.properties(
exp.Properties(
expressions=properties_locs[exp.Properties.Location.POST_INDEX]
),
wrapped=False,
)
ind_columns = f"{ind_columns} {postindex_props_sql}"
indexes_sql.append(
f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
)
index_sql = "".join(indexes_sql)
postcreate_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_CREATE):
postcreate_props_sql = self.properties(
exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_CREATE]),
sep=" ",
prefix=" ",
wrapped=False,
)
modifiers = "".join(
(
replace,
@ -531,6 +577,7 @@ class Generator:
multiset,
global_temporary,
volatile,
postcreate_props_sql,
)
)
no_schema_binding = (
@ -539,7 +586,7 @@ class Generator:
post_expression_modifiers = "".join((data, statistics, no_primary_index))
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
return self.prepend_ctes(expression, expression_sql)
def describe_sql(self, expression: exp.Describe) -> str:
@ -665,24 +712,19 @@ class Generator:
return f"PARTITION({self.expressions(expression)})"
def properties_sql(self, expression: exp.Properties) -> str:
before_properties = []
root_properties = []
with_properties = []
for p in expression.expressions:
p_class = p.__class__
if p_class in self.BEFORE_PROPERTIES:
before_properties.append(p)
elif p_class in self.WITH_PROPERTIES:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
with_properties.append(p)
elif p_class in self.ROOT_PROPERTIES:
elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
root_properties.append(p)
return (
self.properties(exp.Properties(expressions=before_properties), before=True)
+ self.root_properties(exp.Properties(expressions=root_properties))
+ self.with_properties(exp.Properties(expressions=with_properties))
)
return self.root_properties(
exp.Properties(expressions=root_properties)
) + self.with_properties(exp.Properties(expressions=with_properties))
def root_properties(self, properties: exp.Properties) -> str:
if properties.expressions:
@ -695,17 +737,41 @@ class Generator:
prefix: str = "",
sep: str = ", ",
suffix: str = "",
before: bool = False,
wrapped: bool = True,
) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
expressions = expressions if before else self.wrap(expressions)
expressions = self.wrap(expressions) if wrapped else expressions
return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
return ""
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("WITH"))
def locate_properties(
self, properties: exp.Properties
) -> t.Dict[exp.Properties.Location, list[exp.Property]]:
properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = {
key: [] for key in exp.Properties.Location
}
for p in properties.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc == exp.Properties.Location.PRE_SCHEMA:
properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p)
elif p_loc == exp.Properties.Location.POST_INDEX:
properties_locs[exp.Properties.Location.POST_INDEX].append(p)
elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p)
elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p)
elif p_loc == exp.Properties.Location.POST_CREATE:
properties_locs[exp.Properties.Location.POST_CREATE].append(p)
elif p_loc == exp.Properties.Location.UNSUPPORTED:
self.unsupported(f"Unsupported property {p.key}")
return properties_locs
def property_sql(self, expression: exp.Property) -> str:
property_cls = expression.__class__
if property_cls == exp.Property:
@ -713,7 +779,7 @@ class Generator:
property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
if not property_name:
self.unsupported(f"Unsupported property {property_name}")
self.unsupported(f"Unsupported property {expression.key}")
return f"{property_name}={self.sql(expression, 'this')}"
@ -975,7 +1041,7 @@ class Generator:
rollup = self.expressions(expression, key="rollup", indent=False)
rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
return f"{group_by}{grouping_sets}{cube}{rollup}"
return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}"
def having_sql(self, expression: exp.Having) -> str:
this = self.indent(self.sql(expression, "this"))
@ -1015,7 +1081,7 @@ class Generator:
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
args = f"({args})" if len(args.split(",")) > 1 else args
return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
return f"{args} {arrow_sep} {self.sql(expression, 'this')}"
def lateral_sql(self, expression: exp.Lateral) -> str:
this = self.sql(expression, "this")
@ -1043,6 +1109,14 @@ class Generator:
this = self.sql(expression, "this")
return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"
def lock_sql(self, expression: exp.Lock) -> str:
if self.LOCKING_READS_SUPPORTED:
lock_type = "UPDATE" if expression.args["update"] else "SHARE"
return self.seg(f"FOR {lock_type}")
self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
return ""
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
@ -1163,6 +1237,7 @@ class Generator:
self.sql(expression, "order"),
self.sql(expression, "limit"),
self.sql(expression, "offset"),
self.sql(expression, "lock"),
sep="",
)
@ -1773,7 +1848,7 @@ class Generator:
def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
this = self.sql(expression, "this")
expressions = self.no_identify(lambda: self.expressions(expression))
expressions = self.no_identify(self.expressions, expression)
expressions = (
self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
)

View file

@ -9,6 +9,9 @@ from sqlglot.optimizer import Scope, build_scope, optimize
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
@dataclass(frozen=True)
class Node:
@ -36,7 +39,7 @@ def lineage(
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
dialect: t.Optional[str] = None,
dialect: DialectType = None,
) -> Node:
"""Build the lineage graph for a column of a SQL query.
@ -126,7 +129,7 @@ class LineageHTML:
def __init__(
self,
node: Node,
dialect: t.Optional[str] = None,
dialect: DialectType = None,
imports: bool = True,
**opts: t.Any,
):

View file

@ -114,7 +114,7 @@ def _eliminate_union(scope, existing_ctes, taken):
taken[alias] = scope
# Try to maintain the selections
expressions = scope.expression.args.get("expressions")
expressions = scope.selects
selects = [
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
for e in expressions

View file

@ -300,7 +300,7 @@ class Scope:
list[exp.Expression]: expressions
"""
if isinstance(self.expression, exp.Union):
return []
return self.expression.unnest().selects
return self.expression.selects
@property

View file

@ -456,8 +456,10 @@ def extract_interval(interval):
def date_literal(date):
expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
return exp.Cast(this=exp.Literal.string(date), to=expr_type)
return exp.cast(
exp.Literal.string(date),
"DATETIME" if isinstance(date, datetime.datetime) else "DATE",
)
def boolean_literal(condition):

View file

@ -80,6 +80,7 @@ class Parser(metaclass=_Parser):
length=exp.Literal.number(10),
),
"VAR_MAP": parse_var_map,
"IFNULL": exp.Coalesce.from_arg_list,
}
NO_PAREN_FUNCTIONS = {
@ -567,6 +568,8 @@ class Parser(metaclass=_Parser):
default=self._prev.text.upper() == "DEFAULT"
),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"DEFINER": lambda self: self._parse_definer(),
}
CONSTRAINT_PARSERS = {
@ -608,6 +611,7 @@ class Parser(metaclass=_Parser):
"order": lambda self: self._parse_order(),
"limit": lambda self: self._parse_limit(),
"offset": lambda self: self._parse_offset(),
"lock": lambda self: self._parse_lock(),
}
SHOW_PARSERS: t.Dict[str, t.Callable] = {}
@ -850,7 +854,7 @@ class Parser(metaclass=_Parser):
self.raise_error(error_message)
def _find_sql(self, start: Token, end: Token) -> str:
return self.sql[self._find_token(start) : self._find_token(end)]
return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)]
def _find_token(self, token: Token) -> int:
line = 1
@ -901,6 +905,7 @@ class Parser(metaclass=_Parser):
return expression
def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
@ -908,8 +913,7 @@ class Parser(metaclass=_Parser):
if default_kind:
kind = default_kind
else:
self.raise_error(f"Expected {self.CREATABLES}")
return None
return self._parse_as_command(start)
return self.expression(
exp.Drop,
@ -929,6 +933,7 @@ class Parser(metaclass=_Parser):
)
def _parse_create(self) -> t.Optional[exp.Expression]:
start = self._prev
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
set_ = self._match(TokenType.SET) # Teradata
multiset = self._match_text_seq("MULTISET") # Teradata
@ -943,16 +948,19 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE)
properties = None
create_token = self._match_set(self.CREATABLES) and self._prev
if not create_token:
self.raise_error(f"Expected {self.CREATABLES}")
return None
properties = self._parse_properties()
create_token = self._match_set(self.CREATABLES) and self._prev
if not properties or not create_token:
return self._parse_as_command(start)
exists = self._parse_exists(not_=True)
this = None
expression = None
properties = None
data = None
statistics = None
no_primary_index = None
@ -1006,6 +1014,14 @@ class Parser(metaclass=_Parser):
indexes = []
while True:
index = self._parse_create_table_index()
# post index PARTITION BY property
if self._match(TokenType.PARTITION_BY, advance=False):
if properties:
properties.expressions.append(self._parse_property())
else:
properties = self._parse_properties()
if not index:
break
else:
@ -1040,6 +1056,9 @@ class Parser(metaclass=_Parser):
)
def _parse_property_before(self) -> t.Optional[exp.Expression]:
self._match(TokenType.COMMA)
# parsers look to _prev for no/dual/default, so need to consume first
self._match_text_seq("NO")
self._match_text_seq("DUAL")
self._match_text_seq("DEFAULT")
@ -1059,6 +1078,9 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
return self._parse_sortkey(compound=True)
if self._match_text_seq("SQL", "SECURITY"):
return self.expression(exp.SqlSecurityProperty, definer=self._match_text_seq("DEFINER"))
assignment = self._match_pair(
TokenType.VAR, TokenType.EQ, advance=False
) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
@ -1083,7 +1105,6 @@ class Parser(metaclass=_Parser):
while True:
if before:
self._match(TokenType.COMMA)
identified_property = self._parse_property_before()
else:
identified_property = self._parse_property()
@ -1094,7 +1115,7 @@ class Parser(metaclass=_Parser):
properties.append(p)
if properties:
return self.expression(exp.Properties, expressions=properties, before=before)
return self.expression(exp.Properties, expressions=properties)
return None
@ -1118,6 +1139,19 @@ class Parser(metaclass=_Parser):
return self._parse_withisolatedloading()
# https://dev.mysql.com/doc/refman/8.0/en/create-view.html
def _parse_definer(self) -> t.Optional[exp.Expression]:
self._match(TokenType.EQ)
user = self._parse_id_var()
self._match(TokenType.PARAMETER)
host = self._parse_id_var() or (self._match(TokenType.MOD) and self._prev.text)
if not user or not host:
return None
return exp.DefinerProperty(this=f"{user}@{host}")
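
`_parse_definer` consumes MySQL's `DEFINER = user@host` clause, where the `@` arrives as a PARAMETER token (or the host may surface through a `%` MOD token). A hedged sketch:

    import sqlglot

    ast = sqlglot.parse_one(
        "CREATE DEFINER=admin@localhost VIEW v AS SELECT 1", read="mysql"
    )
    # the clause is captured as exp.DefinerProperty(this='admin@localhost')
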
def _parse_withjournaltable(self) -> exp.Expression:
self._match_text_seq("WITH", "JOURNAL", "TABLE")
self._match(TokenType.EQ)
@ -1695,12 +1729,10 @@ class Parser(metaclass=_Parser):
paren += 1
if self._curr.token_type == TokenType.R_PAREN:
paren -= 1
end = self._prev
self._advance()
if paren > 0:
self.raise_error("Expecting )", self._curr)
if not self._curr:
self.raise_error("Expecting pattern", self._curr)
end = self._prev
pattern = exp.Var(this=self._find_sql(start, end))
else:
pattern = None
@ -2044,9 +2076,16 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(self._parse_conjunction)
grouping_sets = self._parse_grouping_sets()
self._match(TokenType.COMMA)
with_ = self._match(TokenType.WITH)
cube = self._match(TokenType.CUBE) and (with_ or self._parse_wrapped_id_vars())
rollup = self._match(TokenType.ROLLUP) and (with_ or self._parse_wrapped_id_vars())
cube = self._match(TokenType.CUBE) and (
with_ or self._parse_wrapped_csv(self._parse_column)
)
self._match(TokenType.COMMA)
rollup = self._match(TokenType.ROLLUP) and (
with_ or self._parse_wrapped_csv(self._parse_column)
)
return self.expression(
exp.Group,
@ -2149,6 +2188,14 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
def _parse_lock(self) -> t.Optional[exp.Expression]:
if self._match_text_seq("FOR", "UPDATE"):
return self.expression(exp.Lock, update=True)
if self._match_text_seq("FOR", "SHARE"):
return self.expression(exp.Lock, update=False)
return None
def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_set(self.SET_OPERATIONS):
return this
@ -2330,12 +2377,21 @@ class Parser(metaclass=_Parser):
maybe_func = True
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
return exp.DataType(
this = exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
nested=True,
)
while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
this = exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[this],
nested=True,
)
return this
if self._match(TokenType.L_BRACKET):
self._retreat(index)
return None
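
The loop above handles stacked `[]` suffixes, so `INT[][]` now builds nested `ARRAY` types instead of stopping at one level. A minimal sketch (the rendered SQL assumes the default dialect):

    from sqlglot import exp

    exp.DataType.build("INT[][]", dialect="postgres").sql()
    # e.g. 'ARRAY<ARRAY<INT>>'
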
@ -2430,7 +2486,12 @@ class Parser(metaclass=_Parser):
self.raise_error("Expected type")
elif op:
self._advance()
field = exp.Literal.string(self._prev.text)
value = self._prev.text
field = (
exp.Literal.number(value)
if self._prev.token_type == TokenType.NUMBER
else exp.Literal.string(value)
)
else:
field = self._parse_star() or self._parse_function() or self._parse_id_var()
@ -2752,7 +2813,23 @@ class Parser(metaclass=_Parser):
if not self._curr:
break
if self._match_text_seq("NOT", "ENFORCED"):
if self._match(TokenType.ON):
action = None
on = self._advance_any() and self._prev.text
if self._match(TokenType.NO_ACTION):
action = "NO ACTION"
elif self._match(TokenType.CASCADE):
action = "CASCADE"
elif self._match_pair(TokenType.SET, TokenType.NULL):
action = "SET NULL"
elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
action = "SET DEFAULT"
else:
self.raise_error("Invalid key constraint")
options.append(f"ON {on} {action}")
elif self._match_text_seq("NOT", "ENFORCED"):
options.append("NOT ENFORCED")
elif self._match_text_seq("DEFERRABLE"):
options.append("DEFERRABLE")
@ -2762,10 +2839,6 @@ class Parser(metaclass=_Parser):
options.append("NORELY")
elif self._match_text_seq("MATCH", "FULL"):
options.append("MATCH FULL")
elif self._match_text_seq("ON", "UPDATE", "NO ACTION"):
options.append("ON UPDATE NO ACTION")
elif self._match_text_seq("ON", "DELETE", "NO ACTION"):
options.append("ON DELETE NO ACTION")
else:
break
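
Foreign-key options are now parsed generically: any `ON <event>` followed by `NO ACTION`, `CASCADE`, `SET NULL`, or `SET DEFAULT` is accepted, replacing the two hardcoded `ON UPDATE/DELETE NO ACTION` branches removed above. A minimal sketch:

    import sqlglot

    sqlglot.parse_one(
        "CREATE TABLE t (x INT, FOREIGN KEY (x) REFERENCES p (x) ON DELETE CASCADE)"
    )
    # the constraint now carries the option 'ON DELETE CASCADE'
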
@ -3158,7 +3231,9 @@ class Parser(metaclass=_Parser):
prefix += self._prev.text
if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
return exp.Identifier(this=prefix + self._prev.text, quoted=False)
quoted = self._prev.token_type == TokenType.STRING
return exp.Identifier(this=prefix + self._prev.text, quoted=quoted)
return None
def _parse_string(self) -> t.Optional[exp.Expression]:
@ -3486,6 +3561,11 @@ class Parser(metaclass=_Parser):
def _parse_set(self) -> exp.Expression:
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
def _parse_as_command(self, start: Token) -> exp.Command:
while self._curr:
self._advance()
return exp.Command(this=self._find_sql(start, self._prev))
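
`_parse_as_command` is the new fallback for `CREATE`/`DROP` statements sqlglot cannot model: it consumes the remaining tokens and returns the raw SQL as an `exp.Command`, replacing the per-dialect `CREATE ...`/`DROP ...` COMMAND keywords dropped from Postgres above. A minimal sketch:

    import sqlglot

    ast = sqlglot.parse_one("CREATE EXTENSION IF NOT EXISTS hstore", read="postgres")
    # falls back to exp.Command carrying the raw statement text
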
def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict
) -> t.Optional[t.Callable]:

View file

@ -11,6 +11,7 @@ from sqlglot.trie import in_trie, new_trie
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dialects.dialect import DialectType
ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
@ -153,7 +154,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
self,
schema: t.Optional[t.Dict] = None,
visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None,
dialect: DialectType = None,
) -> None:
self.dialect = dialect
self.visible = visible or {}

View file

@ -665,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer):
"STRING": TokenType.TEXT,
"TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT,
"LONGVARCHAR": TokenType.TEXT,
"BINARY": TokenType.BINARY,
"BLOB": TokenType.VARBINARY,
"BYTEA": TokenType.VARBINARY,