
sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum, auto
  5from functools import reduce
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.errors import ParseError
 10from sqlglot.generator import Generator
 11from sqlglot.helper import AutoName, flatten, seq_get
 12from sqlglot.parser import Parser
 13from sqlglot.time import TIMEZONES, format_time
 14from sqlglot.tokens import Token, Tokenizer, TokenType
 15from sqlglot.trie import new_trie
 16
 17B = t.TypeVar("B", bound=exp.Binary)
 18
 19DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
 20DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
 21
 22
 23class Dialects(str, Enum):
 24    DIALECT = ""
 25
 26    BIGQUERY = "bigquery"
 27    CLICKHOUSE = "clickhouse"
 28    DATABRICKS = "databricks"
 29    DRILL = "drill"
 30    DUCKDB = "duckdb"
 31    HIVE = "hive"
 32    MYSQL = "mysql"
 33    ORACLE = "oracle"
 34    POSTGRES = "postgres"
 35    PRESTO = "presto"
 36    REDSHIFT = "redshift"
 37    SNOWFLAKE = "snowflake"
 38    SPARK = "spark"
 39    SPARK2 = "spark2"
 40    SQLITE = "sqlite"
 41    STARROCKS = "starrocks"
 42    TABLEAU = "tableau"
 43    TERADATA = "teradata"
 44    TRINO = "trino"
 45    TSQL = "tsql"
 46    Doris = "doris"
 47
 48
 49class NormalizationStrategy(str, AutoName):
 50    """Specifies the strategy according to which identifiers should be normalized."""
 51
 52    LOWERCASE = auto()  # Unquoted identifiers are lowercased
 53    UPPERCASE = auto()  # Unquoted identifiers are uppercased
 54    CASE_SENSITIVE = auto()  # Always case-sensitive, regardless of quotes
 55    CASE_INSENSITIVE = auto()  # Always case-insensitive, regardless of quotes
 56
 57
 58class _Dialect(type):
 59    classes: t.Dict[str, t.Type[Dialect]] = {}
 60
 61    def __eq__(cls, other: t.Any) -> bool:
 62        if cls is other:
 63            return True
 64        if isinstance(other, str):
 65            return cls is cls.get(other)
 66        if isinstance(other, Dialect):
 67            return cls is type(other)
 68
 69        return False
 70
 71    def __hash__(cls) -> int:
 72        return hash(cls.__name__.lower())
 73
 74    @classmethod
 75    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 76        return cls.classes[key]
 77
 78    @classmethod
 79    def get(
 80        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 81    ) -> t.Optional[t.Type[Dialect]]:
 82        return cls.classes.get(key, default)
 83
 84    def __new__(cls, clsname, bases, attrs):
 85        klass = super().__new__(cls, clsname, bases, attrs)
 86        enum = Dialects.__members__.get(clsname.upper())
 87        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 88
 89        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 90        klass.FORMAT_TRIE = (
 91            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 92        )
 93        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 94        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 95
 96        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
 97
 98        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 99        klass.parser_class = getattr(klass, "Parser", Parser)
100        klass.generator_class = getattr(klass, "Generator", Generator)
101
102        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
103        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
104            klass.tokenizer_class._IDENTIFIERS.items()
105        )[0]
106
107        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
108            return next(
109                (
110                    (s, e)
111                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
112                    if t == token_type
113                ),
114                (None, None),
115            )
116
117        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
118        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
119        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
120
121        if enum not in ("", "bigquery"):
122            klass.generator_class.SELECT_KINDS = ()
123
124        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
125            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
126                TokenType.ANTI,
127                TokenType.SEMI,
128            }
129
130        return klass
131
132
133class Dialect(metaclass=_Dialect):
134    # Determines the base index offset for arrays
135    INDEX_OFFSET = 0
136
137    # If true unnest table aliases are considered only as column aliases
138    UNNEST_COLUMN_ONLY = False
139
140    # Determines whether or not the table alias comes after tablesample
141    ALIAS_POST_TABLESAMPLE = False
142
143    # Specifies the strategy according to which identifiers should be normalized.
144    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
145
146    # Determines whether or not an unquoted identifier can start with a digit
147    IDENTIFIERS_CAN_START_WITH_DIGIT = False
148
149    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
150    DPIPE_IS_STRING_CONCAT = True
151
152    # Determines whether or not CONCAT's arguments must be strings
153    STRICT_STRING_CONCAT = False
154
155    # Determines whether or not user-defined data types are supported
156    SUPPORTS_USER_DEFINED_TYPES = True
157
158    # Determines whether or not SEMI/ANTI JOINs are supported
159    SUPPORTS_SEMI_ANTI_JOIN = True
160
161    # Determines how function names are going to be normalized
162    NORMALIZE_FUNCTIONS: bool | str = "upper"
163
164    # Determines whether the base comes first in the LOG function
165    LOG_BASE_FIRST = True
166
167    # Indicates the default null ordering method to use if not explicitly set
168    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
169    NULL_ORDERING = "nulls_are_small"
170
171    # Whether the behavior of a / b depends on the types of a and b.
172    # False means a / b is always float division.
173    # True means a / b is integer division if both a and b are integers.
174    TYPED_DIVISION = False
175
176    # False means 1 / 0 throws an error.
177    # True means 1 / 0 returns null.
178    SAFE_DIVISION = False
179
180    # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
181    CONCAT_COALESCE = False
182
183    DATE_FORMAT = "'%Y-%m-%d'"
184    DATEINT_FORMAT = "'%Y%m%d'"
185    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
186
187    # Custom time mappings in which the key represents dialect time format
188    # and the value represents a python time format
189    TIME_MAPPING: t.Dict[str, str] = {}
190
191    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
192    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
193    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
194    FORMAT_MAPPING: t.Dict[str, str] = {}
195
196    # Mapping of an unescaped escape sequence to the corresponding character
197    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
198
199    # Columns that are auto-generated by the engine corresponding to this dialect
200    # Such columns may be excluded from SELECT * queries, for example
201    PSEUDOCOLUMNS: t.Set[str] = set()
202
203    # --- Autofilled ---
204
205    tokenizer_class = Tokenizer
206    parser_class = Parser
207    generator_class = Generator
208
209    # A trie of the time_mapping keys
210    TIME_TRIE: t.Dict = {}
211    FORMAT_TRIE: t.Dict = {}
212
213    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
214    INVERSE_TIME_TRIE: t.Dict = {}
215
216    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
217
218    # Delimiters for quotes, identifiers and the corresponding escape characters
219    QUOTE_START = "'"
220    QUOTE_END = "'"
221    IDENTIFIER_START = '"'
222    IDENTIFIER_END = '"'
223
224    # Delimiters for bit, hex and byte literals
225    BIT_START: t.Optional[str] = None
226    BIT_END: t.Optional[str] = None
227    HEX_START: t.Optional[str] = None
228    HEX_END: t.Optional[str] = None
229    BYTE_START: t.Optional[str] = None
230    BYTE_END: t.Optional[str] = None
231
232    @classmethod
233    def get_or_raise(cls, dialect: DialectType) -> Dialect:
234        """
235        Look up a dialect in the global dialect registry and return it if it exists.
236
237        Args:
238            dialect: The target dialect. If this is a string, it can be optionally followed by
239                additional key-value pairs that are separated by commas and are used to specify
240                dialect settings, such as whether the dialect's identifiers are case-sensitive.
241
242        Example:
243            >>> dialect = dialect_class = get_or_raise("duckdb")
244            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
245
246        Returns:
247            The corresponding Dialect instance.
248        """
249
250        if not dialect:
251            return cls()
252        if isinstance(dialect, _Dialect):
253            return dialect()
254        if isinstance(dialect, Dialect):
255            return dialect
256        if isinstance(dialect, str):
257            try:
258                dialect_name, *kv_pairs = dialect.split(",")
259                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
260            except ValueError:
261                raise ValueError(
262                    f"Invalid dialect format: '{dialect}'. "
263                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
264                )
265
266            result = cls.get(dialect_name.strip())
267            if not result:
268                raise ValueError(f"Unknown dialect '{dialect_name}'.")
269
270            return result(**kwargs)
271
272        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
273
274    @classmethod
275    def format_time(
276        cls, expression: t.Optional[str | exp.Expression]
277    ) -> t.Optional[exp.Expression]:
278        if isinstance(expression, str):
279            return exp.Literal.string(
280                # the time formats are quoted
281                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
282            )
283
284        if expression and expression.is_string:
285            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
286
287        return expression
288
289    def __init__(self, **kwargs) -> None:
290        normalization_strategy = kwargs.get("normalization_strategy")
291
292        if normalization_strategy is None:
293            self.normalization_strategy = self.NORMALIZATION_STRATEGY
294        else:
295            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
296
297    def __eq__(self, other: t.Any) -> bool:
298        # Does not currently take dialect state into account
299        return type(self) == other
300
301    def __hash__(self) -> int:
302        # Does not currently take dialect state into account
303        return hash(type(self))
304
305    def normalize_identifier(self, expression: E) -> E:
306        """
307        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
308
309        For example, an identifier like FoO would be resolved as foo in Postgres, because it
310        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
311        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
312        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
313
314        There are also dialects like Spark, which are case-insensitive even when quotes are
315        present, and dialects like MySQL, whose resolution rules match those employed by the
316        underlying operating system, for example they may always be case-sensitive in Linux.
317
318        Finally, the normalization behavior of some engines can even be controlled through flags,
319        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
320
321        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
322        that it can analyze queries in the optimizer and successfully capture their semantics.
323        """
324        if (
325            isinstance(expression, exp.Identifier)
326            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
327            and (
328                not expression.quoted
329                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
330            )
331        ):
332            expression.set(
333                "this",
334                expression.this.upper()
335                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
336                else expression.this.lower(),
337            )
338
339        return expression
340
341    def case_sensitive(self, text: str) -> bool:
342        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
343        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
344            return False
345
346        unsafe = (
347            str.islower
348            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
349            else str.isupper
350        )
351        return any(unsafe(char) for char in text)
352
353    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
354        """Checks if text can be identified given an identify option.
355
356        Args:
357            text: The text to check.
358            identify:
359                "always" or `True`: Always returns true.
360                "safe": True if the identifier is case-insensitive.
361
362        Returns:
363            Whether or not the given text can be identified.
364        """
365        if identify is True or identify == "always":
366            return True
367
368        if identify == "safe":
369            return not self.case_sensitive(text)
370
371        return False
372
373    def quote_identifier(self, expression: E, identify: bool = True) -> E:
374        if isinstance(expression, exp.Identifier):
375            name = expression.this
376            expression.set(
377                "quoted",
378                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
379            )
380
381        return expression
382
383    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
384        return self.parser(**opts).parse(self.tokenize(sql), sql)
385
386    def parse_into(
387        self, expression_type: exp.IntoType, sql: str, **opts
388    ) -> t.List[t.Optional[exp.Expression]]:
389        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
390
391    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
392        return self.generator(**opts).generate(expression, copy=copy)
393
394    def transpile(self, sql: str, **opts) -> t.List[str]:
395        return [
396            self.generate(expression, copy=False, **opts) if expression else ""
397            for expression in self.parse(sql)
398        ]
399
400    def tokenize(self, sql: str) -> t.List[Token]:
401        return self.tokenizer.tokenize(sql)
402
403    @property
404    def tokenizer(self) -> Tokenizer:
405        if not hasattr(self, "_tokenizer"):
406            self._tokenizer = self.tokenizer_class(dialect=self)
407        return self._tokenizer
408
409    def parser(self, **opts) -> Parser:
410        return self.parser_class(dialect=self, **opts)
411
412    def generator(self, **opts) -> Generator:
413        return self.generator_class(dialect=self, **opts)
414
415
416DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
417
418
419def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
420    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
421
422
423def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
424    if expression.args.get("accuracy"):
425        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
426    return self.func("APPROX_COUNT_DISTINCT", expression.this)
427
428
429def if_sql(
430    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
431) -> t.Callable[[Generator, exp.If], str]:
432    def _if_sql(self: Generator, expression: exp.If) -> str:
433        return self.func(
434            name,
435            expression.this,
436            expression.args.get("true"),
437            expression.args.get("false") or false_value,
438        )
439
440    return _if_sql
441
442
443def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
444    return self.binary(expression, "->")
445
446
447def arrow_json_extract_scalar_sql(
448    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
449) -> str:
450    return self.binary(expression, "->>")
451
452
453def inline_array_sql(self: Generator, expression: exp.Array) -> str:
454    return f"[{self.expressions(expression, flat=True)}]"
455
456
457def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
458    return self.like_sql(
459        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
460    )
461
462
463def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
464    zone = self.sql(expression, "this")
465    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
466
467
468def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
469    if expression.args.get("recursive"):
470        self.unsupported("Recursive CTEs are unsupported")
471        expression.args["recursive"] = False
472    return self.with_sql(expression)
473
474
475def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
476    n = self.sql(expression, "this")
477    d = self.sql(expression, "expression")
478    return f"IF({d} <> 0, {n} / {d}, NULL)"
479
480
481def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
482    self.unsupported("TABLESAMPLE unsupported")
483    return self.sql(expression.this)
484
485
486def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
487    self.unsupported("PIVOT unsupported")
488    return ""
489
490
491def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
492    return self.cast_sql(expression)
493
494
495def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
496    self.unsupported("Properties unsupported")
497    return ""
498
499
500def no_comment_column_constraint_sql(
501    self: Generator, expression: exp.CommentColumnConstraint
502) -> str:
503    self.unsupported("CommentColumnConstraint unsupported")
504    return ""
505
506
507def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
508    self.unsupported("MAP_FROM_ENTRIES unsupported")
509    return ""
510
511
512def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
513    this = self.sql(expression, "this")
514    substr = self.sql(expression, "substr")
515    position = self.sql(expression, "position")
516    if position:
517        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
518    return f"STRPOS({this}, {substr})"
519
520
521def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
522    return (
523        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
524    )
525
526
527def var_map_sql(
528    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
529) -> str:
530    keys = expression.args["keys"]
531    values = expression.args["values"]
532
533    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
534        self.unsupported("Cannot convert array columns into map.")
535        return self.func(map_func_name, keys, values)
536
537    args = []
538    for key, value in zip(keys.expressions, values.expressions):
539        args.append(self.sql(key))
540        args.append(self.sql(value))
541
542    return self.func(map_func_name, *args)
543
544
545def format_time_lambda(
546    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
547) -> t.Callable[[t.List], E]:
548    """Helper used for time expressions.
549
550    Args:
551        exp_class: the expression class to instantiate.
552        dialect: target sql dialect.
553        default: the default format, True being time.
554
555    Returns:
556        A callable that can be used to return the appropriately formatted time expression.
557    """
558
559    def _format_time(args: t.List):
560        return exp_class(
561            this=seq_get(args, 0),
562            format=Dialect[dialect].format_time(
563                seq_get(args, 1)
564                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
565            ),
566        )
567
568    return _format_time
569
570
571def time_format(
572    dialect: DialectType = None,
573) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
574    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
575        """
576        Returns the time format for a given expression, unless it's equivalent
577        to the default time format of the dialect of interest.
578        """
579        time_format = self.format_time(expression)
580        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
581
582    return _time_format
583
584
585def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
586    """
587    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
588    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
589    columns are removed from the create statement.
590    """
591    has_schema = isinstance(expression.this, exp.Schema)
592    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
593
594    if has_schema and is_partitionable:
595        prop = expression.find(exp.PartitionedByProperty)
596        if prop and prop.this and not isinstance(prop.this, exp.Schema):
597            schema = expression.this
598            columns = {v.name.upper() for v in prop.this.expressions}
599            partitions = [col for col in schema.expressions if col.name.upper() in columns]
600            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
601            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
602            expression.set("this", schema)
603
604    return self.create_sql(expression)
605
606
607def parse_date_delta(
608    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
609) -> t.Callable[[t.List], E]:
610    def inner_func(args: t.List) -> E:
611        unit_based = len(args) == 3
612        this = args[2] if unit_based else seq_get(args, 0)
613        unit = args[0] if unit_based else exp.Literal.string("DAY")
614        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
615        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
616
617    return inner_func
618
619
620def parse_date_delta_with_interval(
621    expression_class: t.Type[E],
622) -> t.Callable[[t.List], t.Optional[E]]:
623    def func(args: t.List) -> t.Optional[E]:
624        if len(args) < 2:
625            return None
626
627        interval = args[1]
628
629        if not isinstance(interval, exp.Interval):
630            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
631
632        expression = interval.this
633        if expression and expression.is_string:
634            expression = exp.Literal.number(expression.this)
635
636        return expression_class(
637            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
638        )
639
640    return func
641
642
643def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
644    unit = seq_get(args, 0)
645    this = seq_get(args, 1)
646
647    if isinstance(this, exp.Cast) and this.is_type("date"):
648        return exp.DateTrunc(unit=unit, this=this)
649    return exp.TimestampTrunc(this=this, unit=unit)
650
651
652def date_add_interval_sql(
653    data_type: str, kind: str
654) -> t.Callable[[Generator, exp.Expression], str]:
655    def func(self: Generator, expression: exp.Expression) -> str:
656        this = self.sql(expression, "this")
657        unit = expression.args.get("unit")
658        unit = exp.var(unit.name.upper() if unit else "DAY")
659        interval = exp.Interval(this=expression.expression, unit=unit)
660        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
661
662    return func
663
664
665def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
666    return self.func(
667        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
668    )
669
670
671def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
672    if not expression.expression:
673        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
674    if expression.text("expression").lower() in TIMEZONES:
675        return self.sql(
676            exp.AtTimeZone(
677                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
678                zone=expression.expression,
679            )
680        )
681    return self.function_fallback_sql(expression)
682
683
684def locate_to_strposition(args: t.List) -> exp.Expression:
685    return exp.StrPosition(
686        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
687    )
688
689
690def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
691    return self.func(
692        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
693    )
694
695
696def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
697    return self.sql(
698        exp.Substring(
699            this=expression.this, start=exp.Literal.number(1), length=expression.expression
700        )
701    )
702
703
704def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
705    return self.sql(
706        exp.Substring(
707            this=expression.this,
708            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
709        )
710    )
711
712
713def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
714    return self.sql(exp.cast(expression.this, "timestamp"))
715
716
717def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
718    return self.sql(exp.cast(expression.this, "date"))
719
720
721# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
722def encode_decode_sql(
723    self: Generator, expression: exp.Expression, name: str, replace: bool = True
724) -> str:
725    charset = expression.args.get("charset")
726    if charset and charset.name.lower() != "utf-8":
727        self.unsupported(f"Expected utf-8 character set, got {charset}.")
728
729    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
730
731
732def min_or_least(self: Generator, expression: exp.Min) -> str:
733    name = "LEAST" if expression.expressions else "MIN"
734    return rename_func(name)(self, expression)
735
736
737def max_or_greatest(self: Generator, expression: exp.Max) -> str:
738    name = "GREATEST" if expression.expressions else "MAX"
739    return rename_func(name)(self, expression)
740
741
742def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
743    cond = expression.this
744
745    if isinstance(expression.this, exp.Distinct):
746        cond = expression.this.expressions[0]
747        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
748
749    return self.func("sum", exp.func("if", cond, 1, 0))
750
751
752def trim_sql(self: Generator, expression: exp.Trim) -> str:
753    target = self.sql(expression, "this")
754    trim_type = self.sql(expression, "position")
755    remove_chars = self.sql(expression, "expression")
756    collation = self.sql(expression, "collation")
757
758    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
759    if not remove_chars and not collation:
760        return self.trim_sql(expression)
761
762    trim_type = f"{trim_type} " if trim_type else ""
763    remove_chars = f"{remove_chars} " if remove_chars else ""
764    from_part = "FROM " if trim_type or remove_chars else ""
765    collation = f" COLLATE {collation}" if collation else ""
766    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
767
768
769def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
770    return self.func("STRPTIME", expression.this, self.format_time(expression))
771
772
773def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
774    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
775        _dialect = Dialect.get_or_raise(dialect)
776        time_format = self.format_time(expression)
777        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
778            return self.sql(
779                exp.cast(
780                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
781                    "date",
782                )
783            )
784        return self.sql(exp.cast(expression.this, "date"))
785
786    return _ts_or_ds_to_date_sql
787
788
789def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
790    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
791
792
793def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
794    delim, *rest_args = expression.expressions
795    return self.sql(
796        reduce(
797            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
798            rest_args,
799        )
800    )
801
802
803def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
804    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
805    if bad_args:
806        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
807
808    return self.func(
809        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
810    )
811
812
813def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
814    bad_args = list(
815        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
816    )
817    if bad_args:
818        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
819
820    return self.func(
821        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
822    )
823
824
825def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
826    names = []
827    for agg in aggregations:
828        if isinstance(agg, exp.Alias):
829            names.append(agg.alias)
830        else:
831            """
832            This case corresponds to aggregations without aliases being used as suffixes
833            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
834            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
835            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
836            """
837            agg_all_unquoted = agg.transform(
838                lambda node: exp.Identifier(this=node.name, quoted=False)
839                if isinstance(node, exp.Identifier)
840                else node
841            )
842            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
843
844    return names
845
846
847def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
848    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
849
850
851# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
852def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
853    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
854
855
856def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
857    return self.func("MAX", expression.this)
858
859
860def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
861    a = self.sql(expression.left)
862    b = self.sql(expression.right)
863    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
864
865
866# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
867def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
868    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
869
870
871def is_parse_json(expression: exp.Expression) -> bool:
872    return isinstance(expression, exp.ParseJSON) or (
873        isinstance(expression, exp.Cast) and expression.is_type("json")
874    )
875
876
877def isnull_to_is_null(args: t.List) -> exp.Expression:
878    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
879
880
881def generatedasidentitycolumnconstraint_sql(
882    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
883) -> str:
884    start = self.sql(expression, "start") or "1"
885    increment = self.sql(expression, "increment") or "1"
886    return f"IDENTITY({start}, {increment})"
887
888
889def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
890    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
891        if expression.args.get("count"):
892            self.unsupported(f"Only two arguments are supported in function {name}.")
893
894        return self.func(name, expression.this, expression.expression)
895
896    return _arg_max_or_min_sql
897
898
899def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
900    this = expression.this.copy()
901
902    return_type = expression.return_type
903    if return_type.is_type(exp.DataType.Type.DATE):
904        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
905        # can truncate timestamp strings, because some dialects can't cast them to DATE
906        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
907
908    expression.this.replace(exp.cast(this, return_type))
909    return expression
910
911
912def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
913    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
914        if cast and isinstance(expression, exp.TsOrDsAdd):
915            expression = ts_or_ds_add_cast(expression)
916
917        return self.func(
918            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
919        )
920
921    return _delta_sql
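
Taken together, these helpers are consumed by dialect subclasses: defining a class that inherits from Dialect is enough for the _Dialect metaclass to register it under its lowercased name, and generation is customized through Generator.TRANSFORMS entries built from helpers such as rename_func. A minimal, hypothetical sketch (the Custom dialect name, the BaseGenerator alias and the sample query are illustrative only, not part of sqlglot):

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, rename_func
from sqlglot.generator import Generator as BaseGenerator

class Custom(Dialect):
    # Subclassing Dialect triggers _Dialect.__new__, which registers this class
    # in Dialect.classes under the key "custom".
    class Generator(BaseGenerator):
        TRANSFORMS = {
            **BaseGenerator.TRANSFORMS,
            # Render exp.ApproxDistinct under a different function name.
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
        }

print(sqlglot.transpile("SELECT APPROX_DISTINCT(a) FROM t", write="custom")[0])
# expected: SELECT APPROX_COUNT_DISTINCT(a) FROM t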
class Dialects(builtins.str, enum.Enum):
24class Dialects(str, Enum):
25    DIALECT = ""
26
27    BIGQUERY = "bigquery"
28    CLICKHOUSE = "clickhouse"
29    DATABRICKS = "databricks"
30    DRILL = "drill"
31    DUCKDB = "duckdb"
32    HIVE = "hive"
33    MYSQL = "mysql"
34    ORACLE = "oracle"
35    POSTGRES = "postgres"
36    PRESTO = "presto"
37    REDSHIFT = "redshift"
38    SNOWFLAKE = "snowflake"
39    SPARK = "spark"
40    SPARK2 = "spark2"
41    SQLITE = "sqlite"
42    STARROCKS = "starrocks"
43    TABLEAU = "tableau"
44    TERADATA = "teradata"
45    TRINO = "trino"
46    TSQL = "tsql"
47    Doris = "doris"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
    enum.Enum: name, value
    builtins.str: encode, replace, split, rsplit, join, capitalize, casefold, title, center, count, expandtabs, find, partition, index, ljust, lower, lstrip, rfind, rindex, rjust, rstrip, rpartition, splitlines, strip, swapcase, translate, upper, startswith, endswith, removeprefix, removesuffix, isascii, islower, isupper, istitle, isspace, isdecimal, isdigit, isnumeric, isalpha, isalnum, isidentifier, isprintable, zfill, format, format_map, maketrans
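
As an illustrative sketch (assuming the bundled dialect modules have been imported, which importing sqlglot does), each member's string value is also the key under which _Dialect.__new__ registers the corresponding dialect class:

import sqlglot  # importing the package loads the bundled dialect modules into the registry
from sqlglot.dialects.dialect import Dialect, Dialects

# The enum value doubles as the registry key, so it can be used wherever a dialect name is expected.
duckdb_class = Dialect.get(Dialects.DUCKDB.value)   # the registered DuckDB dialect class
duckdb_instance = Dialect.get_or_raise("duckdb")    # an instance of that same class
print(duckdb_class, type(duckdb_instance).__name__)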
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
50class NormalizationStrategy(str, AutoName):
51    """Specifies the strategy according to which identifiers should be normalized."""
52
53    LOWERCASE = auto()  # Unquoted identifiers are lowercased
54    UPPERCASE = auto()  # Unquoted identifiers are uppercased
55    CASE_SENSITIVE = auto()  # Always case-sensitive, regardless of quotes
56    CASE_INSENSITIVE = auto()  # Always case-insensitive, regardless of quotes

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>
UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>
CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>
CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>
Inherited Members
    enum.Enum: name, value
    builtins.str: encode, replace, split, rsplit, join, capitalize, casefold, title, center, count, expandtabs, find, partition, index, ljust, lower, lstrip, rfind, rindex, rjust, rstrip, rpartition, splitlines, strip, swapcase, translate, upper, startswith, endswith, removeprefix, removesuffix, isascii, islower, isupper, istitle, isspace, isdecimal, isdigit, isnumeric, isalpha, isalnum, isidentifier, isprintable, zfill, format, format_map, maketrans
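
A small sketch of what each strategy means in practice, using the base Dialect and its normalize_identifier method (the identifier value is arbitrary):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# The strategy decides how unquoted identifiers are folded by normalize_identifier.
lowercasing = Dialect()                                    # default: LOWERCASE
uppercasing = Dialect(normalization_strategy="uppercase")  # coerced to NormalizationStrategy.UPPERCASE

ident = exp.to_identifier("FoO")
print(lowercasing.normalize_identifier(ident.copy()).name)  # foo
print(uppercasing.normalize_identifier(ident.copy()).name)  # FOO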
class Dialect:
134class Dialect(metaclass=_Dialect):
135    # Determines the base index offset for arrays
136    INDEX_OFFSET = 0
137
138    # If true unnest table aliases are considered only as column aliases
139    UNNEST_COLUMN_ONLY = False
140
141    # Determines whether or not the table alias comes after tablesample
142    ALIAS_POST_TABLESAMPLE = False
143
144    # Specifies the strategy according to which identifiers should be normalized.
145    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
146
147    # Determines whether or not an unquoted identifier can start with a digit
148    IDENTIFIERS_CAN_START_WITH_DIGIT = False
149
150    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
151    DPIPE_IS_STRING_CONCAT = True
152
153    # Determines whether or not CONCAT's arguments must be strings
154    STRICT_STRING_CONCAT = False
155
156    # Determines whether or not user-defined data types are supported
157    SUPPORTS_USER_DEFINED_TYPES = True
158
159    # Determines whether or not SEMI/ANTI JOINs are supported
160    SUPPORTS_SEMI_ANTI_JOIN = True
161
162    # Determines how function names are going to be normalized
163    NORMALIZE_FUNCTIONS: bool | str = "upper"
164
165    # Determines whether the base comes first in the LOG function
166    LOG_BASE_FIRST = True
167
168    # Indicates the default null ordering method to use if not explicitly set
169    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
170    NULL_ORDERING = "nulls_are_small"
171
172    # Whether the behavior of a / b depends on the types of a and b.
173    # False means a / b is always float division.
174    # True means a / b is integer division if both a and b are integers.
175    TYPED_DIVISION = False
176
177    # False means 1 / 0 throws an error.
178    # True means 1 / 0 returns null.
179    SAFE_DIVISION = False
180
181    # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
182    CONCAT_COALESCE = False
183
184    DATE_FORMAT = "'%Y-%m-%d'"
185    DATEINT_FORMAT = "'%Y%m%d'"
186    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
187
188    # Custom time mappings in which the key represents dialect time format
189    # and the value represents a python time format
190    TIME_MAPPING: t.Dict[str, str] = {}
191
192    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
193    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
194    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
195    FORMAT_MAPPING: t.Dict[str, str] = {}
196
197    # Mapping of an unescaped escape sequence to the corresponding character
198    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
199
200    # Columns that are auto-generated by the engine corresponding to this dialect
201    # Such columns may be excluded from SELECT * queries, for example
202    PSEUDOCOLUMNS: t.Set[str] = set()
203
204    # --- Autofilled ---
205
206    tokenizer_class = Tokenizer
207    parser_class = Parser
208    generator_class = Generator
209
210    # A trie of the time_mapping keys
211    TIME_TRIE: t.Dict = {}
212    FORMAT_TRIE: t.Dict = {}
213
214    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
215    INVERSE_TIME_TRIE: t.Dict = {}
216
217    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
218
219    # Delimiters for quotes, identifiers and the corresponding escape characters
220    QUOTE_START = "'"
221    QUOTE_END = "'"
222    IDENTIFIER_START = '"'
223    IDENTIFIER_END = '"'
224
225    # Delimiters for bit, hex and byte literals
226    BIT_START: t.Optional[str] = None
227    BIT_END: t.Optional[str] = None
228    HEX_START: t.Optional[str] = None
229    HEX_END: t.Optional[str] = None
230    BYTE_START: t.Optional[str] = None
231    BYTE_END: t.Optional[str] = None
232
233    @classmethod
234    def get_or_raise(cls, dialect: DialectType) -> Dialect:
235        """
236        Look up a dialect in the global dialect registry and return it if it exists.
237
238        Args:
239            dialect: The target dialect. If this is a string, it can be optionally followed by
240                additional key-value pairs that are separated by commas and are used to specify
241                dialect settings, such as whether the dialect's identifiers are case-sensitive.
242
243        Example:
244            >>> dialect = dialect_class = get_or_raise("duckdb")
245            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
246
247        Returns:
248            The corresponding Dialect instance.
249        """
250
251        if not dialect:
252            return cls()
253        if isinstance(dialect, _Dialect):
254            return dialect()
255        if isinstance(dialect, Dialect):
256            return dialect
257        if isinstance(dialect, str):
258            try:
259                dialect_name, *kv_pairs = dialect.split(",")
260                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
261            except ValueError:
262                raise ValueError(
263                    f"Invalid dialect format: '{dialect}'. "
264                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
265                )
266
267            result = cls.get(dialect_name.strip())
268            if not result:
269                raise ValueError(f"Unknown dialect '{dialect_name}'.")
270
271            return result(**kwargs)
272
273        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
274
275    @classmethod
276    def format_time(
277        cls, expression: t.Optional[str | exp.Expression]
278    ) -> t.Optional[exp.Expression]:
279        if isinstance(expression, str):
280            return exp.Literal.string(
281                # the time formats are quoted
282                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
283            )
284
285        if expression and expression.is_string:
286            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
287
288        return expression
289
290    def __init__(self, **kwargs) -> None:
291        normalization_strategy = kwargs.get("normalization_strategy")
292
293        if normalization_strategy is None:
294            self.normalization_strategy = self.NORMALIZATION_STRATEGY
295        else:
296            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
297
298    def __eq__(self, other: t.Any) -> bool:
299        # Does not currently take dialect state into account
300        return type(self) == other
301
302    def __hash__(self) -> int:
303        # Does not currently take dialect state into account
304        return hash(type(self))
305
306    def normalize_identifier(self, expression: E) -> E:
307        """
308        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
309
310        For example, an identifier like FoO would be resolved as foo in Postgres, because it
311        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
312        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
313        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
314
315        There are also dialects like Spark, which are case-insensitive even when quotes are
316        present, and dialects like MySQL, whose resolution rules match those employed by the
317        underlying operating system, for example they may always be case-sensitive in Linux.
318
319        Finally, the normalization behavior of some engines can even be controlled through flags,
320        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
321
322        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
323        that it can analyze queries in the optimizer and successfully capture their semantics.
324        """
325        if (
326            isinstance(expression, exp.Identifier)
327            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
328            and (
329                not expression.quoted
330                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
331            )
332        ):
333            expression.set(
334                "this",
335                expression.this.upper()
336                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
337                else expression.this.lower(),
338            )
339
340        return expression
341
342    def case_sensitive(self, text: str) -> bool:
343        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
344        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
345            return False
346
347        unsafe = (
348            str.islower
349            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
350            else str.isupper
351        )
352        return any(unsafe(char) for char in text)
353
354    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
355        """Checks if text can be identified given an identify option.
356
357        Args:
358            text: The text to check.
359            identify:
360                "always" or `True`: Always returns true.
361                "safe": True if the identifier is case-insensitive.
362
363        Returns:
364            Whether or not the given text can be identified.
365        """
366        if identify is True or identify == "always":
367            return True
368
369        if identify == "safe":
370            return not self.case_sensitive(text)
371
372        return False
373
374    def quote_identifier(self, expression: E, identify: bool = True) -> E:
375        if isinstance(expression, exp.Identifier):
376            name = expression.this
377            expression.set(
378                "quoted",
379                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
380            )
381
382        return expression
383
384    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
385        return self.parser(**opts).parse(self.tokenize(sql), sql)
386
387    def parse_into(
388        self, expression_type: exp.IntoType, sql: str, **opts
389    ) -> t.List[t.Optional[exp.Expression]]:
390        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
391
392    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
393        return self.generator(**opts).generate(expression, copy=copy)
394
395    def transpile(self, sql: str, **opts) -> t.List[str]:
396        return [
397            self.generate(expression, copy=False, **opts) if expression else ""
398            for expression in self.parse(sql)
399        ]
400
401    def tokenize(self, sql: str) -> t.List[Token]:
402        return self.tokenizer.tokenize(sql)
403
404    @property
405    def tokenizer(self) -> Tokenizer:
406        if not hasattr(self, "_tokenizer"):
407            self._tokenizer = self.tokenizer_class(dialect=self)
408        return self._tokenizer
409
410    def parser(self, **opts) -> Parser:
411        return self.parser_class(dialect=self, **opts)
412
413    def generator(self, **opts) -> Generator:
414        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
290    def __init__(self, **kwargs) -> None:
291        normalization_strategy = kwargs.get("normalization_strategy")
292
293        if normalization_strategy is None:
294            self.normalization_strategy = self.NORMALIZATION_STRATEGY
295        else:
296            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = True
SUPPORTS_SEMI_ANTI_JOIN = True
NORMALIZE_FUNCTIONS: bool | str = 'upper'
LOG_BASE_FIRST = True
NULL_ORDERING = 'nulls_are_small'
TYPED_DIVISION = False
SAFE_DIVISION = False
CONCAT_COALESCE = False
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
ESCAPE_SEQUENCES: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
@classmethod
def get_or_raise(cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
233    @classmethod
234    def get_or_raise(cls, dialect: DialectType) -> Dialect:
235        """
236        Look up a dialect in the global dialect registry and return it if it exists.
237
238        Args:
239            dialect: The target dialect. If this is a string, it can be optionally followed by
240                additional key-value pairs that are separated by commas and are used to specify
241                dialect settings, such as whether the dialect's identifiers are case-sensitive.
242
243        Example:
244            >>> dialect = dialect_class = get_or_raise("duckdb")
245            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
246
247        Returns:
248            The corresponding Dialect instance.
249        """
250
251        if not dialect:
252            return cls()
253        if isinstance(dialect, _Dialect):
254            return dialect()
255        if isinstance(dialect, Dialect):
256            return dialect
257        if isinstance(dialect, str):
258            try:
259                dialect_name, *kv_pairs = dialect.split(",")
260                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
261            except ValueError:
262                raise ValueError(
263                    f"Invalid dialect format: '{dialect}'. "
264                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
265                )
266
267            result = cls.get(dialect_name.strip())
268            if not result:
269                raise ValueError(f"Unknown dialect '{dialect_name}'.")
270
271            return result(**kwargs)
272
273        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
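For illustration, a hedged sketch of the settings-string form described above (assuming the bundled MySQL dialect is available):

import sqlglot  # ensures the bundled dialects are registered
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Everything before the first comma is the dialect name; the remaining key-value
# pairs are passed to the dialect's constructor.
d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
print(type(d).__name__)                                                  # MySQL
print(d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE)  # True
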

@classmethod
def format_time(cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
275    @classmethod
276    def format_time(
277        cls, expression: t.Optional[str | exp.Expression]
278    ) -> t.Optional[exp.Expression]:
279        if isinstance(expression, str):
280            return exp.Literal.string(
281                # the time formats are quoted
282                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
283            )
284
285        if expression and expression.is_string:
286            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
287
288        return expression
def normalize_identifier(self, expression: ~E) -> ~E:
306    def normalize_identifier(self, expression: E) -> E:
307        """
308        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
309
310        For example, an identifier like FoO would be resolved as foo in Postgres, because it
311        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
312        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
313        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
314
315        There are also dialects like Spark, which are case-insensitive even when quotes are
316        present, and dialects like MySQL, whose resolution rules match those employed by the
317        underlying operating system, for example they may always be case-sensitive in Linux.
318
319        Finally, the normalization behavior of some engines can even be controlled through flags,
320        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
321
322        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
323        that it can analyze queries in the optimizer and successfully capture their semantics.
324        """
325        if (
326            isinstance(expression, exp.Identifier)
327            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
328            and (
329                not expression.quoted
330                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
331            )
332        ):
333            expression.set(
334                "this",
335                expression.this.upper()
336                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
337                else expression.this.lower(),
338            )
339
340        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
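
A short sketch of the behavior described above; the commented values reflect each dialect's default strategy and are illustrative:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> ident = exp.to_identifier("FoO")
>>> Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name   # expected: 'foo'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name  # expected: 'FOO'
>>> quoted = exp.to_identifier("FoO", quoted=True)
>>> Dialect.get_or_raise("postgres").normalize_identifier(quoted).name         # quoted, so left as 'FoO'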

def case_sensitive(self, text: str) -> bool:
342    def case_sensitive(self, text: str) -> bool:
343        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
344        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
345            return False
346
347        unsafe = (
348            str.islower
349            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
350            else str.isupper
351        )
352        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.
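
For example, given the default strategies of these dialects (illustrative, not verified output):

>>> Dialect.get_or_raise("postgres").case_sensitive("foo")    # expected: False, no uppercase characters
>>> Dialect.get_or_raise("postgres").case_sensitive("Foo")    # expected: True
>>> Dialect.get_or_raise("snowflake").case_sensitive("foo")   # expected: True, lowercase is "unsafe" when the dialect uppercases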

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
354    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
355        """Checks if text can be identified given an identify option.
356
357        Args:
358            text: The text to check.
359            identify:
360                "always" or `True`: Always returns true.
361                "safe": True if the identifier is case-insensitive.
362
363        Returns:
364            Whether or not the given text can be identified.
365        """
366        if identify is True or identify == "always":
367            return True
368
369        if identify == "safe":
370            return not self.case_sensitive(text)
371
372        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify:
    "always" or True: Always returns true.
    "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.
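
A brief sketch using Postgres' default lowercase strategy (commented results are the expected behavior):

>>> d = Dialect.get_or_raise("postgres")
>>> d.can_identify("foo", identify="always")   # expected: True
>>> d.can_identify("Foo", identify="safe")     # expected: False, "Foo" is case-sensitive here
>>> d.can_identify("foo", identify=False)      # expected: False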

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
374    def quote_identifier(self, expression: E, identify: bool = True) -> E:
375        if isinstance(expression, exp.Identifier):
376            name = expression.this
377            expression.set(
378                "quoted",
379                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
380            )
381
382        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
384    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
385        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
387    def parse_into(
388        self, expression_type: exp.IntoType, sql: str, **opts
389    ) -> t.List[t.Optional[exp.Expression]]:
390        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
392    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
393        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
395    def transpile(self, sql: str, **opts) -> t.List[str]:
396        return [
397            self.generate(expression, copy=False, **opts) if expression else ""
398            for expression in self.parse(sql)
399        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
401    def tokenize(self, sql: str) -> t.List[Token]:
402        return self.tokenizer.tokenize(sql)
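
Taken together, these methods form a dialect instance's round-trip API; a minimal sketch, with expected results shown as comments:

>>> from sqlglot.dialects.dialect import Dialect
>>> duckdb = Dialect.get_or_raise("duckdb")
>>> tree = duckdb.parse("SELECT 1 AS x")[0]   # tokenize + parse into an expression tree
>>> duckdb.generate(tree)                     # expected: 'SELECT 1 AS x'
>>> duckdb.transpile("SELECT 1 AS x")         # parse, then generate each statement
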
def parser(self, **opts) -> sqlglot.parser.Parser:
410    def parser(self, **opts) -> Parser:
411        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
413    def generator(self, **opts) -> Generator:
414        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
420def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
421    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
424def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
425    if expression.args.get("accuracy"):
426        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
427    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
430def if_sql(
431    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
432) -> t.Callable[[Generator, exp.If], str]:
433    def _if_sql(self: Generator, expression: exp.If) -> str:
434        return self.func(
435            name,
436            expression.this,
437            expression.args.get("true"),
438            expression.args.get("false") or false_value,
439        )
440
441    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
444def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
445    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
448def arrow_json_extract_scalar_sql(
449    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
450) -> str:
451    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
454def inline_array_sql(self: Generator, expression: exp.Array) -> str:
455    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
458def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
459    return self.like_sql(
460        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
461    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
464def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
465    zone = self.sql(expression, "this")
466    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
469def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
470    if expression.args.get("recursive"):
471        self.unsupported("Recursive CTEs are unsupported")
472        expression.args["recursive"] = False
473    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
476def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
477    n = self.sql(expression, "this")
478    d = self.sql(expression, "expression")
479    return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
482def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
483    self.unsupported("TABLESAMPLE unsupported")
484    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
487def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
488    self.unsupported("PIVOT unsupported")
489    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
492def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
493    return self.cast_sql(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
496def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
497    self.unsupported("Properties unsupported")
498    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
501def no_comment_column_constraint_sql(
502    self: Generator, expression: exp.CommentColumnConstraint
503) -> str:
504    self.unsupported("CommentColumnConstraint unsupported")
505    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
508def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
509    self.unsupported("MAP_FROM_ENTRIES unsupported")
510    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
513def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
514    this = self.sql(expression, "this")
515    substr = self.sql(expression, "substr")
516    position = self.sql(expression, "position")
517    if position:
518        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
519    return f"STRPOS({this}, {substr})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
522def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
523    return (
524        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
525    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
528def var_map_sql(
529    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
530) -> str:
531    keys = expression.args["keys"]
532    values = expression.args["values"]
533
534    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
535        self.unsupported("Cannot convert array columns into map.")
536        return self.func(map_func_name, keys, values)
537
538    args = []
539    for key, value in zip(keys.expressions, values.expressions):
540        args.append(self.sql(key))
541        args.append(self.sql(value))
542
543    return self.func(map_func_name, *args)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
546def format_time_lambda(
547    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
548) -> t.Callable[[t.List], E]:
549    """Helper used for time expressions.
550
551    Args:
552        exp_class: the expression class to instantiate.
553        dialect: target sql dialect.
554        default: the default format, True being time.
555
556    Returns:
557        A callable that can be used to return the appropriately formatted time expression.
558    """
559
560    def _format_time(args: t.List):
561        return exp_class(
562            this=seq_get(args, 0),
563            format=Dialect[dialect].format_time(
564                seq_get(args, 1)
565                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
566            ),
567        )
568
569    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
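
A hedged sketch of how such a helper is typically wired into a dialect's parser; the TO_DATE mapping is illustrative:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import format_time_lambda
>>> parse_to_date = format_time_lambda(exp.TsOrDsToDate, "hive", default=True)
>>> # e.g. registered in a Parser as: FUNCTIONS = {"TO_DATE": parse_to_date}
>>> parse_to_date([exp.column("dt")])   # TsOrDsToDate whose format falls back to Hive's TIME_FORMAT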

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
572def time_format(
573    dialect: DialectType = None,
574) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
575    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
576        """
577        Returns the time format for a given expression, unless it's equivalent
578        to the default time format of the dialect of interest.
579        """
580        time_format = self.format_time(expression)
581        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
582
583    return _time_format
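
A sketch of how this factory is commonly used inside a generator transform; the FROM_UNIXTIME rendering below is illustrative, not a claim about a specific dialect:

>>> # exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, time_format("mysql")(self, e))
>>> # i.e. the format argument is emitted only when it differs from the target dialect's default TIME_FORMAT
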
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
586def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
587    """
588    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
589    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
590    columns are removed from the create statement.
591    """
592    has_schema = isinstance(expression.this, exp.Schema)
593    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
594
595    if has_schema and is_partitionable:
596        prop = expression.find(exp.PartitionedByProperty)
597        if prop and prop.this and not isinstance(prop.this, exp.Schema):
598            schema = expression.this
599            columns = {v.name.upper() for v in prop.this.expressions}
600            partitions = [col for col in schema.expressions if col.name.upper() in columns]
601            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
602            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
603            expression.set("this", schema)
604
605    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
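
Roughly, the rewrite moves the partition columns out of the column list and attaches their definitions to the PARTITIONED BY clause (illustrative Hive/Spark-style DDL):

    before: CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)
    after:  CREATE TABLE t (a INT) PARTITIONED BY (b INT)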

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
608def parse_date_delta(
609    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
610) -> t.Callable[[t.List], E]:
611    def inner_func(args: t.List) -> E:
612        unit_based = len(args) == 3
613        this = args[2] if unit_based else seq_get(args, 0)
614        unit = args[0] if unit_based else exp.Literal.string("DAY")
615        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
616        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
617
618    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
621def parse_date_delta_with_interval(
622    expression_class: t.Type[E],
623) -> t.Callable[[t.List], t.Optional[E]]:
624    def func(args: t.List) -> t.Optional[E]:
625        if len(args) < 2:
626            return None
627
628        interval = args[1]
629
630        if not isinstance(interval, exp.Interval):
631            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
632
633        expression = interval.this
634        if expression and expression.is_string:
635            expression = exp.Literal.number(expression.this)
636
637        return expression_class(
638            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
639        )
640
641    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
644def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
645    unit = seq_get(args, 0)
646    this = seq_get(args, 1)
647
648    if isinstance(this, exp.Cast) and this.is_type("date"):
649        return exp.DateTrunc(unit=unit, this=this)
650    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
653def date_add_interval_sql(
654    data_type: str, kind: str
655) -> t.Callable[[Generator, exp.Expression], str]:
656    def func(self: Generator, expression: exp.Expression) -> str:
657        this = self.sql(expression, "this")
658        unit = expression.args.get("unit")
659        unit = exp.var(unit.name.upper() if unit else "DAY")
660        interval = exp.Interval(this=expression.expression, unit=unit)
661        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
662
663    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
666def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
667    return self.func(
668        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
669    )
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
672def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
673    if not expression.expression:
674        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
675    if expression.text("expression").lower() in TIMEZONES:
676        return self.sql(
677            exp.AtTimeZone(
678                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
679                zone=expression.expression,
680            )
681        )
682    return self.function_fallback_sql(expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
685def locate_to_strposition(args: t.List) -> exp.Expression:
686    return exp.StrPosition(
687        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
688    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
691def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
692    return self.func(
693        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
694    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
697def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
698    return self.sql(
699        exp.Substring(
700            this=expression.this, start=exp.Literal.number(1), length=expression.expression
701        )
702    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
705def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
706    return self.sql(
707        exp.Substring(
708            this=expression.this,
709            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
710        )
711    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
714def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
715    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
718def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
719    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
723def encode_decode_sql(
724    self: Generator, expression: exp.Expression, name: str, replace: bool = True
725) -> str:
726    charset = expression.args.get("charset")
727    if charset and charset.name.lower() != "utf-8":
728        self.unsupported(f"Expected utf-8 character set, got {charset}.")
729
730    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
733def min_or_least(self: Generator, expression: exp.Min) -> str:
734    name = "LEAST" if expression.expressions else "MIN"
735    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
738def max_or_greatest(self: Generator, expression: exp.Max) -> str:
739    name = "GREATEST" if expression.expressions else "MAX"
740    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
743def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
744    cond = expression.this
745
746    if isinstance(expression.this, exp.Distinct):
747        cond = expression.this.expressions[0]
748        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
749
750    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
753def trim_sql(self: Generator, expression: exp.Trim) -> str:
754    target = self.sql(expression, "this")
755    trim_type = self.sql(expression, "position")
756    remove_chars = self.sql(expression, "expression")
757    collation = self.sql(expression, "collation")
758
759    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
760    if not remove_chars and not collation:
761        return self.trim_sql(expression)
762
763    trim_type = f"{trim_type} " if trim_type else ""
764    remove_chars = f"{remove_chars} " if remove_chars else ""
765    from_part = "FROM " if trim_type or remove_chars else ""
766    collation = f" COLLATE {collation}" if collation else ""
767    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
770def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
771    return self.func("STRPTIME", expression.this, self.format_time(expression))
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
774def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
775    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
776        _dialect = Dialect.get_or_raise(dialect)
777        time_format = self.format_time(expression)
778        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
779            return self.sql(
780                exp.cast(
781                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
782                    "date",
783                )
784            )
785        return self.sql(exp.cast(expression.this, "date"))
786
787    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
790def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
791    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
794def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
795    delim, *rest_args = expression.expressions
796    return self.sql(
797        reduce(
798            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
799            rest_args,
800        )
801    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
804def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
805    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
806    if bad_args:
807        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
808
809    return self.func(
810        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
811    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
814def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
815    bad_args = list(
816        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
817    )
818    if bad_args:
819        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
820
821    return self.func(
822        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
823    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
826def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
827    names = []
828    for agg in aggregations:
829        if isinstance(agg, exp.Alias):
830            names.append(agg.alias)
831        else:
832            """
833            This case corresponds to aggregations without aliases being used as suffixes
834            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
835            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
836            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
837            """
838            agg_all_unquoted = agg.transform(
839                lambda node: exp.Identifier(this=node.name, quoted=False)
840                if isinstance(node, exp.Identifier)
841                else node
842            )
843            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
844
845    return names
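
A hedged sketch of the two branches above; the resulting names are the expected output, not verified:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import pivot_column_names
>>> aggs = [exp.Max(this=exp.column("foo")), exp.func("AVG", exp.column("bar")).as_("bar_avg")]
>>> pivot_column_names(aggs, dialect="spark")   # roughly ['max(foo)', 'bar_avg']
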
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
848def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
849    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def parse_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
853def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
854    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
857def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
858    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
861def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
862    a = self.sql(expression.left)
863    b = self.sql(expression.right)
864    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def json_keyvalue_comma_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONKeyValue) -> str:
868def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
869    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
872def is_parse_json(expression: exp.Expression) -> bool:
873    return isinstance(expression, exp.ParseJSON) or (
874        isinstance(expression, exp.Cast) and expression.is_type("json")
875    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
878def isnull_to_is_null(args: t.List) -> exp.Expression:
879    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
882def generatedasidentitycolumnconstraint_sql(
883    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
884) -> str:
885    start = self.sql(expression, "start") or "1"
886    increment = self.sql(expression, "increment") or "1"
887    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
890def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
891    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
892        if expression.args.get("count"):
893            self.unsupported(f"Only two arguments are supported in function {name}.")
894
895        return self.func(name, expression.this, expression.expression)
896
897    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
900def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
901    this = expression.this.copy()
902
903    return_type = expression.return_type
904    if return_type.is_type(exp.DataType.Type.DATE):
905        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
906        # can truncate timestamp strings, because some dialects can't cast them to DATE
907        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
908
909    expression.this.replace(exp.cast(this, return_type))
910    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
913def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
914    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
915        if cast and isinstance(expression, exp.TsOrDsAdd):
916            expression = ts_or_ds_add_cast(expression)
917
918        return self.func(
919            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
920        )
921
922    return _delta_sql