sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection, Set
 10from contextlib import contextmanager
 11from copy import copy
 12from enum import Enum
 13from itertools import count
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot import exp
 17    from sqlglot._typing import A, E, T
 18    from sqlglot.dialects.dialect import DialectType
 19    from sqlglot.expressions import Expression
 20
 21
 22CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 23PYTHON_VERSION = sys.version_info[:2]
 24logger = logging.getLogger("sqlglot")
 25
 26
 27class AutoName(Enum):
 28    """
 29    This is used for creating Enum classes where `auto()` is the string form
 30    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 31
 32    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 33    """
 34
 35    def _generate_next_value_(name, _start, _count, _last_values):
 36        return name
 37
 38
 39class classproperty(property):
 40    """
 41    Similar to a normal property but works for class methods
 42    """
 43
 44    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
 45        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
 46
 47
 48def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 49    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 50    try:
 51        return seq[index]
 52    except IndexError:
 53        return None
 54
 55
 56@t.overload
 57def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
 58
 59
 60@t.overload
 61def ensure_list(value: None) -> t.List: ...
 62
 63
 64@t.overload
 65def ensure_list(value: T) -> t.List[T]: ...
 66
 67
 68def ensure_list(value):
 69    """
 70    Ensures that a value is a list, otherwise casts or wraps it into one.
 71
 72    Args:
 73        value: The value of interest.
 74
 75    Returns:
 76        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 77    """
 78    if value is None:
 79        return []
 80    if isinstance(value, (list, tuple)):
 81        return list(value)
 82
 83    return [value]
 84
 85
 86@t.overload
 87def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
 88
 89
 90@t.overload
 91def ensure_collection(value: T) -> t.Collection[T]: ...
 92
 93
 94def ensure_collection(value):
 95    """
 96    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 97
 98    Args:
 99        value: The value of interest.
100
101    Returns:
102        The value if it's a collection, or else the value wrapped in a list.
103    """
104    if value is None:
105        return []
106    return (
107        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
108    )
109
110
111def csv(*args: str, sep: str = ", ") -> str:
112    """
113    Formats any number of string arguments as CSV.
114
115    Args:
116        args: The string arguments to format.
117        sep: The argument separator.
118
119    Returns:
120        The arguments formatted as a CSV string.
121    """
122    return sep.join(arg for arg in args if arg)
123
124
125def subclasses(
126    module_name: str,
127    classes: t.Type | t.Tuple[t.Type, ...],
128    exclude: t.Type | t.Tuple[t.Type, ...] = (),
129) -> t.List[t.Type]:
130    """
131    Returns all subclasses for a collection of classes, possibly excluding some of them.
132
133    Args:
134        module_name: The name of the module to search for subclasses in.
135        classes: Class(es) we want to find the subclasses of.
136        exclude: Class(es) we want to exclude from the returned list.
137
138    Returns:
139        The target subclasses.
140    """
141    return [
142        obj
143        for _, obj in inspect.getmembers(
144            sys.modules[module_name],
145            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
146        )
147    ]
148
149
150def apply_index_offset(
151    this: exp.Expression,
152    expressions: t.List[E],
153    offset: int,
154    dialect: DialectType = None,
155) -> t.List[E]:
156    """
157    Applies an offset to a given integer literal expression.
158
159    Args:
160        this: The target of the index.
161        expressions: The expression the offset will be applied to, wrapped in a list.
162        offset: The offset that will be applied.
163        dialect: the dialect of interest.
164
165    Returns:
166        The original expression with the offset applied to it, wrapped in a list. If the provided
167        `expressions` argument contains more than one expression, it's returned unaffected.
168    """
169    if not offset or len(expressions) != 1:
170        return expressions
171
172    expression = expressions[0]
173
174    from sqlglot import exp
175    from sqlglot.optimizer.annotate_types import annotate_types
176    from sqlglot.optimizer.simplify import simplify
177
178    if not this.type:
179        annotate_types(this, dialect=dialect)
180
181    if t.cast(exp.DataType, this.type).this not in (
182        exp.DataType.Type.UNKNOWN,
183        exp.DataType.Type.ARRAY,
184    ):
185        return expressions
186
187    if not expression.type:
188        annotate_types(expression, dialect=dialect)
189
190    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
191        logger.info("Applying array index offset (%s)", offset)
192        expression = simplify(expression + offset)
193        return [expression]
194
195    return expressions
196
197
198def camel_to_snake_case(name: str) -> str:
199    """Converts `name` from camelCase to snake_case and returns the result."""
200    return CAMEL_CASE_PATTERN.sub("_", name).upper()
201
202
203def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
204    """
205    Applies a transformation to a given expression until a fixed point is reached.
206
207    Args:
208        expression: The expression to be transformed.
209        func: The transformation to be applied.
210
211    Returns:
212        The transformed expression.
213    """
214    while True:
215        for n in reversed(tuple(expression.walk())):
216            n._hash = hash(n)
217
218        start = hash(expression)
219        expression = func(expression)
220
221        for n in expression.walk():
222            n._hash = None
223        if start == hash(expression):
224            break
225
226    return expression
227
228
229def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
230    """
231    Sorts a given directed acyclic graph in topological order.
232
233    Args:
234        dag: The graph to be sorted.
235
236    Returns:
237        A list that contains all of the graph's nodes in topological order.
238    """
239    result = []
240
241    for node, deps in tuple(dag.items()):
242        for dep in deps:
243            if dep not in dag:
244                dag[dep] = set()
245
246    while dag:
247        current = {node for node, deps in dag.items() if not deps}
248
249        if not current:
250            raise ValueError("Cycle error")
251
252        for node in current:
253            dag.pop(node)
254
255        for deps in dag.values():
256            deps -= current
257
258        result.extend(sorted(current))  # type: ignore
259
260    return result
261
262
263def open_file(file_name: str) -> t.TextIO:
264    """Open a file that may be compressed as gzip and return it in universal newline mode."""
265    with open(file_name, "rb") as f:
266        gzipped = f.read(2) == b"\x1f\x8b"
267
268    if gzipped:
269        import gzip
270
271        return gzip.open(file_name, "rt", newline="")
272
273    return open(file_name, encoding="utf-8", newline="")
274
275
276@contextmanager
277def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
278    """
279    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
280
281    Args:
282        read_csv: A `ReadCSV` function call.
283
284    Yields:
285        A python csv reader.
286    """
287    args = read_csv.expressions
288    file = open_file(read_csv.name)
289
290    delimiter = ","
291    args = iter(arg.name for arg in args)  # type: ignore
292    for k, v in zip(args, args):
293        if k == "delimiter":
294            delimiter = v
295
296    try:
297        import csv as csv_
298
299        yield csv_.reader(file, delimiter=delimiter)
300    finally:
301        file.close()
302
303
304def find_new_name(taken: t.Collection[str], base: str) -> str:
305    """
306    Searches for a new name.
307
308    Args:
309        taken: A collection of taken names.
310        base: Base name to alter.
311
312    Returns:
313        The new, available name.
314    """
315    if base not in taken:
316        return base
317
318    i = 2
319    new = f"{base}_{i}"
320    while new in taken:
321        i += 1
322        new = f"{base}_{i}"
323
324    return new
325
326
327def is_int(text: str) -> bool:
328    return is_type(text, int)
329
330
331def is_float(text: str) -> bool:
332    return is_type(text, float)
333
334
335def is_type(text: str, target_type: t.Type) -> bool:
336    try:
337        target_type(text)
338        return True
339    except ValueError:
340        return False
341
342
343def name_sequence(prefix: str) -> t.Callable[[], str]:
344    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
345    sequence = count()
346    return lambda: f"{prefix}{next(sequence)}"
347
348
349def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
350    """Returns a dictionary created from an object's attributes."""
351    return {
352        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
353        **kwargs,
354    }
355
356
357def split_num_words(
358    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
359) -> t.List[t.Optional[str]]:
360    """
361    Perform a split on a value and return N words as a result with `None` used for words that don't exist.
362
363    Args:
364        value: The value to be split.
365        sep: The value to use to split on.
366        min_num_words: The minimum number of words that are going to be in the result.
367        fill_from_start: Whether `None` values should be inserted at the start or end of the list.
368
369    Examples:
370        >>> split_num_words("db.table", ".", 3)
371        [None, 'db', 'table']
372        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
373        ['db', 'table', None]
374        >>> split_num_words("db.table", ".", 1)
375        ['db', 'table']
376
377    Returns:
378        The list of words returned by `split`, possibly augmented by a number of `None` values.
379    """
380    words = value.split(sep)
381    if fill_from_start:
382        return [None] * (min_num_words - len(words)) + words
383    return words + [None] * (min_num_words - len(words))
384
385
386def is_iterable(value: t.Any) -> bool:
387    """
388    Checks if the value is an iterable, excluding the types `str` and `bytes`.
389
390    Examples:
391        >>> is_iterable([1,2])
392        True
393        >>> is_iterable("test")
394        False
395
396    Args:
397        value: The value to check if it is an iterable.
398
399    Returns:
400        A `bool` value indicating if it is an iterable.
401    """
402    from sqlglot import Expression
403
404    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
405
406
407def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
408    """
409    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
410    type `str` and `bytes` are not regarded as iterables.
411
412    Examples:
413        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
414        [1, 2, 3, 4, 5, 'bla']
415        >>> list(flatten([1, 2, 3]))
416        [1, 2, 3]
417
418    Args:
419        values: The value to be flattened.
420
421    Yields:
422        Non-iterable elements in `values`.
423    """
424    for value in values:
425        if is_iterable(value):
426            yield from flatten(value)
427        else:
428            yield value
429
430
431def dict_depth(d: t.Dict) -> int:
432    """
433    Get the nesting depth of a dictionary.
434
435    Example:
436        >>> dict_depth(None)
437        0
438        >>> dict_depth({})
439        1
440        >>> dict_depth({"a": "b"})
441        1
442        >>> dict_depth({"a": {}})
443        2
444        >>> dict_depth({"a": {"b": {}}})
445        3
446    """
447    try:
448        return 1 + dict_depth(next(iter(d.values())))
449    except AttributeError:
450        # d doesn't have attribute "values"
451        return 0
452    except StopIteration:
453        # d.values() returns an empty sequence
454        return 1
455
456
457def first(it: t.Iterable[T]) -> T:
458    """Returns the first element from an iterable (useful for sets)."""
459    return next(i for i in it)
460
461
462def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
463    if isinstance(value, bool) or value is None:
464        return value
465
467    # Coerce the value to boolean if it matches the truthy/falsy values below
467    value_lower = value.lower()
468    if value_lower in ("true", "1"):
469        return True
470    if value_lower in ("false", "0"):
471        return False
472
473    return value
474
475
476def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
477    """
478    Merges a sequence of ranges, represented as tuples (low, high) whose values
479    belong to some totally-ordered set.
480
481    Example:
482        >>> merge_ranges([(1, 3), (2, 6)])
483        [(1, 6)]
484    """
485    if not ranges:
486        return []
487
488    ranges = sorted(ranges)
489
490    merged = [ranges[0]]
491
492    for start, end in ranges[1:]:
493        last_start, last_end = merged[-1]
494
495        if start <= last_end:
496            merged[-1] = (last_start, max(last_end, end))
497        else:
498            merged.append((start, end))
499
500    return merged
501
502
503def is_iso_date(text: str) -> bool:
504    try:
505        datetime.date.fromisoformat(text)
506        return True
507    except ValueError:
508        return False
509
510
511def is_iso_datetime(text: str) -> bool:
512    try:
513        datetime.datetime.fromisoformat(text)
514        return True
515    except ValueError:
516        return False
517
518
519# Interval units that operate on date components
520DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
521
522
523def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
524    return expression is not None and expression.name.lower() in DATE_UNITS
525
526
527K = t.TypeVar("K")
528V = t.TypeVar("V")
529
530
531class SingleValuedMapping(t.Mapping[K, V]):
532    """
533    Mapping where all keys return the same value.
534
535    This rigamarole is meant to avoid copying keys, which was originally intended
536    as an optimization while qualifying columns for tables with lots of columns.
537    """
538
539    def __init__(self, keys: t.Collection[K], value: V):
540        self._keys = keys if isinstance(keys, Set) else set(keys)
541        self._value = value
542
543    def __getitem__(self, key: K) -> V:
544        if key in self._keys:
545            return self._value
546        raise KeyError(key)
547
548    def __len__(self) -> int:
549        return len(self._keys)
550
551    def __iter__(self) -> t.Iterator[K]:
552        return iter(self._keys)
CAMEL_CASE_PATTERN = re.compile('(?<!^)(?=[A-Z])')
PYTHON_VERSION = (3, 10)
logger = <Logger sqlglot (WARNING)>
class AutoName(enum.Enum):
28class AutoName(Enum):
29    """
30    This is used for creating Enum classes where `auto()` is the string form
31    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
32
33    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
34    """
35
36    def _generate_next_value_(name, _start, _count, _last_values):
37        return name

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
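
For example, with a hypothetical Color enum:

>>> from enum import auto
>>> from sqlglot.helper import AutoName
>>> class Color(AutoName):
...     RED = auto()
>>> Color.RED.value
'RED'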

class classproperty(builtins.property):
40class classproperty(property):
41    """
42    Similar to a normal property but works for class methods
43    """
44
45    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
46        return classmethod(self.fget).__get__(None, owner)()  # type: ignore

Similar to a normal property but works for class methods
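
For example, with a hypothetical Config class:

>>> from sqlglot.helper import classproperty
>>> class Config:
...     @classproperty
...     def label(cls):
...         return cls.__name__.lower()
>>> Config.label
'config'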

def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
49def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
50    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
51    try:
52        return seq[index]
53    except IndexError:
54        return None

Returns the value in seq at position index, or None if index is out of bounds.
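
For example:

>>> from sqlglot.helper import seq_get
>>> seq_get(["a", "b", "c"], 1)
'b'
>>> seq_get(["a", "b", "c"], 5) is None
True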

def ensure_list(value):
69def ensure_list(value):
70    """
71    Ensures that a value is a list, otherwise casts or wraps it into one.
72
73    Args:
74        value: The value of interest.
75
76    Returns:
77        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
78    """
79    if value is None:
80        return []
81    if isinstance(value, (list, tuple)):
82        return list(value)
83
84    return [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
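
For example:

>>> from sqlglot.helper import ensure_list
>>> ensure_list((1, 2))
[1, 2]
>>> ensure_list(1)
[1]
>>> ensure_list(None)
[]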

def ensure_collection(value):
 95def ensure_collection(value):
 96    """
 97    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 98
 99    Args:
100        value: The value of interest.
101
102    Returns:
103        The value if it's a collection, or else the value wrapped in a list.
104    """
105    if value is None:
106        return []
107    return (
108        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
109    )

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.
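
For example:

>>> from sqlglot.helper import ensure_collection
>>> ensure_collection({1, 2})
{1, 2}
>>> ensure_collection("foo")
['foo']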

def csv(*args: str, sep: str = ', ') -> str:
112def csv(*args: str, sep: str = ", ") -> str:
113    """
114    Formats any number of string arguments as CSV.
115
116    Args:
117        args: The string arguments to format.
118        sep: The argument separator.
119
120    Returns:
121        The arguments formatted as a CSV string.
122    """
123    return sep.join(arg for arg in args if arg)

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.
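
For example, empty arguments are dropped:

>>> from sqlglot.helper import csv
>>> csv("a", "", "b", "c")
'a, b, c'
>>> csv("a", "b", sep="|")
'a|b'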

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
126def subclasses(
127    module_name: str,
128    classes: t.Type | t.Tuple[t.Type, ...],
129    exclude: t.Type | t.Tuple[t.Type, ...] = (),
130) -> t.List[t.Type]:
131    """
132    Returns all subclasses for a collection of classes, possibly excluding some of them.
133
134    Args:
135        module_name: The name of the module to search for subclasses in.
136        classes: Class(es) we want to find the subclasses of.
137        exclude: Class(es) we want to exclude from the returned list.
138
139    Returns:
140        The target subclasses.
141    """
142    return [
143        obj
144        for _, obj in inspect.getmembers(
145            sys.modules[module_name],
146            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
147        )
148    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.
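
For example, every expression class defined in sqlglot.expressions is a subclass of exp.Expression, so a membership check like the following should hold:

>>> from sqlglot import exp
>>> from sqlglot.helper import subclasses
>>> exp.Select in subclasses("sqlglot.expressions", exp.Expression)
True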

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[~E], offset: int, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> List[~E]:
151def apply_index_offset(
152    this: exp.Expression,
153    expressions: t.List[E],
154    offset: int,
155    dialect: DialectType = None,
156) -> t.List[E]:
157    """
158    Applies an offset to a given integer literal expression.
159
160    Args:
161        this: The target of the index.
162        expressions: The expression the offset will be applied to, wrapped in a list.
163        offset: The offset that will be applied.
164        dialect: the dialect of interest.
165
166    Returns:
167        The original expression with the offset applied to it, wrapped in a list. If the provided
168        `expressions` argument contains more than one expression, it's returned unaffected.
169    """
170    if not offset or len(expressions) != 1:
171        return expressions
172
173    expression = expressions[0]
174
175    from sqlglot import exp
176    from sqlglot.optimizer.annotate_types import annotate_types
177    from sqlglot.optimizer.simplify import simplify
178
179    if not this.type:
180        annotate_types(this, dialect=dialect)
181
182    if t.cast(exp.DataType, this.type).this not in (
183        exp.DataType.Type.UNKNOWN,
184        exp.DataType.Type.ARRAY,
185    ):
186        return expressions
187
188    if not expression.type:
189        annotate_types(expression, dialect=dialect)
190
191    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
192        logger.info("Applying array index offset (%s)", offset)
193        expression = simplify(expression + offset)
194        return [expression]
195
196    return expressions

Applies an offset to a given integer literal expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
  • dialect: the dialect of interest.
Returns:

The original expression with the offset applied to it, wrapped in a list. If the provided expressions argument contains more than one expression, it's returned unaffected.
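
A rough sketch of shifting a literal array index by -1 (for instance when converting a 1-based dialect to a 0-based one); the x[1] expression is just an illustration:

>>> from sqlglot import parse_one
>>> from sqlglot.helper import apply_index_offset
>>> bracket = parse_one("x[1]")
>>> [e.sql() for e in apply_index_offset(bracket.this, bracket.expressions, -1)]
['0']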

def camel_to_snake_case(name: str) -> str:
199def camel_to_snake_case(name: str) -> str:
200    """Converts `name` from camelCase to snake_case and returns the result."""
201    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from camelCase to snake_case and returns the result.
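
Note that the implementation above also upper-cases the result:

>>> from sqlglot.helper import camel_to_snake_case
>>> camel_to_snake_case("indexOffset")
'INDEX_OFFSET'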

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
204def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
205    """
206    Applies a transformation to a given expression until a fixed point is reached.
207
208    Args:
209        expression: The expression to be transformed.
210        func: The transformation to be applied.
211
212    Returns:
213        The transformed expression.
214    """
215    while True:
216        for n in reversed(tuple(expression.walk())):
217            n._hash = hash(n)
218
219        start = hash(expression)
220        expression = func(expression)
221
222        for n in expression.walk():
223            n._hash = None
224        if start == hash(expression):
225            break
226
227    return expression

Applies a transformation to a given expression until a fixed point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.
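
A sketch using a hypothetical transformation that strips redundant parentheses; while_changing keeps applying it until the tree stops changing:

>>> from sqlglot import exp, parse_one
>>> from sqlglot.helper import while_changing
>>> def unwrap(expression):
...     return expression.transform(
...         lambda node: node.this if isinstance(node, exp.Paren) else node
...     )
>>> while_changing(parse_one("SELECT ((1))"), unwrap).sql()
'SELECT 1'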

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
230def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
231    """
232    Sorts a given directed acyclic graph in topological order.
233
234    Args:
235        dag: The graph to be sorted.
236
237    Returns:
238        A list that contains all of the graph's nodes in topological order.
239    """
240    result = []
241
242    for node, deps in tuple(dag.items()):
243        for dep in deps:
244            if dep not in dag:
245                dag[dep] = set()
246
247    while dag:
248        current = {node for node, deps in dag.items() if not deps}
249
250        if not current:
251            raise ValueError("Cycle error")
252
253        for node in current:
254            dag.pop(node)
255
256        for deps in dag.values():
257            deps -= current
258
259        result.extend(sorted(current))  # type: ignore
260
261    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.
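
For example, mapping each node to the set of nodes it depends on (note that the input dict is modified in place):

>>> from sqlglot.helper import tsort
>>> tsort({"b": {"a"}, "c": {"b"}, "a": set()})
['a', 'b', 'c']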

def open_file(file_name: str) -> TextIO:
264def open_file(file_name: str) -> t.TextIO:
265    """Open a file that may be compressed as gzip and return it in universal newline mode."""
266    with open(file_name, "rb") as f:
267        gzipped = f.read(2) == b"\x1f\x8b"
268
269    if gzipped:
270        import gzip
271
272        return gzip.open(file_name, "rt", newline="")
273
274    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
277@contextmanager
278def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
279    """
280    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
281
282    Args:
283        read_csv: A `ReadCSV` function call.
284
285    Yields:
286        A python csv reader.
287    """
288    args = read_csv.expressions
289    file = open_file(read_csv.name)
290
291    delimiter = ","
292    args = iter(arg.name for arg in args)  # type: ignore
293    for k, v in zip(args, args):
294        if k == "delimiter":
295            delimiter = v
296
297    try:
298        import csv as csv_
299
300        yield csv_.reader(file, delimiter=delimiter)
301    finally:
302        file.close()

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A python csv reader.
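
A sketch of the intended flow, assuming a pipe-delimited file named data.csv exists on disk:

    from sqlglot import parse_one
    from sqlglot.helper import csv_reader

    # READ_CSV('data.csv', 'delimiter', '|') is a hypothetical call for illustration.
    expression = parse_one("READ_CSV('data.csv', 'delimiter', '|')")
    with csv_reader(expression) as reader:
        rows = list(reader)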

def find_new_name(taken: Collection[str], base: str) -> str:
305def find_new_name(taken: t.Collection[str], base: str) -> str:
306    """
307    Searches for a new name.
308
309    Args:
310        taken: A collection of taken names.
311        base: Base name to alter.
312
313    Returns:
314        The new, available name.
315    """
316    if base not in taken:
317        return base
318
319    i = 2
320    new = f"{base}_{i}"
321    while new in taken:
322        i += 1
323        new = f"{base}_{i}"
324
325    return new

Searches for a new name.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.
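
For example:

>>> from sqlglot.helper import find_new_name
>>> find_new_name({"a", "b"}, "c")
'c'
>>> find_new_name({"a", "a_2"}, "a")
'a_3'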

def is_int(text: str) -> bool:
328def is_int(text: str) -> bool:
329    return is_type(text, int)
def is_float(text: str) -> bool:
332def is_float(text: str) -> bool:
333    return is_type(text, float)
def is_type(text: str, target_type: Type) -> bool:
336def is_type(text: str, target_type: t.Type) -> bool:
337    try:
338        target_type(text)
339        return True
340    except ValueError:
341        return False
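
These wrappers simply attempt the conversion and report whether it succeeded:

>>> from sqlglot.helper import is_int, is_float
>>> is_int("42"), is_int("4.2")
(True, False)
>>> is_float("4.2"), is_float("abc")
(True, False)
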
def name_sequence(prefix: str) -> Callable[[], str]:
344def name_sequence(prefix: str) -> t.Callable[[], str]:
345    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
346    sequence = count()
347    return lambda: f"{prefix}{next(sequence)}"

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").
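
For example:

>>> from sqlglot.helper import name_sequence
>>> gen = name_sequence("_t")
>>> gen(), gen(), gen()
('_t0', '_t1', '_t2')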

def object_to_dict(obj: Any, **kwargs) -> Dict:
350def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
351    """Returns a dictionary created from an object's attributes."""
352    return {
353        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
354        **kwargs,
355    }

Returns a dictionary created from an object's attributes.
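
For example, with a hypothetical Options class; attribute values that expose a copy method are copied, and keyword arguments are merged in:

>>> from sqlglot.helper import object_to_dict
>>> class Options:
...     def __init__(self):
...         self.flags = ["a"]
>>> object_to_dict(Options(), extra=True)
{'flags': ['a'], 'extra': True}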

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
358def split_num_words(
359    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
360) -> t.List[t.Optional[str]]:
361    """
362    Perform a split on a value and return N words as a result with `None` used for words that don't exist.
363
364    Args:
365        value: The value to be split.
366        sep: The value to use to split on.
367        min_num_words: The minimum number of words that are going to be in the result.
368        fill_from_start: Whether `None` values should be inserted at the start or end of the list.
369
370    Examples:
371        >>> split_num_words("db.table", ".", 3)
372        [None, 'db', 'table']
373        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
374        ['db', 'table', None]
375        >>> split_num_words("db.table", ".", 1)
376        ['db', 'table']
377
378    Returns:
379        The list of words returned by `split`, possibly augmented by a number of `None` values.
380    """
381    words = value.split(sep)
382    if fill_from_start:
383        return [None] * (min_num_words - len(words)) + words
384    return words + [None] * (min_num_words - len(words))

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Whether None values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
387def is_iterable(value: t.Any) -> bool:
388    """
389    Checks if the value is an iterable, excluding the types `str` and `bytes`.
390
391    Examples:
392        >>> is_iterable([1,2])
393        True
394        >>> is_iterable("test")
395        False
396
397    Args:
398        value: The value to check if it is an iterable.
399
400    Returns:
401        A `bool` value indicating if it is an iterable.
402    """
403    from sqlglot import Expression
404
405    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))

Checks if the value is an iterable, excluding the types str and bytes.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
408def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
409    """
410    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
411    type `str` and `bytes` are not regarded as iterables.
412
413    Examples:
414        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
415        [1, 2, 3, 4, 5, 'bla']
416        >>> list(flatten([1, 2, 3]))
417        [1, 2, 3]
418
419    Args:
420        values: The value to be flattened.
421
422    Yields:
423        Non-iterable elements in `values`.
424    """
425    for value in values:
426        if is_iterable(value):
427            yield from flatten(value)
428        else:
429            yield value

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
432def dict_depth(d: t.Dict) -> int:
433    """
434    Get the nesting depth of a dictionary.
435
436    Example:
437        >>> dict_depth(None)
438        0
439        >>> dict_depth({})
440        1
441        >>> dict_depth({"a": "b"})
442        1
443        >>> dict_depth({"a": {}})
444        2
445        >>> dict_depth({"a": {"b": {}}})
446        3
447    """
448    try:
449        return 1 + dict_depth(next(iter(d.values())))
450    except AttributeError:
451        # d doesn't have attribute "values"
452        return 0
453    except StopIteration:
454        # d.values() returns an empty sequence
455        return 1

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:
458def first(it: t.Iterable[T]) -> T:
459    """Returns the first element from an iterable (useful for sets)."""
460    return next(i for i in it)

Returns the first element from an iterable (useful for sets).
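
For example:

>>> from sqlglot.helper import first
>>> first({42})
42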

def to_bool(value: Union[str, bool, NoneType]) -> Union[str, bool, NoneType]:
463def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
464    if isinstance(value, bool) or value is None:
465        return value
466
467    # Coerce the value to boolean if it matches the truthy/falsy values below
468    value_lower = value.lower()
469    if value_lower in ("true", "1"):
470        return True
471    if value_lower in ("false", "0"):
472        return False
473
474    return value
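
Strings matching the truthy/falsy forms handled above are coerced; any other value is returned unchanged:

>>> from sqlglot.helper import to_bool
>>> to_bool("TRUE"), to_bool("0")
(True, False)
>>> to_bool("maybe")
'maybe'
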
def merge_ranges(ranges: List[Tuple[~A, ~A]]) -> List[Tuple[~A, ~A]]:
477def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
478    """
479    Merges a sequence of ranges, represented as tuples (low, high) whose values
480    belong to some totally-ordered set.
481
482    Example:
483        >>> merge_ranges([(1, 3), (2, 6)])
484        [(1, 6)]
485    """
486    if not ranges:
487        return []
488
489    ranges = sorted(ranges)
490
491    merged = [ranges[0]]
492
493    for start, end in ranges[1:]:
494        last_start, last_end = merged[-1]
495
496        if start <= last_end:
497            merged[-1] = (last_start, max(last_end, end))
498        else:
499            merged.append((start, end))
500
501    return merged

Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
def is_iso_date(text: str) -> bool:
504def is_iso_date(text: str) -> bool:
505    try:
506        datetime.date.fromisoformat(text)
507        return True
508    except ValueError:
509        return False
def is_iso_datetime(text: str) -> bool:
512def is_iso_datetime(text: str) -> bool:
513    try:
514        datetime.datetime.fromisoformat(text)
515        return True
516    except ValueError:
517        return False
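
For example:

>>> from sqlglot.helper import is_iso_date, is_iso_datetime
>>> is_iso_date("2024-01-31"), is_iso_date("31/01/2024")
(True, False)
>>> is_iso_datetime("2024-01-31 12:30:00")
True
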
DATE_UNITS = {'day', 'month', 'year', 'week', 'quarter', 'year_month'}
def is_date_unit(expression: Optional[sqlglot.expressions.Expression]) -> bool:
524def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
525    return expression is not None and expression.name.lower() in DATE_UNITS
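
For example, using exp.var to build the unit expression:

>>> from sqlglot import exp
>>> from sqlglot.helper import is_date_unit
>>> is_date_unit(exp.var("MONTH")), is_date_unit(exp.var("HOUR"))
(True, False)
>>> is_date_unit(None)
False
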
class SingleValuedMapping(typing.Mapping[~K, ~V]):
532class SingleValuedMapping(t.Mapping[K, V]):
533    """
534    Mapping where all keys return the same value.
535
536    This rigamarole is meant to avoid copying keys, which was originally intended
537    as an optimization while qualifying columns for tables with lots of columns.
538    """
539
540    def __init__(self, keys: t.Collection[K], value: V):
541        self._keys = keys if isinstance(keys, Set) else set(keys)
542        self._value = value
543
544    def __getitem__(self, key: K) -> V:
545        if key in self._keys:
546            return self._value
547        raise KeyError(key)
548
549    def __len__(self) -> int:
550        return len(self._keys)
551
552    def __iter__(self) -> t.Iterator[K]:
553        return iter(self._keys)

Mapping where all keys return the same value.

This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.

SingleValuedMapping(keys: Collection[~K], value: ~V)
540    def __init__(self, keys: t.Collection[K], value: V):
541        self._keys = keys if isinstance(keys, Set) else set(keys)
542        self._value = value
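
For example, mapping several column names to the same placeholder value:

>>> from sqlglot.helper import SingleValuedMapping
>>> columns = SingleValuedMapping(["a", "b", "c"], "UNKNOWN")
>>> columns["b"]
'UNKNOWN'
>>> len(columns)
3
>>> "d" in columns
False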