1
0
Fork 0

Merging upstream version 21.0.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:21:45 +01:00
parent 7d0896f08b
commit b7d506d9b2
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
58 changed files with 25616 additions and 25078 deletions

View file

@ -333,6 +333,7 @@ class DuckDB(Dialect):
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
MULTI_ARG_DISTINCT = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,

View file

@ -232,6 +232,9 @@ class Postgres(Dialect):
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
HEREDOC_STRINGS = ["$"]
HEREDOC_TAG_IS_IDENTIFIER = True
HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~~": TokenType.LIKE,
@ -381,6 +384,7 @@ class Postgres(Dialect):
JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
SUPPORTS_UNLOGGED_TABLES = True
LIKE_PROPERTY_INSIDE_SCHEMA = True
MULTI_ARG_DISTINCT = False
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,

View file

@ -292,6 +292,7 @@ class Presto(Dialect):
LIMIT_ONLY_LITERALS = True
SUPPORTS_SINGLE_ARG_CONCAT = False
LIKE_PROPERTY_INSIDE_SCHEMA = True
MULTI_ARG_DISTINCT = False
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,

View file

@ -50,9 +50,6 @@ class Spark(Spark2):
"DATEDIFF": _parse_datediff,
}
FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("ANY_VALUE")
def _parse_generated_as_identity(
self,
) -> (

View file

@ -1796,7 +1796,7 @@ class Lambda(Expression):
class Limit(Expression):
arg_types = {"this": False, "expression": True, "offset": False}
arg_types = {"this": False, "expression": True, "offset": False, "expressions": False}
class Literal(Condition):
@ -1969,7 +1969,7 @@ class Final(Expression):
class Offset(Expression):
arg_types = {"this": False, "expression": True}
arg_types = {"this": False, "expression": True, "expressions": False}
class Order(Expression):
@ -4291,6 +4291,11 @@ class RespectNulls(Expression):
pass
# https://cloud.google.com/bigquery/docs/reference/standard-sql/aggregate-function-calls#max_min_clause
class HavingMax(Expression):
arg_types = {"this": True, "expression": True, "max": True}
# Functions
class Func(Condition):
"""
@ -4491,7 +4496,7 @@ class Avg(AggFunc):
class AnyValue(AggFunc):
arg_types = {"this": True, "having": False, "max": False}
pass
class Lag(AggFunc):

View file

@ -296,6 +296,10 @@ class Generator(metaclass=_Generator):
# Whether or not the LikeProperty needs to be specified inside of the schema clause
LIKE_PROPERTY_INSIDE_SCHEMA = False
# Whether or not DISTINCT can be followed by multiple args in an AggFunc. If not, it will be
# transpiled into a series of CASE-WHEN-ELSE, ultimately using a tuple conseisting of the args
MULTI_ARG_DISTINCT = True
# Whether or not the JSON extraction operators expect a value of type JSON
JSON_TYPE_REQUIRED_FOR_EXTRACTION = False
@ -1841,15 +1845,18 @@ class Generator(metaclass=_Generator):
args_sql = ", ".join(self.sql(e) for e in args)
args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql
return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}"
expressions = self.expressions(expression, flat=True)
expressions = f" BY {expressions}" if expressions else ""
return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}{expressions}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
expression = expression.expression
expression = (
self._simplify_unless_literal(expression) if self.LIMIT_ONLY_LITERALS else expression
)
return f"{this}{self.seg('OFFSET')} {self.sql(expression)}"
value = expression.expression
value = self._simplify_unless_literal(value) if self.LIMIT_ONLY_LITERALS else value
expressions = self.expressions(expression, flat=True)
expressions = f" BY {expressions}" if expressions else ""
return f"{this}{self.seg('OFFSET')} {self.sql(value)}{expressions}"
def setitem_sql(self, expression: exp.SetItem) -> str:
kind = self.sql(expression, "kind")
@ -2834,6 +2841,13 @@ class Generator(metaclass=_Generator):
def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
if not self.MULTI_ARG_DISTINCT and len(expression.expressions) > 1:
case = exp.case()
for arg in expression.expressions:
case = case.when(arg.is_(exp.null()), exp.null())
this = self.sql(case.else_(f"({this})"))
this = f" {this}" if this else ""
on = self.sql(expression, "on")
@ -2846,13 +2860,33 @@ class Generator(metaclass=_Generator):
def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
return self._embed_ignore_nulls(expression, "RESPECT NULLS")
def havingmax_sql(self, expression: exp.HavingMax) -> str:
this_sql = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
kind = "MAX" if expression.args.get("max") else "MIN"
return f"{this_sql} HAVING {kind} {expression_sql}"
def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
if self.IGNORE_NULLS_IN_FUNC:
this = expression.find(exp.AggFunc)
if this:
sql = self.sql(this)
sql = sql[:-1] + f" {text})"
return sql
if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
# The first modifier here will be the one closest to the AggFunc's arg
mods = sorted(
expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
key=lambda x: 0
if isinstance(x, exp.HavingMax)
else (1 if isinstance(x, exp.Order) else 2),
)
if mods:
mod = mods[0]
this = expression.__class__(this=mod.this.copy())
this.meta["inline"] = True
mod.this.replace(this)
return self.sql(expression.this)
agg_func = expression.find(exp.AggFunc)
if agg_func:
return self.sql(agg_func)[:-1] + f" {text})"
return f"{self.sql(expression, 'this')} {text}"

View file

@ -263,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Div: lambda self, e: self._annotate_div(e),
exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
@ -333,9 +334,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._visited: t.Set[int] = set()
def _set_type(
self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type]
) -> None:
expression.type = target_type # type: ignore
expression.type = target_type or exp.DataType.Type.UNKNOWN # type: ignore
self._visited.add(id(expression))
def annotate(self, expression: E) -> E:
@ -564,13 +565,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
if isinstance(bracket_arg, exp.Slice):
self._set_type(expression, this.type)
elif this.type.is_type(exp.DataType.Type.ARRAY):
contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
self._set_type(expression, contained_type)
self._set_type(expression, seq_get(this.type.expressions, 0))
elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
index = this.keys.index(bracket_arg)
value = seq_get(this.values, index)
value_type = value.type if value else exp.DataType.Type.UNKNOWN
self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
self._set_type(expression, value.type if value else None)
else:
self._set_type(expression, exp.DataType.Type.UNKNOWN)
@ -591,3 +590,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, self._maybe_coerce(left_type, right_type))
return expression
def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
self._annotate_args(expression)
self._set_type(expression, seq_get(expression.this.type.expressions, 0))
return expression

View file

@ -872,7 +872,6 @@ class Parser(metaclass=_Parser):
FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
FUNCTION_PARSERS = {
"ANY_VALUE": lambda self: self._parse_any_value(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"DECODE": lambda self: self._parse_decode(),
@ -2465,8 +2464,14 @@ class Parser(metaclass=_Parser):
this.set(key, expression)
if key == "limit":
offset = expression.args.pop("offset", None)
if offset:
this.set("offset", exp.Offset(expression=offset))
offset = exp.Offset(expression=offset)
this.set("offset", offset)
limit_by_expressions = expression.expressions
expression.set("expressions", None)
offset.set("expressions", limit_by_expressions)
continue
break
return this
@ -3341,7 +3346,12 @@ class Parser(metaclass=_Parser):
offset = None
limit_exp = self.expression(
exp.Limit, this=this, expression=expression, offset=offset, comments=comments
exp.Limit,
this=this,
expression=expression,
offset=offset,
comments=comments,
expressions=self._parse_limit_by(),
)
return limit_exp
@ -3377,7 +3387,13 @@ class Parser(metaclass=_Parser):
count = self._parse_term()
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
return self.expression(
exp.Offset, this=this, expression=count, expressions=self._parse_limit_by()
)
def _parse_limit_by(self) -> t.Optional[t.List[exp.Expression]]:
return self._match_text_seq("BY") and self._parse_csv(self._parse_bitwise)
def _parse_locks(self) -> t.List[exp.Lock]:
locks = []
@ -4115,7 +4131,9 @@ class Parser(metaclass=_Parser):
else:
this = self._parse_select_or_expression(alias=alias)
return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this)))
return self._parse_limit(
self._parse_order(self._parse_having_max(self._parse_respect_or_ignore_nulls(this)))
)
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
@ -4549,18 +4567,6 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
def _parse_any_value(self) -> exp.AnyValue:
this = self._parse_lambda()
is_max = None
having = None
if self._match(TokenType.HAVING):
self._match_texts(("MAX", "MIN"))
is_max = self._prev.text == "MAX"
having = self._parse_column()
return self.expression(exp.AnyValue, this=this, having=having, max=is_max)
def _parse_cast(self, strict: bool, safe: t.Optional[bool] = None) -> exp.Expression:
this = self._parse_conjunction()
@ -4941,6 +4947,16 @@ class Parser(metaclass=_Parser):
return self.expression(exp.RespectNulls, this=this)
return this
def _parse_having_max(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if self._match(TokenType.HAVING):
self._match_texts(("MAX", "MIN"))
max = self._prev.text.upper() != "MIN"
return self.expression(
exp.HavingMax, this=this, expression=self._parse_column(), max=max
)
return this
def _parse_window(
self, this: t.Optional[exp.Expression], alias: bool = False
) -> t.Optional[exp.Expression]:

View file

@ -106,19 +106,6 @@ class Schema(abc.ABC):
name = column if isinstance(column, str) else column.name
return name in self.column_names(table, dialect=dialect, normalize=normalize)
@abc.abstractmethod
def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
"""
Returns the schema of a given table.
Args:
table: the target table.
raise_on_missing: whether or not to raise in case the schema is not found.
Returns:
The schema of the target table.
"""
@property
@abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
@ -170,6 +157,16 @@ class AbstractMappingSchema:
return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
"""
Returns the schema of a given table.
Args:
table: the target table.
raise_on_missing: whether or not to raise in case the schema is not found.
Returns:
The schema of the target table.
"""
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.mapping_trie, parts)

View file

@ -504,6 +504,7 @@ class _Tokenizer(type):
command_prefix_tokens={
_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
},
heredoc_tag_is_identifier=klass.HEREDOC_TAG_IS_IDENTIFIER,
)
token_types = RsTokenTypeSettings(
bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
@ -517,6 +518,7 @@ class _Tokenizer(type):
semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON],
string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING],
var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR],
heredoc_string_alternative=_TOKEN_TYPE_TO_INDEX[klass.HEREDOC_STRING_ALTERNATIVE],
)
klass._RS_TOKENIZER = RsTokenizer(settings, token_types)
else:
@ -573,6 +575,12 @@ class Tokenizer(metaclass=_Tokenizer):
STRING_ESCAPES = ["'"]
VAR_SINGLE_TOKENS: t.Set[str] = set()
# Whether or not the heredoc tags follow the same lexical rules as unquoted identifiers
HEREDOC_TAG_IS_IDENTIFIER = False
# Token that we'll generate as a fallback if the heredoc prefix doesn't correspond to a heredoc
HEREDOC_STRING_ALTERNATIVE = TokenType.VAR
# Autofilled
_COMMENTS: t.Dict[str, str] = {}
_FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
@ -1249,6 +1257,18 @@ class Tokenizer(metaclass=_Tokenizer):
elif token_type == TokenType.BIT_STRING:
base = 2
elif token_type == TokenType.HEREDOC_STRING:
if (
self.HEREDOC_TAG_IS_IDENTIFIER
and not self._peek.isidentifier()
and not self._peek == end
):
if self.HEREDOC_STRING_ALTERNATIVE != token_type.VAR:
self._add(self.HEREDOC_STRING_ALTERNATIVE)
else:
self._scan_var()
return True
self._advance()
tag = "" if self._char == end else self._extract_string(end)
end = f"{start}{tag}{end}"