Edit on GitHub

sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection, Set
 10from contextlib import contextmanager
 11from copy import copy
 12from enum import Enum
 13from itertools import count
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot import exp
 17    from sqlglot._typing import A, E, T
 18    from sqlglot.expressions import Expression
 19
 20
 21CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 22PYTHON_VERSION = sys.version_info[:2]
 23logger = logging.getLogger("sqlglot")
 24
 25
 26class AutoName(Enum):
 27    """
 28    This is used for creating Enum classes where `auto()` is the string form
 29    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 30
 31    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 32    """
 33
 34    def _generate_next_value_(name, _start, _count, _last_values):
 35        return name
 36
 37
 38class classproperty(property):
 39    """
 40    Similar to a normal property but works for class methods
 41    """
 42
 43    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
 44        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
 45
 46
 47def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 48    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 49    try:
 50        return seq[index]
 51    except IndexError:
 52        return None
 53
 54
 55@t.overload
 56def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
 57
 58
 59@t.overload
 60def ensure_list(value: T) -> t.List[T]: ...
 61
 62
 63def ensure_list(value):
 64    """
 65    Ensures that a value is a list, otherwise casts or wraps it into one.
 66
 67    Args:
 68        value: The value of interest.
 69
 70    Returns:
 71        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 72    """
 73    if value is None:
 74        return []
 75    if isinstance(value, (list, tuple)):
 76        return list(value)
 77
 78    return [value]
 79
 80
 81@t.overload
 82def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
 83
 84
 85@t.overload
 86def ensure_collection(value: T) -> t.Collection[T]: ...
 87
 88
 89def ensure_collection(value):
 90    """
 91    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 92
 93    Args:
 94        value: The value of interest.
 95
 96    Returns:
 97        The value if it's a collection, or else the value wrapped in a list.
 98    """
 99    if value is None:
100        return []
101    return (
102        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
103    )
104
105
def csv(*args: str, sep: str = ", ") -> str:
    """
    Joins the non-empty string arguments with a separator.

    Falsy arguments, such as empty strings, are skipped, so no doubled separators appear.

    Args:
        args: The strings to join.
        sep: The separator placed between consecutive kept strings.

    Returns:
        The joined string.
    """
    return sep.join(filter(None, args))
118
119
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Only classes that are attributes of the module are considered. The result is ordered
    by attribute name, as `inspect.getmembers` returns it. The classes in `classes`
    match too, since `issubclass(X, X)` is true, unless they are excluded.

    Args:
        module_name: The name of the (already imported) module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list.

    Returns:
        The target subclasses.
    """
    # The annotation allows a single class for `exclude`. A membership test against a bare
    # class (`obj in SomeClass`) raises TypeError, so wrap it in a tuple first.
    excluded = (exclude,) if isinstance(exclude, type) else exclude

    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in excluded,
        )
    ]
143
144
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
) -> t.List[E]:
    """
    Shifts a single integer-typed index expression by `offset`.

    The shift only happens when the type of `this` (the indexed value) is ARRAY or
    UNKNOWN, and the index's type is one of the integer types. If `.type` is missing on
    either expression, `annotate_types` is called on it first. The code relies on that
    call to populate `.type` in place. The shifted index is passed through `simplify`.

    Args:
        this: The target of the index.
        expressions: The index expression(s), wrapped in a list.
        offset: The offset to add to the index.

    Returns:
        A one-element list holding the shifted, simplified index when the offset applies.
        Otherwise `expressions` is returned unchanged. This includes a zero offset and a
        list that does not contain exactly one expression.
    """
    if not offset or len(expressions) != 1:
        return expressions

    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    index = expressions[0]

    if not this.type:
        annotate_types(this)
    target_kind = t.cast(exp.DataType, this.type).this
    if target_kind not in (exp.DataType.Type.UNKNOWN, exp.DataType.Type.ARRAY):
        return expressions

    if not index.type:
        annotate_types(index)
    index_kind = t.cast(exp.DataType, index.type).this
    if index_kind not in exp.DataType.INTEGER_TYPES:
        return expressions

    logger.info("Applying array index offset (%s)", offset)
    return [simplify(index + offset)]
189
190
def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from camelCase/PascalCase to UPPER_SNAKE_CASE and returns the result.

    An underscore is inserted before every uppercase letter except a leading one, and the
    result is uppercased: "DateAdd" -> "DATE_ADD", "camelCase" -> "CAMEL_CASE".
    Consecutive capitals are split individually ("JSONPath" -> "J_S_O_N_PATH").
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()
194
195
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Repeatedly applies `func` until a pass leaves the expression's hash unchanged.

    Before each pass, every node's hash is computed in reverse walk order and stored on
    the node as `_hash`. This is presumably so that `Expression.__hash__` can reuse the
    cached values; TODO confirm. After the pass the cached values are cleared on the
    resulting tree, so the comparison hash reflects any changes `func` made.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied; its result feeds the next pass.

    Returns:
        The transformed expression once a pass no longer changes its hash.
    """
    while True:
        for node in reversed(tuple(expression.walk())):
            node._hash = hash(node)

        before = hash(expression)
        expression = func(expression)

        # Drop the cached hashes so the next hash() sees the current tree.
        for node in expression.walk():
            node._hash = None

        if before == hash(expression):
            return expression
220
221
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    The input is left untouched; the algorithm works on a private copy.

    Args:
        dag: The graph to be sorted, mapping each node to the set of nodes it depends on.
            Nodes that only appear as dependencies are treated as having no dependencies.

    Returns:
        A list of all the graph's nodes in which every node comes after its dependencies.
        Nodes that become ready in the same round are emitted in sorted order, so they
        must be mutually comparable.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    # Copy the graph and each dependency set: the loop below consumes both. Working on the
    # copy keeps the caller's dict and the sets it owns unchanged.
    pending = {node: set(deps) for node, deps in dag.items()}
    for deps in dag.values():
        for dep in deps:
            pending.setdefault(dep, set())

    result = []
    while pending:
        ready = {node for node, deps in pending.items() if not deps}
        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del pending[node]
        for deps in pending.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result
254
255
def open_file(file_name: str) -> t.TextIO:
    """
    Opens a file, which may be gzip-compressed, for reading as UTF-8 text.

    Compression is detected from the gzip magic number (0x1f 0x8b), not from the file
    extension. The stream is opened with `newline=""`: universal newline detection is on,
    but line endings are returned untranslated, which is what the `csv` module expects.

    Args:
        file_name: Path of the file to open.

    Returns:
        A text stream over the (decompressed) content. The caller must close it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Decode compressed files as UTF-8 as well, not with the locale's default encoding.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")
267
268
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Yields a csv reader for the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The options after the file name are read as alternating key/value names. Only
    `delimiter` is recognized (default ","), and a trailing unpaired key is ignored. The
    file is opened with `open_file`, so it may be gzip-compressed. It is closed when the
    with-block exits.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    delimiter = ","
    # Zipping an iterator with itself pairs consecutive option names as (key, value).
    option_names = iter(arg.name for arg in read_csv.expressions)
    for key, value in zip(option_names, option_names):
        if key == "delimiter":
            delimiter = value

    import csv as csv_  # aliased: this module defines its own `csv` helper

    # Open the file only after the options are parsed, and close it however the
    # with-block exits, so the handle cannot leak if parsing fails.
    with open_file(read_csv.name) as file:
        yield csv_.reader(file, delimiter=delimiter)
295
296
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Picks a name derived from `base` that does not clash with `taken`.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        `base` itself if it is free. Otherwise, the first of `base_2`, `base_3`, ... that
        is not in `taken`.
    """
    candidate = base
    suffix = 1
    while candidate in taken:
        suffix += 1
        candidate = f"{base}_{suffix}"
    return candidate
318
319
def is_int(text: str) -> bool:
    """Returns True if `int(text)` succeeds, i.e. `text` parses as an integer."""
    return is_type(text, int)


def is_float(text: str) -> bool:
    """Returns True if `float(text)` succeeds (this includes integers, "inf" and "nan")."""
    return is_type(text, float)


def is_type(text: str, target_type: t.Type) -> bool:
    """
    Checks whether calling `target_type` on `text` succeeds.

    Args:
        text: The value to test, normally a string.
        target_type: A converter such as `int` or `float`.

    Returns:
        True if the conversion succeeds. False if it raises `ValueError` (unparseable text)
        or `TypeError` (unsupported input such as `None`).
    """
    try:
        target_type(text)
    except (TypeError, ValueError):
        return False
    return True
334
335
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """
    Builds a generator of unique names that share a prefix.

    Each call of the returned function gives the next name: a0, a1, a2, ... for the
    prefix "a". Separate calls to `name_sequence` have independent counters.
    """
    counter = count()

    def next_name() -> str:
        return f"{prefix}{next(counter)}"

    return next_name
340
341
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Returns a copy of an object's instance attributes (`vars(obj)`) as a dictionary.

    Each value is copied with its own `copy()` method when it has one, otherwise with
    `copy.copy`. Keyword arguments are applied last and override same-named attributes.
    """
    result = {}
    for attr, value in vars(obj).items():
        result[attr] = value.copy() if hasattr(value, "copy") else copy(value)
    result.update(kwargs)
    return result
348
349
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Splits `value` on `sep` and pads the result with `None` up to `min_num_words` entries.

    Args:
        value: The value to be split.
        sep: The separator to split on.
        min_num_words: The minimum length of the result; longer splits are not truncated.
        fill_from_start: Whether the `None` padding goes at the start (True) or at the end
            (False) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly padded with `None` values.
    """
    parts: t.List[t.Optional[str]] = list(value.split(sep))
    padding: t.List[t.Optional[str]] = [None] * max(min_num_words - len(parts), 0)
    return padding + parts if fill_from_start else parts + padding
377
378
def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is iterable, treating `str`, `bytes` and sqlglot `Expression`
    objects as scalars.

    Iterability is detected by the presence of an `__iter__` attribute.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    from sqlglot import Expression

    if not hasattr(value, "__iter__"):
        return False
    return not isinstance(value, (str, bytes, Expression))
398
399
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Recursively flattens nested iterables, yielding the non-iterable leaves in order.

    Iterability is decided by `is_iterable`, so `str`, `bytes` and sqlglot `Expression`
    objects are yielded as leaves rather than descended into.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue
        yield from flatten(item)
422
423
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary, following only the first value at each level.

    Objects without a `.values()` method (e.g. `None` or a leaf value) have depth 0. An
    empty mapping has depth 1.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # Not mapping-like: no "values" attribute.
        return 0
    except StopIteration:
        # An empty mapping still counts as one level.
        return 1
    return 1 + dict_depth(first_value)
448
449
def first(it: t.Iterable[T]) -> T:
    """
    Returns the first element produced by iterating `it` (useful for sets).

    Raises:
        StopIteration: If `it` is empty.
    """
    return next(iter(it))
453
454
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merges overlapping ranges, represented as tuples (low, high) whose values
    belong to some totally-ordered set.

    The ranges are sorted first. A range is combined with the previous one when its low
    bound is <= the previous high bound, so touching ranges merge as well. The input
    list is not modified.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    head, *tail = sorted(ranges)
    merged = [head]
    for low, high in tail:
        prev_low, prev_high = merged[-1]
        if low <= prev_high:
            merged[-1] = (prev_low, max(prev_high, high))
        else:
            merged.append((low, high))
    return merged
480
481
def is_iso_date(text: str) -> bool:
    """
    Returns True if `datetime.date.fromisoformat` accepts `text`, False on ValueError.

    The accepted formats depend on the running Python version.
    """
    try:
        datetime.date.fromisoformat(text)
    except ValueError:
        return False
    return True
488
489
def is_iso_datetime(text: str) -> bool:
    """
    Returns True if `datetime.datetime.fromisoformat` accepts `text`, False on ValueError.

    The accepted formats depend on the running Python version.
    """
    try:
        datetime.datetime.fromisoformat(text)
    except ValueError:
        return False
    return True
496
497
# Lower-case names of interval units that operate on date components.
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}


def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """
    Checks whether `expression` names one of the `DATE_UNITS`, case-insensitively.

    Args:
        expression: A unit expression, or None.

    Returns:
        False for None; otherwise whether `expression.name`, lower-cased, is a date unit.
    """
    if expression is None:
        return False
    return expression.name.lower() in DATE_UNITS
504
505
K = t.TypeVar("K")
V = t.TypeVar("V")


class SingleValuedMapping(t.Mapping[K, V]):
    """
    Read-only mapping in which every key maps to the same value.

    This rigmarole avoids building a dict with one entry per key. It was originally
    intended as an optimization while qualifying columns for tables with lots of columns.
    Keys that are already a `Set` are stored as-is (not copied); any other collection is
    converted to a `set` once.
    """

    def __init__(self, keys: t.Collection[K], value: V):
        if not isinstance(keys, Set):
            keys = set(keys)
        self._keys = keys
        self._value = value

    def __getitem__(self, key: K) -> V:
        if key not in self._keys:
            raise KeyError(key)
        return self._value

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> t.Iterator[K]:
        return iter(self._keys)
CAMEL_CASE_PATTERN = re.compile('(?<!^)(?=[A-Z])')
PYTHON_VERSION = (3, 10)
logger = <Logger sqlglot (WARNING)>
class AutoName(enum.Enum):
class AutoName(Enum):
    """
    Base class for enums whose `auto()` values are the members' own identifiers.

    Declaring `FOO = auto()` on a subclass gives `FOO.value == "FOO"`.

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    def _generate_next_value_(member_name, _start, _count, _previous_values):
        # Invoked by the Enum machinery for each auto(): the identifier becomes the value.
        return member_name

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values

Inherited Members
enum.Enum
name
value
class classproperty(builtins.property):
class classproperty(property):
    """
    A read-only property that is evaluated against the owning class.

    The wrapped getter receives the class as its first argument, like a classmethod.
    The attribute can therefore be read both as `Owner.attr` and as `instance.attr`.
    """

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        # Bind the getter to `owner` exactly as a classmethod would, then call it.
        bound_getter = classmethod(self.fget).__get__(None, owner)  # type: ignore
        return bound_getter()

Similar to a normal property, but the getter receives the class (like a class method), so it can also be read directly on the class.

Inherited Members
builtins.property
property
getter
setter
deleter
fget
fset
fdel
def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """
    Indexes into a sequence without raising on out-of-range positions.

    Args:
        seq: The sequence to read from.
        index: The position to read; negative values count from the end as usual.

    Returns:
        `seq[index]`, or `None` when indexing raises `IndexError`.
    """
    try:
        item = seq[index]
    except IndexError:
        item = None
    return item

Returns the value in seq at position index, or None if index is out of bounds.

def ensure_list(value):
def ensure_list(value):
    """
    Coerces a value into a list.

    Args:
        value: The value of interest.

    Returns:
        An empty list for `None`, and a new (shallow) list copy for a list or tuple.
        Anything else, including sets and strings, is wrapped in a one-element list.
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    return [] if value is None else [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

An empty list if the value is None, the value cast as a list if it's a list or a tuple, or else the value wrapped in a list.

def ensure_collection(value):
 90def ensure_collection(value):
 91    """
 92    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 93
 94    Args:
 95        value: The value of interest.
 96
 97    Returns:
 98        The value if it's a collection, or else the value wrapped in a list.
 99    """
100    if value is None:
101        return []
102    return (
103        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
104    )

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

An empty list if the value is None, the value itself if it's a collection (other than str or bytes), or else the value wrapped in a list.

def csv(*args: str, sep: str = ', ') -> str:
def csv(*args: str, sep: str = ", ") -> str:
    """
    Joins the non-empty string arguments with a separator.

    Falsy arguments, such as empty strings, are skipped, so no doubled separators appear.

    Args:
        args: The strings to join.
        sep: The separator placed between consecutive kept strings.

    Returns:
        The joined string.
    """
    return sep.join(filter(None, args))

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Only classes that are attributes of the module are considered. The result is ordered
    by attribute name, as `inspect.getmembers` returns it. The classes in `classes`
    match too, since `issubclass(X, X)` is true, unless they are excluded.

    Args:
        module_name: The name of the (already imported) module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list.

    Returns:
        The target subclasses.
    """
    # The annotation allows a single class for `exclude`. A membership test against a bare
    # class (`obj in SomeClass`) raises TypeError, so wrap it in a tuple first.
    excluded = (exclude,) if isinstance(exclude, type) else exclude

    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in excluded,
        )
    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[~E], offset: int) -> List[~E]:
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
) -> t.List[E]:
    """
    Shifts a single integer-typed index expression by `offset`.

    The shift only happens when the type of `this` (the indexed value) is ARRAY or
    UNKNOWN, and the index's type is one of the integer types. If `.type` is missing on
    either expression, `annotate_types` is called on it first. The code relies on that
    call to populate `.type` in place. The shifted index is passed through `simplify`.

    Args:
        this: The target of the index.
        expressions: The index expression(s), wrapped in a list.
        offset: The offset to add to the index.

    Returns:
        A one-element list holding the shifted, simplified index when the offset applies.
        Otherwise `expressions` is returned unchanged. This includes a zero offset and a
        list that does not contain exactly one expression.
    """
    if not offset or len(expressions) != 1:
        return expressions

    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    index = expressions[0]

    if not this.type:
        annotate_types(this)
    target_kind = t.cast(exp.DataType, this.type).this
    if target_kind not in (exp.DataType.Type.UNKNOWN, exp.DataType.Type.ARRAY):
        return expressions

    if not index.type:
        annotate_types(index)
    index_kind = t.cast(exp.DataType, index.type).this
    if index_kind not in exp.DataType.INTEGER_TYPES:
        return expressions

    logger.info("Applying array index offset (%s)", offset)
    return [simplify(index + offset)]

Applies an offset to a given integer-typed index expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
Returns:

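The original expression with the offset applied to it, wrapped in a list. The expressions argument is returned unaffected in any of these cases: the offset is zero; the list does not contain exactly one expression; the target is neither an array nor of unknown type; or the index is not integer-typed.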

def camel_to_snake_case(name: str) -> str:
def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from camelCase/PascalCase to UPPER_SNAKE_CASE and returns the result.

    An underscore is inserted before every uppercase letter except a leading one, and the
    result is uppercased: "DateAdd" -> "DATE_ADD", "camelCase" -> "CAMEL_CASE".
    Consecutive capitals are split individually ("JSONPath" -> "J_S_O_N_PATH").
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from camelCase (or PascalCase) to UPPER_SNAKE_CASE and returns the result, e.g. "DateAdd" becomes "DATE_ADD".

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Repeatedly applies `func` until a pass leaves the expression's hash unchanged.

    Before each pass, every node's hash is computed in reverse walk order and stored on
    the node as `_hash`. This is presumably so that `Expression.__hash__` can reuse the
    cached values; TODO confirm. After the pass the cached values are cleared on the
    resulting tree, so the comparison hash reflects any changes `func` made.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied; its result feeds the next pass.

    Returns:
        The transformed expression once a pass no longer changes its hash.
    """
    while True:
        for node in reversed(tuple(expression.walk())):
            node._hash = hash(node)

        before = hash(expression)
        expression = func(expression)

        # Drop the cached hashes so the next hash() sees the current tree.
        for node in expression.walk():
            node._hash = None

        if before == hash(expression):
            return expression

Applies a transformation to a given expression until a fix point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    The input is left untouched; the algorithm works on a private copy.

    Args:
        dag: The graph to be sorted, mapping each node to the set of nodes it depends on.
            Nodes that only appear as dependencies are treated as having no dependencies.

    Returns:
        A list of all the graph's nodes in which every node comes after its dependencies.
        Nodes that become ready in the same round are emitted in sorted order, so they
        must be mutually comparable.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    # Copy the graph and each dependency set: the loop below consumes both. Working on the
    # copy keeps the caller's dict and the sets it owns unchanged.
    pending = {node: set(deps) for node, deps in dag.items()}
    for deps in dag.values():
        for dep in deps:
            pending.setdefault(dep, set())

    result = []
    while pending:
        ready = {node for node, deps in pending.items() if not deps}
        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del pending[node]
        for deps in pending.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.

def open_file(file_name: str) -> <class 'TextIO'>:
def open_file(file_name: str) -> t.TextIO:
    """
    Opens a file, which may be gzip-compressed, for reading as UTF-8 text.

    Compression is detected from the gzip magic number (0x1f 0x8b), not from the file
    extension. The stream is opened with `newline=""`: universal newline detection is on,
    but line endings are returned untranslated, which is what the `csv` module expects.

    Args:
        file_name: Path of the file to open.

    Returns:
        A text stream over the (decompressed) content. The caller must close it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Decode compressed files as UTF-8 as well, not with the locale's default encoding.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it as a text stream opened with newline="". Universal newline detection is enabled, but line endings are returned untranslated.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Yields a csv reader for the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The options after the file name are read as alternating key/value names. Only
    `delimiter` is recognized (default ","), and a trailing unpaired key is ignored. The
    file is opened with `open_file`, so it may be gzip-compressed. It is closed when the
    with-block exits.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    delimiter = ","
    # Zipping an iterator with itself pairs consecutive option names as (key, value).
    option_names = iter(arg.name for arg in read_csv.expressions)
    for key, value in zip(option_names, option_names):
        if key == "delimiter":
            delimiter = value

    import csv as csv_  # aliased: this module defines its own `csv` helper

    # Open the file only after the options are parsed, and close it however the
    # with-block exits, so the handle cannot leak if parsing fails.
    with open_file(read_csv.name) as file:
        yield csv_.reader(file, delimiter=delimiter)

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A python csv reader.

def find_new_name(taken: Collection[str], base: str) -> str:
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Picks a name derived from `base` that does not clash with `taken`.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        `base` itself if it is free. Otherwise, the first of `base_2`, `base_3`, ... that
        is not in `taken`.
    """
    candidate = base
    suffix = 1
    while candidate in taken:
        suffix += 1
        candidate = f"{base}_{suffix}"
    return candidate

Searches for a new name: returns base if it is not taken, otherwise the first of base_2, base_3, ... that is not taken.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.

def is_int(text: str) -> bool:
def is_int(text: str) -> bool:
    """Returns True if `int(text)` succeeds, i.e. `text` parses as an integer."""
    return is_type(text, int)
def is_float(text: str) -> bool:
def is_float(text: str) -> bool:
    """Returns True if `float(text)` succeeds (this includes integers, "inf" and "nan")."""
    return is_type(text, float)
def is_type(text: str, target_type: Type) -> bool:
def is_type(text: str, target_type: t.Type) -> bool:
    """
    Checks whether calling `target_type` on `text` succeeds.

    Args:
        text: The value to test, normally a string.
        target_type: A converter such as `int` or `float`.

    Returns:
        True if the conversion succeeds. False if it raises `ValueError` (unparseable text)
        or `TypeError` (unsupported input such as `None`).
    """
    try:
        target_type(text)
    except (TypeError, ValueError):
        return False
    return True
def name_sequence(prefix: str) -> Callable[[], str]:
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """
    Builds a generator of unique names that share a prefix.

    Each call of the returned function gives the next name: a0, a1, a2, ... for the
    prefix "a". Separate calls to `name_sequence` have independent counters.
    """
    counter = count()

    def next_name() -> str:
        return f"{prefix}{next(counter)}"

    return next_name

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").

def object_to_dict(obj: Any, **kwargs) -> Dict:
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Returns a copy of an object's instance attributes (`vars(obj)`) as a dictionary.

    Each value is copied with its own `copy()` method when it has one, otherwise with
    `copy.copy`. Keyword arguments are applied last and override same-named attributes.
    """
    result = {}
    for attr, value in vars(obj).items():
        result[attr] = value.copy() if hasattr(value, "copy") else copy(value)
    result.update(kwargs)
    return result
348    }

Returns a dictionary created from an object's attributes.

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Splits `value` on `sep` and pads the result with `None` up to `min_num_words` entries.

    Args:
        value: The value to be split.
        sep: The separator to split on.
        min_num_words: The minimum length of the result; longer splits are not truncated.
        fill_from_start: Whether the `None` padding goes at the start (True) or at the end
            (False) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly padded with `None` values.
    """
    parts: t.List[t.Optional[str]] = list(value.split(sep))
    padding: t.List[t.Optional[str]] = [None] * max(min_num_words - len(parts), 0)
    return padding + parts if fill_from_start else parts + padding

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Whether None values should be inserted at the start (True) or at the end (False) of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is iterable, treating `str`, `bytes` and sqlglot `Expression`
    objects as scalars.

    Iterability is detected by the presence of an `__iter__` attribute.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    from sqlglot import Expression

    if not hasattr(value, "__iter__"):
        return False
    return not isinstance(value, (str, bytes, Expression))

Checks if the value is an iterable, excluding the types str and bytes as well as sqlglot Expression objects.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Recursively flattens nested iterables, yielding the non-iterable leaves in order.

    Iterability is decided by `is_iterable`, so `str`, `bytes` and sqlglot `Expression`
    objects are yielded as leaves rather than descended into.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue
        yield from flatten(item)

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes, as well as sqlglot Expression objects, are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary, following only the first value at each level.

    Objects without a `.values()` method (e.g. `None` or a leaf value) have depth 0. An
    empty mapping has depth 1.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # Not mapping-like: no "values" attribute.
        return 0
    except StopIteration:
        # An empty mapping still counts as one level.
        return 1
    return 1 + dict_depth(first_value)

Get the nesting depth of a dictionary, following only the first value at each level.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:
def first(it: t.Iterable[T]) -> T:
    """
    Returns the first element produced by iterating `it` (useful for sets).

    Raises:
        StopIteration: If `it` is empty.
    """
    return next(iter(it))

Returns the first element from an iterable (useful for sets).

def merge_ranges(ranges: List[Tuple[~A, ~A]]) -> List[Tuple[~A, ~A]]:
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merges overlapping ranges, represented as tuples (low, high) whose values
    belong to some totally-ordered set.

    The ranges are sorted first. A range is combined with the previous one when its low
    bound is <= the previous high bound, so touching ranges merge as well. The input
    list is not modified.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    head, *tail = sorted(ranges)
    merged = [head]
    for low, high in tail:
        prev_low, prev_high = merged[-1]
        if low <= prev_high:
            merged[-1] = (prev_low, max(prev_high, high))
        else:
            merged.append((low, high))
    return merged

Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
def is_iso_date(text: str) -> bool:
def is_iso_date(text: str) -> bool:
    """
    Returns True if `datetime.date.fromisoformat` accepts `text`, False on ValueError.

    The accepted formats depend on the running Python version.
    """
    try:
        datetime.date.fromisoformat(text)
    except ValueError:
        return False
    return True
def is_iso_datetime(text: str) -> bool:
def is_iso_datetime(text: str) -> bool:
    """
    Returns True if `datetime.datetime.fromisoformat` accepts `text`, False on ValueError.

    The accepted formats depend on the running Python version.
    """
    try:
        datetime.datetime.fromisoformat(text)
    except ValueError:
        return False
    return True
DATE_UNITS = {'day', 'quarter', 'month', 'year_month', 'week', 'year'}
def is_date_unit(expression: Optional[sqlglot.expressions.Expression]) -> bool:
def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """
    Checks whether `expression` names one of the `DATE_UNITS`, case-insensitively.

    Args:
        expression: A unit expression, or None.

    Returns:
        False for None; otherwise whether `expression.name`, lower-cased, is a date unit.
    """
    if expression is None:
        return False
    return expression.name.lower() in DATE_UNITS
class SingleValuedMapping(typing.Mapping[~K, ~V]):
class SingleValuedMapping(t.Mapping[K, V]):
    """
    Read-only mapping in which every key maps to the same value.

    This rigmarole avoids building a dict with one entry per key. It was originally
    intended as an optimization while qualifying columns for tables with lots of columns.
    Keys that are already a `Set` are stored as-is (not copied); any other collection is
    converted to a `set` once.
    """

    def __init__(self, keys: t.Collection[K], value: V):
        if not isinstance(keys, Set):
            keys = set(keys)
        self._keys = keys
        self._value = value

    def __getitem__(self, key: K) -> V:
        if key not in self._keys:
            raise KeyError(key)
        return self._value

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> t.Iterator[K]:
        return iter(self._keys)

Mapping where all keys return the same value.

This rigmarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.

SingleValuedMapping(keys: Collection[~K], value: ~V)
519    def __init__(self, keys: t.Collection[K], value: V):
520        self._keys = keys if isinstance(keys, Set) else set(keys)
521        self._value = value
Inherited Members
collections.abc.Mapping
get
keys
items
values