
Merging upstream version 10.2.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:58:37 +01:00
parent 40155883c5
commit 17f6b2c749
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
36 changed files with 1281 additions and 493 deletions

View file

@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.1.3"
__version__ = "10.2.6"
pretty = False

View file

@ -317,7 +317,7 @@ class DataFrame:
sqlglot.schema.add_table(
cache_table_name,
{
expression.alias_or_name: expression.type.name
expression.alias_or_name: expression.type.sql("spark")
for expression in select_expression.expressions
},
)
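
Switching from type.name to type.sql("spark") preserves parameterized types when caching a DataFrame's schema: the bare enum name loses precision/scale and length arguments, while rendering the DataType keeps them. A minimal sketch of the difference, assuming DataType.build accepts a parameterized type string (the outputs in the comments are expectations, not verified):

    from sqlglot import exp

    dtype = exp.DataType.build("decimal(10, 2)")
    print(dtype.this.name)     # expected: DECIMAL (the enum name drops precision/scale)
    print(dtype.sql("spark"))  # expected: DECIMAL(10, 2) (the rendered type keeps them)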

View file

@ -110,17 +110,17 @@ class BigQuery(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"INT64": TokenType.BIGINT,
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
}
KEYWORDS.pop("DIV")
@ -131,6 +131,7 @@ class BigQuery(Dialect):
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
@ -144,6 +145,7 @@ class BigQuery(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
}
FUNCTION_PARSERS.pop("TRIM")
@ -161,7 +163,6 @@ class BigQuery(Dialect):
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
@ -183,6 +184,7 @@ class BigQuery(Dialect):
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
}
TYPE_MAPPING = {
@ -210,24 +212,31 @@ class BigQuery(Dialect):
EXPLICIT_UNION = True
def transaction_sql(self, *_):
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
if isinstance(first_arg, exp.Subqueryable):
return f"ARRAY{self.wrap(self.sql(first_arg))}"
return inline_array_sql(self, expression)
def transaction_sql(self, *_) -> str:
return "BEGIN TRANSACTION"
def commit_sql(self, *_):
def commit_sql(self, *_) -> str:
return "COMMIT TRANSACTION"
def rollback_sql(self, *_):
def rollback_sql(self, *_) -> str:
return "ROLLBACK TRANSACTION"
def in_unnest_op(self, unnest):
return self.sql(unnest)
def in_unnest_op(self, expression: exp.Unnest) -> str:
return self.sql(expression)
def except_op(self, expression):
def except_op(self, expression: exp.Except) -> str:
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
def intersect_op(self, expression):
def intersect_op(self, expression: exp.Intersect) -> str:
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"

View file

@ -190,6 +190,7 @@ class Hive(Dialect):
"ADD FILES": TokenType.COMMAND,
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
}
class Parser(parser.Parser):
@ -238,6 +239,13 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
expressions=self._parse_wrapped_csv(self._parse_property)
),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -297,6 +305,8 @@ class Hive(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
}
@ -308,12 +318,15 @@ class Hive(Dialect):
exp.SchemaCommentProperty,
exp.LocationProperty,
exp.TableFormatProperty,
exp.RowFormatDelimitedProperty,
exp.RowFormatSerdeProperty,
exp.SerdeProperties,
}
def with_properties(self, properties):
return self.properties(
properties,
prefix="TBLPROPERTIES",
prefix=self.seg("TBLPROPERTIES"),
)
def datatype_sql(self, expression):
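
With the new SERDE_PROPERTIES token, the Hive PROPERTY_PARSERS entry, and the RowFormatSerdeProperty/SerdeProperties generator transforms, a SerDe table definition should round-trip through the Hive dialect. A sketch, assuming the statement below is accepted as written (exact output formatting is an assumption):

    import sqlglot

    sql = (
        "CREATE TABLE t (a STRING) "
        "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' "
        "WITH SERDEPROPERTIES ('separatorChar' = ',')"
    )
    print(sqlglot.transpile(sql, read="hive", write="hive")[0])
    # expected to round-trip both the ROW FORMAT SERDE and WITH SERDEPROPERTIES clauses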

View file

@ -98,6 +98,7 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"MINUS": TokenType.EXCEPT,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,

View file

@ -1,6 +1,7 @@
from __future__ import annotations
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@ -13,12 +14,20 @@ class Redshift(Postgres):
"HH": "%H",
}
class Parser(Postgres.Parser):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS, # type: ignore
"DECODE": exp.Matches.from_arg_list,
"NVL": exp.Coalesce.from_arg_list,
}
class Tokenizer(Postgres.Tokenizer):
ESCAPES = ["\\"]
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
"COPY": TokenType.COMMAND,
"ENCODE": TokenType.ENCODE,
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
@ -50,4 +59,5 @@ class Redshift(Postgres):
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.Matches: rename_func("DECODE"),
}
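
On the Redshift side, NVL now parses to exp.Coalesce and DECODE to exp.Matches, with exp.Matches generated back as DECODE. A sketch of the NVL mapping; the expected output is an assumption:

    import sqlglot

    print(sqlglot.transpile("SELECT NVL(a, b) FROM t", read="redshift", write="redshift")[0])
    # expected: SELECT COALESCE(a, b) FROM t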

View file

@ -198,6 +198,7 @@ class Snowflake(Dialect):
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
"MINUS": TokenType.EXCEPT,
"SAMPLE": TokenType.TABLE_SAMPLE,
}

View file

@ -19,10 +19,13 @@ class reverse_key:
return other.obj < self.obj
def filter_nulls(func):
def filter_nulls(func, empty_null=True):
@wraps(func)
def _func(values):
return func(v for v in values if v is not None)
filtered = tuple(v for v in values if v is not None)
if not filtered and empty_null:
return None
return func(filtered)
return _func
@ -126,7 +129,7 @@ ENV = {
# aggs
"SUM": filter_nulls(sum),
"AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
"COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
"COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
"MAX": filter_nulls(max),
"MIN": filter_nulls(min),
# scalar functions
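
The empty_null flag makes aggregates over all-NULL input behave like SQL: SUM returns NULL, while COUNT opts out and returns 0. A standalone sketch of the decorator exactly as changed above:

    from functools import wraps

    def filter_nulls(func, empty_null=True):
        @wraps(func)
        def _func(values):
            filtered = tuple(v for v in values if v is not None)
            if not filtered and empty_null:
                return None  # SQL aggregates over no non-NULL rows yield NULL...
            return func(filtered)
        return _func

    sql_sum = filter_nulls(sum)
    sql_count = filter_nulls(lambda acc: sum(1 for _ in acc), False)  # ...except COUNT

    print(sql_sum([None, None]))    # None
    print(sql_count([None, None]))  # 0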

View file

@ -310,9 +310,9 @@ class PythonExecutor:
if i == length - 1:
context.set_range(start, end - 1)
add_row()
elif step.limit > 0:
elif step.limit > 0 and not group_by:
context.set_range(0, 0)
table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))
table.append(context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}})

View file

@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
key = "Expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "type", "comments")
__slots__ = ("args", "parent", "arg_key", "comments", "_type")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
self.comments = None
self._type: t.Optional[DataType] = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@ -122,6 +122,16 @@ class Expression(metaclass=_Expression):
return "NULL"
return self.alias or self.name
@property
def type(self) -> t.Optional[DataType]:
return self._type
@type.setter
def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None:
if dtype and not isinstance(dtype, DataType):
dtype = DataType.build(dtype)
self._type = dtype # type: ignore
def __deepcopy__(self, memo):
copy = self.__class__(**deepcopy(self.args))
copy.comments = self.comments
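
Storing the type as a full exp.DataType behind a coercing setter means callers can assign a string or a DataType.Type and still read back a DataType node. A small sketch (using parse_one on a bare column, which is an assumption of this snippet, not part of the diff):

    from sqlglot import exp, parse_one

    col = parse_one("a")      # a bare column expression
    col.type = "varchar(10)"  # strings are coerced through DataType.build
    print(isinstance(col.type, exp.DataType), col.type.sql())
    # expected: True VARCHAR(10)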
@ -348,7 +358,7 @@ class Expression(metaclass=_Expression):
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
args = {
args: t.Dict[str, t.Any] = {
k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
for v in ensure_collection(vs)
@ -612,6 +622,7 @@ class Create(Expression):
"properties": False,
"temporary": False,
"transient": False,
"external": False,
"replace": False,
"unique": False,
"materialized": False,
@ -744,13 +755,17 @@ class DefaultColumnConstraint(ColumnConstraintKind):
pass
class EncodeColumnConstraint(ColumnConstraintKind):
pass
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {"this": True, "expression": False}
class NotNullColumnConstraint(ColumnConstraintKind):
pass
arg_types = {"allow_null": False}
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
@ -766,7 +781,7 @@ class Constraint(Expression):
class Delete(Expression):
arg_types = {"with": False, "this": True, "using": False, "where": False}
arg_types = {"with": False, "this": False, "using": False, "where": False}
class Drop(Expression):
@ -850,7 +865,7 @@ class Insert(Expression):
arg_types = {
"with": False,
"this": True,
"expression": True,
"expression": False,
"overwrite": False,
"exists": False,
"partition": False,
@ -1125,6 +1140,27 @@ class VolatilityProperty(Property):
arg_types = {"this": True}
class RowFormatDelimitedProperty(Property):
# https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
arg_types = {
"fields": False,
"escaped": False,
"collection_items": False,
"map_keys": False,
"lines": False,
"null": False,
"serde": False,
}
class RowFormatSerdeProperty(Property):
arg_types = {"this": True}
class SerdeProperties(Property):
arg_types = {"expressions": True}
class Properties(Expression):
arg_types = {"expressions": True}
@ -1169,18 +1205,6 @@ class Reference(Expression):
arg_types = {"this": True, "expressions": True}
class RowFormat(Expression):
# https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
arg_types = {
"fields": False,
"escaped": False,
"collection_items": False,
"map_keys": False,
"lines": False,
"null": False,
}
class Tuple(Expression):
arg_types = {"expressions": False}
@ -1208,6 +1232,9 @@ class Subqueryable(Unionable):
alias=TableAlias(this=to_identifier(alias)),
)
def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
raise NotImplementedError
@property
def ctes(self):
with_ = self.args.get("with")
@ -1320,6 +1347,32 @@ class Union(Subqueryable):
**QUERY_MODIFIERS,
}
def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
"""
Set the LIMIT expression.
Example:
>>> select("1").union(select("1")).limit(1).sql()
'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
Args:
expression (str | int | Expression): the SQL code string to parse.
This can also be an integer.
If a `Limit` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Limit`.
dialect (str): the dialect used to parse the input expression.
copy (bool): if `False`, modify this expression instance in-place.
opts (kwargs): other options to use to parse the input expressions.
Returns:
Select: The limited subqueryable.
"""
return (
select("*")
.from_(self.subquery(alias="_l_0", copy=copy))
.limit(expression, dialect=dialect, copy=False, **opts)
)
@property
def named_selects(self):
return self.this.unnest().named_selects
@ -1356,7 +1409,7 @@ class Unnest(UDTF):
class Update(Expression):
arg_types = {
"with": False,
"this": True,
"this": False,
"expressions": True,
"from": False,
"where": False,
@ -2057,15 +2110,20 @@ class DataType(Expression):
Type.TEXT,
}
NUMERIC_TYPES = {
INTEGER_TYPES = {
Type.INT,
Type.TINYINT,
Type.SMALLINT,
Type.BIGINT,
}
FLOAT_TYPES = {
Type.FLOAT,
Type.DOUBLE,
}
NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
TEMPORAL_TYPES = {
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
@ -2968,6 +3026,14 @@ class Use(Expression):
pass
class Merge(Expression):
arg_types = {"this": True, "using": True, "on": True, "expressions": True}
class When(Func):
arg_types = {"this": True, "then": True}
def _norm_args(expression):
args = {}

File diff suppressed because it is too large

View file

@ -385,3 +385,11 @@ def dict_depth(d: t.Dict) -> int:
except StopIteration:
# d.values() returns an empty sequence
return 1
def first(it: t.Iterable[T]) -> T:
"""Returns the first element from an iterable.
Useful for sets.
"""
return next(i for i in it)

View file

@ -14,7 +14,7 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola"
>>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Args:
@ -41,9 +41,12 @@ class TypeAnnotator:
expr_type: lambda self, expr: self._annotate_binary(expr)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
exp.Alias: lambda self, expr: self._annotate_unary(expr),
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Literal: lambda self, expr: self._annotate_literal(expr),
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
@ -52,6 +55,9 @@ class TypeAnnotator:
expr, exp.DataType.Type.BIGINT
),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
@ -263,10 +269,10 @@ class TypeAnnotator:
}
# First annotate the current scope's column references
for col in scope.columns:
source = scope.sources[col.table]
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
else:
elif source:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
@ -280,6 +286,7 @@ class TypeAnnotator:
return expression # We've already inferred the expression's type
annotator = self.annotators.get(expression.__class__)
return (
annotator(self, expression)
if annotator
@ -295,18 +302,23 @@ class TypeAnnotator:
def _maybe_coerce(self, type1, type2):
# We propagate the NULL / UNKNOWN types upwards if found
if isinstance(type1, exp.DataType):
type1 = type1.this
if isinstance(type2, exp.DataType):
type2 = type2.this
if exp.DataType.Type.NULL in (type1, type2):
return exp.DataType.Type.NULL
if exp.DataType.Type.UNKNOWN in (type1, type2):
return exp.DataType.Type.UNKNOWN
return type2 if type2 in self.coerces_to[type1] else type1
return type2 if type2 in self.coerces_to.get(type1, {}) else type1
def _annotate_binary(self, expression):
self._annotate_args(expression)
left_type = expression.left.type
right_type = expression.right.type
left_type = expression.left.type.this
right_type = expression.right.type.this
if isinstance(expression, (exp.And, exp.Or)):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@ -348,7 +360,7 @@ class TypeAnnotator:
expression.type = target_type
return self._annotate_args(expression)
def _annotate_by_args(self, expression, *args):
def _annotate_by_args(self, expression, *args, promote=False):
self._annotate_args(expression)
expressions = []
for arg in args:
@ -360,4 +372,11 @@ class TypeAnnotator:
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
expression.type = last_datatype or exp.DataType.Type.UNKNOWN
if promote:
if expression.type.this in exp.DataType.INTEGER_TYPES:
expression.type = exp.DataType.Type.BIGINT
elif expression.type.this in exp.DataType.FLOAT_TYPES:
expression.type = exp.DataType.Type.DOUBLE
return expression
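
The new promote flag widens aggregate result types: SUM over an integer column becomes BIGINT, and over a float column becomes DOUBLE. A sketch using the module's public entry point (the printed type is an assumption derived from the rules above):

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    expression = annotate_types(
        sqlglot.parse_one("SELECT SUM(x.a) AS s FROM x"),
        schema={"x": {"a": "INT"}},
    )
    print(expression.expressions[0].type.this)
    # expected: Type.BIGINT (INT promoted to BIGINT by the Sum annotator)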

View file

@ -13,13 +13,16 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression: The expression to canonicalize.
"""
exp.replace_children(expression, canonicalize)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
return expression
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
node = exp.Concat(this=node.this, expression=node.expression)
return node
@ -30,14 +33,30 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
elif isinstance(node, exp.Extract):
if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
_replace_cast(node.expression, "datetime")
return node
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
if (
isinstance(expression, exp.Cast)
and expression.to.type
and expression.this.type
and expression.to.type.this == expression.this.type.this
):
return expression.this
return expression
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
if (
a.type
and a.type.this == exp.DataType.Type.DATE
and b.type
and b.type.this != exp.DataType.Type.DATE
):
_replace_cast(b, "date")

View file

@ -7,7 +7,7 @@ from decimal import Decimal
from sqlglot import exp
from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import while_changing
from sqlglot.helper import first, while_changing
GENERATOR = Generator(normalize=True, identify=True)
@ -30,6 +30,7 @@ def simplify(expression):
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
node = uniq_sort(node)
node = absorb_and_eliminate(node)
exp.replace_children(node, lambda e: _simplify(e, False))
@ -49,6 +50,19 @@ def simplify(expression):
return expression
def rewrite_between(expression: exp.Expression) -> exp.Expression:
"""Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
"""
if isinstance(expression, exp.Between):
return exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
)
return expression
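
Because the comparison simplifier below only understands lt/lte/gt/gte, BETWEEN is first lowered to a conjunction. A sketch; the expected output is an assumption:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    print(simplify(sqlglot.parse_one("x BETWEEN 1 AND 10")).sql())
    # expected: x >= 1 AND x <= 10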
def simplify_not(expression):
"""
Demorgan's Law
@ -57,7 +71,7 @@ def simplify_not(expression):
"""
if isinstance(expression, exp.Not):
if isinstance(expression.this, exp.Null):
return NULL
return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
@ -65,11 +79,11 @@ def simplify_not(expression):
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Null):
return NULL
return exp.null()
if always_true(expression.this):
return FALSE
return exp.false()
if expression.this == FALSE:
return TRUE
return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
# NOT NOT x -> x
@ -91,40 +105,119 @@ def flatten(expression):
def simplify_connectors(expression):
if isinstance(expression, exp.Connector):
left = expression.left
right = expression.right
if left == right:
return left
if isinstance(expression, exp.And):
if FALSE in (left, right):
return FALSE
if NULL in (left, right):
return NULL
if always_true(left) and always_true(right):
return TRUE
if always_true(left):
return right
if always_true(right):
def _simplify_connectors(expression, left, right):
if isinstance(expression, exp.Connector):
if left == right:
return left
elif isinstance(expression, exp.Or):
if always_true(left) or always_true(right):
return TRUE
if left == FALSE and right == FALSE:
return FALSE
if (
(left == NULL and right == NULL)
or (left == NULL and right == FALSE)
or (left == FALSE and right == NULL)
):
return NULL
if left == FALSE:
return right
if right == FALSE:
return left
return expression
if isinstance(expression, exp.And):
if FALSE in (left, right):
return exp.false()
if NULL in (left, right):
return exp.null()
if always_true(left) and always_true(right):
return exp.true()
if always_true(left):
return right
if always_true(right):
return left
return _simplify_comparison(expression, left, right)
elif isinstance(expression, exp.Or):
if always_true(left) or always_true(right):
return exp.true()
if left == FALSE and right == FALSE:
return exp.false()
if (
(left == NULL and right == NULL)
or (left == NULL and right == FALSE)
or (left == FALSE and right == NULL)
):
return exp.null()
if left == FALSE:
return right
if right == FALSE:
return left
return _simplify_comparison(expression, left, right, or_=True)
return None
return _flat_simplify(expression, _simplify_connectors)
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)
COMPARISONS = (
*LT_LTE,
*GT_GTE,
exp.EQ,
exp.NEQ,
)
INVERSE_COMPARISONS = {
exp.LT: exp.GT,
exp.GT: exp.LT,
exp.LTE: exp.GTE,
exp.GTE: exp.LTE,
}
def _simplify_comparison(expression, left, right, or_=False):
if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
ll, lr = left.args.values()
rl, rr = right.args.values()
largs = {ll, lr}
rargs = {rl, rr}
matching = largs & rargs
columns = {m for m in matching if isinstance(m, exp.Column)}
if matching and columns:
try:
l = first(largs - columns)
r = first(rargs - columns)
except StopIteration:
return expression
# make sure the comparison is always of the form x > 1 instead of 1 < x
if left.__class__ in INVERSE_COMPARISONS and l == ll:
left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
if right.__class__ in INVERSE_COMPARISONS and r == rl:
right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
if l.is_number and r.is_number:
l = float(l.name)
r = float(r.name)
elif l.is_string and r.is_string:
l = l.name
r = r.name
else:
return None
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
return left if (av > bv if or_ else av <= bv) else right
if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
return left if (av < bv if or_ else av >= bv) else right
# we can't ever shortcut to true because the column could be null
if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
if not or_ and av <= bv:
return exp.false()
elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
if not or_ and av >= bv:
return exp.false()
elif isinstance(a, exp.EQ):
if isinstance(b, exp.LT):
return exp.false() if av >= bv else a
if isinstance(b, exp.LTE):
return exp.false() if av > bv else a
if isinstance(b, exp.GT):
return exp.false() if av <= bv else a
if isinstance(b, exp.GTE):
return exp.false() if av < bv else a
if isinstance(b, exp.NEQ):
return exp.false() if av == bv else a
return None
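
With rewrite_between feeding it, _simplify_comparison can merge overlapping range predicates over the same column, shortcutting to FALSE only where NULL-safety allows. A few sketched cases (expected outputs are assumptions derived from the rules above):

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    print(simplify(sqlglot.parse_one("x > 1 AND x > 5")).sql())  # expected: x > 5
    print(simplify(sqlglot.parse_one("x > 1 OR x > 5")).sql())   # expected: x > 1
    print(simplify(sqlglot.parse_one("x = 1 AND x < 1")).sql())  # expected: FALSE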
def remove_compliments(expression):
@ -135,7 +228,7 @@ def remove_compliments(expression):
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector):
compliment = FALSE if isinstance(expression, exp.And) else TRUE
compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
@ -211,27 +304,7 @@ def absorb_and_eliminate(expression):
def simplify_literals(expression):
if isinstance(expression, exp.Binary):
operands = []
queue = deque(expression.flatten(unnest=False))
size = len(queue)
while queue:
a = queue.popleft()
for b in queue:
result = _simplify_binary(expression, a, b)
if result:
queue.remove(b)
queue.append(result)
break
else:
operands.append(a)
if len(operands) < size:
return functools.reduce(
lambda a, b: expression.__class__(this=a, expression=b), operands
)
return _flat_simplify(expression, _simplify_binary)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b):
if c == NULL:
if isinstance(a, exp.Literal):
return TRUE if not_ else FALSE
return exp.true() if not_ else exp.false()
if a == NULL:
return FALSE if not_ else TRUE
elif isinstance(expression, exp.NullSafeEQ):
if a == b:
return TRUE
elif isinstance(expression, exp.NullSafeNEQ):
if a == b:
return FALSE
return exp.false() if not_ else exp.true()
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
return None
elif NULL in (a, b):
return NULL
if isinstance(expression, exp.EQ) and a == b:
return TRUE
return exp.null()
if a.is_number and b.is_number:
a = int(a.name) if a.is_int else Decimal(a.name)
@ -388,4 +454,27 @@ def date_literal(date):
def boolean_literal(condition):
return TRUE if condition else FALSE
return exp.true() if condition else exp.false()
def _flat_simplify(expression, simplifier):
operands = []
queue = deque(expression.flatten(unnest=False))
size = len(queue)
while queue:
a = queue.popleft()
for b in queue:
result = simplifier(expression, a, b)
if result:
queue.remove(b)
queue.append(result)
break
else:
operands.append(a)
if len(operands) < size:
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
return expression

View file

@ -185,6 +185,7 @@ class Parser(metaclass=_Parser):
TokenType.LOCAL,
TokenType.LOCATION,
TokenType.MATERIALIZED,
TokenType.MERGE,
TokenType.NATURAL,
TokenType.NEXT,
TokenType.ONLY,
@ -211,7 +212,6 @@ class Parser(metaclass=_Parser):
TokenType.TABLE,
TokenType.TABLE_FORMAT,
TokenType.TEMPORARY,
TokenType.TRANSIENT,
TokenType.TOP,
TokenType.TRAILING,
TokenType.TRUE,
@ -229,6 +229,8 @@ class Parser(metaclass=_Parser):
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
FUNC_TOKENS = {
@ -241,6 +243,7 @@ class Parser(metaclass=_Parser):
TokenType.FORMAT,
TokenType.IDENTIFIER,
TokenType.ISNULL,
TokenType.MERGE,
TokenType.OFFSET,
TokenType.PRIMARY_KEY,
TokenType.REPLACE,
@ -407,6 +410,7 @@ class Parser(metaclass=_Parser):
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.MERGE: lambda self: self._parse_merge(),
}
UNARY_PARSERS = {
@ -474,6 +478,7 @@ class Parser(metaclass=_Parser):
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.LIKE: lambda self: self._parse_create_like(),
TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.ROW: lambda self: self._parse_row(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
@ -495,6 +500,8 @@ class Parser(metaclass=_Parser):
TokenType.VOLATILE: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
),
TokenType.WITH: lambda self: self._parse_wrapped_csv(self._parse_property),
TokenType.PROPERTIES: lambda self: self._parse_wrapped_csv(self._parse_property),
}
CONSTRAINT_PARSERS = {
@ -802,7 +809,8 @@ class Parser(metaclass=_Parser):
def _parse_create(self):
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
transient = self._match(TokenType.TRANSIENT)
transient = self._match_text_seq("TRANSIENT")
external = self._match_text_seq("EXTERNAL")
unique = self._match(TokenType.UNIQUE)
materialized = self._match(TokenType.MATERIALIZED)
@ -846,6 +854,7 @@ class Parser(metaclass=_Parser):
properties=properties,
temporary=temporary,
transient=transient,
external=external,
replace=replace,
unique=unique,
materialized=materialized,
@ -861,8 +870,12 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
return self._parse_sortkey(compound=True)
if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
key = self._parse_var()
assignment = self._match_pair(
TokenType.VAR, TokenType.EQ, advance=False
) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
if assignment:
key = self._parse_var() or self._parse_string()
self._match(TokenType.EQ)
return self.expression(exp.Property, this=key, value=self._parse_column())
@ -871,7 +884,10 @@ class Parser(metaclass=_Parser):
def _parse_property_assignment(self, exp_class):
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())
return self.expression(
exp_class,
this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
)
def _parse_partitioned_by(self):
self._match(TokenType.EQ)
@ -881,7 +897,7 @@ class Parser(metaclass=_Parser):
)
def _parse_distkey(self):
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
def _parse_create_like(self):
table = self._parse_table(schema=True)
@ -898,7 +914,7 @@ class Parser(metaclass=_Parser):
def _parse_sortkey(self, compound=False):
return self.expression(
exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound
)
def _parse_character_set(self, default=False):
@ -929,23 +945,11 @@ class Parser(metaclass=_Parser):
properties = []
while True:
if self._match(TokenType.WITH):
properties.extend(self._parse_wrapped_csv(self._parse_property))
elif self._match(TokenType.PROPERTIES):
properties.extend(
self._parse_wrapped_csv(
lambda: self.expression(
exp.Property,
this=self._parse_string(),
value=self._match(TokenType.EQ) and self._parse_string(),
)
)
)
else:
identified_property = self._parse_property()
if not identified_property:
break
properties.append(identified_property)
identified_property = self._parse_property()
if not identified_property:
break
for p in ensure_collection(identified_property):
properties.append(p)
if properties:
return self.expression(exp.Properties, expressions=properties)
@ -963,7 +967,7 @@ class Parser(metaclass=_Parser):
exp.Directory,
this=self._parse_var_or_string(),
local=local,
row_format=self._parse_row_format(),
row_format=self._parse_row_format(match_row=True),
)
else:
self._match(TokenType.INTO)
@ -978,9 +982,17 @@ class Parser(metaclass=_Parser):
overwrite=overwrite,
)
def _parse_row_format(self):
if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
def _parse_row(self):
if not self._match(TokenType.FORMAT):
return None
return self._parse_row_format()
def _parse_row_format(self, match_row=False):
if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None
if self._match_text_seq("SERDE"):
return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string())
self._match_text_seq("DELIMITED")
@ -998,7 +1010,7 @@ class Parser(metaclass=_Parser):
kwargs["lines"] = self._parse_string()
if self._match_text_seq("NULL", "DEFINED", "AS"):
kwargs["null"] = self._parse_string()
return self.expression(exp.RowFormat, **kwargs)
return self.expression(exp.RowFormatDelimitedProperty, **kwargs)
def _parse_load_data(self):
local = self._match(TokenType.LOCAL)
@ -1032,7 +1044,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Update,
**{
"this": self._parse_table(schema=True),
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
"from": self._parse_from(),
"where": self._parse_where(),
@ -1183,9 +1195,11 @@ class Parser(metaclass=_Parser):
alias=alias,
)
def _parse_table_alias(self):
def _parse_table_alias(self, alias_tokens=None):
any_token = self._match(TokenType.ALIAS)
alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS)
alias = self._parse_id_var(
any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
)
columns = None
if self._match(TokenType.L_PAREN):
@ -1337,7 +1351,7 @@ class Parser(metaclass=_Parser):
columns=self._parse_expression(),
)
def _parse_table(self, schema=False):
def _parse_table(self, schema=False, alias_tokens=None):
lateral = self._parse_lateral()
if lateral:
@ -1372,7 +1386,7 @@ class Parser(metaclass=_Parser):
table = self._parse_id_var()
if not table:
self.raise_error("Expected table name")
self.raise_error(f"Expected table name but got {self._curr}")
this = self.expression(
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
@ -1384,7 +1398,7 @@ class Parser(metaclass=_Parser):
if self.alias_post_tablesample:
table_sample = self._parse_table_sample()
alias = self._parse_table_alias()
alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
if alias:
this.set("alias", alias)
@ -2092,10 +2106,14 @@ class Parser(metaclass=_Parser):
kind = self.expression(exp.CheckColumnConstraint, this=constraint)
elif self._match(TokenType.COLLATE):
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
elif self._match(TokenType.ENCODE):
kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.NULL):
kind = exp.NotNullColumnConstraint(allow_null=True)
elif self._match(TokenType.SCHEMA_COMMENT):
kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
elif self._match(TokenType.PRIMARY_KEY):
@ -2234,7 +2252,7 @@ class Parser(metaclass=_Parser):
return self._parse_window(this)
def _parse_extract(self):
this = self._parse_var() or self._parse_type()
this = self._parse_function() or self._parse_var() or self._parse_type()
if self._match(TokenType.FROM):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
@ -2635,6 +2653,54 @@ class Parser(metaclass=_Parser):
parser = self._find_parser(self.SET_PARSERS, self._set_trie)
return parser(self) if parser else self._default_parse_set_item()
def _parse_merge(self):
self._match(TokenType.INTO)
target = self._parse_table(schema=True)
self._match(TokenType.USING)
using = self._parse_table()
self._match(TokenType.ON)
on = self._parse_conjunction()
whens = []
while self._match(TokenType.WHEN):
this = self._parse_conjunction()
self._match(TokenType.THEN)
if self._match(TokenType.INSERT):
_this = self._parse_star()
if _this:
then = self.expression(exp.Insert, this=_this)
else:
then = self.expression(
exp.Insert,
this=self._parse_value(),
expression=self._match(TokenType.VALUES) and self._parse_value(),
)
elif self._match(TokenType.UPDATE):
expressions = self._parse_star()
if expressions:
then = self.expression(exp.Update, expressions=expressions)
else:
then = self.expression(
exp.Update,
expressions=self._match(TokenType.SET)
and self._parse_csv(self._parse_equality),
)
elif self._match(TokenType.DELETE):
then = self.expression(exp.Var, this=self._prev.text)
whens.append(self.expression(exp.When, this=this, then=then))
return self.expression(
exp.Merge,
this=target,
using=using,
on=on,
expressions=whens,
)
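
The new MERGE parser builds an exp.Merge whose WHEN branches become exp.When nodes wrapping an Insert, an Update, or a DELETE var. A parsing sketch; whether it round-trips depends on the suppressed generator diff, so the printed SQL is an assumption:

    import sqlglot

    sql = (
        "MERGE INTO target USING source ON target.id = source.id "
        "WHEN MATCHED THEN UPDATE SET v = source.v "
        "WHEN NOT MATCHED THEN INSERT (id, v) VALUES (source.id, source.v)"
    )
    merge = sqlglot.parse_one(sql)
    print(type(merge).__name__)  # expected: Merge
    print(merge.sql())           # expected to reproduce the statement above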
def _parse_set(self):
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))

View file

@ -47,7 +47,7 @@ class Schema(abc.ABC):
"""
@abc.abstractmethod
def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type:
def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
"""
Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
@ -160,8 +160,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
super().__init__(schema)
self.visible = visible or {}
self.dialect = dialect
self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
"STR": exp.DataType.Type.TEXT,
self._type_mapping_cache: t.Dict[str, exp.DataType] = {
"STR": exp.DataType.build("text"),
}
@classmethod
@ -231,18 +231,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
def get_column_type(
self, table: exp.Table | str, column: exp.Column | str
) -> exp.DataType.Type:
def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
table_schema = self.find(table_)
schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type)
table_schema = self.find(table_, raise_on_missing=False)
if table_schema:
schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type)
return exp.DataType(this=exp.DataType.Type.UNKNOWN)
raise SchemaError(f"Could not convert table '{table}'")
def _convert_type(self, schema_type: str) -> exp.DataType.Type:
def _convert_type(self, schema_type: str) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
@ -257,7 +257,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
if expression is None:
raise ValueError(f"Could not parse {schema_type}")
self._type_mapping_cache[schema_type] = expression.this
self._type_mapping_cache[schema_type] = expression # type: ignore
except AttributeError:
raise SchemaError(f"Failed to convert type {schema_type}")
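
get_column_type now returns a full exp.DataType instead of the bare enum, and accepts plain strings for both table and column. A sketch; the printed enum value is an assumption:

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"t": {"a": "INT"}})
    dtype = schema.get_column_type("t", "a")
    print(isinstance(dtype, exp.DataType), dtype.this)
    # expected: True Type.INT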

View file

@ -49,6 +49,9 @@ class TokenType(AutoName):
PARAMETER = auto()
SESSION_PARAMETER = auto()
BLOCK_START = auto()
BLOCK_END = auto()
SPACE = auto()
BREAK = auto()
@ -156,6 +159,7 @@ class TokenType(AutoName):
DIV = auto()
DROP = auto()
ELSE = auto()
ENCODE = auto()
END = auto()
ENGINE = auto()
ESCAPE = auto()
@ -207,6 +211,7 @@ class TokenType(AutoName):
LOCATION = auto()
MAP = auto()
MATERIALIZED = auto()
MERGE = auto()
MOD = auto()
NATURAL = auto()
NEXT = auto()
@ -255,6 +260,7 @@ class TokenType(AutoName):
SELECT = auto()
SEMI = auto()
SEPARATOR = auto()
SERDE_PROPERTIES = auto()
SET = auto()
SHOW = auto()
SIMILAR_TO = auto()
@ -267,7 +273,6 @@ class TokenType(AutoName):
TABLE_FORMAT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
TRANSIENT = auto()
TOP = auto()
THEN = auto()
TRAILING = auto()
@ -420,6 +425,16 @@ class Tokenizer(metaclass=_Tokenizer):
ESCAPES = ["'"]
KEYWORDS = {
**{
f"{key}{postfix}": TokenType.BLOCK_START
for key in ("{{", "{%", "{#")
for postfix in ("", "+", "-")
},
**{
f"{prefix}{key}": TokenType.BLOCK_END
for key in ("}}", "%}", "#}")
for prefix in ("", "+", "-")
},
"/*+": TokenType.HINT,
"==": TokenType.EQ,
"::": TokenType.DCOLON,
@ -523,6 +538,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LOCAL": TokenType.LOCAL,
"LOCATION": TokenType.LOCATION,
"MATERIALIZED": TokenType.MATERIALIZED,
"MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NO ACTION": TokenType.NO_ACTION,
@ -582,7 +598,6 @@ class Tokenizer(metaclass=_Tokenizer):
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
"TEMPORARY": TokenType.TEMPORARY,
"TRANSIENT": TokenType.TRANSIENT,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"TRAILING": TokenType.TRAILING,