Adding upstream version 0.12.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>

Parent: d887bee5ca
Commit: 148efc9122

69 changed files with 12923 additions and 0 deletions
src/scripts/benchmark.py (new file, 157 lines)
@@ -0,0 +1,157 @@
from __future__ import annotations

import gc
from pathlib import Path
from time import perf_counter

import pandas as pd
import polars as pl
from textual.app import App, ComposeResult
from textual.driver import Driver
from textual.pilot import Pilot
from textual.types import CSSPathType
from textual.widgets import DataTable as BuiltinDataTable
from textual_fastdatatable import ArrowBackend
from textual_fastdatatable import DataTable as FastDataTable
from textual_fastdatatable.backend import PolarsBackend

BENCHMARK_DATA = Path(__file__).parent.parent.parent / "tests" / "data"


async def scroller(pilot: Pilot) -> None:
    first_paint = perf_counter() - pilot.app.start  # type: ignore
    for _ in range(5):
        await pilot.press("pagedown")
    for _ in range(15):
        await pilot.press("right")
    for _ in range(5):
        await pilot.press("pagedown")
    elapsed = perf_counter() - pilot.app.start  # type: ignore
    pilot.app.exit(result=(first_paint, elapsed))


class BuiltinApp(App):
    TITLE = "Built-In DataTable"

    def __init__(
        self,
        data_path: Path,
        driver_class: type[Driver] | None = None,
        css_path: CSSPathType | None = None,
        watch_css: bool = False,
    ):
        super().__init__(driver_class, css_path, watch_css)
        self.data_path = data_path

    def compose(self) -> ComposeResult:
        df = pd.read_parquet(self.data_path)
        rows = [tuple(row) for row in df.itertuples(index=False)]
        self.start = perf_counter()
        table: BuiltinDataTable = BuiltinDataTable()
        table.add_columns(*[str(col) for col in df.columns])
        for row in rows:
            table.add_row(*row, height=1, label=None)
        yield table


class ArrowBackendApp(App):
    TITLE = "FastDataTable (Arrow from Parquet)"

    def __init__(
        self,
        data_path: Path,
        driver_class: type[Driver] | None = None,
        css_path: CSSPathType | None = None,
        watch_css: bool = False,
    ):
        super().__init__(driver_class, css_path, watch_css)
        self.data_path = data_path

    def compose(self) -> ComposeResult:
        self.start = perf_counter()
        yield FastDataTable(data=self.data_path)


class ArrowBackendAppFromRecords(App):
    TITLE = "FastDataTable (Arrow from Records)"

    def __init__(
        self,
        data_path: Path,
        driver_class: type[Driver] | None = None,
        css_path: CSSPathType | None = None,
        watch_css: bool = False,
    ):
        super().__init__(driver_class, css_path, watch_css)
        self.data_path = data_path

    def compose(self) -> ComposeResult:
        df = pd.read_parquet(self.data_path)
        rows = [tuple(row) for row in df.itertuples(index=False)]
        self.start = perf_counter()
        backend = ArrowBackend.from_records(rows, has_header=False)
        table = FastDataTable(
            backend=backend, column_labels=[str(col) for col in df.columns]
        )
        yield table


class PolarsBackendApp(App):
    TITLE = "FastDataTable (Polars from Parquet)"

    def __init__(
        self,
        data_path: Path,
        driver_class: type[Driver] | None = None,
        css_path: CSSPathType | None = None,
        watch_css: bool = False,
    ):
        super().__init__(driver_class, css_path, watch_css)
        self.data_path = data_path

    def compose(self) -> ComposeResult:
        self.start = perf_counter()
        yield FastDataTable(
            data=PolarsBackend.from_dataframe(pl.read_parquet(self.data_path))
        )


if __name__ == "__main__":
    app_defs = [
        BuiltinApp,
        ArrowBackendApp,
        ArrowBackendAppFromRecords,
        PolarsBackendApp,
    ]
    bench = [
        (f"lap_times_{n}.parquet", 3 if n <= 10000 else 1)
        for n in [100, 1000, 10000, 100000, 538121]
    ]
    bench.extend([(f"wide_{n}.parquet", 1) for n in [10000, 100000]])
    with open("results.md", "w") as f:
        print(
            "Records |",
            " | ".join([a.TITLE for a in app_defs]),  # type: ignore
            sep="",
            file=f,
        )
        print("--------|", "|".join(["--------" for _ in app_defs]), sep="", file=f)
        for p, tries in bench:
            first_paint: list[list[float]] = [list() for _ in app_defs]
            elapsed: list[list[float]] = [list() for _ in app_defs]
            for i, app_cls in enumerate(app_defs):
                for _ in range(tries):
                    app = app_cls(BENCHMARK_DATA / p)
                    gc.disable()
                    fp, el = app.run(headless=True, auto_pilot=scroller)  # type: ignore
                    gc.collect()
                    first_paint[i].append(fp)
                    elapsed[i].append(el)
                    gc.enable()
            avg_first_paint = [sum(app_times) / tries for app_times in first_paint]
            avg_elapsed = [sum(app_times) / tries for app_times in elapsed]
            formatted = [
                f"{fp:7,.3f}s / {el:7,.3f}s"
                for fp, el in zip(avg_first_paint, avg_elapsed)
            ]
            print(f"{p} | {' | '.join(formatted)}", file=f)
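For a one-off measurement outside the full matrix above, any of the app classes can be driven headlessly with the same `scroller` autopilot. A minimal sketch (the fixture name is assumed to exist in tests/data, as in the script):

```python
# Sketch: run one benchmark app headlessly and print its timings.
from benchmark import BENCHMARK_DATA, ArrowBackendApp, scroller

app = ArrowBackendApp(BENCHMARK_DATA / "lap_times_100.parquet")
result = app.run(headless=True, auto_pilot=scroller)  # returns (first_paint, elapsed)
if result is not None:
    first_paint, elapsed = result
    print(f"first paint: {first_paint:.3f}s / total: {elapsed:.3f}s")
```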
src/scripts/run_arrow_wide.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from __future__ import annotations

from pathlib import Path

from textual.app import App, ComposeResult
from textual.driver import Driver
from textual.types import CSSPathType
from textual_fastdatatable import DataTable

BENCHMARK_DATA = Path(__file__).parent.parent.parent / "tests" / "data"


class ArrowBackendApp(App):
    TITLE = "FastDataTable (Arrow)"

    def __init__(
        self,
        data_path: Path,
        driver_class: type[Driver] | None = None,
        css_path: CSSPathType | None = None,
        watch_css: bool = False,
    ):
        super().__init__(driver_class, css_path, watch_css)
        self.data_path = data_path

    def compose(self) -> ComposeResult:
        yield DataTable(data=self.data_path)


if __name__ == "__main__":
    app = ArrowBackendApp(data_path=BENCHMARK_DATA / "wide_100000.parquet")
    app.run()
src/scripts/run_builtin_wide.py (new file, 39 lines)
@@ -0,0 +1,39 @@
from __future__ import annotations

from pathlib import Path

import pandas as pd
from textual.app import App, ComposeResult
from textual.driver import Driver
from textual.types import CSSPathType
from textual.widgets import DataTable

BENCHMARK_DATA = Path(__file__).parent.parent.parent / "tests" / "data"


class BuiltinApp(App):
    TITLE = "Built-In DataTable"

    def __init__(
        self,
        data_path: Path,
        driver_class: type[Driver] | None = None,
        css_path: CSSPathType | None = None,
        watch_css: bool = False,
    ):
        super().__init__(driver_class, css_path, watch_css)
        self.data_path = data_path

    def compose(self) -> ComposeResult:
        df = pd.read_parquet(self.data_path)
        rows = [tuple(row) for row in df.itertuples(index=False)]
        table: DataTable = DataTable()
        table.add_columns(*[str(col) for col in df.columns])
        for row in rows:
            table.add_row(*row, height=1, label=None)
        yield table


if __name__ == "__main__":
    app = BuiltinApp(data_path=BENCHMARK_DATA / "wide_10000.parquet")
    app.run()
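The two run_*_wide.py scripts are interactive counterparts to benchmark.py: each renders a single wide fixture so the table can be scrolled by hand or profiled (for example via `textual run --dev`, if the separate textual-dev tooling is installed; that tooling is an assumption, not part of this commit). Note the deliberate asymmetry in fixtures: the Arrow-backed runner loads wide_100000.parquet, while the built-in table is pointed at the ten-times smaller wide_10000.parquet.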
src/textual_fastdatatable/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from textual_fastdatatable.backend import (
    ArrowBackend,
    DataTableBackend,
    create_backend,
)
from textual_fastdatatable.data_table import DataTable

__all__ = [
    "DataTable",
    "ArrowBackend",
    "DataTableBackend",
    "create_backend",
]
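The public surface re-exported here is small. A minimal sketch of the `create_backend` factory with an in-memory mapping (column names and values are illustrative; see backend.py below for the dispatch logic):

```python
from textual_fastdatatable import create_backend

# A Mapping[str, Sequence] is routed to ArrowBackend.from_pydict.
backend = create_backend({"id": [1, 2, 3], "name": ["a", "b", "c"]})
print(backend.row_count)      # 3
print(backend.columns)        # ['id', 'name']
print(backend.get_row_at(0))  # [1, 'a']
```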
src/textual_fastdatatable/__main__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from textual.app import App, ComposeResult

from textual_fastdatatable import ArrowBackend, DataTable


class TableApp(App, inherit_bindings=False):
    BINDINGS = [("ctrl+q", "quit", "Quit"), ("ctrl+d", "quit", "Quit")]

    def compose(self) -> ComposeResult:
        backend = ArrowBackend.from_parquet("./tests/data/wide_100000.parquet")
        yield DataTable(backend=backend, cursor_type="range", fixed_columns=2)


if __name__ == "__main__":
    import locale

    locale.setlocale(locale.LC_ALL, "")
    app = TableApp()
    app.run()
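The `locale.setlocale(locale.LC_ALL, "")` call matters because the formatter (formatter.py below) renders numbers with the locale-aware `n` format code. A sketch of the effect; the en_US.UTF-8 locale named here is an assumption, and any locale with thousands grouping shows the same behavior:

```python
import locale

print(f"{1234567:n}")                           # 1234567 under the default C locale
locale.setlocale(locale.LC_ALL, "en_US.UTF-8")  # assumes this locale is installed
print(f"{1234567:n}")                           # 1,234,567
```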
src/textual_fastdatatable/backend.py (new file, 706 lines)
@@ -0,0 +1,706 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from contextlib import suppress
from datetime import date, datetime
from pathlib import Path
from typing import (
    Any,
    Dict,
    Generic,
    Iterable,
    Literal,
    Mapping,
    Sequence,
    TypeVar,
)

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.lib as pal
import pyarrow.parquet as pq
import pyarrow.types as pt
from rich.console import Console

from textual_fastdatatable.formatter import measure_width

AutoBackendType = Any

try:
    import polars as pl
    import polars.datatypes as pld
except ImportError:
    _HAS_POLARS = False
else:
    _HAS_POLARS = True


def create_backend(
    data: "AutoBackendType",
    max_rows: int | None = None,
    has_header: bool = False,
) -> DataTableBackend:
    if isinstance(data, pa.Table):
        return ArrowBackend(data, max_rows=max_rows)
    if isinstance(data, pa.RecordBatch):
        return ArrowBackend.from_batches(data, max_rows=max_rows)
    if _HAS_POLARS and isinstance(data, pl.DataFrame):
        return PolarsBackend.from_dataframe(data, max_rows=max_rows)

    if isinstance(data, Path) or isinstance(data, str):
        data = Path(data)
        if data.suffix in [".pqt", ".parquet"]:
            return ArrowBackend.from_parquet(data, max_rows=max_rows)
        if _HAS_POLARS:
            return PolarsBackend.from_file_path(
                data, max_rows=max_rows, has_header=has_header
            )
    if isinstance(data, Sequence) and not data:
        return ArrowBackend(pa.table([]), max_rows=max_rows)
    if isinstance(data, Sequence) and _is_iterable(data[0]):
        return ArrowBackend.from_records(data, max_rows=max_rows, has_header=has_header)

    if (
        isinstance(data, Mapping)
        and isinstance(next(iter(data.keys())), str)
        and isinstance(next(iter(data.values())), Sequence)
    ):
        return ArrowBackend.from_pydict(data, max_rows=max_rows)

    raise TypeError(
        f"Cannot automatically create backend for data of type: {type(data)}. "
        "Data must be of type: Union[pa.Table, pa.RecordBatch, Path, str, "
        "Sequence[Iterable[Any]], Mapping[str, Sequence[Any]], pl.DataFrame]."
    )


def _is_iterable(item: Any) -> bool:
    try:
        iter(item)
    except TypeError:
        return False
    else:
        return True


_TableTypeT = TypeVar("_TableTypeT")


class DataTableBackend(ABC, Generic[_TableTypeT]):
    data: _TableTypeT

    @abstractmethod
    def __init__(self, data: _TableTypeT, max_rows: int | None = None) -> None:
        pass

    @classmethod
    @abstractmethod
    def from_pydict(
        cls, data: Mapping[str, Sequence[Any]], max_rows: int | None = None
    ) -> "DataTableBackend":
        pass

    @property
    @abstractmethod
    def source_data(self) -> _TableTypeT:
        """
        Return the source data, before any max_rows filtering
        """
        pass

    @property
    @abstractmethod
    def source_row_count(self) -> int:
        """
        The number of rows in the source data, before filtering down to max_rows
        """
        pass

    @property
    @abstractmethod
    def row_count(self) -> int:
        """
        The number of rows in the backend's retained data, after filtering down
        to max_rows
        """
        pass

    @property
    def column_count(self) -> int:
        return len(self.columns)

    @property
    @abstractmethod
    def columns(self) -> Sequence[str]:
        """
        A list of column labels
        """
        pass

    @property
    @abstractmethod
    def column_content_widths(self) -> Sequence[int]:
        """
        A list of integers corresponding to the widest utf8 string length
        of any data in each column.
        """
        pass

    @abstractmethod
    def get_row_at(self, index: int) -> Sequence[Any]:
        pass

    @abstractmethod
    def get_column_at(self, index: int) -> Sequence[Any]:
        pass

    @abstractmethod
    def get_cell_at(self, row_index: int, column_index: int) -> Any:
        pass

    @abstractmethod
    def append_column(self, label: str, default: Any | None = None) -> int:
        """
        Returns the new column's index
        """

    @abstractmethod
    def append_rows(self, records: Iterable[Iterable[Any]]) -> list[int]:
        """
        Returns the new row indices
        """
        pass

    @abstractmethod
    def drop_row(self, row_index: int) -> None:
        pass

    @abstractmethod
    def update_cell(self, row_index: int, column_index: int, value: Any) -> None:
        """
        Raises IndexError if the indices are out of bounds
        """

    @abstractmethod
    def sort(
        self, by: list[tuple[str, Literal["ascending", "descending"]]] | str
    ) -> None:
        """
        by: str sorts the table by the data in the column with that name (asc).
        by: list[tuple] sorts the table by the named column(s) with the directions
            indicated.
        """


class ArrowBackend(DataTableBackend[pa.Table]):
    def __init__(self, data: pa.Table, max_rows: int | None = None) -> None:
        self._source_data = data

        # Arrow allows duplicate field names, but a table's to_pylist() and
        # to_pydict() methods will drop duplicate-named fields!
        field_names: list[str] = []
        renamed = False
        for field in data.column_names:
            n = 0
            while field in field_names:
                field = f"{field}{n}"
                renamed = True
                n += 1
            field_names.append(field)
        if renamed:
            data = data.rename_columns(field_names)

        self._source_row_count = data.num_rows
        if max_rows is not None and max_rows < self._source_row_count:
            self.data = data.slice(offset=0, length=max_rows)
        else:
            self.data = data
        self._console = Console()
        self._column_content_widths: list[int] = []

    @staticmethod
    def _pydict_from_records(
        records: Sequence[Iterable[Any]], has_header: bool = False
    ) -> dict[str, list[Any]]:
        headers = (
            records[0]
            if has_header
            else [f"f{i}" for i in range(len(list(records[0])))]
        )
        data = list(map(list, records[1:] if has_header else records))
        pydict = {header: [row[i] for row in data] for i, header in enumerate(headers)}
        return pydict

    @staticmethod
    def _handle_overflow(scalar: pa.Scalar) -> Any | None:
        """
        PyArrow may throw an OverflowError when casting arrow types
        to python types; in some cases we can catch these and
        present a sensible value in the data table; otherwise
        we return None.
        """
        if pt.is_date32(scalar.type):
            if scalar.value > 0:  # type: ignore[attr-defined]
                return date.max
            elif scalar.value <= 0:  # type: ignore[attr-defined]
                return date.min
        elif pt.is_date64(scalar.type):
            if scalar.value > 0:  # type: ignore[attr-defined]
                return date.max
            elif scalar.value <= 0:  # type: ignore[attr-defined]
                return date.min
        elif pt.is_timestamp(scalar.type):
            if scalar.value > 0:  # type: ignore[attr-defined]
                return datetime.max
            elif scalar.value <= 0:  # type: ignore[attr-defined]
                return datetime.min

        return None

    @classmethod
    def from_batches(
        cls, data: pa.RecordBatch, max_rows: int | None = None
    ) -> "ArrowBackend":
        tbl = pa.Table.from_batches([data])
        return cls(tbl, max_rows=max_rows)

    @classmethod
    def from_parquet(
        cls, path: Path | str, max_rows: int | None = None
    ) -> "ArrowBackend":
        tbl = pq.read_table(str(path))
        return cls(tbl, max_rows=max_rows)

    @classmethod
    def from_pydict(
        cls, data: Mapping[str, Sequence[Any]], max_rows: int | None = None
    ) -> "ArrowBackend":
        try:
            tbl = pa.Table.from_pydict(dict(data))
        except (pal.ArrowInvalid, pal.ArrowTypeError):
            # one or more fields has mixed types, like int and
            # string. Cast all to string for safety
            new_data = {
                k: [str(val) if val is not None else None for val in v]
                for k, v in data.items()
            }
            tbl = pa.Table.from_pydict(new_data)
        return cls(tbl, max_rows=max_rows)

    @classmethod
    def from_records(
        cls,
        records: Sequence[Iterable[Any]],
        has_header: bool = False,
        max_rows: int | None = None,
    ) -> "ArrowBackend":
        pydict = cls._pydict_from_records(records, has_header)
        return cls.from_pydict(pydict, max_rows=max_rows)

    @property
    def source_data(self) -> pa.Table:
        return self._source_data

    @property
    def source_row_count(self) -> int:
        return self._source_row_count

    @property
    def row_count(self) -> int:
        return self.data.num_rows

    @property
    def column_count(self) -> int:
        return self.data.num_columns

    @property
    def columns(self) -> Sequence[str]:
        return self.data.column_names

    @property
    def column_content_widths(self) -> list[int]:
        if not self._column_content_widths:
            measurements = [self._measure(arr) for arr in self.data.columns]
            # pc.max returns None for each column without rows; we need to
            # return 0 instead.
            self._column_content_widths = [cw or 0 for cw in measurements]

        return self._column_content_widths

    def get_row_at(self, index: int) -> Sequence[Any]:
        try:
            row: Dict[str, Any] = self.data.slice(index, length=1).to_pylist()[0]
        except OverflowError:
            return [
                self._handle_overflow(self.data[i][index])
                for i in range(len(self.columns))
            ]
        else:
            return list(row.values())

    def get_column_at(self, column_index: int) -> list[Any]:
        try:
            values = self.data[column_index].to_pylist()
        except OverflowError:
            # TODO: consider registering a scalar UDF here for parallel processing
            return [self._handle_overflow(scalar) for scalar in self.data[column_index]]
        else:
            return values

    def get_cell_at(self, row_index: int, column_index: int) -> Any:
        scalar = self.data[column_index][row_index]
        try:
            value = scalar.as_py()
        except OverflowError:
            value = self._handle_overflow(scalar)
        return value

    def append_column(self, label: str, default: Any | None = None) -> int:
        """
        Returns the new column's index
        """
        if default is None:
            arr: pa.Array = pa.nulls(self.row_count)
        else:
            arr = pa.nulls(self.row_count, type=pa.string())
            arr = arr.fill_null(str(default))

        self.data = self.data.append_column(label, arr)
        if self._column_content_widths:
            self._column_content_widths.append(measure_width(default, self._console))
        return self.data.num_columns - 1

    def append_rows(self, records: Iterable[Iterable[Any]]) -> list[int]:
        rows = list(records)
        indices = list(range(self.row_count, self.row_count + len(rows)))
        records_with_headers = [self.data.column_names, *rows]
        pydict = self._pydict_from_records(records_with_headers, has_header=True)
        old_rows = self.data.to_batches()
        new_rows = pa.RecordBatch.from_pydict(
            pydict,
            schema=self.data.schema,
        )
        self.data = pa.Table.from_batches([*old_rows, new_rows])
        self._reset_content_widths()
        return indices

    def drop_row(self, row_index: int) -> None:
        if row_index < 0 or row_index >= self.row_count:
            raise IndexError(f"Can't drop row {row_index} of {self.row_count}")
        above = self.data.slice(0, row_index).to_batches()
        below = self.data.slice(row_index + 1).to_batches()
        self.data = pa.Table.from_batches([*above, *below])
        self._reset_content_widths()

    def update_cell(self, row_index: int, column_index: int, value: Any) -> None:
        column = self.data.column(column_index)
        pycolumn = self.get_column_at(column_index=column_index)
        pycolumn[row_index] = value
        new_type = pa.string() if pt.is_null(column.type) else column.type
        self.data = self.data.set_column(
            column_index,
            self.data.column_names[column_index],
            pa.array(pycolumn, type=new_type),
        )
        if self._column_content_widths:
            self._column_content_widths[column_index] = max(
                measure_width(value, self._console),
                self._column_content_widths[column_index],
            )

    def sort(
        self, by: list[tuple[str, Literal["ascending", "descending"]]] | str
    ) -> None:
        """
        by: str sorts the table by the data in the column with that name (asc).
        by: list[tuple] sorts the table by the named column(s) with the directions
            indicated.
        """
        self.data = self.data.sort_by(by)

    def _reset_content_widths(self) -> None:
        self._column_content_widths = []

    def _measure(self, arr: pa._PandasConvertible) -> int:
        # with some types we can measure the width more efficiently
        if pt.is_boolean(arr.type):
            return 7
        elif pt.is_null(arr.type):
            return 0
        elif (
            pt.is_integer(arr.type)
            or pt.is_floating(arr.type)
            or pt.is_decimal(arr.type)
        ):
            try:
                col_max = pc.max(arr.fill_null(0)).as_py()
            except OverflowError:
                col_max = 9223372036854775807
            try:
                col_min = pc.min(arr.fill_null(0)).as_py()
            except OverflowError:
                col_min = -9223372036854775807
            return max([measure_width(el, self._console) for el in [col_max, col_min]])
        elif pt.is_temporal(arr.type):
            try:
                value = arr.drop_null()[0].as_py()
            except OverflowError:
                return 26  # need space for the infinity sign and a space
            except IndexError:
                return 24
            else:
                # valid temporal types all have the same width for their type
                return measure_width(value, self._console)

        # for everything else, we need to compute it.
        # First, cast the data to strings
        try:
            arr = arr.cast(
                pa.string(),
                safe=False,
            )
        except (pal.ArrowNotImplementedError, pal.ArrowInvalid):
            # some types can't be cast to strings natively by arrow, but they
            # can be cast to strings by python. The arrow way is faster, but
            # if it fails, register a python udf and try again
            def py_str(_ctx: Any, arr: pa.Array) -> str | pa.Array | pa.ChunkedArray:
                return pa.array([str(el) for el in arr], type=pa.string())

            udf_name = f"tfdt_pystr_{arr.type}"
            with suppress(pal.ArrowKeyError):  # already registered
                pc.register_scalar_function(
                    py_str,
                    function_name=udf_name,
                    function_doc={"summary": "str", "description": "built-in str"},
                    in_types={"arr": arr.type},
                    out_type=pa.string(),
                )

            arr = pc.call_function(udf_name, [arr])

        # next, try to measure the UTF-encoded string length of each cell,
        # then take the max
        try:
            width: int = pc.max(pc.utf8_length(arr.fill_null("")).fill_null(0)).as_py()
        except OverflowError:
            width = 10
        return width


if _HAS_POLARS:

    class PolarsBackend(DataTableBackend[pl.DataFrame]):
        @classmethod
        def from_file_path(
            cls, path: Path, max_rows: int | None = None, has_header: bool = True
        ) -> "PolarsBackend":
            if path.suffix in [".arrow", ".feather"]:
                tbl = pl.read_ipc(path)
            elif path.suffix == ".arrows":
                tbl = pl.read_ipc_stream(path)
            elif path.suffix == ".json":
                tbl = pl.read_json(path)
            elif path.suffix == ".csv":
                tbl = pl.read_csv(path, has_header=has_header)
            else:
                raise TypeError(
                    f"Don't know how to load file type {path.suffix} for {path}"
                )
            return cls(tbl, max_rows=max_rows)

        @classmethod
        def from_pydict(
            cls, pydict: Mapping[str, Sequence[Any]], max_rows: int | None = None
        ) -> "PolarsBackend":
            return cls(pl.from_dict(pydict), max_rows=max_rows)

        @classmethod
        def from_dataframe(
            cls, frame: pl.DataFrame, max_rows: int | None = None
        ) -> "PolarsBackend":
            return cls(frame, max_rows=max_rows)

        def __init__(self, data: pl.DataFrame, max_rows: int | None = None) -> None:
            self._source_data = data

            # Polars allows duplicate field names, but a frame's to_dicts()
            # method will drop duplicate-named fields!
            field_names: list[str] = []
            for field in data.columns:
                n = 0
                while field in field_names:
                    field = f"{field}{n}"
                    n += 1
                field_names.append(field)
            data.columns = field_names

            self._source_row_count = len(data)
            if max_rows is not None and max_rows < self._source_row_count:
                self.data = data.slice(offset=0, length=max_rows)
            else:
                self.data = data
            self._console = Console()
            self._column_content_widths: list[int] = []

        @property
        def source_data(self) -> pl.DataFrame:
            return self._source_data

        @property
        def source_row_count(self) -> int:
            return self._source_row_count

        @property
        def row_count(self) -> int:
            return len(self.data)

        @property
        def column_count(self) -> int:
            return len(self.data.columns)

        @property
        def columns(self) -> Sequence[str]:
            return self.data.columns

        def get_row_at(self, index: int) -> Sequence[Any]:
            if index < 0 or index >= len(self.data):
                raise IndexError(
                    f"Cannot get row={index} in table with {len(self.data)} rows "
                    f"and {len(self.data.columns)} cols"
                )
            return list(self.data.slice(index, length=1).to_dicts()[0].values())

        def get_column_at(self, column_index: int) -> Sequence[Any]:
            if column_index < 0 or column_index >= len(self.data.columns):
                raise IndexError(
                    f"Cannot get column={column_index} in table with {len(self.data)} "
                    f"rows and {len(self.data.columns)} cols."
                )
            return list(self.data.to_series(column_index))

        def get_cell_at(self, row_index: int, column_index: int) -> Any:
            if (
                row_index >= len(self.data)
                or row_index < 0
                or column_index < 0
                or column_index >= len(self.data.columns)
            ):
                raise IndexError(
                    f"Cannot get cell at row={row_index} col={column_index} in table "
                    f"with {len(self.data)} rows and {len(self.data.columns)} cols"
                )
            return self.data.to_series(column_index)[row_index]

        def drop_row(self, row_index: int) -> None:
            if row_index < 0 or row_index >= self.row_count:
                raise IndexError(f"Can't drop row {row_index} of {self.row_count}")
            above = self.data.slice(0, row_index)
            below = self.data.slice(row_index + 1)
            self.data = pl.concat([above, below])
            self._reset_content_widths()

        def append_rows(self, records: Iterable[Iterable[Any]]) -> list[int]:
            rows_to_add = pl.from_dicts(
                [dict(zip(self.data.columns, row)) for row in records]
            )
            indices = list(range(self.row_count, self.row_count + len(rows_to_add)))
            self.data = pl.concat([self.data, rows_to_add])
            self._reset_content_widths()
            return indices

        def append_column(self, label: str, default: Any | None = None) -> int:
            """
            Returns the new column's index
            """
            self.data = self.data.with_columns(
                pl.Series([default])
                .extend_constant(default, self.row_count - 1)
                .alias(label)
            )
            if self._column_content_widths:
                self._column_content_widths.append(
                    measure_width(default, self._console)
                )
            return len(self.data.columns) - 1

        def _reset_content_widths(self) -> None:
            self._column_content_widths = []

        def update_cell(self, row_index: int, column_index: int, value: Any) -> None:
            if row_index >= len(self.data) or column_index >= len(self.data.columns):
                raise IndexError(
                    f"Cannot update cell at row={row_index} col={column_index} in "
                    f"table with {len(self.data)} rows and "
                    f"{len(self.data.columns)} cols"
                )
            col_name = self.data.columns[column_index]
            self.data = self.data.with_columns(
                self.data.to_series(column_index)
                .scatter(row_index, value)
                .alias(col_name)
            )
            if self._column_content_widths:
                self._column_content_widths[column_index] = max(
                    measure_width(value, self._console),
                    self._column_content_widths[column_index],
                )

        @property
        def column_content_widths(self) -> list[int]:
            if not self._column_content_widths:
                measurements = [
                    self._measure(self.data[arr]) for arr in self.data.columns
                ]
                # a column with no rows measures as None; we need to return 0
                # instead.
                self._column_content_widths = [cw or 0 for cw in measurements]

            return self._column_content_widths

        def _measure(self, arr: pl.Series) -> int:
            # with some types we can measure the width more efficiently
            dtype = arr.dtype
            if dtype == pld.Categorical():
                return self._measure(arr.cat.get_categories())

            if dtype.is_decimal() or dtype.is_float() or dtype.is_integer():
                col_max = arr.max()
                col_min = arr.min()
                return max(
                    [measure_width(el, self._console) for el in [col_max, col_min]]
                )
            if dtype.is_temporal():
                try:
                    value = arr.drop_nulls()[0]
                except IndexError:
                    return 0
                else:
                    return measure_width(value, self._console)
            if dtype.is_(pld.Boolean()):
                return 7

            # for everything else, we need to compute it

            arr = arr.cast(
                pl.Utf8(),
                strict=False,
            )
            width = arr.fill_null("<null>").str.len_chars().max()
            assert isinstance(width, int)
            return width

        def sort(
            self, by: list[tuple[str, Literal["ascending", "descending"]]] | str
        ) -> None:
            """
            by: str sorts the table by the data in the column with that name (asc).
            by: list[tuple] sorts the table by the named column(s) with the
                directions indicated.
            """
            if isinstance(by, str):
                cols = [by]
                typs = [False]
            else:
                cols = [x for x, _ in by]
                typs = [x == "descending" for _, x in by]
            self.data = self.data.sort(cols, descending=typs)
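A minimal sketch exercising the backend API defined above (assumes pyarrow is installed; the column names and values are illustrative):

```python
import pyarrow as pa

from textual_fastdatatable.backend import ArrowBackend

backend = ArrowBackend(pa.table({"name": ["a", "b"], "score": [10, 20]}))
backend.append_rows([("c", 30)])              # returns the new row indices: [2]
backend.append_column("flag", default=False)  # non-None defaults are stored as strings
backend.sort(by=[("score", "descending")])
print(backend.get_row_at(0))                  # ['c', 30, 'False']
```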
src/textual_fastdatatable/column.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from __future__ import annotations

import re
from dataclasses import dataclass

from rich.text import Text

CELL_X_PADDING = 2

SNAKE_ID_PROG = re.compile(r"(\b|_)id\b", flags=re.IGNORECASE)
CAMEL_ID_PROG = re.compile(r"[a-z]I[dD]\b")


@dataclass
class Column:
    """Metadata for a column in the DataTable."""

    label: Text
    width: int = 0
    content_width: int = 0
    auto_width: bool = False
    max_content_width: int | None = None

    def __post_init__(self) -> None:
        self._is_id: bool | None = None

    @property
    def render_width(self) -> int:
        """Width in cells, required to render a column."""
        # +2 is to account for space padding either side of the cell
        if self.auto_width and self.max_content_width is not None:
            return (
                min(max(len(self.label), self.content_width), self.max_content_width)
                + CELL_X_PADDING
            )
        elif self.auto_width:
            return max(len(self.label), self.content_width) + CELL_X_PADDING
        else:
            return self.width + CELL_X_PADDING

    @property
    def is_id(self) -> bool:
        if self._is_id is None:
            snake_id = SNAKE_ID_PROG.search(str(self.label)) is not None
            camel_id = CAMEL_ID_PROG.search(str(self.label)) is not None
            self._is_id = snake_id or camel_id
        return self._is_id
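A sketch of how `render_width` combines the label width, the content width, and the cell padding (values are illustrative):

```python
from rich.text import Text

from textual_fastdatatable.column import Column

col = Column(label=Text("name"), content_width=12, auto_width=True)
print(col.render_width)  # max(4, 12) + 2 == 14

capped = Column(
    label=Text("name"), content_width=12, auto_width=True, max_content_width=8
)
print(capped.render_width)  # min(max(4, 12), 8) + 2 == 10
```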
src/textual_fastdatatable/data_table.py (new file, 2808 lines)
File diff suppressed because it is too large
src/textual_fastdatatable/formatter.py (new file, 101 lines)
@@ -0,0 +1,101 @@
from __future__ import annotations

from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import cast

from rich.align import Align
from rich.console import Console, RenderableType
from rich.errors import MarkupError
from rich.markup import escape
from rich.protocol import is_renderable
from rich.text import Text

from textual_fastdatatable.column import Column


def cell_formatter(
    obj: object, null_rep: Text, col: Column | None = None, render_markup: bool = True
) -> RenderableType:
    """Convert a cell into a Rich renderable for display.

    For correct formatting, clients should call `locale.setlocale()` first.

    Args:
        obj: Data for a cell.
        null_rep: Placeholder text to render when the cell is None.
        col: Column that the cell came from (used to compute width).
        render_markup: If True, parse strings as Rich console markup.

    Returns:
        A renderable to be displayed which represents the data.
    """
    if obj is None:
        return Align(null_rep, align="center")

    elif isinstance(obj, str) and render_markup:
        try:
            rich_text: Text | str = Text.from_markup(obj)
        except MarkupError:
            rich_text = escape(obj)
        return rich_text

    elif isinstance(obj, str):
        return escape(obj)

    elif isinstance(obj, bool):
        return Align(
            f"[dim]{'✓' if obj else 'X'}[/] {obj}{' ' if obj else ''}",
            style="bold" if obj else "",
            align="right",
        )

    elif isinstance(obj, (float, Decimal)):
        return Align(f"{obj:n}", align="right")

    elif isinstance(obj, int):
        if col is not None and col.is_id:
            # no separators in ID fields
            return Align(str(obj), align="right")
        else:
            return Align(f"{obj:n}", align="right")

    elif isinstance(obj, (datetime, time)):

        def _fmt_datetime(obj: datetime | time) -> str:
            return obj.isoformat(timespec="milliseconds").replace("+00:00", "Z")

        if obj in (datetime.max, datetime.min):
            return Align(
                (
                    f"[bold]{'∞ ' if obj == datetime.max else '-∞ '}[/]"
                    f"[dim]{_fmt_datetime(obj)}[/]"
                ),
                align="right",
            )

        return Align(_fmt_datetime(obj), align="right")

    elif isinstance(obj, date):
        if obj in (date.max, date.min):
            return Align(
                (
                    f"[bold]{'∞ ' if obj == date.max else '-∞ '}[/]"
                    f"[dim]{obj.isoformat()}[/]"
                ),
                align="right",
            )

        return Align(obj.isoformat(), align="right")

    elif isinstance(obj, timedelta):
        return Align(str(obj), align="right")

    elif not is_renderable(obj):
        return str(obj)

    else:
        return cast(RenderableType, obj)


def measure_width(obj: object, console: Console) -> int:
    renderable = cell_formatter(obj, null_rep=Text(""))
    return console.measure(renderable).maximum
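A sketch of the formatter's dispatch on Python types (the per-branch return shapes are noted in comments; values are illustrative):

```python
from datetime import date

from rich.console import Console
from rich.text import Text

from textual_fastdatatable.formatter import cell_formatter, measure_width

null = Text("")
cell_formatter(None, null_rep=null)              # Align(center) around the placeholder
cell_formatter("[bold]hi[/]", null_rep=null)     # Text parsed from console markup
cell_formatter(1234, null_rep=null)              # Align(right), locale-aware separators
cell_formatter(date(2024, 1, 1), null_rep=null)  # Align(right), ISO format
print(measure_width(1234, Console()))            # 4 under the default C locale
```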
src/textual_fastdatatable/py.typed (new file, 0 lines)