Adding upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 70d5d3451a
commit bb75596aa9
167 changed files with 58268 additions and 51337 deletions
@@ -3,7 +3,12 @@ from __future__ import annotations
 import typing as t
 
 from sqlglot import exp, parser, transforms
-from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
+from sqlglot.dialects.dialect import (
+    create_with_partitions_sql,
+    pivot_column_names,
+    rename_func,
+    trim_sql,
+)
 from sqlglot.dialects.hive import Hive
 from sqlglot.helper import seq_get
 
@@ -26,7 +31,7 @@ def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
     return f"MAP_FROM_ARRAYS({keys}, {values})"
 
 
-def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
+def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
     return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
 
 
@@ -53,10 +58,56 @@ def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
     raise ValueError("Improper scale for timestamp")
 
 
+def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
+    """
+    Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a
+    pivoted source in a subquery with the same alias to preserve the query's semantics.
+
+    Example:
+        >>> from sqlglot import parse_one
+        >>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
+        >>> print(_unalias_pivot(expr).sql(dialect="spark"))
+        SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
+    """
+    if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
+        pivot = expression.this.args["pivots"][0]
+        if pivot.alias:
+            alias = pivot.args["alias"].pop()
+            return exp.From(
+                this=expression.this.replace(
+                    exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
+                )
+            )
+
+    return expression
+
+
+def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
+    """
+    Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
+    so we need to unqualify it.
+
+    Example:
+        >>> from sqlglot import parse_one
+        >>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
+        >>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
+        SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
+    """
+    if isinstance(expression, exp.Pivot):
+        expression.args["field"].transform(
+            lambda node: exp.column(node.output_name, quoted=node.this.quoted)
+            if isinstance(node, exp.Column)
+            else node,
+            copy=False,
+        )
+
+    return expression
+
+
 class Spark2(Hive):
     class Parser(Hive.Parser):
         FUNCTIONS = {
-            **Hive.Parser.FUNCTIONS,  # type: ignore
+            **Hive.Parser.FUNCTIONS,
             "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
             "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
             "LEFT": lambda args: exp.Substring(
@@ -110,7 +161,7 @@ class Spark2(Hive):
         }
 
         FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
+            **parser.Parser.FUNCTION_PARSERS,
             "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
             "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
             "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -131,43 +182,21 @@ class Spark2(Hive):
                 kind="COLUMNS",
             )
 
-        def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
-            # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
-            if len(pivot_columns) == 1:
+        def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
+            if len(aggregations) == 1:
                 return [""]
-
-            names = []
-            for agg in pivot_columns:
-                if isinstance(agg, exp.Alias):
-                    names.append(agg.alias)
-                else:
-                    """
-                    This case corresponds to aggregations without aliases being used as suffixes
-                    (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
-                    be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
-                    Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
-                    Moreover, function names are lowercased in order to mimic Spark's naming scheme.
-                    """
-                    agg_all_unquoted = agg.transform(
-                        lambda node: exp.Identifier(this=node.name, quoted=False)
-                        if isinstance(node, exp.Identifier)
-                        else node
-                    )
-                    names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
-
-            return names
+            return pivot_column_names(aggregations, dialect="spark")
 
     class Generator(Hive.Generator):
         TYPE_MAPPING = {
-            **Hive.Generator.TYPE_MAPPING,  # type: ignore
+            **Hive.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "BYTE",
             exp.DataType.Type.SMALLINT: "SHORT",
             exp.DataType.Type.BIGINT: "LONG",
         }
 
         PROPERTIES_LOCATION = {
-            **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
+            **Hive.Generator.PROPERTIES_LOCATION,
             exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
             exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
             exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
@@ -175,7 +204,7 @@ class Spark2(Hive):
         }
 
         TRANSFORMS = {
-            **Hive.Generator.TRANSFORMS,  # type: ignore
+            **Hive.Generator.TRANSFORMS,
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
             exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
@@ -188,11 +217,12 @@ class Spark2(Hive):
             exp.DayOfWeek: rename_func("DAYOFWEEK"),
             exp.DayOfYear: rename_func("DAYOFYEAR"),
             exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
+            exp.From: transforms.preprocess([_unalias_pivot]),
             exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
             exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.Map: _map_sql,
-            exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
+            exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
             exp.Reduce: rename_func("AGGREGATE"),
             exp.StrToDate: _str_to_date,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
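
To see the new pivot handling end to end, here is a minimal usage sketch. It is an illustration, not part of the commit: it assumes sqlglot 15.0.0 is installed, and the expected outputs are taken from the docstrings of _unalias_pivot and _unqualify_pivot_columns above.

# Minimal sketch, assuming sqlglot 15.0.0. Because the Spark generator registers
# _unalias_pivot under exp.From and _unqualify_pivot_columns under exp.Pivot,
# generating Spark SQL should apply both rewrites without calling them directly.
from sqlglot import parse_one

# PIVOT alias: Spark disallows it, so the pivoted source is expected to be
# wrapped in a subquery that carries the alias (per _unalias_pivot's docstring).
aliased = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
print(aliased.sql(dialect="spark"))
# SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv

# Qualified PIVOT column: the qualifier on the FOR column is expected to be
# dropped (per _unqualify_pivot_columns's docstring).
qualified = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
print(qualified.sql(dialect="spark"))
# SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))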