1
0
Fork 0
sqlglot/sqlglot/executor/context.py
Daniel Baumann 87cdb8246e
Adding upstream version 10.0.8.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-13 14:53:43 +01:00

101 lines
3.3 KiB
Python

from __future__ import annotations
import typing as t
from sqlglot.executor.env import ENV
if t.TYPE_CHECKING:
from sqlglot.executor.table import Table, TableIter
class Context:
    """
    Execution context for sql expressions.

    Context is used to hold relevant data tables which can then be queried on with eval.

    References to columns can either be scalar or vectors. When set_row is used, column references
    evaluate to scalars while set_range evaluates to vectors. This allows convenient and efficient
    evaluation of aggregation functions.
    """

    def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None:
        """
        Args:
            tables: representing the scope of the current execution context.
            env: dictionary of functions within the execution context.
        """
        self.tables = tables
        # Lazily resolved (and validated) by the `table` property.
        self._table: t.Optional[Table] = None
        # Per-table readers, keyed by table name; one of these two mappings is
        # exposed to evaluated code as "scope" at any given time.
        self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
        self.row_readers = {name: table.reader for name, table in tables.items()}
        # The caller's env is copied, never mutated; "scope" starts in row mode.
        self.env = {**(env or {}), "scope": self.row_readers}

    def eval(self, code):
        """
        Evaluate `code` with ENV as globals and this context's env as locals.

        NOTE(review): this delegates to the builtin `eval`, so `code` must never be derived
        from untrusted input. Presumably it is generated by the executor itself; confirm
        against the callers.
        """
        return eval(code, ENV, self.env)

    def eval_tuple(self, codes):
        """Evaluate every entry of `codes` with `eval` and return the results as a tuple."""
        return tuple(self.eval(code) for code in codes)

    @property
    def table(self) -> Table:
        """
        The first table of the context, after checking that all tables are compatible.

        Raises:
            IndexError: if the context holds no tables.
            Exception: if any table has different columns or a different number of rows
                than the first one.
        """
        if self._table is None:
            tables = list(self.tables.values())
            first = tables[0]

            for other in tables[1:]:
                if first.columns != other.columns:
                    raise Exception("Columns are different.")
                if len(first.rows) != len(other.rows):
                    raise Exception("Rows are different.")

            # Cache only after validation succeeded. Caching up front would make a
            # failed check stick: the next access would return the first table
            # without ever re-validating the others.
            self._table = first
        return self._table

    def add_columns(self, *columns: str) -> None:
        """Forward `columns` to `add_columns` of every table in the context."""
        for table in self.tables.values():
            table.add_columns(*columns)

    @property
    def columns(self) -> t.Tuple:
        """The columns of `self.table`."""
        return self.table.columns

    def __iter__(self):
        """
        Yield `(reader, self)` once per row index of `self.table`, in row mode.

        Every table is indexed at the current position before each yield; the reader
        that is yielded is the one returned by the last table in `self.tables`.
        """
        self.env["scope"] = self.row_readers
        for i in range(len(self.table.rows)):
            for table in self.tables.values():
                # presumably indexing a table points its row reader at row i -- confirm in Table
                reader = table[i]
            yield reader, self

    def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]:
        """Switch to row mode and yield `(reader, self)` for each item produced by the named table."""
        self.env["scope"] = self.row_readers

        for reader in self.tables[table]:
            yield reader, self

    def filter(self, condition) -> None:
        """
        Keep only the rows for which `condition` evaluates truthy.

        The surviving rows are collected into one list, and that same list object is
        assigned to `rows` of every table in the context.
        """
        rows = [reader.row for reader, _ in self if self.eval(condition)]

        for table in self.tables.values():
            table.rows = rows

    def sort(self, key) -> None:
        """Sort the rows of `self.table` in place by the tuple obtained from evaluating `key`."""

        def sort_key(row: t.Tuple) -> t.Tuple:
            # Make `row` the current row so that column references in `key` resolve to it.
            self.set_row(row)
            return self.eval_tuple(key)

        self.table.rows.sort(key=sort_key)

    def set_row(self, row: t.Tuple) -> None:
        """Assign `row` to the row reader of every table and switch to row mode."""
        for table in self.tables.values():
            table.reader.row = row
        self.env["scope"] = self.row_readers

    def set_index(self, index: int) -> None:
        """Index every table at `index` and switch to row mode."""
        for table in self.tables.values():
            # The result is discarded: the lookup is performed only for its effect on the
            # table (presumably positioning its row reader -- confirm in Table).
            table[index]
        self.env["scope"] = self.row_readers

    def set_range(self, start: int, end: int) -> None:
        """Assign `range(start, end)` to every range reader and switch to range mode."""
        for name in self.tables:
            self.range_readers[name].range = range(start, end)
        self.env["scope"] = self.range_readers

    def __contains__(self, table: str) -> bool:
        """Return whether a table named `table` is part of this context."""
        return table in self.tables