sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection, Set
 10from contextlib import contextmanager
 11from copy import copy
 12from enum import Enum
 13from itertools import count
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot import exp
 17    from sqlglot._typing import A, E, T
 18    from sqlglot.expressions import Expression
 19
 20
 21CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 22PYTHON_VERSION = sys.version_info[:2]
 23logger = logging.getLogger("sqlglot")
 24
 25
 26class AutoName(Enum):
 27    """
 28    This is used for creating Enum classes where `auto()` is the string form
 29    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 30
 31    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 32    """
 33
 34    def _generate_next_value_(name, _start, _count, _last_values):
 35        return name
 36
 37
 38class classproperty(property):
 39    """
 40    Similar to a normal property but works for class methods
 41    """
 42
 43    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
 44        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
 45
 46
 47def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 48    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 49    try:
 50        return seq[index]
 51    except IndexError:
 52        return None
 53
 54
 55@t.overload
 56def ensure_list(value: t.Collection[T]) -> t.List[T]:
 57    ...
 58
 59
 60@t.overload
 61def ensure_list(value: T) -> t.List[T]:
 62    ...
 63
 64
 65def ensure_list(value):
 66    """
 67    Ensures that a value is a list, otherwise casts or wraps it into one.
 68
 69    Args:
 70        value: The value of interest.
 71
 72    Returns:
 73        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 74    """
 75    if value is None:
 76        return []
 77    if isinstance(value, (list, tuple)):
 78        return list(value)
 79
 80    return [value]
 81
 82
 83@t.overload
 84def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
 85    ...
 86
 87
 88@t.overload
 89def ensure_collection(value: T) -> t.Collection[T]:
 90    ...
 91
 92
 93def ensure_collection(value):
 94    """
 95    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 96
 97    Args:
 98        value: The value of interest.
 99
100    Returns:
101        The value if it's a collection, or else the value wrapped in a list.
102    """
103    if value is None:
104        return []
105    return (
106        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
107    )
108
109
110def csv(*args: str, sep: str = ", ") -> str:
111    """
112    Formats any number of string arguments as CSV.
113
114    Args:
115        args: The string arguments to format.
116        sep: The argument separator.
117
118    Returns:
119        The arguments formatted as a CSV string.
120    """
121    return sep.join(arg for arg in args if arg)
122
123
124def subclasses(
125    module_name: str,
126    classes: t.Type | t.Tuple[t.Type, ...],
127    exclude: t.Type | t.Tuple[t.Type, ...] = (),
128) -> t.List[t.Type]:
129    """
130    Returns all subclasses for a collection of classes, possibly excluding some of them.
131
132    Args:
133        module_name: The name of the module to search for subclasses in.
134        classes: Class(es) we want to find the subclasses of.
135        exclude: Class(es) we want to exclude from the returned list.
136
137    Returns:
138        The target subclasses.
139    """
140    return [
141        obj
142        for _, obj in inspect.getmembers(
143            sys.modules[module_name],
144            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
145        )
146    ]
147
148
149def apply_index_offset(
150    this: exp.Expression,
151    expressions: t.List[E],
152    offset: int,
153) -> t.List[E]:
154    """
155    Applies an offset to a given integer literal expression.
156
157    Args:
158        this: The target of the index.
159        expressions: The expression the offset will be applied to, wrapped in a list.
160        offset: The offset that will be applied.
161
162    Returns:
163        The original expression with the offset applied to it, wrapped in a list. If the provided
164        `expressions` argument contains more than one expression, it's returned unaffected.
165    """
166    if not offset or len(expressions) != 1:
167        return expressions
168
169    expression = expressions[0]
170
171    from sqlglot import exp
172    from sqlglot.optimizer.annotate_types import annotate_types
173    from sqlglot.optimizer.simplify import simplify
174
175    if not this.type:
176        annotate_types(this)
177
178    if t.cast(exp.DataType, this.type).this not in (
179        exp.DataType.Type.UNKNOWN,
180        exp.DataType.Type.ARRAY,
181    ):
182        return expressions
183
184    if not expression.type:
185        annotate_types(expression)
186    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
187        logger.warning("Applying array index offset (%s)", offset)
188        expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
189        return [expression]
190
191    return expressions
192
193
194def camel_to_snake_case(name: str) -> str:
195    """Converts `name` from camelCase to upper snake_case (e.g. fooBar -> FOO_BAR) and returns the result."""
196    return CAMEL_CASE_PATTERN.sub("_", name).upper()
197
198
199def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
200    """
201    Applies a transformation to a given expression until a fixed point is reached.
202
203    Args:
204        expression: The expression to be transformed.
205        func: The transformation to be applied.
206
207    Returns:
208        The transformed expression.
209    """
210    while True:
211        for n, *_ in reversed(tuple(expression.walk())):
212            n._hash = hash(n)
213
214        start = hash(expression)
215        expression = func(expression)
216
217        for n, *_ in expression.walk():
218            n._hash = None
219        if start == hash(expression):
220            break
221
222    return expression
223
224
225def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
226    """
227    Sorts a given directed acyclic graph in topological order.
228
229    Args:
230        dag: The graph to be sorted.
231
232    Returns:
233        A list that contains all of the graph's nodes in topological order.
234    """
235    result = []
236
237    for node, deps in tuple(dag.items()):
238        for dep in deps:
239            if dep not in dag:
240                dag[dep] = set()
241
242    while dag:
243        current = {node for node, deps in dag.items() if not deps}
244
245        if not current:
246            raise ValueError("Cycle error")
247
248        for node in current:
249            dag.pop(node)
250
251        for deps in dag.values():
252            deps -= current
253
254        result.extend(sorted(current))  # type: ignore
255
256    return result
257
258
259def open_file(file_name: str) -> t.TextIO:
260    """Open a file that may be compressed as gzip and return it in universal newline mode."""
261    with open(file_name, "rb") as f:
262        gzipped = f.read(2) == b"\x1f\x8b"
263
264    if gzipped:
265        import gzip
266
267        return gzip.open(file_name, "rt", newline="")
268
269    return open(file_name, encoding="utf-8", newline="")
270
271
272@contextmanager
273def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
274    """
275    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
276
277    Args:
278        read_csv: A `ReadCSV` function call.
279
280    Yields:
281        A python csv reader.
282    """
283    args = read_csv.expressions
284    file = open_file(read_csv.name)
285
286    delimiter = ","
287    args = iter(arg.name for arg in args)  # type: ignore
288    for k, v in zip(args, args):
289        if k == "delimiter":
290            delimiter = v
291
292    try:
293        import csv as csv_
294
295        yield csv_.reader(file, delimiter=delimiter)
296    finally:
297        file.close()
298
299
300def find_new_name(taken: t.Collection[str], base: str) -> str:
301    """
302    Searches for a name that isn't in `taken`, appending a numeric suffix to `base` if needed.
303
304    Args:
305        taken: A collection of taken names.
306        base: Base name to alter.
307
308    Returns:
309        The new, available name.
310    """
311    if base not in taken:
312        return base
313
314    i = 2
315    new = f"{base}_{i}"
316    while new in taken:
317        i += 1
318        new = f"{base}_{i}"
319
320    return new
321
322
323def is_int(text: str) -> bool:
324    try:
325        int(text)
326        return True
327    except ValueError:
328        return False
329
330
331def name_sequence(prefix: str) -> t.Callable[[], str]:
332    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
333    sequence = count()
334    return lambda: f"{prefix}{next(sequence)}"
335
336
337def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
338    """Returns a dictionary created from an object's attributes."""
339    return {
340        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
341        **kwargs,
342    }
343
344
345def split_num_words(
346    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
347) -> t.List[t.Optional[str]]:
348    """
349    Performs a split on a value and returns N words, with `None` used for words that don't exist.
350
351    Args:
352        value: The value to be split.
353        sep: The value to use to split on.
354        min_num_words: The minimum number of words that are going to be in the result.
355        fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.
356
357    Examples:
358        >>> split_num_words("db.table", ".", 3)
359        [None, 'db', 'table']
360        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
361        ['db', 'table', None]
362        >>> split_num_words("db.table", ".", 1)
363        ['db', 'table']
364
365    Returns:
366        The list of words returned by `split`, possibly augmented by a number of `None` values.
367    """
368    words = value.split(sep)
369    if fill_from_start:
370        return [None] * (min_num_words - len(words)) + words
371    return words + [None] * (min_num_words - len(words))
372
373
374def is_iterable(value: t.Any) -> bool:
375    """
376    Checks if the value is an iterable, excluding the types `str` and `bytes` (sqlglot `Expression`s are also excluded).
377
378    Examples:
379        >>> is_iterable([1,2])
380        True
381        >>> is_iterable("test")
382        False
383
384    Args:
385        value: The value to check if it is an iterable.
386
387    Returns:
388        A `bool` value indicating if it is an iterable.
389    """
390    from sqlglot import Expression
391
392    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
393
394
395def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
396    """
397    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
398    type `str` and `bytes` are not regarded as iterables.
399
400    Examples:
401        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
402        [1, 2, 3, 4, 5, 'bla']
403        >>> list(flatten([1, 2, 3]))
404        [1, 2, 3]
405
406    Args:
407        values: The value to be flattened.
408
409    Yields:
410        Non-iterable elements in `values`.
411    """
412    for value in values:
413        if is_iterable(value):
414            yield from flatten(value)
415        else:
416            yield value
417
418
419def dict_depth(d: t.Dict) -> int:
420    """
421    Get the nesting depth of a dictionary.
422
423    Example:
424        >>> dict_depth(None)
425        0
426        >>> dict_depth({})
427        1
428        >>> dict_depth({"a": "b"})
429        1
430        >>> dict_depth({"a": {}})
431        2
432        >>> dict_depth({"a": {"b": {}}})
433        3
434    """
435    try:
436        return 1 + dict_depth(next(iter(d.values())))
437    except AttributeError:
438        # d doesn't have attribute "values"
439        return 0
440    except StopIteration:
441        # d.values() returns an empty sequence
442        return 1
443
444
445def first(it: t.Iterable[T]) -> T:
446    """Returns the first element from an iterable (useful for sets)."""
447    return next(i for i in it)
448
449
450def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
451    """
452    Merges a sequence of ranges, represented as tuples (low, high) whose values
453    belong to some totally-ordered set.
454
455    Example:
456        >>> merge_ranges([(1, 3), (2, 6)])
457        [(1, 6)]
458    """
459    if not ranges:
460        return []
461
462    ranges = sorted(ranges)
463
464    merged = [ranges[0]]
465
466    for start, end in ranges[1:]:
467        last_start, last_end = merged[-1]
468
469        if start <= last_end:
470            merged[-1] = (last_start, max(last_end, end))
471        else:
472            merged.append((start, end))
473
474    return merged
475
476
477def is_iso_date(text: str) -> bool:
478    try:
479        datetime.date.fromisoformat(text)
480        return True
481    except ValueError:
482        return False
483
484
485def is_iso_datetime(text: str) -> bool:
486    try:
487        datetime.datetime.fromisoformat(text)
488        return True
489    except ValueError:
490        return False
491
492
493# Interval units that operate on date components
494DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
495
496
497def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
498    return expression is not None and expression.name.lower() in DATE_UNITS
499
500
501K = t.TypeVar("K")
502V = t.TypeVar("V")
503
504
505class SingleValuedMapping(t.Mapping[K, V]):
506    """
507    Mapping where all keys return the same value.
508
509    This rigamarole is meant to avoid copying keys, which was originally intended
510    as an optimization while qualifying columns for tables with lots of columns.
511    """
512
513    def __init__(self, keys: t.Collection[K], value: V):
514        self._keys = keys if isinstance(keys, Set) else set(keys)
515        self._value = value
516
517    def __getitem__(self, key: K) -> V:
518        if key in self._keys:
519            return self._value
520        raise KeyError(key)
521
522    def __len__(self) -> int:
523        return len(self._keys)
524
525    def __iter__(self) -> t.Iterator[K]:
526        return iter(self._keys)
CAMEL_CASE_PATTERN = re.compile('(?<!^)(?=[A-Z])')
PYTHON_VERSION = (3, 10)
logger = <Logger sqlglot (WARNING)>
class AutoName(enum.Enum):
27class AutoName(Enum):
28    """
29    This is used for creating Enum classes where `auto()` is the string form
30    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
31
32    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
33    """
34
35    def _generate_next_value_(name, _start, _count, _last_values):
36        return name

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
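
Example (the Keyword enum is hypothetical):
>>> from enum import auto
>>> class Keyword(AutoName):
...     SELECT = auto()
>>> Keyword.SELECT.value
'SELECT'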

Inherited Members
enum.Enum
name
value
class classproperty(builtins.property):
39class classproperty(property):
40    """
41    Similar to a normal property but works for class methods
42    """
43
44    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
45        return classmethod(self.fget).__get__(None, owner)()  # type: ignore

Similar to a normal property but works for class methods
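
Example (the Dialect class is hypothetical):
>>> class Dialect:
...     _name = "duckdb"
...     @classproperty
...     def name(cls):
...         return cls._name
>>> Dialect.name
'duckdb'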

Inherited Members
builtins.property
property
getter
setter
deleter
fget
fset
fdel
def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
48def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
49    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
50    try:
51        return seq[index]
52    except IndexError:
53        return None

Returns the value in seq at position index, or None if index is out of bounds.
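
Example:
>>> seq_get(["a", "b", "c"], 1)
'b'
>>> seq_get(["a", "b", "c"], 5) is None
True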

def ensure_list(value):
66def ensure_list(value):
67    """
68    Ensures that a value is a list, otherwise casts or wraps it into one.
69
70    Args:
71        value: The value of interest.
72
73    Returns:
74        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
75    """
76    if value is None:
77        return []
78    if isinstance(value, (list, tuple)):
79        return list(value)
80
81    return [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
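
Example:
>>> ensure_list((1, 2))
[1, 2]
>>> ensure_list(1)
[1]
>>> ensure_list(None)
[]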

def ensure_collection(value):
 94def ensure_collection(value):
 95    """
 96    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 97
 98    Args:
 99        value: The value of interest.
100
101    Returns:
102        The value if it's a collection, or else the value wrapped in a list.
103    """
104    if value is None:
105        return []
106    return (
107        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
108    )

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.
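
Example:
>>> ensure_collection({1, 2})
{1, 2}
>>> ensure_collection("foo")
['foo']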

def csv(*args: str, sep: str = ', ') -> str:
111def csv(*args: str, sep: str = ", ") -> str:
112    """
113    Formats any number of string arguments as CSV.
114
115    Args:
116        args: The string arguments to format.
117        sep: The argument separator.
118
119    Returns:
120        The arguments formatted as a CSV string.
121    """
122    return sep.join(arg for arg in args if arg)

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.
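
Example (empty arguments are dropped):
>>> csv("a", "b", "", "c")
'a, b, c'
>>> csv("x", "y", sep=" AND ")
'x AND y'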

def subclasses(module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
125def subclasses(
126    module_name: str,
127    classes: t.Type | t.Tuple[t.Type, ...],
128    exclude: t.Type | t.Tuple[t.Type, ...] = (),
129) -> t.List[t.Type]:
130    """
131    Returns all subclasses for a collection of classes, possibly excluding some of them.
132
133    Args:
134        module_name: The name of the module to search for subclasses in.
135        classes: Class(es) we want to find the subclasses of.
136        exclude: Class(es) we want to exclude from the returned list.
137
138    Returns:
139        The target subclasses.
140    """
141    return [
142        obj
143        for _, obj in inspect.getmembers(
144            sys.modules[module_name],
145            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
146        )
147    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.
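
Example (a rough sketch against sqlglot's own expressions module):
>>> from sqlglot import expressions as exp
>>> exp.Select in subclasses("sqlglot.expressions", exp.Expression)
True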

def apply_index_offset(this: sqlglot.expressions.Expression, expressions: List[~E], offset: int) -> List[~E]:
150def apply_index_offset(
151    this: exp.Expression,
152    expressions: t.List[E],
153    offset: int,
154) -> t.List[E]:
155    """
156    Applies an offset to a given integer literal expression.
157
158    Args:
159        this: The target of the index.
160        expressions: The expression the offset will be applied to, wrapped in a list.
161        offset: The offset that will be applied.
162
163    Returns:
164        The original expression with the offset applied to it, wrapped in a list. If the provided
165        `expressions` argument contains more than one expression, it's returned unaffected.
166    """
167    if not offset or len(expressions) != 1:
168        return expressions
169
170    expression = expressions[0]
171
172    from sqlglot import exp
173    from sqlglot.optimizer.annotate_types import annotate_types
174    from sqlglot.optimizer.simplify import simplify
175
176    if not this.type:
177        annotate_types(this)
178
179    if t.cast(exp.DataType, this.type).this not in (
180        exp.DataType.Type.UNKNOWN,
181        exp.DataType.Type.ARRAY,
182    ):
183        return expressions
184
185    if not expression.type:
186        annotate_types(expression)
187    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
188        logger.warning("Applying array index offset (%s)", offset)
189        expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
190        return [expression]
191
192    return expressions

Applies an offset to a given integer literal expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
Returns:

The original expression with the offset applied to it, wrapped in a list. If the provided expressions argument contains more than one expression, it's returned unaffected.
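
Example (a rough sketch; assumes a bare column annotates to an UNKNOWN type and that simplify folds the literal arithmetic):
>>> from sqlglot import exp
>>> [shifted] = apply_index_offset(exp.column("arr"), [exp.Literal.number(1)], -1)
>>> shifted.sql()
'0'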

def camel_to_snake_case(name: str) -> str:
195def camel_to_snake_case(name: str) -> str:
196    """Converts `name` from camelCase to upper snake_case (e.g. fooBar -> FOO_BAR) and returns the result."""
197    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from camelCase to upper snake_case (e.g. fooBar -> FOO_BAR) and returns the result.
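
Example:
>>> camel_to_snake_case("myColumnName")
'MY_COLUMN_NAME'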

def while_changing(expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
200def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
201    """
202    Applies a transformation to a given expression until a fixed point is reached.
203
204    Args:
205        expression: The expression to be transformed.
206        func: The transformation to be applied.
207
208    Returns:
209        The transformed expression.
210    """
211    while True:
212        for n, *_ in reversed(tuple(expression.walk())):
213            n._hash = hash(n)
214
215        start = hash(expression)
216        expression = func(expression)
217
218        for n, *_ in expression.walk():
219            n._hash = None
220        if start == hash(expression):
221            break
222
223    return expression

Applies a transformation to a given expression until a fixed point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.
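
Example (a sketch using sqlglot's simplify as the transformation):
>>> from sqlglot import parse_one
>>> from sqlglot.optimizer.simplify import simplify
>>> while_changing(parse_one("1 + 1 + 1"), simplify).sql()
'3'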

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
226def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
227    """
228    Sorts a given directed acyclic graph in topological order.
229
230    Args:
231        dag: The graph to be sorted.
232
233    Returns:
234        A list that contains all of the graph's nodes in topological order.
235    """
236    result = []
237
238    for node, deps in tuple(dag.items()):
239        for dep in deps:
240            if dep not in dag:
241                dag[dep] = set()
242
243    while dag:
244        current = {node for node, deps in dag.items() if not deps}
245
246        if not current:
247            raise ValueError("Cycle error")
248
249        for node in current:
250            dag.pop(node)
251
252        for deps in dag.values():
253            deps -= current
254
255        result.extend(sorted(current))  # type: ignore
256
257    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.
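
Example (note that the input dict is modified in place):
>>> tsort({"b": {"a"}, "c": {"b"}, "a": set()})
['a', 'b', 'c']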

def open_file(file_name: str) -> TextIO:
260def open_file(file_name: str) -> t.TextIO:
261    """Open a file that may be compressed as gzip and return it in universal newline mode."""
262    with open(file_name, "rb") as f:
263        gzipped = f.read(2) == b"\x1f\x8b"
264
265    if gzipped:
266        import gzip
267
268        return gzip.open(file_name, "rt", newline="")
269
270    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.
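
Example (a sketch; the path is hypothetical, and gzip is detected from the magic bytes rather than the extension):

    f = open_file("queries.csv.gz")  # hypothetical path
    try:
        header = f.readline()
    finally:
        f.close()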

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
273@contextmanager
274def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
275    """
276    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
277
278    Args:
279        read_csv: A `ReadCSV` function call.
280
281    Yields:
282        A python csv reader.
283    """
284    args = read_csv.expressions
285    file = open_file(read_csv.name)
286
287    delimiter = ","
288    args = iter(arg.name for arg in args)  # type: ignore
289    for k, v in zip(args, args):
290        if k == "delimiter":
291            delimiter = v
292
293    try:
294        import csv as csv_
295
296        yield csv_.reader(file, delimiter=delimiter)
297    finally:
298        file.close()

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A python csv reader.
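
Example (a sketch; assumes "data.csv" exists and builds the ReadCSV expression by hand):

    from sqlglot import exp
    read_csv = exp.ReadCSV(
        this=exp.Literal.string("data.csv"),
        expressions=[exp.Literal.string("delimiter"), exp.Literal.string("|")],
    )
    with csv_reader(read_csv) as reader:
        for row in reader:
            ...  # each row is a list of strings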

def find_new_name(taken: Collection[str], base: str) -> str:
301def find_new_name(taken: t.Collection[str], base: str) -> str:
302    """
303    Searches for a name that isn't in `taken`, appending a numeric suffix to `base` if needed.
304
305    Args:
306        taken: A collection of taken names.
307        base: Base name to alter.
308
309    Returns:
310        The new, available name.
311    """
312    if base not in taken:
313        return base
314
315    i = 2
316    new = f"{base}_{i}"
317    while new in taken:
318        i += 1
319        new = f"{base}_{i}"
320
321    return new

Searches for a name that isn't in taken, appending a numeric suffix to base if needed.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.
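
Example:
>>> find_new_name({"a", "a_2"}, "a")
'a_3'
>>> find_new_name({"a"}, "b")
'b'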

def is_int(text: str) -> bool:
324def is_int(text: str) -> bool:
325    try:
326        int(text)
327        return True
328    except ValueError:
329        return False
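
Example:
>>> is_int("42")
True
>>> is_int("4.2")
False
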
def name_sequence(prefix: str) -> Callable[[], str]:
332def name_sequence(prefix: str) -> t.Callable[[], str]:
333    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
334    sequence = count()
335    return lambda: f"{prefix}{next(sequence)}"

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").
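
Example:
>>> next_name = name_sequence("_c")
>>> next_name(), next_name(), next_name()
('_c0', '_c1', '_c2')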

def object_to_dict(obj: Any, **kwargs) -> Dict:
338def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
339    """Returns a dictionary created from an object's attributes."""
340    return {
341        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
342        **kwargs,
343    }

Returns a dictionary created from an object's attributes.
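
Example (the Point class is hypothetical):
>>> class Point:
...     def __init__(self, x, y):
...         self.x, self.y = x, y
>>> object_to_dict(Point(1, 2), z=3)
{'x': 1, 'y': 2, 'z': 3}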

def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
346def split_num_words(
347    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
348) -> t.List[t.Optional[str]]:
349    """
350    Performs a split on a value and returns N words, with `None` used for words that don't exist.
351
352    Args:
353        value: The value to be split.
354        sep: The value to use to split on.
355        min_num_words: The minimum number of words that are going to be in the result.
356        fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.
357
358    Examples:
359        >>> split_num_words("db.table", ".", 3)
360        [None, 'db', 'table']
361        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
362        ['db', 'table', None]
363        >>> split_num_words("db.table", ".", 1)
364        ['db', 'table']
365
366    Returns:
367        The list of words returned by `split`, possibly augmented by a number of `None` values.
368    """
369    words = value.split(sep)
370    if fill_from_start:
371        return [None] * (min_num_words - len(words)) + words
372    return words + [None] * (min_num_words - len(words))

Performs a split on a value and returns N words, with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Indicates whether None values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
375def is_iterable(value: t.Any) -> bool:
376    """
377    Checks if the value is an iterable, excluding the types `str` and `bytes` (sqlglot `Expression`s are also excluded).
378
379    Examples:
380        >>> is_iterable([1,2])
381        True
382        >>> is_iterable("test")
383        False
384
385    Args:
386        value: The value to check if it is an iterable.
387
388    Returns:
389        A `bool` value indicating if it is an iterable.
390    """
391    from sqlglot import Expression
392
393    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))

Checks if the value is an iterable, excluding the types str and bytes (sqlglot Expressions are also excluded).

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
396def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
397    """
398    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
399    type `str` and `bytes` are not regarded as iterables.
400
401    Examples:
402        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
403        [1, 2, 3, 4, 5, 'bla']
404        >>> list(flatten([1, 2, 3]))
405        [1, 2, 3]
406
407    Args:
408        values: The value to be flattened.
409
410    Yields:
411        Non-iterable elements in `values`.
412    """
413    for value in values:
414        if is_iterable(value):
415            yield from flatten(value)
416        else:
417            yield value

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
420def dict_depth(d: t.Dict) -> int:
421    """
422    Get the nesting depth of a dictionary.
423
424    Example:
425        >>> dict_depth(None)
426        0
427        >>> dict_depth({})
428        1
429        >>> dict_depth({"a": "b"})
430        1
431        >>> dict_depth({"a": {}})
432        2
433        >>> dict_depth({"a": {"b": {}}})
434        3
435    """
436    try:
437        return 1 + dict_depth(next(iter(d.values())))
438    except AttributeError:
439        # d doesn't have attribute "values"
440        return 0
441    except StopIteration:
442        # d.values() returns an empty sequence
443        return 1

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:
446def first(it: t.Iterable[T]) -> T:
447    """Returns the first element from an iterable (useful for sets)."""
448    return next(i for i in it)

Returns the first element from an iterable (useful for sets).
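
Example:
>>> first({"only"})
'only'
>>> first(x for x in range(3))
0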

def merge_ranges(ranges: List[Tuple[~A, ~A]]) -> List[Tuple[~A, ~A]]:
451def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
452    """
453    Merges a sequence of ranges, represented as tuples (low, high) whose values
454    belong to some totally-ordered set.
455
456    Example:
457        >>> merge_ranges([(1, 3), (2, 6)])
458        [(1, 6)]
459    """
460    if not ranges:
461        return []
462
463    ranges = sorted(ranges)
464
465    merged = [ranges[0]]
466
467    for start, end in ranges[1:]:
468        last_start, last_end = merged[-1]
469
470        if start <= last_end:
471            merged[-1] = (last_start, max(last_end, end))
472        else:
473            merged.append((start, end))
474
475    return merged

Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
def is_iso_date(text: str) -> bool:
478def is_iso_date(text: str) -> bool:
479    try:
480        datetime.date.fromisoformat(text)
481        return True
482    except ValueError:
483        return False
def is_iso_datetime(text: str) -> bool:
486def is_iso_datetime(text: str) -> bool:
487    try:
488        datetime.datetime.fromisoformat(text)
489        return True
490    except ValueError:
491        return False
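
Examples:
>>> is_iso_date("2024-01-31")
True
>>> is_iso_date("01/31/2024")
False
>>> is_iso_datetime("2024-01-31 12:30:00")
True
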
DATE_UNITS = {'month', 'day', 'year_month', 'quarter', 'week', 'year'}
def is_date_unit(expression: Optional[sqlglot.expressions.Expression]) -> bool:
498def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
499    return expression is not None and expression.name.lower() in DATE_UNITS
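
Example (assumes interval units are represented as exp.Var nodes):
>>> from sqlglot import exp
>>> is_date_unit(exp.var("MONTH"))
True
>>> is_date_unit(exp.var("HOUR"))
False
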
class SingleValuedMapping(typing.Mapping[~K, ~V]):
506class SingleValuedMapping(t.Mapping[K, V]):
507    """
508    Mapping where all keys return the same value.
509
510    This rigamarole is meant to avoid copying keys, which was originally intended
511    as an optimization while qualifying columns for tables with lots of columns.
512    """
513
514    def __init__(self, keys: t.Collection[K], value: V):
515        self._keys = keys if isinstance(keys, Set) else set(keys)
516        self._value = value
517
518    def __getitem__(self, key: K) -> V:
519        if key in self._keys:
520            return self._value
521        raise KeyError(key)
522
523    def __len__(self) -> int:
524        return len(self._keys)
525
526    def __iter__(self) -> t.Iterator[K]:
527        return iter(self._keys)

Mapping where all keys return the same value.

This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.
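
Example:
>>> m = SingleValuedMapping(["a", "b", "c"], 0)
>>> m["b"]
0
>>> len(m)
3
>>> "z" in m
False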

SingleValuedMapping(keys: Collection[~K], value: ~V)
514    def __init__(self, keys: t.Collection[K], value: V):
515        self._keys = keys if isinstance(keys, Set) else set(keys)
516        self._value = value
Inherited Members
collections.abc.Mapping
get
keys
items
values