Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot._typing import E
  8from sqlglot.errors import ParseError
  9from sqlglot.generator import Generator
 10from sqlglot.helper import flatten, seq_get
 11from sqlglot.parser import Parser
 12from sqlglot.time import format_time
 13from sqlglot.tokens import Token, Tokenizer, TokenType
 14from sqlglot.trie import new_trie
 15
# Type variable bound to binary expressions; used by `binary_from_function`.
B = t.TypeVar("B", bound=exp.Binary)
 17
 18
class Dialects(str, Enum):
    """Names of all supported SQL dialects.

    Mixes in `str`, so members compare equal to their lowercase string
    values (e.g. ``Dialects.MYSQL == "mysql"``).
    """

    # The empty string denotes the base (generic) dialect.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
 42
 43
class _Dialect(type):
    """Metaclass that registers each Dialect subclass and autofills derived attributes."""

    # Registry mapping dialect name -> Dialect class, populated in __new__.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, to its registered name
        # string, and to instances of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Raises KeyError for unknown dialect names.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the matching Dialects enum value when one exists,
        # otherwise under the lowercased class name (ad-hoc dialects).
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute tries for longest-prefix matching of time-format tokens.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Pick up nested Tokenizer/Parser/Generator overrides, defaulting to the base classes.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first entry of each tokenizer mapping is used as the canonical delimiter pair.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Delimiters of the first format string producing `token_type`, or (None, None).
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Collect all non-callable, non-dunder class attributes as dialect properties.
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
        }

        # Reset SELECT_KINDS to empty for every dialect except the base dialect and BigQuery.
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # Non-strict CONCAT dialects parse || as the null-safe concatenation operator.
        if not klass.STRICT_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

        return klass
130
131
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects; subclasses are auto-registered by `_Dialect`."""

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare equal to their class; see _Dialect.__eq__ for the class-side logic.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name/instance/class to a Dialect class; raise for unknown names."""
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time-format string/literal into a python-format string literal."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # A character is "unsafe" when its case differs from the dialect's canonical case.
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an identifier as quoted when required (or when `identify` forces it)."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql`, coercing the result into the given expression type."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render an expression tree back into SQL for this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and regenerate it in this dialect, one string per statement."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache a tokenizer per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
312
313
# Accepted forms for a "dialect" argument: name string, instance, class, or None (base dialect).
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
315
316
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform that renders an expression as `name(...args)`."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
319
320
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning when an accuracy argument is present."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
325
326
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as IF(condition, true_value, false_value)."""
    true_value = expression.args.get("true")
    false_value = expression.args.get("false")
    return self.func("IF", expression.this, true_value, false_value)
331
332
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the arrow (`->`) operator."""
    operator = "->"
    return self.binary(expression, operator)
335
336
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the double-arrow (`->>`) operator."""
    operator = "->>"
    return self.binary(expression, operator)
341
342
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal using bracket syntax, e.g. [1, 2, 3]."""
    inner = self.expressions(expression)
    return "[" + inner + "]"
345
346
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE as LIKE over a lowercased left-hand side."""
    lowered = exp.Lower(this=expression.this.copy())
    like = exp.Like(this=lowered, expression=expression.expression.copy())
    return self.like_sql(like)
353
354
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending AT TIME ZONE when one is set."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
358
359
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause, stripping the unsupported RECURSIVE keyword."""
    recursive = expression.args.get("recursive")
    if recursive:
        self.unsupported("Recursive CTEs are unsupported")
        # Drop the flag so the base renderer emits a plain WITH.
        expression.args["recursive"] = False
    return self.with_sql(expression)
365
366
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with IF, yielding NULL on division by zero."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
371
372
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE, rendering only the sampled table itself."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
376
377
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Warn about and omit PIVOT clauses entirely."""
    self.unsupported("PIVOT unsupported")
    return ""
381
382
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Fall back to a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
385
386
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Warn about and omit table properties entirely."""
    self.unsupported("Properties unsupported")
    return ""
390
391
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Warn about and omit COMMENT column constraints entirely."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
397
398
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Warn about and omit MAP_FROM_ENTRIES entirely."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
402
403
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition with STRPOS, emulating a start position via SUBSTR."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")
    if start:
        # STRPOS has no start-offset argument: search the suffix, then shift the result back.
        return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
    return f"STRPOS({haystack}, {needle})"
411
412
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access using dotted, quoted-identifier syntax."""
    struct_sql = self.sql(expression, "this")
    key_sql = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True))
    return f"{struct_sql}.{key_sql}"
417
418
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor as MAP(k1, v1, k2, v2, ...) from key/value arrays."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        # Without literal arrays, keys and values cannot be interleaved pairwise.
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = []
    for key, value in zip(keys.expressions, values.expressions):
        interleaved.extend((self.sql(key), self.sql(value)))

    return self.func(map_func_name, *interleaved)
435
436
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        # Fall back to the dialect's TIME_FORMAT (default=True) or an explicit default.
        if default is True:
            fallback = Dialect[dialect].TIME_FORMAT
        else:
            fallback = default or None
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(seq_get(args, 1) or fallback),
        )

    return _format_time
461
462
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's AST is left untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        # Only rewrite when PARTITIONED BY lists plain columns (i.e. not already a schema).
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            # Columns named in PARTITIONED BY are moved out of the CREATE schema...
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            # ...and into a schema attached to the PARTITIONED BY property itself.
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
484
485
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date add/sub functions that take an optional unit argument."""

    def inner_func(args: t.List) -> E:
        # Three-arg form is (unit, amount, this); two-arg form defaults the unit to DAY.
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
497
498
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic functions whose second argument is an INTERVAL."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        # Normalize string interval amounts (e.g. '5') into numeric literals.
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=expression, unit=unit)

    return func
520
521
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Map DATE_TRUNC arguments to DateTrunc for date casts, TimestampTrunc otherwise."""
    this = seq_get(args, 1)
    unit = seq_get(args, 0)

    is_date = isinstance(this, exp.Cast) and this.is_type("date")
    klass = exp.DateTrunc if is_date else exp.TimestampTrunc
    return klass(this=this, unit=unit)
529
530
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', expr), defaulting the unit to 'day'."""
    unit = expression.text("unit") or "day"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
535
536
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Convert LOCATE(substr, this[, position]) arguments into a StrPosition node."""
    return exp.StrPosition(
        this=seq_get(args, 1),
        substr=seq_get(args, 0),
        position=seq_get(args, 2),
    )
541
542
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, this[, position])."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
547
548
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Rewrite LEFT(x, n) as SUBSTRING(x, 1, n)."""
    expression = expression.copy()
    substring = exp.Substring(
        this=expression.this,
        start=exp.Literal.number(1),
        length=expression.expression,
    )
    return self.sql(substring)
556
557
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite RIGHT(x, n) as SUBSTRING(x, LENGTH(x) - (n - 1)).

    The parameter was previously annotated as `exp.Left` — a copy-paste slip
    from `left_to_substring_sql`; this helper handles `exp.Right` nodes.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # Start LENGTH(x) - (n - 1) characters in, so the last n characters are kept.
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
566
567
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string-to-time conversion as CAST(x AS TIMESTAMP)."""
    cast = exp.cast(expression.this, "timestamp")
    return self.sql(cast)
570
571
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string-to-date conversion as CAST(x AS DATE)."""
    cast = exp.cast(expression.this, "date")
    return self.sql(cast)
574
575
# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render ENCODE/DECODE-style calls, warning on any charset other than UTF-8."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
585
586
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN, or LEAST when the expression carries multiple operands."""
    renderer = rename_func("LEAST" if expression.expressions else "MIN")
    return renderer(self, expression)
590
591
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX, or GREATEST when the expression carries multiple operands."""
    renderer = rename_func("GREATEST" if expression.expressions else "MAX")
    return renderer(self, expression)
595
596
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        # The SUM-based emulation cannot express DISTINCT counting; warn and drop it.
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond.copy(), 1, 0))
605
606
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, using the extended `TRIM(... FROM ...)` syntax only when needed."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    parts = ["TRIM("]
    if trim_type:
        parts.append(f"{trim_type} ")
    if remove_chars:
        parts.append(f"{remove_chars} ")
    if trim_type or remove_chars:
        parts.append("FROM ")
    parts.append(target)
    if collation:
        parts.append(f" COLLATE {collation}")
    parts.append(")")
    return "".join(parts)
622
623
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion using STRPTIME."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
626
627
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a renderer that casts timestamp-or-date-string values to DATE."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        # A non-default format means the value has to be parsed with STRPTIME first.
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))
        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql
638
639
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Rewrite CONCAT(a, b, c, ...) as the chained concatenation a || b || c."""
    expression = expression.copy()
    operands = expression.expressions
    chained = operands[0]
    for operand in operands[1:]:
        chained = exp.DPipe(this=chained, expression=operand)

    return self.sql(chained)
647
648
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about unsupported extra arguments."""
    bad_args = [key for key in ("position", "occurrence", "parameters") if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
657
658
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about unsupported extra arguments."""
    bad_args = [key for key in ("position", "occurrence", "parameters") if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
667
668
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the column-name suffixes produced by pivoted aggregations."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as suffixes (e.g. col_avg(foo)).
        # Unquote identifiers: the base parser's `_parse_pivot` quotes them later
        # via `to_identifier`, so keeping quotes here would double them up,
        # e.g. `col_avg(`foo`)`.
        def _unquote(node: exp.Expression) -> exp.Expression:
            if isinstance(node, exp.Identifier):
                return exp.Identifier(this=node.name, quoted=False)
            return node

        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
689
690
def simplify_literal(expression: E) -> E:
    # Constant-folds `expression.expression` (e.g. arithmetic between literals)
    # when it isn't already a plain literal, then returns the original node.
    # NOTE(review): the return value of `simplify` is discarded, so this relies
    # on `simplify` rewriting the child in place — confirm against the optimizer.
    if not isinstance(expression.expression, exp.Literal):
        from sqlglot.optimizer.simplify import simplify

        simplify(expression.expression)

    return expression
698
699
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser that maps a two-argument function call onto a binary expression."""

    def _parse(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _parse
class Dialects(builtins.str, enum.Enum):
20class Dialects(str, Enum):
21    DIALECT = ""
22
23    BIGQUERY = "bigquery"
24    CLICKHOUSE = "clickhouse"
25    DATABRICKS = "databricks"
26    DRILL = "drill"
27    DUCKDB = "duckdb"
28    HIVE = "hive"
29    MYSQL = "mysql"
30    ORACLE = "oracle"
31    POSTGRES = "postgres"
32    PRESTO = "presto"
33    REDSHIFT = "redshift"
34    SNOWFLAKE = "snowflake"
35    SPARK = "spark"
36    SPARK2 = "spark2"
37    SQLITE = "sqlite"
38    STARROCKS = "starrocks"
39    TABLEAU = "tableau"
40    TERADATA = "teradata"
41    TRINO = "trino"
42    TSQL = "tsql"

An enumeration of the SQL dialect names supported by sqlglot.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect(metaclass=_Dialect):
    """Base class describing a SQL dialect's parsing, tokenizing and generation rules."""

    # Base index offset for array access (0- or 1-based depending on the dialect)
    INDEX_OFFSET = 0

    # If True, UNNEST table aliases are treated as column aliases only
    UNNEST_COLUMN_ONLY = False

    # If True, the table alias is placed after the TABLESAMPLE clause
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers resolve as uppercase.
    # None means the dialect treats all identifiers as case-insensitive.
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier may begin with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # How function names are normalized ("upper", "lower", or a bool toggle)
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Default null ordering when not set explicitly:
    # "nulls_are_small", "nulls_are_large" or "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Maps dialect-specific time format tokens to python strftime-style tokens
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # Special syntax CAST(x AS DATE FORMAT 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns auto-generated by the engine for this dialect; such columns
    # may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled by the metaclass
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # Tries built from the time/format mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare equal to their own class (mirrors _Dialect.__eq__)
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` (name, instance or class) to a Dialect class.

        Raises:
            ValueError: if `dialect` is an unknown dialect name.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if result:
            return result

        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time-format literal into a python-format string literal."""
        if isinstance(expression, str):
            # Time formats arrive quoted, so strip the surrounding quote characters
            return exp.Literal.string(
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if isinstance(expression, exp.Identifier) and (not expression.quoted or upper is None):
            normalized = expression.this.upper() if upper else expression.this.lower()
            expression.set("this", normalized)

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if upper is None:
            # Everything is case-insensitive for this dialect
            return False

        unsafe = str.islower if upper else str.isupper
        return any(map(unsafe, text))

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an identifier as quoted when forced, case-sensitive, or syntactically unsafe."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            needs_quotes = (
                identify
                or cls.case_sensitive(name)
                or not exp.SAFE_IDENTIFIER_RE.match(name)
            )
            expression.set("quoted", needs_quotes)

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of syntax trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql`, coercing the result into the given expression type."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render a syntax tree back into SQL for this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and re-generate each statement in this dialect."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` using this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache the tokenizer on first access
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Construct a fresh parser instance with the given options."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Construct a fresh generator instance with the given options."""
        return self.generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = 'upper'
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
195    @classmethod
196    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
197        if not dialect:
198            return cls
199        if isinstance(dialect, _Dialect):
200            return dialect
201        if isinstance(dialect, Dialect):
202            return dialect.__class__
203
204        result = cls.get(dialect)
205        if not result:
206            raise ValueError(f"Unknown dialect '{dialect}'")
207
208        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
210    @classmethod
211    def format_time(
212        cls, expression: t.Optional[str | exp.Expression]
213    ) -> t.Optional[exp.Expression]:
214        if isinstance(expression, str):
215            return exp.Literal.string(
216                # the time formats are quoted
217                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
218            )
219
220        if expression and expression.is_string:
221            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
222
223        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
225    @classmethod
226    def normalize_identifier(cls, expression: E) -> E:
227        """
228        Normalizes an unquoted identifier to either lower or upper case, thus essentially
229        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
230        they will be normalized regardless of being quoted or not.
231        """
232        if isinstance(expression, exp.Identifier) and (
233            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
234        ):
235            expression.set(
236                "this",
237                expression.this.upper()
238                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
239                else expression.this.lower(),
240            )
241
242        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
244    @classmethod
245    def case_sensitive(cls, text: str) -> bool:
246        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
247        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
248            return False
249
250        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
251        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
253    @classmethod
254    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
255        """Checks if text can be identified given an identify option.
256
257        Args:
258            text: The text to check.
259            identify:
260                "always" or `True`: Always returns true.
261                "safe": True if the identifier is case-insensitive.
262
263        Returns:
264            Whether or not the given text can be identified.
265        """
266        if identify is True or identify == "always":
267            return True
268
269        if identify == "safe":
270            return not cls.case_sensitive(text)
271
272        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
274    @classmethod
275    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
276        if isinstance(expression, exp.Identifier):
277            name = expression.this
278            expression.set(
279                "quoted",
280                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
281            )
282
283        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
285    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
286        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
288    def parse_into(
289        self, expression_type: exp.IntoType, sql: str, **opts
290    ) -> t.List[t.Optional[exp.Expression]]:
291        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
293    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
294        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
296    def transpile(self, sql: str, **opts) -> t.List[str]:
297        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
299    def tokenize(self, sql: str) -> t.List[Token]:
300        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
308    def parser(self, **opts) -> Parser:
309        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
311    def generator(self, **opts) -> Generator:
312        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, sqlglot.dialects.dialect.Dialect, typing.Type[sqlglot.dialects.dialect.Dialect], NoneType]
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator callable that emits `name(...)` over all of an expression's args."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning that `accuracy` is unsupported."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an IF(condition, true_value, false_value) call."""
    true_part = expression.args.get("true")
    false_part = expression.args.get("false")
    return self.func("IF", expression.this, true_part, false_part)
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction using the arrow (`->`) operator."""
    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction using the double-arrow (`->>`) operator."""
    return self.binary(expression, "->>")
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal using bracket syntax: [a, b, ...]."""
    return f"[{self.expressions(expression)}]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE by lowercasing the matched side and using plain LIKE."""
    lowered = exp.Lower(this=expression.this.copy())
    return self.like_sql(exp.Like(this=lowered, expression=expression.expression.copy()))
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional AT TIME ZONE clause."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Strip the RECURSIVE keyword from a WITH clause, warning that it is unsupported."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF guard against a zero denominator."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop a TABLESAMPLE clause with a warning, rendering only the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop a PIVOT clause entirely, emitting an unsupported warning."""
    self.unsupported("PIVOT unsupported")
    return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop table properties entirely, emitting an unsupported warning."""
    self.unsupported("Properties unsupported")
    return ""
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop a COMMENT column constraint, emitting an unsupported warning."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Drop MAP_FROM_ENTRIES entirely, emitting an unsupported warning."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render STRPOS, emulating a starting position via SUBSTR arithmetic."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")

    if not position:
        return f"STRPOS({this}, {substr})"

    # Search only the suffix, then shift the match index back to the full string
    return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as dotted, quoted-key notation."""
    this = self.sql(expression, "this")
    struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True))
    return f"{this}.{struct_key}"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a MAP constructor from parallel key/value arrays as interleaved arguments."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave keys and values: k1, v1, k2, v2, ...
    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]

    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        fmt = seq_get(args, 1)
        if not fmt:
            # Fall back to the dialect's time format (or an explicit default)
            fmt = Dialect[dialect].TIME_FORMAT if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {v.name.upper() for v in prop.this.expressions}
            # Move the partition columns out of the schema and into the property
            partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions, with an optional unit-name remapping."""

    def _parse(args: t.List) -> E:
        # Three args means the unit comes first; otherwise the unit defaults to DAY
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _parse
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date-delta functions whose second argument is an INTERVAL."""

    def _parse(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            # Interval amounts may be parsed as strings; coerce to a numeric literal
            expression = exp.Literal.number(expression.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=expression, unit=unit)

    return _parse
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC args into DateTrunc for date casts, TimestampTrunc otherwise."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)

    return exp.TimestampTrunc(this=this, unit=unit)
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render DATE_TRUNC with the unit defaulting to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, string, position) arguments onto a StrPosition node."""
    return exp.StrPosition(
        substr=seq_get(args, 0), this=seq_get(args, 1), position=seq_get(args, 2)
    )
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, string, position)."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Rewrite LEFT(x, n) as SUBSTRING(x, 1, n)."""
    expression = expression.copy()
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite RIGHT(x, n) as SUBSTRING(x, LENGTH(x) - (n - 1)).

    Note: the parameter was previously annotated as `exp.Left` — a copy-paste from
    `left_to_substring_sql` — but this function handles RIGHT; the annotation is
    corrected to `exp.Right` (no runtime behavior change).
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # Start at length - (n - 1) so exactly n trailing characters are kept
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as CAST(x AS TIMESTAMP)."""
    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as CAST(x AS DATE)."""
    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render ENCODE/DECODE-style calls, warning on any charset other than utf-8."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replacement = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replacement)
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Use LEAST for multi-argument MIN, plain MIN otherwise."""
    name = "MIN" if not expression.expressions else "LEAST"
    return rename_func(name)(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Use GREATEST for multi-argument MAX, plain MAX otherwise."""
    name = "MAX" if not expression.expressions else "GREATEST"
    return rename_func(name)(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        # Strip DISTINCT (not representable under SUM) and warn
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond.copy(), 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with optional position, characters-to-remove and collation parts."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Plain TRIM/LTRIM/RTRIM syntax suffices when nothing dialect-specific is present
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator for TsOrDsToDate that parses through STRPTIME when a
    non-default time format is present, otherwise casts directly to DATE."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (target_dialect.TIME_FORMAT, target_dialect.DATE_FORMAT):
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Fold CONCAT arguments into a left-associated chain of `||` (DPipe) operators."""
    expression = expression.copy()
    head, *tail = expression.expressions
    for arg in tail:
        head = exp.DPipe(this=head, expression=arg)

    return self.sql(head)
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about unsupported optional arguments."""
    bad_args = [k for k in ("position", "occurrence", "parameters") if expression.args.get(k)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about args the target syntax can't express."""
    bad_args = [key for key in ("position", "occurrence", "parameters") if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column names produced by a list of PIVOT aggregations."""
    names = []

    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as name suffixes
        # (e.g. col_avg(foo)). Identifiers must be unquoted here because the
        # base parser's `_parse_pivot` will quote them again via
        # `to_identifier`; otherwise we would end up with doubly quoted names
        # such as col_avg(`foo`).
        unquoted = agg.transform(
            lambda node: (
                exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
        )
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def simplify_literal(expression: E) -> E:
    """Simplify `expression.expression` in place unless it is already a literal.

    Returns the (possibly mutated) input expression for chaining.
    """
    if isinstance(expression.expression, exp.Literal):
        return expression

    # Imported lazily to avoid a circular dependency with the optimizer.
    from sqlglot.optimizer.simplify import simplify

    simplify(expression.expression)
    return expression
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Return a parser callback that builds `expr_type` from its first two args."""

    def _build(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _build