1
0
Fork 0
sqlglot/sqlglot/dataframe/sql/session.py
Daniel Baumann fc63828ee4
Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-13 15:57:23 +01:00

151 lines
5.2 KiB
Python

from __future__ import annotations
import typing as t
import uuid
from collections import defaultdict
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
class SparkSession:
    """Session object from which ``DataFrame`` instances are created.

    DataFrames can be obtained from a table name (``table`` / ``read``), from
    in-memory rows (``createDataFrame``) or from a SQL string (``sql``).

    The identifier registries below are declared as ``ClassVar`` and are only
    ever mutated in place, so every instance of this class shares them.
    """

    # Every id handed out by ``_random_id`` (directly or via the two wrappers).
    known_ids: t.ClassVar[t.Set[str]] = set()
    # Ids handed out by ``_random_branch_id``.
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    # Ids handed out by ``_random_sequence_id``.
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    # name -> sequence ids registered for it through ``_add_alias_to_mapping``.
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        # Counter consumed by ``_auto_incrementing_name`` ("a1", "a2", ...).
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        """Return the session itself for any attribute that is not otherwise defined.

        Python only calls ``__getattr__`` after normal attribute lookup fails,
        so real attributes, properties and methods are unaffected.
        NOTE(review): presumably this (together with ``__call__``) lets
        builder-style chains such as ``session.builder.appName(...).getOrCreate()``
        resolve to the session -- confirm against callers.
        """
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        """Ignore all arguments and return the session itself."""
        return self

    @property
    def read(self) -> DataFrameReader:
        """A new ``DataFrameReader`` bound to this session."""
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        """Shortcut for ``self.read.table(tableName)``."""
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        """Build a ``DataFrame`` that selects from an inline ``VALUES`` expression.

        Args:
            data: Non-empty sequence of rows. Each row is a dict, list or tuple
                of literal values.
            schema: Optional column description: a ``StructType``, a string, or
                a non-empty list whose first element is a string. When omitted,
                column names are the (stripped) keys of the first row if it is
                a dict, otherwise ``_1``, ``_2``, ... based on the first row's
                length.
            samplingRatio: Not supported; must be ``None``.
            verifySchema: Not supported; must be falsy.

        Returns:
            A ``DataFrame`` wrapping ``SELECT <columns> FROM VALUES (...) AS aN(<columns>)``,
            where columns with a known data type are cast to it.

        Raises:
            NotImplementedError: If ``samplingRatio`` / ``verifySchema`` are used
                or ``schema`` has an unsupported form (including an empty list).
            ValueError: If ``data`` is empty.
        """
        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")

        if schema is not None:
            if isinstance(schema, list):
                # An empty list used to hit ``schema[0]`` and raise IndexError;
                # treat it like any other unsupported schema instead.
                schema_supported = bool(schema) and isinstance(schema[0], str)
            else:
                schema_supported = isinstance(schema, (str, StructType))
            if not schema_supported:
                raise NotImplementedError("Only schema of either list or string of list supported")

        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        # Imported only after validation so that bad arguments fail fast.
        from sqlglot.dataframe.sql.dataframe import DataFrame

        # Column name -> data type (``None`` when the type is unknown).
        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        # One tuple of literals per input row.
        # NOTE(review): dict rows contribute their values in their own
        # insertion order, so a row whose key order differs from ``data[0]``
        # will not line up with the inferred columns -- confirm whether
        # callers can pass such rows.
        data_expressions = [
            exp.Tuple(
                expressions=[
                    F.lit(value).expression
                    for value in (row.values() if isinstance(row, dict) else row)
                ]
            )
            for row in data
        ]

        # Typed columns are cast and re-aliased to their own name; untyped
        # columns are selected as-is.
        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        # Reading this property also advances the counter.
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }
        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        """Parse ``sqlQuery`` with the ``spark`` dialect and wrap it in a ``DataFrame``.

        * A ``SELECT`` becomes the DataFrame's expression directly.
        * For ``CREATE`` / ``INSERT``, a copy of the statement's ``expression``
          argument becomes the DataFrame's expression, and the statement itself
          (with that argument removed) is passed along as
          ``output_expression_container``. For ``INSERT``, the statement's
          ``with`` argument is moved onto the copied expression.

        In every case the result is passed through ``_convert_leaf_to_cte``.

        Raises:
            ValueError: If the parsed statement is of any other type, or is a
                ``CREATE`` / ``INSERT`` that carries no ``expression`` argument.
        """
        expression = sqlglot.parse_one(sqlQuery, read="spark")

        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
        elif isinstance(expression, (exp.Create, exp.Insert)):
            source = expression.expression
            if source is None:
                # Previously surfaced as an AttributeError on ``None.copy()``.
                raise ValueError(
                    "The CREATE/INSERT statement has no query to build a DataFrame from."
                )
            select_expression = source.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df._convert_leaf_to_cte()

    @property
    def _auto_incrementing_name(self) -> str:
        """Next name in the sequence ``a1``, ``a2``, ...; each read advances the counter."""
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        """A fresh random id, additionally recorded in ``known_branch_ids``."""
        branch_id = self._random_id
        self.known_branch_ids.add(branch_id)
        return branch_id

    @property
    def _random_sequence_id(self) -> str:
        """A fresh random id, additionally recorded in ``known_sequence_ids``."""
        sequence_id = self._random_id
        self.known_sequence_ids.add(sequence_id)
        return sequence_id

    @property
    def _random_id(self) -> str:
        """A fresh id of the form ``"r" + uuid4().hex``, recorded in ``known_ids``."""
        random_id = "r" + uuid.uuid4().hex
        self.known_ids.add(random_id)
        return random_id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        """The set of join hint names; a new set is returned on every access."""
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        """Append ``sequence_id`` to the class-level list registered under ``name``."""
        self.name_to_sequence_id_mapping[name].append(sequence_id)