1
0
Fork 0
sqlglot/sqlglot/dataframe/sql/session.py
Daniel Baumann bf03050a25
Merging upstream version 21.1.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-13 21:27:51 +01:00

190 lines
6.4 KiB
Python

from __future__ import annotations
import typing as t
import uuid
from collections import defaultdict
import sqlglot
from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
from sqlglot.helper import classproperty
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
class SparkSession:
DEFAULT_DIALECT = "spark"
_instance = None
def __init__(self):
if not hasattr(self, "known_ids"):
self.known_ids = set()
self.known_branch_ids = set()
self.known_sequence_ids = set()
self.name_to_sequence_id_mapping = defaultdict(list)
self.incrementing_id = 1
self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)
def __new__(cls, *args, **kwargs) -> SparkSession:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
@property
def read(self) -> DataFrameReader:
return DataFrameReader(self)
def table(self, tableName: str) -> DataFrame:
return self.read.table(tableName)
def createDataFrame(
self,
data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
schema: t.Optional[SchemaInput] = None,
samplingRatio: t.Optional[float] = None,
verifySchema: bool = False,
) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame
if samplingRatio is not None or verifySchema:
raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
if schema is not None and (
not isinstance(schema, (StructType, str, list))
or (isinstance(schema, list) and not isinstance(schema[0], str))
):
raise NotImplementedError("Only schema of either list or string of list supported")
if not data:
raise ValueError("Must provide data to create into a DataFrame")
column_mapping: t.Dict[str, t.Optional[str]]
if schema is not None:
column_mapping = get_column_mapping_from_schema_input(schema)
elif isinstance(data[0], dict):
column_mapping = {col_name.strip(): None for col_name in data[0]}
else:
column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
data_expressions = [
exp.tuple_(
*map(
lambda x: F.lit(x).expression,
row if not isinstance(row, dict) else row.values(),
)
)
for row in data
]
sel_columns = [
(
F.col(name).cast(data_type).alias(name).expression
if data_type is not None
else F.col(name).expression
)
for name, data_type in column_mapping.items()
]
select_kwargs = {
"expressions": sel_columns,
"from": exp.From(
this=exp.Values(
expressions=data_expressions,
alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
),
),
}
sel_expression = exp.Select(**select_kwargs)
return DataFrame(self, sel_expression)
def sql(self, sqlQuery: str) -> DataFrame:
expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
if isinstance(expression, exp.Select):
df = DataFrame(self, expression)
df = df._convert_leaf_to_cte()
elif isinstance(expression, (exp.Create, exp.Insert)):
select_expression = expression.expression.copy()
if isinstance(expression, exp.Insert):
select_expression.set("with", expression.args.get("with"))
expression.set("with", None)
del expression.args["expression"]
df = DataFrame(self, select_expression, output_expression_container=expression) # type: ignore
df = df._convert_leaf_to_cte()
else:
raise ValueError(
"Unknown expression type provided in the SQL. Please create an issue with the SQL."
)
return df
@property
def _auto_incrementing_name(self) -> str:
name = f"a{self.incrementing_id}"
self.incrementing_id += 1
return name
@property
def _random_branch_id(self) -> str:
id = self._random_id
self.known_branch_ids.add(id)
return id
@property
def _random_sequence_id(self):
id = self._random_id
self.known_sequence_ids.add(id)
return id
@property
def _random_id(self) -> str:
id = "r" + uuid.uuid4().hex
self.known_ids.add(id)
return id
@property
def _join_hint_names(self) -> t.Set[str]:
return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}
def _add_alias_to_mapping(self, name: str, sequence_id: str):
self.name_to_sequence_id_mapping[name].append(sequence_id)
class Builder:
SQLFRAME_DIALECT_KEY = "sqlframe.dialect"
def __init__(self):
self.dialect = "spark"
def __getattr__(self, item) -> SparkSession.Builder:
return self
def __call__(self, *args, **kwargs):
return self
def config(
self,
key: t.Optional[str] = None,
value: t.Optional[t.Any] = None,
*,
map: t.Optional[t.Dict[str, t.Any]] = None,
**kwargs: t.Any,
) -> SparkSession.Builder:
if key == self.SQLFRAME_DIALECT_KEY:
self.dialect = value
elif map and self.SQLFRAME_DIALECT_KEY in map:
self.dialect = map[self.SQLFRAME_DIALECT_KEY]
return self
def getOrCreate(self) -> SparkSession:
spark = SparkSession()
spark.dialect = Dialect.get_or_raise(self.dialect)
return spark
@classproperty
def builder(cls) -> Builder:
return cls.Builder()