
Merging upstream version 18.13.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:08:10 +01:00
parent a56b8dde5c
commit 320822f1c4
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
76 changed files with 21248 additions and 19605 deletions

View file

@@ -1019,11 +1019,11 @@ def posexplode(col: ColumnOrName) -> Column:
 def explode_outer(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "EXPLODE_OUTER")
+    return Column.invoke_expression_over_column(col, expression.ExplodeOuter)
 def posexplode_outer(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "POSEXPLODE_OUTER")
+    return Column.invoke_expression_over_column(col, expression.PosexplodeOuter)
 def get_json_object(col: ColumnOrName, path: str) -> Column:

View file

@@ -10,6 +10,7 @@ from sqlglot.tokens import TokenType
 class Databricks(Spark):
     class Parser(Spark.Parser):
         LOG_DEFAULTS_TO_LN = True
+        STRICT_CAST = True
         FUNCTIONS = {
             **Spark.Parser.FUNCTIONS,
@@ -51,6 +52,8 @@ class Databricks(Spark):
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
         }
+        TRANSFORMS.pop(exp.TryCast)
         def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
             constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint)
             kind = expression.args.get("kind")

View file

@@ -133,6 +133,10 @@ class DuckDB(Dialect):
             "UINTEGER": TokenType.UINT,
             "USMALLINT": TokenType.USMALLINT,
             "UTINYINT": TokenType.UTINYINT,
+            "TIMESTAMP_S": TokenType.TIMESTAMP_S,
+            "TIMESTAMP_MS": TokenType.TIMESTAMP_MS,
+            "TIMESTAMP_NS": TokenType.TIMESTAMP_NS,
+            "TIMESTAMP_US": TokenType.TIMESTAMP,
         }
     class Parser(parser.Parser):
@@ -321,6 +325,9 @@ class DuckDB(Dialect):
             exp.DataType.Type.UINT: "UINTEGER",
             exp.DataType.Type.VARBINARY: "BLOB",
             exp.DataType.Type.VARCHAR: "TEXT",
+            exp.DataType.Type.TIMESTAMP_S: "TIMESTAMP_S",
+            exp.DataType.Type.TIMESTAMP_MS: "TIMESTAMP_MS",
+            exp.DataType.Type.TIMESTAMP_NS: "TIMESTAMP_NS",
         }
         STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}

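A minimal sketch of what the new mappings enable, assuming a sqlglot 18.13 install (outputs indicative):

    import sqlglot

    # TIMESTAMP_S/_MS/_NS survive a DuckDB round trip; TIMESTAMP_US collapses
    # to plain TIMESTAMP, since microseconds are DuckDB's default resolution.
    print(sqlglot.transpile("SELECT CAST(x AS TIMESTAMP_MS)", read="duckdb", write="duckdb")[0])
    print(sqlglot.transpile("SELECT CAST(x AS TIMESTAMP_US)", read="duckdb", write="duckdb")[0])
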
View file

@@ -82,7 +82,6 @@ class Oracle(Dialect):
                 this=self._parse_format_json(self._parse_bitwise()),
                 order=self._parse_order(),
             ),
-            "JSON_TABLE": lambda self: self._parse_json_table(),
             "XMLTABLE": _parse_xml_table,
         }
@@ -96,29 +95,6 @@ class Oracle(Dialect):
         # Reference: https://stackoverflow.com/a/336455
         DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}
-        # Note: this is currently incomplete; it only implements the "JSON_value_column" part
-        def _parse_json_column_def(self) -> exp.JSONColumnDef:
-            this = self._parse_id_var()
-            kind = self._parse_types(allow_identifiers=False)
-            path = self._match_text_seq("PATH") and self._parse_string()
-            return self.expression(exp.JSONColumnDef, this=this, kind=kind, path=path)
-        def _parse_json_table(self) -> exp.JSONTable:
-            this = self._parse_format_json(self._parse_bitwise())
-            path = self._match(TokenType.COMMA) and self._parse_string()
-            error_handling = self._parse_on_handling("ERROR", "ERROR", "NULL")
-            empty_handling = self._parse_on_handling("EMPTY", "ERROR", "NULL")
-            self._match(TokenType.COLUMN)
-            expressions = self._parse_wrapped_csv(self._parse_json_column_def, optional=True)
-            return exp.JSONTable(
-                this=this,
-                expressions=expressions,
-                path=path,
-                error_handling=error_handling,
-                empty_handling=empty_handling,
-            )
         def _parse_json_array(self, expr_type: t.Type[E], **kwargs) -> E:
             return self.expression(
                 expr_type,

View file

@@ -34,7 +34,7 @@ def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct)
 def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
-    if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
+    if isinstance(expression.this, exp.Explode):
         expression = expression.copy()
         return self.sql(
             exp.Join(

View file

@@ -58,6 +58,11 @@ class Redshift(Postgres):
             "STRTOL": exp.FromBase.from_arg_list,
         }
+        NO_PAREN_FUNCTION_PARSERS = {
+            **Postgres.Parser.NO_PAREN_FUNCTION_PARSERS,
+            "APPROXIMATE": lambda self: self._parse_approximate_count(),
+        }
         def _parse_table(
             self,
             schema: bool = False,
@@ -93,11 +98,22 @@ class Redshift(Postgres):
             return this
-        def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+        def _parse_convert(
+            self, strict: bool, safe: t.Optional[bool] = None
+        ) -> t.Optional[exp.Expression]:
             to = self._parse_types()
             self._match(TokenType.COMMA)
             this = self._parse_bitwise()
-            return self.expression(exp.TryCast, this=this, to=to)
+            return self.expression(exp.TryCast, this=this, to=to, safe=safe)
+        def _parse_approximate_count(self) -> t.Optional[exp.ApproxDistinct]:
+            index = self._index - 1
+            func = self._parse_function()
+            if isinstance(func, exp.Count) and isinstance(func.this, exp.Distinct):
+                return self.expression(exp.ApproxDistinct, this=seq_get(func.this.expressions, 0))
+            self._retreat(index)
+            return None
     class Tokenizer(Postgres.Tokenizer):
         BIT_STRINGS = []
@@ -144,6 +160,7 @@ class Redshift(Postgres):
             **Postgres.Generator.TRANSFORMS,
             exp.Concat: concat_to_dpipe_sql,
             exp.ConcatWs: concat_ws_to_dpipe_sql,
+            exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
             exp.CurrentTimestamp: lambda self, e: "SYSDATE",
             exp.DateAdd: lambda self, e: self.func(
                 "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this

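A quick sketch of the new APPROXIMATE handling, assuming 18.13 behavior (table and column names are made up):

    import sqlglot
    from sqlglot import exp

    # APPROXIMATE COUNT(DISTINCT ...) parses into exp.ApproxDistinct; the
    # generator transform above renders it back in Redshift syntax.
    ast = sqlglot.parse_one("SELECT APPROXIMATE COUNT(DISTINCT user_id) FROM events", read="redshift")
    assert isinstance(ast.selects[0], exp.ApproxDistinct)
    print(ast.sql(dialect="redshift"))
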
View file

@@ -76,6 +76,9 @@ class Spark(Spark2):
             exp.TimestampAdd: lambda self, e: self.func(
                 "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
             ),
+            exp.TryCast: lambda self, e: self.trycast_sql(e)
+            if e.args.get("safe")
+            else self.cast_sql(e),
         }
         TRANSFORMS.pop(exp.AnyValue)
         TRANSFORMS.pop(exp.DateDiff)

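The transform only emits TRY_CAST when the node is flagged safe, i.e. when the source syntax guaranteed NULL-on-failure semantics. A minimal sketch:

    import sqlglot

    # TRY_CAST parsed from Spark carries safe=True, so it round-trips as
    # TRY_CAST instead of degrading to a plain CAST.
    print(sqlglot.transpile("SELECT TRY_CAST(x AS INT)", read="spark", write="spark")[0])
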
View file

@@ -477,7 +477,9 @@ class TSQL(Dialect):
             returns.set("table", table)
             return returns
-        def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+        def _parse_convert(
+            self, strict: bool, safe: t.Optional[bool] = None
+        ) -> t.Optional[exp.Expression]:
             to = self._parse_types()
             self._match(TokenType.COMMA)
             this = self._parse_conjunction()
@@ -513,12 +515,13 @@ class TSQL(Dialect):
                     exp.Cast if strict else exp.TryCast,
                     to=to,
                     this=self.expression(exp.TimeToStr, this=this, format=format_norm),
+                    safe=safe,
                 )
             elif to.this == DataType.Type.TEXT:
                 return self.expression(exp.TimeToStr, this=this, format=format_norm)
             # Entails a simple cast without any format requirement
-            return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+            return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe)
         def _parse_user_defined_function(
             self, kind: t.Optional[TokenType] = None

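A small sketch of the TRY_CONVERT change, assuming 18.13 behavior:

    import sqlglot
    from sqlglot import exp

    # TRY_CONVERT now marks the resulting TryCast as safe, so dialects that
    # distinguish safe casts (e.g. Spark, above) can preserve the semantics.
    ast = sqlglot.parse_one("SELECT TRY_CONVERT(INT, x)", read="tsql")
    cast = ast.find(exp.TryCast)
    assert cast is not None and cast.args.get("safe") is True
    print(cast.sql(dialect="spark"))  # TRY_CAST(x AS INT)
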
View file

@@ -105,7 +105,7 @@ class RowReader:
         return self.row[self.columns[column]]
-class Tables(AbstractMappingSchema[Table]):
+class Tables(AbstractMappingSchema):
     pass

View file

@@ -487,7 +487,7 @@ class Expression(metaclass=_Expression):
         """
         for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
             if not type(node) is self.__class__:
-                yield node.unnest() if unnest else node
+                yield node.unnest() if unnest and not isinstance(node, Subquery) else node
     def __str__(self) -> str:
         return self.sql()
@@ -2107,7 +2107,7 @@ class LockingProperty(Property):
     arg_types = {
         "this": False,
         "kind": True,
-        "for_or_in": True,
+        "for_or_in": False,
         "lock_type": True,
         "override": False,
     }
@@ -3605,6 +3605,9 @@ class DataType(Expression):
         TIMESTAMP = auto()
         TIMESTAMPLTZ = auto()
         TIMESTAMPTZ = auto()
+        TIMESTAMP_S = auto()
+        TIMESTAMP_MS = auto()
+        TIMESTAMP_NS = auto()
         TINYINT = auto()
         TSMULTIRANGE = auto()
         TSRANGE = auto()
@@ -3661,6 +3664,9 @@ class DataType(Expression):
         Type.TIMESTAMP,
         Type.TIMESTAMPTZ,
         Type.TIMESTAMPLTZ,
+        Type.TIMESTAMP_S,
+        Type.TIMESTAMP_MS,
+        Type.TIMESTAMP_NS,
         Type.DATE,
         Type.DATETIME,
         Type.DATETIME64,
@@ -4286,7 +4292,7 @@ class Case(Func):
 class Cast(Func):
-    arg_types = {"this": True, "to": True, "format": False}
+    arg_types = {"this": True, "to": True, "format": False, "safe": False}
     @property
     def name(self) -> str:
@@ -4538,6 +4544,18 @@ class Explode(Func):
     pass
+class ExplodeOuter(Explode):
+    pass
+class Posexplode(Explode):
+    pass
+class PosexplodeOuter(Posexplode):
+    pass
 class Floor(Func):
     arg_types = {"this": True, "decimals": False}
@@ -4621,14 +4639,18 @@ class JSONArrayAgg(Func):
 # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html
+# Note: parsing of JSON column definitions is currently incomplete.
 class JSONColumnDef(Expression):
-    arg_types = {"this": True, "kind": False, "path": False}
+    arg_types = {"this": False, "kind": False, "path": False, "nested_schema": False}
+class JSONSchema(Expression):
+    arg_types = {"expressions": True}
 # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html
 class JSONTable(Func):
     arg_types = {
         "this": True,
-        "expressions": True,
+        "schema": True,
         "path": False,
         "error_handling": False,
         "empty_handling": False,
@@ -4790,10 +4812,6 @@ class Nvl2(Func):
     arg_types = {"this": True, "true": True, "false": False}
-class Posexplode(Func):
-    pass
 # https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function
 class Predict(Func):
     arg_types = {"this": True, "expression": True, "params_struct": False}

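The reshuffled hierarchy means one isinstance check now covers all four explode variants; a minimal illustration:

    from sqlglot import exp

    # Posexplode and PosexplodeOuter are now Explode subclasses, so dialect
    # code (e.g. Presto's _explode_to_unnest_sql) can test one base class.
    node = exp.PosexplodeOuter(this=exp.column("xs"))
    assert isinstance(node, exp.Posexplode)
    assert isinstance(node, exp.Explode)
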
View file

@@ -1226,9 +1226,10 @@ class Generator:
         kind = expression.args.get("kind")
         this = f" {self.sql(expression, 'this')}" if expression.this else ""
         for_or_in = expression.args.get("for_or_in")
+        for_or_in = f" {for_or_in}" if for_or_in else ""
         lock_type = expression.args.get("lock_type")
         override = " OVERRIDE" if expression.args.get("override") else ""
-        return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}"
+        return f"LOCKING {kind}{this}{for_or_in} {lock_type}{override}"
     def withdataproperty_sql(self, expression: exp.WithDataProperty) -> str:
         data_sql = f"WITH {'NO ' if expression.args.get('no') else ''}DATA"
@@ -2179,13 +2180,21 @@ class Generator:
         )
     def jsoncolumndef_sql(self, expression: exp.JSONColumnDef) -> str:
+        path = self.sql(expression, "path")
+        path = f" PATH {path}" if path else ""
+        nested_schema = self.sql(expression, "nested_schema")
+        if nested_schema:
+            return f"NESTED{path} {nested_schema}"
         this = self.sql(expression, "this")
         kind = self.sql(expression, "kind")
         kind = f" {kind}" if kind else ""
-        path = self.sql(expression, "path")
-        path = f" PATH {path}" if path else ""
         return f"{this}{kind}{path}"
+    def jsonschema_sql(self, expression: exp.JSONSchema) -> str:
+        return self.func("COLUMNS", *expression.expressions)
     def jsontable_sql(self, expression: exp.JSONTable) -> str:
         this = self.sql(expression, "this")
         path = self.sql(expression, "path")
@@ -2194,9 +2203,9 @@ class Generator:
         error_handling = f" {error_handling}" if error_handling else ""
         empty_handling = expression.args.get("empty_handling")
         empty_handling = f" {empty_handling}" if empty_handling else ""
-        columns = f" COLUMNS ({self.expressions(expression, skip_first=True)})"
+        schema = self.sql(expression, "schema")
         return self.func(
-            "JSON_TABLE", this, suffix=f"{path}{error_handling}{empty_handling}{columns})"
+            "JSON_TABLE", this, suffix=f"{path}{error_handling}{empty_handling} {schema})"
         )
     def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:

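With jsonschema_sql and the NESTED branch in place, Oracle-style nested column lists should round-trip; a hedged sketch (exact whitespace may differ):

    import sqlglot

    sql = """
    SELECT *
    FROM JSON_TABLE(payload, '$.orders[*]' COLUMNS (
      order_id NUMBER PATH '$.id',
      NESTED PATH '$.items[*]' COLUMNS (sku VARCHAR2(20) PATH '$.sku')
    ))
    """
    print(sqlglot.transpile(sql, read="oracle", write="oracle")[0])
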
View file

@@ -441,6 +441,14 @@ def first(it: t.Iterable[T]) -> T:
+def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
+    """
+    Merges a sequence of ranges, represented as tuples (low, high) whose values
+    belong to some totally-ordered set.
+    Example:
+        >>> merge_ranges([(1, 3), (2, 6)])
+        [(1, 6)]
+    """
+    if not ranges:
+        return []

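merge_ranges is a small utility picked up by the simplify module's date-range logic (see its imports further down); usage per its doctest:

    from sqlglot.helper import merge_ranges

    # Overlapping ranges collapse, disjoint ones are kept; the values only
    # need to be totally ordered (ints, dates, ...).
    assert merge_ranges([(1, 3), (2, 6), (8, 9)]) == [(1, 6), (8, 9)]
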
View file

@@ -6,6 +6,7 @@ from sqlglot import exp
 from sqlglot.errors import OptimizeError
 from sqlglot.generator import cached_generator
 from sqlglot.helper import while_changing
+from sqlglot.optimizer.scope import find_all_in_scope
 from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
 logger = logging.getLogger("sqlglot")
@@ -63,15 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
     return expression
-def normalized(expression, dnf=False):
-    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
-    return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
-def normalization_distance(expression, dnf=False):
+def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
     """
-    The difference in the number of predicates between the current expression and the normalized form.
+    Checks whether a given expression is in a normal form of interest.
+    Example:
+        >>> from sqlglot import parse_one
+        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
+        True
+        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
+        True
+        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
+        False
+    Args:
+        expression: The expression to check if it's normalized.
+        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
+    """
+    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
+    return not any(
+        connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
+    )
+def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
+    """
+    The difference in the number of predicates between a given expression and its normalized form.
     This is used as an estimate of the cost of the conversion which is exponential in complexity.
@@ -82,10 +101,12 @@ def normalization_distance(expression, dnf=False):
         4
     Args:
-        expression (sqlglot.Expression): expression to compute distance
-        dnf (bool): compute to dnf distance instead
+        expression: The expression to compute the normalization distance for.
+        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
     Returns:
-        int: difference
+        The normalization distance.
     """
     return sum(_predicate_lengths(expression, dnf)) - (
         sum(1 for _ in expression.find_all(exp.Connector)) + 1

View file

@@ -39,10 +39,14 @@ def optimize_joins(expression):
             if len(other_table_names(dep)) < 2:
                 continue
+            operator = type(on)
             for predicate in on.flatten():
                 if name in exp.column_table_names(predicate):
                     predicate.replace(exp.true())
-                    join.on(predicate, copy=False)
+                    predicate = exp._combine(
+                        [join.args.get("on"), predicate], operator, copy=False
+                    )
+                    join.on(predicate, append=False, copy=False)
     expression = reorder_joins(expression)
     expression = normalize(expression)

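A rough sketch of the effect, assuming 18.13 behavior (output shape indicative):

    from sqlglot import parse_one
    from sqlglot.optimizer.optimize_joins import optimize_joins

    # A predicate lifted out of a multi-table ON clause is now ANDed into the
    # dependent join's existing condition instead of overwriting it.
    sql = "SELECT * FROM a CROSS JOIN b JOIN c ON a.x = c.x AND b.y = c.y"
    print(optimize_joins(parse_one(sql)).sql())
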
View file

@@ -9,7 +9,9 @@ from sqlglot.schema import ensure_schema
 SELECT_ALL = object()
 # Selection to use if selection list is empty
-DEFAULT_SELECTION = lambda: alias("1", "_")
+DEFAULT_SELECTION = lambda is_agg: alias(
+    exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_"
+)
 def pushdown_projections(expression, schema=None, remove_unused_selections=True):
@@ -98,6 +100,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
     new_selections = []
     removed = False
     star = False
+    is_agg = False
     select_all = SELECT_ALL in parent_selections
@@ -112,6 +115,9 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
             star = True
             removed = True
+        if not is_agg and selection.find(exp.AggFunc):
+            is_agg = True
     if star:
         resolver = Resolver(scope, schema)
         names = {s.alias_or_name for s in new_selections}
@@ -124,7 +130,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
     # If there are no remaining selections, just select a single constant
     if not new_selections:
-        new_selections.append(DEFAULT_SELECTION())
+        new_selections.append(DEFAULT_SELECTION(is_agg))
     scope.expression.select(*new_selections, append=False, copy=False)

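The MAX(1) default matters because an aggregate query returns one row even over empty input, while a bare literal selection returns one row per source row. A sketch:

    from sqlglot import parse_one
    from sqlglot.optimizer.pushdown_projections import pushdown_projections

    # The unused aggregate selection is replaced by MAX(1) rather than 1, so
    # the inner query still yields exactly one row (even for an empty t).
    sql = "SELECT 1 AS c FROM (SELECT COUNT(*) AS n FROM t) AS sub"
    print(pushdown_projections(parse_one(sql)).sql())
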
View file

@@ -137,8 +137,8 @@ class Scope:
         if not self._collected:
             self._collect()
-    def walk(self, bfs=True):
-        return walk_in_scope(self.expression, bfs=bfs)
+    def walk(self, bfs=True, prune=None):
+        return walk_in_scope(self.expression, bfs=bfs, prune=prune)
     def find(self, *expression_types, bfs=True):
         return find_in_scope(self.expression, expression_types, bfs=bfs)
@@ -731,7 +731,7 @@ def _traverse_ddl(scope):
     yield from _traverse_scope(query_scope)
-def walk_in_scope(expression, bfs=True):
+def walk_in_scope(expression, bfs=True, prune=None):
     """
     Returns a generator object which visits all nodes in the syntax tree, stopping at
     nodes that start child scopes.
@@ -740,16 +740,20 @@ def walk_in_scope(expression, bfs=True):
         expression (exp.Expression):
         bfs (bool): if set to True the BFS traversal order will be applied,
             otherwise the DFS traversal will be used instead.
+        prune ((node, parent, arg_key) -> bool): callable that returns True if
+            the generator should stop traversing this branch of the tree.
     Yields:
         tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
     """
     # We'll use this variable to pass state into the dfs generator.
     # Whenever we set it to True, we exclude a subtree from traversal.
-    prune = False
+    crossed_scope_boundary = False
-    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
-        prune = False
+    for node, parent, key in expression.walk(
+        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
+    ):
+        crossed_scope_boundary = False
         yield node, parent, key
@@ -765,7 +769,7 @@ def walk_in_scope(expression, bfs=True):
         or isinstance(node, exp.UDTF)
         or isinstance(node, exp.Subqueryable)
     ):
-        prune = True
+        crossed_scope_boundary = True
     if isinstance(node, (exp.Subquery, exp.UDTF)):
         # The following args are not actually in the inner scope, so we should visit them

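walk_in_scope (and Scope.walk) now take a user prune callback on top of the automatic scope-boundary pruning; a minimal sketch:

    from sqlglot import exp, parse_one
    from sqlglot.optimizer.scope import build_scope

    # Skip WHERE subtrees during the walk; the subquery's internals are
    # skipped anyway because they belong to a child scope.
    scope = build_scope(parse_one("SELECT a + b FROM t WHERE c > (SELECT MAX(d) FROM u)"))
    for node, parent, key in scope.walk(prune=lambda n, *_: isinstance(n, exp.Where)):
        print(type(node).__name__)
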
View file

@@ -5,9 +5,11 @@ import typing as t
 from collections import deque
 from decimal import Decimal
+import sqlglot
 from sqlglot import exp
 from sqlglot.generator import cached_generator
 from sqlglot.helper import first, merge_ranges, while_changing
+from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
 # Final means that an expression should not be simplified
 FINAL = "final"
@@ -17,7 +19,7 @@ class UnsupportedUnit(Exception):
     pass
-def simplify(expression):
+def simplify(expression, constant_propagation=False):
     """
     Rewrite sqlglot AST to simplify expressions.
@@ -29,6 +31,8 @@ def simplify(expression):
     Args:
         expression (sqlglot.Expression): expression to simplify
+        constant_propagation: whether or not the constant propagation rule should be used
     Returns:
         sqlglot.Expression: simplified expression
     """
@@ -67,13 +71,16 @@ def simplify(expression):
         node = absorb_and_eliminate(node, root)
         node = simplify_concat(node)
+        if constant_propagation:
+            node = propagate_constants(node, root)
         exp.replace_children(node, lambda e: _simplify(e, False))
         # Post-order transformations
         node = simplify_not(node)
         node = flatten(node)
         node = simplify_connectors(node, root)
-        node = remove_compliments(node, root)
+        node = remove_complements(node, root)
         node = simplify_coalesce(node)
         node.parent = expression.parent
         node = simplify_literals(node, root)
@@ -287,19 +294,19 @@ def _simplify_comparison(expression, left, right, or_=False):
     return None
-def remove_compliments(expression, root=True):
+def remove_complements(expression, root=True):
     """
-    Removing compliments.
+    Removing complements.
     A AND NOT A -> FALSE
     A OR NOT A -> TRUE
     """
     if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
-        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
+        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
         for a, b in itertools.permutations(expression.flatten(), 2):
             if is_complement(a, b):
-                return compliment
+                return complement
     return expression
@@ -369,6 +376,51 @@ def absorb_and_eliminate(expression, root=True):
     return expression
+def propagate_constants(expression, root=True):
+    """
+    Propagate constants for conjunctions in DNF:
+    SELECT * FROM t WHERE a = b AND b = 5 becomes
+    SELECT * FROM t WHERE a = 5 AND b = 5
+    Reference: https://www.sqlite.org/optoverview.html
+    """
+    if (
+        isinstance(expression, exp.And)
+        and (root or not expression.same_parent)
+        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
+    ):
+        constant_mapping = {}
+        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
+            if isinstance(expr, exp.EQ):
+                l, r = expr.left, expr.right
+                # TODO: create a helper that can be used to detect nested literal expressions such
+                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
+                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
+                    pass
+                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
+                    l, r = r, l
+                else:
+                    continue
+                constant_mapping[l] = (id(l), r)
+        if constant_mapping:
+            for column in find_all_in_scope(expression, exp.Column):
+                parent = column.parent
+                column_id, constant = constant_mapping.get(column) or (None, None)
+                if (
+                    column_id is not None
+                    and id(column) != column_id
+                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
+                ):
+                    column.replace(constant.copy())
+    return expression
 INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
     exp.DateAdd: exp.Sub,
     exp.DateSub: exp.Add,
@@ -609,21 +661,38 @@ SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
 def simplify_concat(expression):
     """Reduces all groups that contain string literals by concatenating them."""
-    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
+    if not isinstance(expression, CONCATS) or (
+        # We can't reduce a CONCAT_WS call if we don't statically know the separator
+        isinstance(expression, exp.ConcatWs)
+        and not expression.expressions[0].is_string
+    ):
         return expression
+    if isinstance(expression, exp.ConcatWs):
+        sep_expr, *expressions = expression.expressions
+        sep = sep_expr.name
+        concat_type = exp.ConcatWs
+    else:
+        expressions = expression.expressions
+        sep = ""
+        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
     new_args = []
     for is_string_group, group in itertools.groupby(
-        expression.expressions or expression.flatten(), lambda e: e.is_string
+        expressions or expression.flatten(), lambda e: e.is_string
     ):
         if is_string_group:
-            new_args.append(exp.Literal.string("".join(string.name for string in group)))
+            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
         else:
             new_args.extend(group)
-    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
-    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
-    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
+    if len(new_args) == 1 and new_args[0].is_string:
+        return new_args[0]
+    if concat_type is exp.ConcatWs:
+        new_args = [sep_expr] + new_args
+    return concat_type(expressions=new_args)
 DateRange = t.Tuple[datetime.date, datetime.date]

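Two of the new rules in action, as a hedged sketch assuming 18.13 behavior:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # Constant propagation is opt-in: a = b AND b = 5  ->  a = 5 AND b = 5
    print(simplify(parse_one("a = b AND b = 5"), constant_propagation=True).sql())

    # CONCAT_WS folding now joins literal groups with the literal separator:
    # CONCAT_WS('-', 'a', 'b', x)  ->  CONCAT_WS('-', 'a-b', x)
    print(simplify(parse_one("CONCAT_WS('-', 'a', 'b', x)")).sql())
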
View file

@@ -160,6 +160,9 @@ class Parser(metaclass=_Parser):
         TokenType.TIME,
         TokenType.TIMETZ,
         TokenType.TIMESTAMP,
+        TokenType.TIMESTAMP_S,
+        TokenType.TIMESTAMP_MS,
+        TokenType.TIMESTAMP_NS,
         TokenType.TIMESTAMPTZ,
         TokenType.TIMESTAMPLTZ,
         TokenType.DATETIME,
@@ -792,17 +795,18 @@ class Parser(metaclass=_Parser):
         "DECODE": lambda self: self._parse_decode(),
         "EXTRACT": lambda self: self._parse_extract(),
         "JSON_OBJECT": lambda self: self._parse_json_object(),
+        "JSON_TABLE": lambda self: self._parse_json_table(),
         "LOG": lambda self: self._parse_logarithm(),
         "MATCH": lambda self: self._parse_match_against(),
         "OPENJSON": lambda self: self._parse_open_json(),
         "POSITION": lambda self: self._parse_position(),
         "PREDICT": lambda self: self._parse_predict(),
-        "SAFE_CAST": lambda self: self._parse_cast(False),
+        "SAFE_CAST": lambda self: self._parse_cast(False, safe=True),
         "STRING_AGG": lambda self: self._parse_string_agg(),
         "SUBSTRING": lambda self: self._parse_substring(),
         "TRIM": lambda self: self._parse_trim(),
-        "TRY_CAST": lambda self: self._parse_cast(False),
-        "TRY_CONVERT": lambda self: self._parse_convert(False),
+        "TRY_CAST": lambda self: self._parse_cast(False, safe=True),
+        "TRY_CONVERT": lambda self: self._parse_convert(False, safe=True),
     }
     QUERY_MODIFIER_PARSERS = {
@@ -4135,7 +4139,7 @@ class Parser(metaclass=_Parser):
         return self.expression(exp.AnyValue, this=this, having=having, max=is_max)
-    def _parse_cast(self, strict: bool) -> exp.Expression:
+    def _parse_cast(self, strict: bool, safe: t.Optional[bool] = None) -> exp.Expression:
         this = self._parse_conjunction()
         if not self._match(TokenType.ALIAS):
@@ -4176,7 +4180,9 @@ class Parser(metaclass=_Parser):
             return this
-        return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt)
+        return self.expression(
+            exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe
+        )
     def _parse_concat(self) -> t.Optional[exp.Expression]:
         args = self._parse_csv(self._parse_conjunction)
@@ -4230,7 +4236,9 @@ class Parser(metaclass=_Parser):
         order = self._parse_order(this=seq_get(args, 0))
         return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
-    def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+    def _parse_convert(
+        self, strict: bool, safe: t.Optional[bool] = None
+    ) -> t.Optional[exp.Expression]:
         this = self._parse_bitwise()
         if self._match(TokenType.USING):
@@ -4242,7 +4250,7 @@ class Parser(metaclass=_Parser):
         else:
             to = None
-        return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+        return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe)
     def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]:
         """
@@ -4347,6 +4355,50 @@ class Parser(metaclass=_Parser):
             encoding=encoding,
         )
+    # Note: this is currently incomplete; it only implements the "JSON_value_column" part
+    def _parse_json_column_def(self) -> exp.JSONColumnDef:
+        if not self._match_text_seq("NESTED"):
+            this = self._parse_id_var()
+            kind = self._parse_types(allow_identifiers=False)
+            nested = None
+        else:
+            this = None
+            kind = None
+            nested = True
+        path = self._match_text_seq("PATH") and self._parse_string()
+        nested_schema = nested and self._parse_json_schema()
+        return self.expression(
+            exp.JSONColumnDef,
+            this=this,
+            kind=kind,
+            path=path,
+            nested_schema=nested_schema,
+        )
+    def _parse_json_schema(self) -> exp.JSONSchema:
+        self._match_text_seq("COLUMNS")
+        return self.expression(
+            exp.JSONSchema,
+            expressions=self._parse_wrapped_csv(self._parse_json_column_def, optional=True),
+        )
+    def _parse_json_table(self) -> exp.JSONTable:
+        this = self._parse_format_json(self._parse_bitwise())
+        path = self._match(TokenType.COMMA) and self._parse_string()
+        error_handling = self._parse_on_handling("ERROR", "ERROR", "NULL")
+        empty_handling = self._parse_on_handling("EMPTY", "ERROR", "NULL")
+        schema = self._parse_json_schema()
+        return exp.JSONTable(
+            this=this,
+            schema=schema,
+            path=path,
+            error_handling=error_handling,
+            empty_handling=empty_handling,
+        )
     def _parse_logarithm(self) -> exp.Func:
         # Default argument order is base, expression
         args = self._parse_csv(self._parse_range)
@@ -4973,7 +5025,17 @@ class Parser(metaclass=_Parser):
         self._match(TokenType.ON)
         on = self._parse_conjunction()
+        return self.expression(
+            exp.Merge,
+            this=target,
+            using=using,
+            on=on,
+            expressions=self._parse_when_matched(),
+        )
+    def _parse_when_matched(self) -> t.List[exp.When]:
         whens = []
         while self._match(TokenType.WHEN):
             matched = not self._match(TokenType.NOT)
             self._match_text_seq("MATCHED")
@@ -5020,14 +5082,7 @@ class Parser(metaclass=_Parser):
                 then=then,
             )
         )
-        return self.expression(
-            exp.Merge,
-            this=target,
-            using=using,
-            on=on,
-            expressions=whens,
-        )
+        return whens
     def _parse_show(self) -> t.Optional[exp.Expression]:
         parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)

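A sketch of the two parser-level changes above, assuming 18.13 behavior:

    import sqlglot
    from sqlglot import exp

    # SAFE_CAST, TRY_CAST and TRY_CONVERT now record safe=True on the node.
    ast = sqlglot.parse_one("SELECT SAFE_CAST(x AS INT64)", read="bigquery")
    assert ast.find(exp.TryCast).args.get("safe") is True
    print(ast.sql(dialect="spark"))  # SELECT TRY_CAST(x AS BIGINT)

    # JSON_TABLE is now parsed by the base parser, not just the Oracle dialect.
    print(sqlglot.parse_one("SELECT * FROM JSON_TABLE(j COLUMNS (a INT))").sql())
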
View file

@@ -5,7 +5,6 @@ import typing as t
 import sqlglot
 from sqlglot import expressions as exp
-from sqlglot._typing import T
 from sqlglot.dialects.dialect import Dialect
 from sqlglot.errors import ParseError, SchemaError
 from sqlglot.helper import dict_depth
@@ -71,7 +70,7 @@ class Schema(abc.ABC):
     def get_column_type(
         self,
         table: exp.Table | str,
-        column: exp.Column,
+        column: exp.Column | str,
         dialect: DialectType = None,
         normalize: t.Optional[bool] = None,
     ) -> exp.DataType:
@@ -88,6 +87,28 @@ class Schema(abc.ABC):
             The resulting column type.
         """
+    def has_column(
+        self,
+        table: exp.Table | str,
+        column: exp.Column | str,
+        dialect: DialectType = None,
+        normalize: t.Optional[bool] = None,
+    ) -> bool:
+        """
+        Returns whether or not `column` appears in `table`'s schema.
+        Args:
+            table: the source table.
+            column: the target column.
+            dialect: the SQL dialect that will be used to parse `table` if it's a string.
+            normalize: whether to normalize identifiers according to the dialect of interest.
+        Returns:
+            True if the column appears in the schema, False otherwise.
+        """
+        name = column if isinstance(column, str) else column.name
+        return name in self.column_names(table, dialect=dialect, normalize=normalize)
     @property
     @abc.abstractmethod
     def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -101,7 +122,7 @@ class Schema(abc.ABC):
         return True
-class AbstractMappingSchema(t.Generic[T]):
+class AbstractMappingSchema:
     def __init__(
         self,
         mapping: t.Optional[t.Dict] = None,
@@ -140,7 +161,7 @@ class AbstractMappingSchema(t.Generic[T]):
     def find(
         self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
-    ) -> t.Optional[T]:
+    ) -> t.Optional[t.Any]:
         parts = self.table_parts(table)[0 : len(self.supported_table_args)]
         value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
@@ -170,7 +191,7 @@ class AbstractMappingSchema(t.Generic[T]):
     )
-class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
+class MappingSchema(AbstractMappingSchema, Schema):
     """
     Schema based on a nested mapping.
@@ -287,7 +308,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
     def get_column_type(
         self,
         table: exp.Table | str,
-        column: exp.Column,
+        column: exp.Column | str,
         dialect: DialectType = None,
         normalize: t.Optional[bool] = None,
     ) -> exp.DataType:
@@ -304,10 +325,26 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         if isinstance(column_type, exp.DataType):
             return column_type
         elif isinstance(column_type, str):
-            return self._to_data_type(column_type.upper(), dialect=dialect)
+            return self._to_data_type(column_type, dialect=dialect)
         return exp.DataType.build("unknown")
+    def has_column(
+        self,
+        table: exp.Table | str,
+        column: exp.Column | str,
+        dialect: DialectType = None,
+        normalize: t.Optional[bool] = None,
+    ) -> bool:
+        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
+        normalized_column_name = self._normalize_name(
+            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
+        )
+        table_schema = self.find(normalized_table, raise_on_missing=False)
+        return normalized_column_name in table_schema if table_schema else False
     def _normalize(self, schema: t.Dict) -> t.Dict:
         """
         Normalizes all identifiers in the schema.

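has_column gives callers a direct membership test on any schema implementation; a minimal usage sketch (names are made up):

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"users": {"id": "int", "name": "text"}}})
    assert schema.has_column("db.users", "id")
    assert not schema.has_column("db.users", "email")
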
View file

@@ -121,6 +121,9 @@ class TokenType(AutoName):
     TIMESTAMP = auto()
     TIMESTAMPTZ = auto()
     TIMESTAMPLTZ = auto()
+    TIMESTAMP_S = auto()
+    TIMESTAMP_MS = auto()
+    TIMESTAMP_NS = auto()
     DATETIME = auto()
     DATETIME64 = auto()
     DATE = auto()

View file

@@ -189,9 +189,9 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
         # we use list here because expression.selects is mutated inside the loop
         for select in expression.selects.copy():
-            explode = select.find(exp.Explode, exp.Posexplode)
+            explode = select.find(exp.Explode)
-            if isinstance(explode, (exp.Explode, exp.Posexplode)):
+            if explode:
                 pos_alias = ""
                 explode_alias = ""
@@ -204,7 +204,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
                     alias = select.replace(exp.alias_(select.this, "", copy=False))
                 else:
                     alias = select.replace(exp.alias_(select, ""))
-                explode = alias.find(exp.Explode, exp.Posexplode)
+                explode = alias.find(exp.Explode)
                 assert explode
                 is_posexplode = isinstance(explode, exp.Posexplode)
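
End to end, the single find(exp.Explode) still catches POSEXPLODE thanks to the new subclass relationship; a hedged sketch (output shape indicative):

    import sqlglot

    # POSEXPLODE from Spark still lowers to UNNEST ... WITH ORDINALITY on
    # Presto, found via the single exp.Explode check above.
    print(sqlglot.transpile("SELECT POSEXPLODE(xs) FROM t", read="spark", write="presto")[0])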