Edit on GitHub

sqlglot.executor.table

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot.dialects.dialect import DialectType
  6from sqlglot.helper import dict_depth
  7from sqlglot.schema import AbstractMappingSchema, normalize_name
  8
  9
 10class Table:
 11    def __init__(self, columns, rows=None, column_range=None):
 12        self.columns = tuple(columns)
 13        self.column_range = column_range
 14        self.reader = RowReader(self.columns, self.column_range)
 15        self.rows = rows or []
 16        if rows:
 17            assert len(rows[0]) == len(self.columns)
 18        self.range_reader = RangeReader(self)
 19
 20    def add_columns(self, *columns: str) -> None:
 21        self.columns += columns
 22        if self.column_range:
 23            self.column_range = range(
 24                self.column_range.start, self.column_range.stop + len(columns)
 25            )
 26        self.reader = RowReader(self.columns, self.column_range)
 27
 28    def append(self, row):
 29        assert len(row) == len(self.columns)
 30        self.rows.append(row)
 31
 32    def pop(self):
 33        self.rows.pop()
 34
 35    def to_pylist(self):
 36        return [dict(zip(self.columns, row)) for row in self.rows]
 37
 38    @property
 39    def width(self):
 40        return len(self.columns)
 41
 42    def __len__(self):
 43        return len(self.rows)
 44
 45    def __iter__(self):
 46        return TableIter(self)
 47
 48    def __getitem__(self, index):
 49        self.reader.row = self.rows[index]
 50        return self.reader
 51
 52    def __repr__(self):
 53        columns = tuple(
 54            column
 55            for i, column in enumerate(self.columns)
 56            if not self.column_range or i in self.column_range
 57        )
 58        widths = {column: len(column) for column in columns}
 59        lines = [" ".join(column for column in columns)]
 60
 61        for i, row in enumerate(self):
 62            if i > 10:
 63                break
 64
 65            lines.append(
 66                " ".join(
 67                    str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns
 68                )
 69            )
 70        return "\n".join(lines)
 71
 72
 73class TableIter:
 74    def __init__(self, table):
 75        self.table = table
 76        self.index = -1
 77
 78    def __iter__(self):
 79        return self
 80
 81    def __next__(self):
 82        self.index += 1
 83        if self.index < len(self.table):
 84            return self.table[self.index]
 85        raise StopIteration
 86
 87
 88class RangeReader:
 89    def __init__(self, table):
 90        self.table = table
 91        self.range = range(0)
 92
 93    def __len__(self):
 94        return len(self.range)
 95
 96    def __getitem__(self, column):
 97        return (self.table[i][column] for i in self.range)
 98
 99
class RowReader:
    """Name-based accessor over a single positional row.

    ``columns`` maps each visible column name to its position in the row.
    When ``column_range`` is given and non-empty, only positions inside it are
    visible. Assign ``row`` before subscripting.
    """

    def __init__(self, columns, column_range=None):
        self.columns = {}
        for position, name in enumerate(columns):
            # A missing or empty range means every column is visible.
            if column_range and position not in column_range:
                continue
            self.columns[name] = position
        self.row = None

    def __getitem__(self, column):
        position = self.columns[column]
        return self.row[position]
109
110
class Tables(AbstractMappingSchema):
    """Mapping schema produced by ``ensure_tables``.

    Adds no behaviour of its own on top of ``AbstractMappingSchema``; its
    mapping's leaf values are the ``Table`` objects built by ``ensure_tables``.
    """
113
114
def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables:
    """Normalize ``d`` with ``_ensure_tables`` and wrap the result in ``Tables``.

    ``None`` or an empty mapping produces a ``Tables`` over an empty dict.
    """
    normalized = _ensure_tables(d, dialect=dialect)
    return Tables(normalized)
117
118
def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict:
    """Recursively normalize a nested ``{name: ...}`` mapping into ``Table`` leaves.

    - Empty or ``None`` input returns ``{}``.
    - When ``dict_depth(d) > 1``, each key is normalized with ``is_table=True``
      and each value is processed recursively.
    - At the innermost level each key is normalized (without ``is_table``).
      A value that is already a ``Table`` is kept as-is. Any other value is
      iterated as rows of ``{column: value}`` dicts whose column names are
      normalized. The first row fixes the column order, so extra keys in later
      rows are dropped and a missing key raises ``KeyError``.
    """
    if not d:
        return {}

    if dict_depth(d) > 1:
        nested = {}
        for key, value in d.items():
            nested_key = normalize_name(key, dialect=dialect, is_table=True).name
            nested[nested_key] = _ensure_tables(value, dialect=dialect)
        return nested

    tables = {}
    for raw_name, source in d.items():
        name = normalize_name(raw_name, dialect=dialect).name

        if isinstance(source, Table):
            tables[name] = source
            continue

        normalized_rows = [
            {normalize_name(col, dialect=dialect).name: val for col, val in record.items()}
            for record in source
        ]
        columns = tuple(normalized_rows[0]) if normalized_rows else ()
        tables[name] = Table(
            columns=columns,
            rows=[tuple(record[col] for col in columns) for record in normalized_rows],
        )

    return tables
class Table:
11class Table:
12    def __init__(self, columns, rows=None, column_range=None):
13        self.columns = tuple(columns)
14        self.column_range = column_range
15        self.reader = RowReader(self.columns, self.column_range)
16        self.rows = rows or []
17        if rows:
18            assert len(rows[0]) == len(self.columns)
19        self.range_reader = RangeReader(self)
20
21    def add_columns(self, *columns: str) -> None:
22        self.columns += columns
23        if self.column_range:
24            self.column_range = range(
25                self.column_range.start, self.column_range.stop + len(columns)
26            )
27        self.reader = RowReader(self.columns, self.column_range)
28
29    def append(self, row):
30        assert len(row) == len(self.columns)
31        self.rows.append(row)
32
33    def pop(self):
34        self.rows.pop()
35
36    def to_pylist(self):
37        return [dict(zip(self.columns, row)) for row in self.rows]
38
39    @property
40    def width(self):
41        return len(self.columns)
42
43    def __len__(self):
44        return len(self.rows)
45
46    def __iter__(self):
47        return TableIter(self)
48
49    def __getitem__(self, index):
50        self.reader.row = self.rows[index]
51        return self.reader
52
53    def __repr__(self):
54        columns = tuple(
55            column
56            for i, column in enumerate(self.columns)
57            if not self.column_range or i in self.column_range
58        )
59        widths = {column: len(column) for column in columns}
60        lines = [" ".join(column for column in columns)]
61
62        for i, row in enumerate(self):
63            if i > 10:
64                break
65
66            lines.append(
67                " ".join(
68                    str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns
69                )
70            )
71        return "\n".join(lines)
Table(columns, rows=None, column_range=None)
12    def __init__(self, columns, rows=None, column_range=None):
13        self.columns = tuple(columns)
14        self.column_range = column_range
15        self.reader = RowReader(self.columns, self.column_range)
16        self.rows = rows or []
17        if rows:
18            assert len(rows[0]) == len(self.columns)
19        self.range_reader = RangeReader(self)
columns
column_range
reader
rows
range_reader
def add_columns(self, *columns: str) -> None:
21    def add_columns(self, *columns: str) -> None:
22        self.columns += columns
23        if self.column_range:
24            self.column_range = range(
25                self.column_range.start, self.column_range.stop + len(columns)
26            )
27        self.reader = RowReader(self.columns, self.column_range)
def append(self, row):
29    def append(self, row):
30        assert len(row) == len(self.columns)
31        self.rows.append(row)
def pop(self):
33    def pop(self):
34        self.rows.pop()
def to_pylist(self):
36    def to_pylist(self):
37        return [dict(zip(self.columns, row)) for row in self.rows]
width
39    @property
40    def width(self):
41        return len(self.columns)
class TableIter:
74class TableIter:
75    def __init__(self, table):
76        self.table = table
77        self.index = -1
78
79    def __iter__(self):
80        return self
81
82    def __next__(self):
83        self.index += 1
84        if self.index < len(self.table):
85            return self.table[self.index]
86        raise StopIteration
TableIter(table)
75    def __init__(self, table):
76        self.table = table
77        self.index = -1
table
index
class RangeReader:
89class RangeReader:
90    def __init__(self, table):
91        self.table = table
92        self.range = range(0)
93
94    def __len__(self):
95        return len(self.range)
96
97    def __getitem__(self, column):
98        return (self.table[i][column] for i in self.range)
RangeReader(table)
90    def __init__(self, table):
91        self.table = table
92        self.range = range(0)
table
range
class RowReader:
101class RowReader:
102    def __init__(self, columns, column_range=None):
103        self.columns = {
104            column: i for i, column in enumerate(columns) if not column_range or i in column_range
105        }
106        self.row = None
107
108    def __getitem__(self, column):
109        return self.row[self.columns[column]]
RowReader(columns, column_range=None)
102    def __init__(self, columns, column_range=None):
103        self.columns = {
104            column: i for i, column in enumerate(columns) if not column_range or i in column_range
105        }
106        self.row = None
columns
row
class Tables(sqlglot.schema.AbstractMappingSchema):
112class Tables(AbstractMappingSchema):
113    pass
def ensure_tables( d: Optional[Dict], dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> Tables:
116def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables:
117    return Tables(_ensure_tables(d, dialect=dialect))