from __future__ import annotations

import functools
import logging
import typing as t
import zlib
from copy import copy

import sqlglot
from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.normalize import normalize
from sqlglot.dataframe.sql.operations import Operation, operation
from sqlglot.dataframe.sql.readwriter import DataFrameWriter
from sqlglot.dataframe.sql.transforms import replace_id_value
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer import optimize as optimize_func
from sqlglot.optimizer.qualify_columns import quote_identifiers

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql._typing import (
        ColumnLiterals,
        ColumnOrLiteral,
        ColumnOrName,
        OutputExpressionContainer,
    )
    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dialects.dialect import DialectType

logger = logging.getLogger("sqlglot")

JOIN_HINTS = {
    "BROADCAST",
    "BROADCASTJOIN",
    "MAPJOIN",
    "MERGE",
    "SHUFFLEMERGE",
    "MERGEJOIN",
    "SHUFFLE_HASH",
    "SHUFFLE_REPLICATE_NL",
}
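
# The names in JOIN_HINTS correspond to Spark SQL join strategy hints; DataFrame.hint
# turns a name in this set into an exp.JoinHint and anything else into a generic
# exp.Anonymous hint. A usage sketch (illustrative, not part of this module):
#
#   df1.join(df2.hint("broadcast"), on="id", how="inner")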


class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
        expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
        from sqlglot.dataframe.sql.session import SparkSession

        dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)

        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}

        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                quote_identifiers(select_expression, dialect=dialect)
                select_expression = t.cast(
                    exp.Select, optimize_func(select_expression, dialect=dialect)
                )

            select_expression = df._replace_cte_names_with_hashes(select_expression)

            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql(dialect=dialect)
                        for expression in select_expression.expressions
                    },
                    dialect=dialect,
                )

                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )

                # Drop the cache table's "view" if it exists before running the cache statement
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)

                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")

            output_expressions.append(expression)

        return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]
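
    # Usage sketch (illustrative): a plain select produces a single statement, while
    # cached DataFrames also emit the DROP VIEW / CACHE TABLE statements built above.
    #
    #   for statement in df.sql(dialect="spark"):
    #       print(statement)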

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If we have columns that resolve to multiple CTE expressions, we use each CTE
                # left-to-right and therefore allow multiple columns with the same name in the
                # result. This matches the behavior of Spark.
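                # Illustrative sketch: if both join inputs expose "name", then
                # df1.join(df2, on="id").select("name", "name") resolves the first
                # "name" against df1's CTE and the second against df2's.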
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before.
                    # If so, use it. Otherwise, reuse the last CTE we resolved to.
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We will determine the actual "join on" expression later, so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determines the join clause and the select columns to be used based on what type of
        # columns were provided for the join. The columns returned change based on how the
        # "on" expression is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of a join on column names only:
            * The join column names are put at the front of the select list
            * The join column names are deduplicated across the entire select list; other
              duplicates are allowed
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each
            # of the left side tables to see if they have the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match Spark's behavior, only the join columns are deduplicated, and they are
            # placed at the front of the select list
            select_column_names = [
                (
                    column.alias_or_name
                    if not isinstance(column.expression.this, exp.Star)
                    else column.sql()
                )
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of a join on expressions:
            * There is no deduplication of the results
            * The left DataFrame's columns go first and the right's come after; no sort
              preference is given to join columns
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any pre-ordered columns take priority over whatever is
        provided in `ascending`. Spark's behavior here is irregular and can result in
        runtime errors. Users shouldn't be mixing the two anyway, so this is unlikely
        to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            (
                exp.Ordered(this=col.expression, desc=not asc)
                if i not in pre_ordered_col_indexes
                else columns[i].column_expression
            )
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
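            # Worked example: with 3 checked columns and thresh=2 (keep rows with at
            # least 2 non-null values), minimum_num_nulls = 3 - 2 + 1 = 2, so rows
            # containing 2 or more nulls are filtered out below.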
            if minimum_num_nulls > len(null_check_columns):
                raise RuntimeError(
                    f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                    f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
                )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality difference: if you provide a replacement value whose type conflicts
        with the column's type, PySpark simply ignores the replacement, whereas this
        implementation will try to cast the two to the same type in some cases, so the
        results won't always match. It's best not to mix types: make sure the replacement
        is the same type as the column.

        Possibility for improvement: use the `typeof` function to get the type of the
        column and check whether it matches the type of the value provided. If not,
        make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)


class DataFrameNaFunctions:
    def __init__(self, df: DataFrame):
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)