
Adding upstream version 0.12.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Author: Daniel Baumann <daniel@debian.org>
Date: 2025-02-24 10:57:24 +01:00
Parent: d887bee5ca
Commit: 148efc9122
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
69 changed files with 12923 additions and 0 deletions

src/scripts/benchmark.py (new file, 157 lines)

@@ -0,0 +1,157 @@
from __future__ import annotations
import gc
from pathlib import Path
from time import perf_counter
import pandas as pd
import polars as pl
from textual.app import App, ComposeResult
from textual.driver import Driver
from textual.pilot import Pilot
from textual.types import CSSPathType
from textual.widgets import DataTable as BuiltinDataTable
from textual_fastdatatable import ArrowBackend
from textual_fastdatatable import DataTable as FastDataTable
from textual_fastdatatable.backend import PolarsBackend
BENCHMARK_DATA = Path(__file__).parent.parent.parent / "tests" / "data"
async def scroller(pilot: Pilot) -> None:
first_paint = perf_counter() - pilot.app.start # type: ignore
for _ in range(5):
await pilot.press("pagedown")
for _ in range(15):
await pilot.press("right")
for _ in range(5):
await pilot.press("pagedown")
elapsed = perf_counter() - pilot.app.start # type: ignore
pilot.app.exit(result=(first_paint, elapsed))
class BuiltinApp(App):
TITLE = "Built-In DataTable"
def __init__(
self,
data_path: Path,
driver_class: type[Driver] | None = None,
css_path: CSSPathType | None = None,
watch_css: bool = False,
):
super().__init__(driver_class, css_path, watch_css)
self.data_path = data_path
def compose(self) -> ComposeResult:
df = pd.read_parquet(self.data_path)
rows = [tuple(row) for row in df.itertuples(index=False)]
self.start = perf_counter()
table: BuiltinDataTable = BuiltinDataTable()
table.add_columns(*[str(col) for col in df.columns])
for row in rows:
table.add_row(*row, height=1, label=None)
yield table
class ArrowBackendApp(App):
TITLE = "FastDataTable (Arrow from Parquet)"
def __init__(
self,
data_path: Path,
driver_class: type[Driver] | None = None,
css_path: CSSPathType | None = None,
watch_css: bool = False,
):
super().__init__(driver_class, css_path, watch_css)
self.data_path = data_path
def compose(self) -> ComposeResult:
self.start = perf_counter()
yield FastDataTable(data=self.data_path)
class ArrowBackendAppFromRecords(App):
TITLE = "FastDataTable (Arrow from Records)"
def __init__(
self,
data_path: Path,
driver_class: type[Driver] | None = None,
css_path: CSSPathType | None = None,
watch_css: bool = False,
):
super().__init__(driver_class, css_path, watch_css)
self.data_path = data_path
def compose(self) -> ComposeResult:
df = pd.read_parquet(self.data_path)
rows = [tuple(row) for row in df.itertuples(index=False)]
self.start = perf_counter()
backend = ArrowBackend.from_records(rows, has_header=False)
table = FastDataTable(
backend=backend, column_labels=[str(col) for col in df.columns]
)
yield table
class PolarsBackendApp(App):
TITLE = "FastDataTable (Polars from Parquet)"
def __init__(
self,
data_path: Path,
driver_class: type[Driver] | None = None,
css_path: CSSPathType | None = None,
watch_css: bool = False,
):
super().__init__(driver_class, css_path, watch_css)
self.data_path = data_path
def compose(self) -> ComposeResult:
self.start = perf_counter()
yield FastDataTable(
data=PolarsBackend.from_dataframe(pl.read_parquet(self.data_path))
)
if __name__ == "__main__":
app_defs = [
BuiltinApp,
ArrowBackendApp,
ArrowBackendAppFromRecords,
PolarsBackendApp,
]
bench = [
(f"lap_times_{n}.parquet", 3 if n <= 10000 else 1)
for n in [100, 1000, 10000, 100000, 538121]
]
bench.extend([(f"wide_{n}.parquet", 1) for n in [10000, 100000]])
with open("results.md", "w") as f:
print(
"Records |",
" | ".join([a.TITLE for a in app_defs]), # type: ignore
sep="",
file=f,
)
print("--------|", "|".join(["--------" for _ in app_defs]), sep="", file=f)
for p, tries in bench:
first_paint: list[list[float]] = [list() for _ in app_defs]
elapsed: list[list[float]] = [list() for _ in app_defs]
for i, app_cls in enumerate(app_defs):
for _ in range(tries):
app = app_cls(BENCHMARK_DATA / p)
gc.disable()
fp, el = app.run(headless=True, auto_pilot=scroller) # type: ignore
gc.collect()
first_paint[i].append(fp)
elapsed[i].append(el)
gc.enable()
avg_first_paint = [sum(app_times) / tries for app_times in first_paint]
avg_elapsed = [sum(app_times) / tries for app_times in elapsed]
formatted = [
f"{fp:7,.3f}s / {el:7,.3f}s"
for fp, el in zip(avg_first_paint, avg_elapsed)
]
print(f"{p} | {' | '.join(formatted)}", file=f)


@@ -0,0 +1,32 @@
from __future__ import annotations
from pathlib import Path
from textual.app import App, ComposeResult
from textual.driver import Driver
from textual.types import CSSPathType
from textual_fastdatatable import DataTable
BENCHMARK_DATA = Path(__file__).parent.parent.parent / "tests" / "data"
class ArrowBackendApp(App):
TITLE = "FastDataTable (Arrow)"
def __init__(
self,
data_path: Path,
driver_class: type[Driver] | None = None,
css_path: CSSPathType | None = None,
watch_css: bool = False,
):
super().__init__(driver_class, css_path, watch_css)
self.data_path = data_path
def compose(self) -> ComposeResult:
yield DataTable(data=self.data_path)
if __name__ == "__main__":
app = ArrowBackendApp(data_path=BENCHMARK_DATA / "wide_100000.parquet")
app.run()


@@ -0,0 +1,39 @@
from __future__ import annotations
from pathlib import Path
import pandas as pd
from textual.app import App, ComposeResult
from textual.driver import Driver
from textual.types import CSSPathType
from textual.widgets import DataTable
BENCHMARK_DATA = Path(__file__).parent.parent.parent / "tests" / "data"
class BuiltinApp(App):
TITLE = "Built-In DataTable"
def __init__(
self,
data_path: Path,
driver_class: type[Driver] | None = None,
css_path: CSSPathType | None = None,
watch_css: bool = False,
):
super().__init__(driver_class, css_path, watch_css)
self.data_path = data_path
def compose(self) -> ComposeResult:
df = pd.read_parquet(self.data_path)
rows = [tuple(row) for row in df.itertuples(index=False)]
table: DataTable = DataTable()
table.add_columns(*[str(col) for col in df.columns])
for row in rows:
table.add_row(*row, height=1, label=None)
yield table
if __name__ == "__main__":
app = BuiltinApp(data_path=BENCHMARK_DATA / "wide_10000.parquet")
app.run()


@@ -0,0 +1,13 @@
from textual_fastdatatable.backend import (
ArrowBackend,
DataTableBackend,
create_backend,
)
from textual_fastdatatable.data_table import DataTable
__all__ = [
"DataTable",
"ArrowBackend",
"DataTableBackend",
"create_backend",
]
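
The package re-exports the widget and the Arrow helpers; PolarsBackend is not in __all__ but remains importable from textual_fastdatatable.backend, as benchmark.py does above. A minimal wiring of the exports (a sketch; the column data is illustrative):

from textual.app import App, ComposeResult
from textual_fastdatatable import DataTable, create_backend

class MinimalApp(App):
    def compose(self) -> ComposeResult:
        # a Mapping of column name -> values dispatches to ArrowBackend.from_pydict
        yield DataTable(backend=create_backend({"id": [1, 2], "name": ["a", "b"]}))

if __name__ == "__main__":
    MinimalApp().run()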


@@ -0,0 +1,19 @@
from textual.app import App, ComposeResult
from textual_fastdatatable import ArrowBackend, DataTable
class TableApp(App, inherit_bindings=False):
BINDINGS = [("ctrl+q", "quit", "Quit"), ("ctrl+d", "quit", "Quit")]
def compose(self) -> ComposeResult:
backend = ArrowBackend.from_parquet("./tests/data/wide_100000.parquet")
yield DataTable(backend=backend, cursor_type="range", fixed_columns=2)
if __name__ == "__main__":
import locale
locale.setlocale(locale.LC_ALL, "")
app = TableApp()
app.run()


@@ -0,0 +1,706 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import suppress
from datetime import date, datetime
from pathlib import Path
from typing import (
Any,
Dict,
Generic,
Iterable,
Literal,
Mapping,
Sequence,
TypeVar,
)
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.lib as pal
import pyarrow.parquet as pq
import pyarrow.types as pt
from rich.console import Console
from textual_fastdatatable.formatter import measure_width
AutoBackendType = Any
try:
import polars as pl
import polars.datatypes as pld
except ImportError:
_HAS_POLARS = False
else:
_HAS_POLARS = True
def create_backend(
data: "AutoBackendType",
max_rows: int | None = None,
has_header: bool = False,
) -> DataTableBackend:
if isinstance(data, pa.Table):
return ArrowBackend(data, max_rows=max_rows)
if isinstance(data, pa.RecordBatch):
return ArrowBackend.from_batches(data, max_rows=max_rows)
if _HAS_POLARS and isinstance(data, pl.DataFrame):
return PolarsBackend.from_dataframe(data, max_rows=max_rows)
if isinstance(data, Path) or isinstance(data, str):
data = Path(data)
if data.suffix in [".pqt", ".parquet"]:
return ArrowBackend.from_parquet(data, max_rows=max_rows)
if _HAS_POLARS:
return PolarsBackend.from_file_path(
data, max_rows=max_rows, has_header=has_header
)
if isinstance(data, Sequence) and not data:
return ArrowBackend(pa.table([]), max_rows=max_rows)
if isinstance(data, Sequence) and _is_iterable(data[0]):
return ArrowBackend.from_records(data, max_rows=max_rows, has_header=has_header)
if (
isinstance(data, Mapping)
and isinstance(next(iter(data.keys())), str)
and isinstance(next(iter(data.values())), Sequence)
):
return ArrowBackend.from_pydict(data, max_rows=max_rows)
raise TypeError(
f"Cannot automatically create backend for data of type: {type(data)}. "
f"Data must be of type: Union[pa.Table, pa.RecordBatch, Path, str, "
"Sequence[Iterable[Any]], Mapping[str, Sequence[Any]], pl.DataFrame",
)
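# Dispatch summary for create_backend (illustrative; the PolarsBackend branches
# apply only when the optional polars dependency is installed):
#   create_backend(pa.table({"x": [1, 2]}))   -> ArrowBackend
#   create_backend("data.parquet")            -> ArrowBackend.from_parquet
#   create_backend([("a", 1), ("b", 2)])      -> ArrowBackend.from_records
#   create_backend({"x": [1, 2]})             -> ArrowBackend.from_pydict
#   create_backend("data.csv")                -> PolarsBackend.from_file_path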
def _is_iterable(item: Any) -> bool:
try:
iter(item)
except TypeError:
return False
else:
return True
_TableTypeT = TypeVar("_TableTypeT")
class DataTableBackend(ABC, Generic[_TableTypeT]):
data: _TableTypeT
@abstractmethod
def __init__(self, data: _TableTypeT, max_rows: int | None = None) -> None:
pass
@classmethod
@abstractmethod
def from_pydict(
cls, data: Mapping[str, Sequence[Any]], max_rows: int | None = None
) -> "DataTableBackend":
pass
@property
@abstractmethod
def source_data(self) -> _TableTypeT:
"""
        Return the source data in the backend's native table type
"""
pass
@property
@abstractmethod
def source_row_count(self) -> int:
"""
The number of rows in the source data, before filtering down to max_rows
"""
pass
@property
@abstractmethod
def row_count(self) -> int:
"""
        The number of rows in the backend's retained data, after filtering down to max_rows
"""
pass
@property
def column_count(self) -> int:
return len(self.columns)
@property
@abstractmethod
def columns(self) -> Sequence[str]:
"""
A list of column labels
"""
pass
@property
@abstractmethod
def column_content_widths(self) -> Sequence[int]:
"""
A list of integers corresponding to the widest utf8 string length
of any data in each column.
"""
pass
@abstractmethod
def get_row_at(self, index: int) -> Sequence[Any]:
pass
@abstractmethod
def get_column_at(self, index: int) -> Sequence[Any]:
pass
@abstractmethod
def get_cell_at(self, row_index: int, column_index: int) -> Any:
pass
@abstractmethod
def append_column(self, label: str, default: Any | None = None) -> int:
"""
Returns column index
"""
@abstractmethod
def append_rows(self, records: Iterable[Iterable[Any]]) -> list[int]:
"""
        Returns new row indices
"""
pass
@abstractmethod
def drop_row(self, row_index: int) -> None:
pass
@abstractmethod
def update_cell(self, row_index: int, column_index: int, value: Any) -> None:
"""
        Raises IndexError on bad indices
"""
@abstractmethod
def sort(
self, by: list[tuple[str, Literal["ascending", "descending"]]] | str
) -> None:
"""
by: str sorts table by the data in the column with that name (asc).
by: list[tuple] sorts the table by the named column(s) with the directions
indicated.
"""
class ArrowBackend(DataTableBackend[pa.Table]):
def __init__(self, data: pa.Table, max_rows: int | None = None) -> None:
self._source_data = data
# Arrow allows duplicate field names, but a table's to_pylist() and
# to_pydict() methods will drop duplicate-named fields!
field_names: list[str] = []
renamed = False
for field in data.column_names:
n = 0
while field in field_names:
field = f"{field}{n}"
renamed = True
n += 1
field_names.append(field)
if renamed:
data = data.rename_columns(field_names)
self._source_row_count = data.num_rows
if max_rows is not None and max_rows < self._source_row_count:
self.data = data.slice(offset=0, length=max_rows)
else:
self.data = data
self._console = Console()
self._column_content_widths: list[int] = []
@staticmethod
def _pydict_from_records(
records: Sequence[Iterable[Any]], has_header: bool = False
) -> dict[str, list[Any]]:
headers = (
records[0]
if has_header
else [f"f{i}" for i in range(len(list(records[0])))]
)
data = list(map(list, records[1:] if has_header else records))
pydict = {header: [row[i] for row in data] for i, header in enumerate(headers)}
return pydict
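    # e.g. _pydict_from_records([("a", "b"), (1, 2), (3, 4)], has_header=True)
    # returns {"a": [1, 3], "b": [2, 4]}; without a header, columns are
    # auto-named f0, f1, ...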
@staticmethod
def _handle_overflow(scalar: pa.Scalar) -> Any | None:
"""
PyArrow may throw an OverflowError when casting arrow types
to python types; in some cases we can catch these and
present a sensible value in the data table; otherwise
we return None.
"""
if pt.is_date32(scalar.type):
if scalar.value > 0: # type: ignore[attr-defined]
return date.max
elif scalar.value <= 0: # type: ignore[attr-defined]
return date.min
elif pt.is_date64(scalar.type):
if scalar.value > 0: # type: ignore[attr-defined]
return date.max
elif scalar.value <= 0: # type: ignore[attr-defined]
return date.min
elif pt.is_timestamp(scalar.type):
if scalar.value > 0: # type: ignore[attr-defined]
return datetime.max
elif scalar.value <= 0: # type: ignore[attr-defined]
return datetime.min
return None
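    # Sketch of the failure mode this guards: converting an extreme Arrow
    # timestamp such as pa.scalar(2**62, type=pa.timestamp("s")) to a Python
    # datetime overflows; positive overflows map to datetime.max / date.max,
    # the rest to datetime.min / date.min.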
@classmethod
def from_batches(
cls, data: pa.RecordBatch, max_rows: int | None = None
) -> "ArrowBackend":
tbl = pa.Table.from_batches([data])
return cls(tbl, max_rows=max_rows)
@classmethod
def from_parquet(
cls, path: Path | str, max_rows: int | None = None
) -> "ArrowBackend":
tbl = pq.read_table(str(path))
return cls(tbl, max_rows=max_rows)
@classmethod
def from_pydict(
cls, data: Mapping[str, Sequence[Any]], max_rows: int | None = None
) -> "ArrowBackend":
try:
tbl = pa.Table.from_pydict(dict(data))
except (pal.ArrowInvalid, pal.ArrowTypeError):
# one or more fields has mixed types, like int and
# string. Cast all to string for safety
new_data = {
k: [str(val) if val is not None else None for val in v]
for k, v in data.items()
}
tbl = pa.Table.from_pydict(new_data)
return cls(tbl, max_rows=max_rows)
@classmethod
def from_records(
cls,
records: Sequence[Iterable[Any]],
has_header: bool = False,
max_rows: int | None = None,
) -> "ArrowBackend":
pydict = cls._pydict_from_records(records, has_header)
return cls.from_pydict(pydict, max_rows=max_rows)
@property
def source_data(self) -> pa.Table:
return self._source_data
@property
def source_row_count(self) -> int:
return self._source_row_count
@property
def row_count(self) -> int:
return self.data.num_rows
@property
def column_count(self) -> int:
return self.data.num_columns
@property
def columns(self) -> Sequence[str]:
return self.data.column_names
@property
def column_content_widths(self) -> list[int]:
if not self._column_content_widths:
measurements = [self._measure(arr) for arr in self.data.columns]
# pc.max returns None for each column without rows; we need to return 0
# instead.
self._column_content_widths = [cw or 0 for cw in measurements]
return self._column_content_widths
def get_row_at(self, index: int) -> Sequence[Any]:
try:
row: Dict[str, Any] = self.data.slice(index, length=1).to_pylist()[0]
except OverflowError:
return [
self._handle_overflow(self.data[i][index])
for i in range(len(self.columns))
]
else:
return list(row.values())
def get_column_at(self, column_index: int) -> list[Any]:
try:
values = self.data[column_index].to_pylist()
except OverflowError:
# TODO: consider registering a scalar UDF here for parallel processing
return [self._handle_overflow(scalar) for scalar in self.data[column_index]]
else:
return values
def get_cell_at(self, row_index: int, column_index: int) -> Any:
scalar = self.data[column_index][row_index]
try:
value = scalar.as_py()
except OverflowError:
value = self._handle_overflow(scalar)
return value
def append_column(self, label: str, default: Any | None = None) -> int:
"""
Returns column index
"""
if default is None:
arr: pa.Array = pa.nulls(self.row_count)
else:
arr = pa.nulls(self.row_count, type=pa.string())
arr = arr.fill_null(str(default))
self.data = self.data.append_column(label, arr)
if self._column_content_widths:
self._column_content_widths.append(measure_width(default, self._console))
return self.data.num_columns - 1
def append_rows(self, records: Iterable[Iterable[Any]]) -> list[int]:
rows = list(records)
        indices = list(range(self.row_count, self.row_count + len(rows)))
records_with_headers = [self.data.column_names, *rows]
pydict = self._pydict_from_records(records_with_headers, has_header=True)
old_rows = self.data.to_batches()
new_rows = pa.RecordBatch.from_pydict(
pydict,
schema=self.data.schema,
)
self.data = pa.Table.from_batches([*old_rows, new_rows])
self._reset_content_widths()
        return indices
def drop_row(self, row_index: int) -> None:
if row_index < 0 or row_index >= self.row_count:
raise IndexError(f"Can't drop row {row_index} of {self.row_count}")
above = self.data.slice(0, row_index).to_batches()
below = self.data.slice(row_index + 1).to_batches()
self.data = pa.Table.from_batches([*above, *below])
        self._reset_content_widths()
def update_cell(self, row_index: int, column_index: int, value: Any) -> None:
column = self.data.column(column_index)
pycolumn = self.get_column_at(column_index=column_index)
pycolumn[row_index] = value
new_type = pa.string() if pt.is_null(column.type) else column.type
self.data = self.data.set_column(
column_index,
self.data.column_names[column_index],
pa.array(pycolumn, type=new_type),
)
if self._column_content_widths:
self._column_content_widths[column_index] = max(
measure_width(value, self._console),
self._column_content_widths[column_index],
)
def sort(
self, by: list[tuple[str, Literal["ascending", "descending"]]] | str
) -> None:
"""
by: str sorts table by the data in the column with that name (asc).
by: list[tuple] sorts the table by the named column(s) with the directions
indicated.
"""
self.data = self.data.sort_by(by)
def _reset_content_widths(self) -> None:
self._column_content_widths = []
def _measure(self, arr: pa._PandasConvertible) -> int:
# with some types we can measure the width more efficiently
if pt.is_boolean(arr.type):
return 7
elif pt.is_null(arr.type):
return 0
elif (
pt.is_integer(arr.type)
or pt.is_floating(arr.type)
or pt.is_decimal(arr.type)
):
try:
col_max = pc.max(arr.fill_null(0)).as_py()
except OverflowError:
col_max = 9223372036854775807
try:
col_min = pc.min(arr.fill_null(0)).as_py()
except OverflowError:
col_min = -9223372036854775807
return max([measure_width(el, self._console) for el in [col_max, col_min]])
elif pt.is_temporal(arr.type):
try:
value = arr.drop_null()[0].as_py()
except OverflowError:
return 26 # need space for the infinity sign and a space
except IndexError:
return 24
else:
# valid temporal types all have the same width for their type
return measure_width(value, self._console)
# for everything else, we need to compute it
# First, cast the data to strings
try:
arr = arr.cast(
pa.string(),
safe=False,
)
except (pal.ArrowNotImplementedError, pal.ArrowInvalid):
# some types can't be casted to strings natively by arrow, but they
# can be casted to strings by python. The arrow way is faster, but
# if it fails, register a python udf and try again
def py_str(_ctx: Any, arr: pa.Array) -> str | pa.Array | pa.ChunkedArray:
return pa.array([str(el) for el in arr], type=pa.string())
udf_name = f"tfdt_pystr_{arr.type}"
with suppress(pal.ArrowKeyError): # already registered
pc.register_scalar_function(
py_str,
function_name=udf_name,
function_doc={"summary": "str", "description": "built-in str"},
in_types={"arr": arr.type},
out_type=pa.string(),
)
arr = pc.call_function(udf_name, [arr])
# next, try to measure the UTF-encoded string length of each cell,
# then take the max
try:
width: int = pc.max(pc.utf8_length(arr.fill_null("")).fill_null(0)).as_py()
except OverflowError:
width = 10
return width
if _HAS_POLARS:
class PolarsBackend(DataTableBackend[pl.DataFrame]):
@classmethod
def from_file_path(
cls, path: Path, max_rows: int | None = None, has_header: bool = True
) -> "PolarsBackend":
if path.suffix in [".arrow", ".feather"]:
tbl = pl.read_ipc(path)
elif path.suffix == ".arrows":
tbl = pl.read_ipc_stream(path)
elif path.suffix == ".json":
tbl = pl.read_json(path)
elif path.suffix == ".csv":
tbl = pl.read_csv(path, has_header=has_header)
else:
raise TypeError(
f"Dont know how to load file type {path.suffix} for {path}"
)
return cls(tbl, max_rows=max_rows)
@classmethod
def from_pydict(
cls, pydict: Mapping[str, Sequence[Any]], max_rows: int | None = None
) -> "PolarsBackend":
return cls(pl.from_dict(pydict), max_rows=max_rows)
@classmethod
def from_dataframe(
cls, frame: pl.DataFrame, max_rows: int | None = None
) -> "PolarsBackend":
return cls(frame, max_rows=max_rows)
def __init__(self, data: pl.DataFrame, max_rows: int | None = None) -> None:
self._source_data = data
            # dedupe column names the same way the Arrow backend does;
            # downstream lookups assume unique labels
field_names: list[str] = []
for field in data.columns:
n = 0
while field in field_names:
field = f"{field}{n}"
n += 1
field_names.append(field)
data.columns = field_names
self._source_row_count = len(data)
if max_rows is not None and max_rows < self._source_row_count:
self.data = data.slice(offset=0, length=max_rows)
else:
self.data = data
self._console = Console()
self._column_content_widths: list[int] = []
@property
def source_data(self) -> pl.DataFrame:
return self._source_data
@property
def source_row_count(self) -> int:
return self._source_row_count
@property
def row_count(self) -> int:
return len(self.data)
@property
def column_count(self) -> int:
return len(self.data.columns)
@property
def columns(self) -> Sequence[str]:
return self.data.columns
def get_row_at(self, index: int) -> Sequence[Any]:
if index < 0 or index >= len(self.data):
raise IndexError(
f"Cannot get row={index} in table with {len(self.data)} rows "
f"and {len(self.data.columns)} cols"
)
return list(self.data.slice(index, length=1).to_dicts()[0].values())
def get_column_at(self, column_index: int) -> Sequence[Any]:
if column_index < 0 or column_index >= len(self.data.columns):
raise IndexError(
f"Cannot get column={column_index} in table with {len(self.data)} "
f"rows and {len(self.data.columns)} cols."
)
return list(self.data.to_series(column_index))
def get_cell_at(self, row_index: int, column_index: int) -> Any:
if (
row_index >= len(self.data)
or row_index < 0
or column_index < 0
or column_index >= len(self.data.columns)
):
raise IndexError(
f"Cannot get cell at row={row_index} col={column_index} in table "
f"with {len(self.data)} rows and {len(self.data.columns)} cols"
)
return self.data.to_series(column_index)[row_index]
def drop_row(self, row_index: int) -> None:
if row_index < 0 or row_index >= self.row_count:
raise IndexError(f"Can't drop row {row_index} of {self.row_count}")
above = self.data.slice(0, row_index)
below = self.data.slice(row_index + 1)
self.data = pl.concat([above, below])
self._reset_content_widths()
def append_rows(self, records: Iterable[Iterable[Any]]) -> list[int]:
rows_to_add = pl.from_dicts(
[dict(zip(self.data.columns, row)) for row in records]
)
            indices = list(range(self.row_count, self.row_count + len(rows_to_add)))
self.data = pl.concat([self.data, rows_to_add])
self._reset_content_widths()
            return indices
def append_column(self, label: str, default: Any | None = None) -> int:
"""
Returns column index
"""
self.data = self.data.with_columns(
pl.Series([default])
.extend_constant(default, self.row_count - 1)
.alias(label)
)
if self._column_content_widths:
self._column_content_widths.append(
measure_width(default, self._console)
)
return len(self.data.columns) - 1
def _reset_content_widths(self) -> None:
self._column_content_widths = []
def update_cell(self, row_index: int, column_index: int, value: Any) -> None:
if row_index >= len(self.data) or column_index >= len(self.data.columns):
raise IndexError(
f"Cannot update cell at row={row_index} col={column_index} in "
f"table with {len(self.data)} rows and "
f"{len(self.data.columns)} cols"
)
col_name = self.data.columns[column_index]
self.data = self.data.with_columns(
self.data.to_series(column_index)
.scatter(row_index, value)
.alias(col_name)
)
if self._column_content_widths:
self._column_content_widths[column_index] = max(
measure_width(value, self._console),
self._column_content_widths[column_index],
)
@property
def column_content_widths(self) -> list[int]:
if not self._column_content_widths:
measurements = [
self._measure(self.data[arr]) for arr in self.data.columns
]
                # Series.max() returns None for a column without rows; we need
                # to return 0 instead.
self._column_content_widths = [cw or 0 for cw in measurements]
return self._column_content_widths
def _measure(self, arr: pl.Series) -> int:
# with some types we can measure the width more efficiently
dtype = arr.dtype
if dtype == pld.Categorical():
return self._measure(arr.cat.get_categories())
if dtype.is_decimal() or dtype.is_float() or dtype.is_integer():
col_max = arr.max()
col_min = arr.min()
return max(
[measure_width(el, self._console) for el in [col_max, col_min]]
)
if dtype.is_temporal():
try:
value = arr.drop_nulls()[0]
except IndexError:
return 0
else:
return measure_width(value, self._console)
if dtype.is_(pld.Boolean()):
return 7
# for everything else, we need to compute it
arr = arr.cast(
pl.Utf8(),
strict=False,
)
width = arr.fill_null("<null>").str.len_chars().max()
assert isinstance(width, int)
return width
def sort(
self, by: list[tuple[str, Literal["ascending", "descending"]]] | str
) -> None:
"""
by: str sorts table by the data in the column with that name (asc).
by: list[tuple] sorts the table by the named column(s) with the directions
indicated.
"""
if isinstance(by, str):
cols = [by]
typs = [False]
else:
cols = [x for x, _ in by]
typs = [x == "descending" for _, x in by]
self.data = self.data.sort(cols, descending=typs)
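
Both backends accept the same sort specification: a single column name sorts ascending, and a list of (column, direction) pairs sorts by several columns at once. A sketch:

backend = ArrowBackend.from_pydict({"a": [2, 1], "b": ["x", "y"]})
backend.sort("a")                                        # ascending by "a"
backend.sort([("a", "descending"), ("b", "ascending")])
print(backend.get_column_at(0))                          # [2, 1]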


@@ -0,0 +1,47 @@
from __future__ import annotations
import re
from dataclasses import dataclass
from rich.text import Text
CELL_X_PADDING = 2
SNAKE_ID_PROG = re.compile(r"(\b|_)id\b", flags=re.IGNORECASE)
CAMEL_ID_PROG = re.compile(r"[a-z]I[dD]\b")
@dataclass
class Column:
"""Metadata for a column in the DataTable."""
label: Text
width: int = 0
content_width: int = 0
auto_width: bool = False
max_content_width: int | None = None
def __post_init__(self) -> None:
self._is_id: bool | None = None
@property
def render_width(self) -> int:
"""Width in cells, required to render a column."""
        # CELL_X_PADDING (+2) accounts for one space of padding on either side of the cell
if self.auto_width and self.max_content_width is not None:
return (
min(max(len(self.label), self.content_width), self.max_content_width)
+ CELL_X_PADDING
)
elif self.auto_width:
return max(len(self.label), self.content_width) + CELL_X_PADDING
else:
return self.width + CELL_X_PADDING
@property
def is_id(self) -> bool:
if self._is_id is None:
snake_id = SNAKE_ID_PROG.search(str(self.label)) is not None
camel_id = CAMEL_ID_PROG.search(str(self.label)) is not None
self._is_id = snake_id or camel_id
return self._is_id
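
For auto-width columns, render_width is the larger of the label width and the content width, optionally capped by max_content_width, plus CELL_X_PADDING. A sketch:

from rich.text import Text

col = Column(label=Text("name"), content_width=10, auto_width=True)
assert col.render_width == 12  # max(4, 10) + CELL_X_PADDING

capped = Column(label=Text("name"), content_width=50, auto_width=True, max_content_width=20)
assert capped.render_width == 22  # min(max(4, 50), 20) + CELL_X_PADDING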

File diff suppressed because it is too large.


@@ -0,0 +1,101 @@
from __future__ import annotations
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import cast
from rich.align import Align
from rich.console import Console, RenderableType
from rich.errors import MarkupError
from rich.markup import escape
from rich.protocol import is_renderable
from rich.text import Text
from textual_fastdatatable.column import Column
def cell_formatter(
obj: object, null_rep: Text, col: Column | None = None, render_markup: bool = True
) -> RenderableType:
"""Convert a cell into a Rich renderable for display.
For correct formatting, clients should call `locale.setlocale()` first.
    Args:
        obj: Data for a cell.
        null_rep: Renderable shown for missing (None) values.
        col: Column that the cell came from (used to compute width).
        render_markup: If True, parse strings as Rich console markup.
    Returns:
        A renderable to be displayed which represents the data.
    """
if obj is None:
return Align(null_rep, align="center")
elif isinstance(obj, str) and render_markup:
try:
rich_text: Text | str = Text.from_markup(obj)
except MarkupError:
rich_text = escape(obj)
return rich_text
elif isinstance(obj, str):
return escape(obj)
elif isinstance(obj, bool):
return Align(
f"[dim]{'' if obj else 'X'}[/] {obj}{' ' if obj else ''}",
style="bold" if obj else "",
align="right",
)
elif isinstance(obj, (float, Decimal)):
return Align(f"{obj:n}", align="right")
elif isinstance(obj, int):
if col is not None and col.is_id:
# no separators in ID fields
return Align(str(obj), align="right")
else:
return Align(f"{obj:n}", align="right")
elif isinstance(obj, (datetime, time)):
def _fmt_datetime(obj: datetime | time) -> str:
return obj.isoformat(timespec="milliseconds").replace("+00:00", "Z")
if obj in (datetime.max, datetime.min):
return Align(
(
f"[bold]{'' if obj == datetime.max else '-∞ '}[/]"
f"[dim]{_fmt_datetime(obj)}[/]"
),
align="right",
)
return Align(_fmt_datetime(obj), align="right")
elif isinstance(obj, date):
if obj in (date.max, date.min):
return Align(
(
f"[bold]{'' if obj == date.max else '-∞ '}[/]"
f"[dim]{obj.isoformat()}[/]"
),
align="right",
)
return Align(obj.isoformat(), align="right")
elif isinstance(obj, timedelta):
return Align(str(obj), align="right")
elif not is_renderable(obj):
return str(obj)
else:
return cast(RenderableType, obj)
def measure_width(obj: object, console: Console) -> int:
renderable = cell_formatter(obj, null_rep=Text(""))
return console.measure(renderable).maximum
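
A quick sketch of the formatter's dispatch (return values described per the branches above; outputs are Rich renderables, not strings):

from rich.text import Text

cell_formatter(None, null_rep=Text("∅"))          # centered null placeholder
cell_formatter(1_234_567, null_rep=Text(""))      # right-aligned, locale-grouped digits
cell_formatter(True, null_rep=Text(""))           # right-aligned "✓ True", bold
cell_formatter("[red]hi[/]", null_rep=Text(""))   # parsed as Rich markup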
