
sqlglot.dialects.dialect

from __future__ import annotations

import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]


class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()  # Unquoted identifiers are lowercased
    UPPERCASE = auto()  # Unquoted identifiers are uppercased
    CASE_SENSITIVE = auto()  # Always case-sensitive, regardless of quotes
    CASE_INSENSITIVE = auto()  # Always case-insensitive, regardless of quotes


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true, UNNEST table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Specifies the strategy according to which identifiers should be normalized
    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents a dialect-specific time format
    # and the value represents a Python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # The special syntax CAST(x AS DATE FORMAT 'yyyy') defaults to TIME_MAPPING
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect.
    # For example, such columns may be excluded from SELECT * queries.
    PSEUDOCOLUMNS: t.Set[str] = set()

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for quotes, identifiers and the corresponding escape characters
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex and byte literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                raise ValueError(f"Unknown dialect '{dialect_name}'.")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like FoO would be resolved as foo in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower(),
            )

        return expression

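    # For example, given the LOWERCASE strategy of Postgres and the UPPERCASE
    # strategy of Snowflake described above, one would expect (illustrative sketch):
    #
    #   Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).this
    #   # -> 'foo'
    #   Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).this
    #   # -> 'FOO'
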
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns True.
                "safe": True only if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

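    # For example, under a lowercase normalization strategy, a mixed-case name is
    # case-sensitive and must stay quoted to survive a round trip (illustrative sketch):
    #
    #   Dialect.get_or_raise("postgres").quote_identifier(exp.to_identifier("FoO"), identify=False).sql()
    #   # -> '"FoO"'
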
    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
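# For example, a dialect's generator can map an expression type to a renamed
# function via its TRANSFORMS dictionary (hypothetical mapping shown):
#
#   TRANSFORMS = {exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT")}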


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
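# For example, a position argument of 3 renders as
#   STRPOS(SUBSTR(haystack, 3), needle) + 3 - 1
# where the trailing arithmetic converts the match position back to an index
# relative to the full string.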


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format to use; if True, the dialect's TIME_FORMAT is used.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time
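# For example, a dialect's parser might register a function whose second
# argument is a time format to be translated (illustrative mapping):
#
#   FUNCTIONS = {"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive")}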


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
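# For example, a statement along the lines of
#   CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)
# would be rewritten so that b moves out of the column list and into the
# partition schema (exact rendering may vary by generator settings):
#   CREATE TABLE t (a INT) PARTITIONED BY (b INT)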


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
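# For example, both DATE_ADD(col, 1) and the three-argument form
# DATE_ADD(DAY, 1, col) parse into the same expression: in the latter the unit
# comes first, while in the former it defaults to DAY.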


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and DuckDB, whose encode/decode functions don't support a
# charset argument and assume UTF-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (note the nested backtick quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of a colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
        )

    return _delta_sql
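
As a closing illustration, a minimal sketch relying only on the methods defined above: a Dialect instance ties the tokenizer, parser and generator together, so SQL can be parsed and regenerated through a single object.

from sqlglot.dialects.dialect import Dialect

# Look up a registered dialect by its string key
duckdb = Dialect.get_or_raise("duckdb")

tree = duckdb.parse("SELECT 1 AS x")[0]   # tokenize + parse into an expression tree
print(duckdb.generate(tree))              # SELECT 1 AS x
print(duckdb.transpile("SELECT 1 AS x"))  # ['SELECT 1 AS x']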
class Dialects(builtins.str, enum.Enum):
24class Dialects(str, Enum):
25    DIALECT = ""
26
27    BIGQUERY = "bigquery"
28    CLICKHOUSE = "clickhouse"
29    DATABRICKS = "databricks"
30    DRILL = "drill"
31    DUCKDB = "duckdb"
32    HIVE = "hive"
33    MYSQL = "mysql"
34    ORACLE = "oracle"
35    POSTGRES = "postgres"
36    PRESTO = "presto"
37    REDSHIFT = "redshift"
38    SNOWFLAKE = "snowflake"
39    SPARK = "spark"
40    SPARK2 = "spark2"
41    SQLITE = "sqlite"
42    STARROCKS = "starrocks"
43    TABLEAU = "tableau"
44    TERADATA = "teradata"
45    TRINO = "trino"
46    TSQL = "tsql"
47    Doris = "doris"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
50class NormalizationStrategy(str, AutoName):
51    """Specifies the strategy according to which identifiers should be normalized."""
52
53    LOWERCASE = auto()  # Unquoted identifiers are lowercased
54    UPPERCASE = auto()  # Unquoted identifiers are uppercased
55    CASE_SENSITIVE = auto()  # Always case-sensitive, regardless of quotes
56    CASE_INSENSITIVE = auto()  # Always case-insensitive, regardless of quotes

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>
UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>
CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>
CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
134class Dialect(metaclass=_Dialect):
135    # Determines the base index offset for arrays
136    INDEX_OFFSET = 0
137
138    # If true unnest table aliases are considered only as column aliases
139    UNNEST_COLUMN_ONLY = False
140
141    # Determines whether or not the table alias comes after tablesample
142    ALIAS_POST_TABLESAMPLE = False
143
144    # Specifies the strategy according to which identifiers should be normalized.
145    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
146
147    # Determines whether or not an unquoted identifier can start with a digit
148    IDENTIFIERS_CAN_START_WITH_DIGIT = False
149
150    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
151    DPIPE_IS_STRING_CONCAT = True
152
153    # Determines whether or not CONCAT's arguments must be strings
154    STRICT_STRING_CONCAT = False
155
156    # Determines whether or not user-defined data types are supported
157    SUPPORTS_USER_DEFINED_TYPES = True
158
159    # Determines whether or not SEMI/ANTI JOINs are supported
160    SUPPORTS_SEMI_ANTI_JOIN = True
161
162    # Determines how function names are going to be normalized
163    NORMALIZE_FUNCTIONS: bool | str = "upper"
164
165    # Determines whether the base comes first in the LOG function
166    LOG_BASE_FIRST = True
167
168    # Indicates the default null ordering method to use if not explicitly set
169    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
170    NULL_ORDERING = "nulls_are_small"
171
172    # Whether the behavior of a / b depends on the types of a and b.
173    # False means a / b is always float division.
174    # True means a / b is integer division if both a and b are integers.
175    TYPED_DIVISION = False
176
177    # False means 1 / 0 throws an error.
178    # True means 1 / 0 returns null.
179    SAFE_DIVISION = False
180
181    DATE_FORMAT = "'%Y-%m-%d'"
182    DATEINT_FORMAT = "'%Y%m%d'"
183    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
184
185    # Custom time mappings in which the key represents dialect time format
186    # and the value represents a python time format
187    TIME_MAPPING: t.Dict[str, str] = {}
188
189    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
190    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
191    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
192    FORMAT_MAPPING: t.Dict[str, str] = {}
193
194    # Mapping of an unescaped escape sequence to the corresponding character
195    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
196
197    # Columns that are auto-generated by the engine corresponding to this dialect
198    # Such columns may be excluded from SELECT * queries, for example
199    PSEUDOCOLUMNS: t.Set[str] = set()
200
201    # --- Autofilled ---
202
203    tokenizer_class = Tokenizer
204    parser_class = Parser
205    generator_class = Generator
206
207    # A trie of the time_mapping keys
208    TIME_TRIE: t.Dict = {}
209    FORMAT_TRIE: t.Dict = {}
210
211    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
212    INVERSE_TIME_TRIE: t.Dict = {}
213
214    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
215
216    # Delimiters for quotes, identifiers and the corresponding escape characters
217    QUOTE_START = "'"
218    QUOTE_END = "'"
219    IDENTIFIER_START = '"'
220    IDENTIFIER_END = '"'
221
222    # Delimiters for bit, hex and byte literals
223    BIT_START: t.Optional[str] = None
224    BIT_END: t.Optional[str] = None
225    HEX_START: t.Optional[str] = None
226    HEX_END: t.Optional[str] = None
227    BYTE_START: t.Optional[str] = None
228    BYTE_END: t.Optional[str] = None
229
230    @classmethod
231    def get_or_raise(cls, dialect: DialectType) -> Dialect:
232        """
233        Look up a dialect in the global dialect registry and return it if it exists.
234
235        Args:
236            dialect: The target dialect. If this is a string, it can be optionally followed by
237                additional key-value pairs that are separated by commas and are used to specify
238                dialect settings, such as whether the dialect's identifiers are case-sensitive.
239
240        Example:
241            >>> dialect = dialect_class = get_or_raise("duckdb")
242            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
243
244        Returns:
245            The corresponding Dialect instance.
246        """
247
248        if not dialect:
249            return cls()
250        if isinstance(dialect, _Dialect):
251            return dialect()
252        if isinstance(dialect, Dialect):
253            return dialect
254        if isinstance(dialect, str):
255            try:
256                dialect_name, *kv_pairs = dialect.split(",")
257                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
258            except ValueError:
259                raise ValueError(
260                    f"Invalid dialect format: '{dialect}'. "
261                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
262                )
263
264            result = cls.get(dialect_name.strip())
265            if not result:
266                raise ValueError(f"Unknown dialect '{dialect_name}'.")
267
268            return result(**kwargs)
269
270        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
271
272    @classmethod
273    def format_time(
274        cls, expression: t.Optional[str | exp.Expression]
275    ) -> t.Optional[exp.Expression]:
276        if isinstance(expression, str):
277            return exp.Literal.string(
278                # the time formats are quoted
279                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
280            )
281
282        if expression and expression.is_string:
283            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
284
285        return expression
286
287    def __init__(self, **kwargs) -> None:
288        normalization_strategy = kwargs.get("normalization_strategy")
289
290        if normalization_strategy is None:
291            self.normalization_strategy = self.NORMALIZATION_STRATEGY
292        else:
293            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
294
295    def __eq__(self, other: t.Any) -> bool:
296        # Does not currently take dialect state into account
297        return type(self) == other
298
299    def __hash__(self) -> int:
300        # Does not currently take dialect state into account
301        return hash(type(self))
302
303    def normalize_identifier(self, expression: E) -> E:
304        """
305        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
306
307        For example, an identifier like FoO would be resolved as foo in Postgres, because it
308        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
309        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
310        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
311
312        There are also dialects like Spark, which are case-insensitive even when quotes are
313        present, and dialects like MySQL, whose resolution rules match those employed by the
314        underlying operating system, for example they may always be case-sensitive in Linux.
315
316        Finally, the normalization behavior of some engines can even be controlled through flags,
317        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
318
319        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
320        that it can analyze queries in the optimizer and successfully capture their semantics.
321        """
322        if (
323            isinstance(expression, exp.Identifier)
324            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
325            and (
326                not expression.quoted
327                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
328            )
329        ):
330            expression.set(
331                "this",
332                expression.this.upper()
333                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
334                else expression.this.lower(),
335            )
336
337        return expression
338
339    def case_sensitive(self, text: str) -> bool:
340        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
341        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
342            return False
343
344        unsafe = (
345            str.islower
346            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
347            else str.isupper
348        )
349        return any(unsafe(char) for char in text)
350
351    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
352        """Checks if text can be identified given an identify option.
353
354        Args:
355            text: The text to check.
356            identify:
357                "always" or `True`: Always returns true.
358                "safe": True if the identifier is case-insensitive.
359
360        Returns:
361            Whether or not the given text can be identified.
362        """
363        if identify is True or identify == "always":
364            return True
365
366        if identify == "safe":
367            return not self.case_sensitive(text)
368
369        return False
370
371    def quote_identifier(self, expression: E, identify: bool = True) -> E:
372        if isinstance(expression, exp.Identifier):
373            name = expression.this
374            expression.set(
375                "quoted",
376                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
377            )
378
379        return expression
380
381    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
382        return self.parser(**opts).parse(self.tokenize(sql), sql)
383
384    def parse_into(
385        self, expression_type: exp.IntoType, sql: str, **opts
386    ) -> t.List[t.Optional[exp.Expression]]:
387        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
388
389    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
390        return self.generator(**opts).generate(expression, copy=copy)
391
392    def transpile(self, sql: str, **opts) -> t.List[str]:
393        return [
394            self.generate(expression, copy=False, **opts) if expression else ""
395            for expression in self.parse(sql)
396        ]
397
398    def tokenize(self, sql: str) -> t.List[Token]:
399        return self.tokenizer.tokenize(sql)
400
401    @property
402    def tokenizer(self) -> Tokenizer:
403        if not hasattr(self, "_tokenizer"):
404            self._tokenizer = self.tokenizer_class(dialect=self)
405        return self._tokenizer
406
407    def parser(self, **opts) -> Parser:
408        return self.parser_class(dialect=self, **opts)
409
410    def generator(self, **opts) -> Generator:
411        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
287    def __init__(self, **kwargs) -> None:
288        normalization_strategy = kwargs.get("normalization_strategy")
289
290        if normalization_strategy is None:
291            self.normalization_strategy = self.NORMALIZATION_STRATEGY
292        else:
293            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = True
SUPPORTS_SEMI_ANTI_JOIN = True
NORMALIZE_FUNCTIONS: bool | str = 'upper'
LOG_BASE_FIRST = True
NULL_ORDERING = 'nulls_are_small'
TYPED_DIVISION = False
SAFE_DIVISION = False
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
ESCAPE_SEQUENCES: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
230    @classmethod
231    def get_or_raise(cls, dialect: DialectType) -> Dialect:
232        """
233        Look up a dialect in the global dialect registry and return it if it exists.
234
235        Args:
236            dialect: The target dialect. If this is a string, it can be optionally followed by
237                additional key-value pairs that are separated by commas and are used to specify
238                dialect settings, such as whether the dialect's identifiers are case-sensitive.
239
240        Example:
241            >>> dialect = dialect_class = get_or_raise("duckdb")
242            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
243
244        Returns:
245            The corresponding Dialect instance.
246        """
247
248        if not dialect:
249            return cls()
250        if isinstance(dialect, _Dialect):
251            return dialect()
252        if isinstance(dialect, Dialect):
253            return dialect
254        if isinstance(dialect, str):
255            try:
256                dialect_name, *kv_pairs = dialect.split(",")
257                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
258            except ValueError:
259                raise ValueError(
260                    f"Invalid dialect format: '{dialect}'. "
261                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
262                )
263
264            result = cls.get(dialect_name.strip())
265            if not result:
266                raise ValueError(f"Unknown dialect '{dialect_name}'.")
267
268            return result(**kwargs)
269
270        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
272    @classmethod
273    def format_time(
274        cls, expression: t.Optional[str | exp.Expression]
275    ) -> t.Optional[exp.Expression]:
276        if isinstance(expression, str):
277            return exp.Literal.string(
278                # the time formats are quoted
279                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
280            )
281
282        if expression and expression.is_string:
283            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
284
285        return expression
def normalize_identifier(self, expression: ~E) -> ~E:
303    def normalize_identifier(self, expression: E) -> E:
304        """
305        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
306
307        For example, an identifier like FoO would be resolved as foo in Postgres, because it
308        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
309        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
310        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
311
312        There are also dialects like Spark, which are case-insensitive even when quotes are
313        present, and dialects like MySQL, whose resolution rules match those employed by the
314        underlying operating system, for example they may always be case-sensitive in Linux.
315
316        Finally, the normalization behavior of some engines can even be controlled through flags,
317        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
318
319        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
320        that it can analyze queries in the optimizer and successfully capture their semantics.
321        """
322        if (
323            isinstance(expression, exp.Identifier)
324            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
325            and (
326                not expression.quoted
327                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
328            )
329        ):
330            expression.set(
331                "this",
332                expression.this.upper()
333                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
334                else expression.this.lower(),
335            )
336
337        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
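
A hedged illustration of these rules, assuming each dialect's default settings:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
'foo'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
'FOO'
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO", quoted=True)).name
'FoO'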

def case_sensitive(self, text: str) -> bool:
339    def case_sensitive(self, text: str) -> bool:
340        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
341        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
342            return False
343
344        unsafe = (
345            str.islower
346            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
347            else str.isupper
348        )
349        return any(unsafe(char) for char in text)

Checks if text contains any case-sensitive characters, based on the dialect's rules.
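
For example, under DuckDB's default lowercase normalization any uppercase character makes the text case-sensitive (a hedged sketch):

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("duckdb").case_sensitive("foo")
False
>>> Dialect.get_or_raise("duckdb").case_sensitive("Foo")
True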

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
351    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
352        """Checks if text can be identified given an identify option.
353
354        Args:
355            text: The text to check.
356            identify:
357                "always" or `True`: Always returns True.
358                "safe": True if the identifier is case-insensitive.
359
360        Returns:
361            Whether or not the given text can be identified.
362        """
363        if identify is True or identify == "always":
364            return True
365
366        if identify == "safe":
367            return not self.case_sensitive(text)
368
369        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: always returns True; "safe": returns True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.
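
Continuing the DuckDB example (hedged, assuming default settings):

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect.get_or_raise("duckdb")
>>> d.can_identify("Foo", "safe")
False
>>> d.can_identify("Foo", "always")
True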

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
371    def quote_identifier(self, expression: E, identify: bool = True) -> E:
372        if isinstance(expression, exp.Identifier):
373            name = expression.this
374            expression.set(
375                "quoted",
376                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
377            )
378
379        return expression
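
A hedged example: even with identify=False, an identifier is quoted when it is case-sensitive for the dialect or does not match exp.SAFE_IDENTIFIER_RE:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> ident = Dialect.get_or_raise("duckdb").quote_identifier(exp.to_identifier("Foo"), identify=False)
>>> ident.sql(dialect="duckdb")
'"Foo"'
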
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
381    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
382        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
384    def parse_into(
385        self, expression_type: exp.IntoType, sql: str, **opts
386    ) -> t.List[t.Optional[exp.Expression]]:
387        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
389    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
390        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
392    def transpile(self, sql: str, **opts) -> t.List[str]:
393        return [
394            self.generate(expression, copy=False, **opts) if expression else ""
395            for expression in self.parse(sql)
396        ]
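
These methods tie the dialect's tokenizer, parser and generator together. For example (hedged, assuming default options):

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("duckdb").transpile("SELECT 1")
['SELECT 1']
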
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
398    def tokenize(self, sql: str) -> t.List[Token]:
399        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
407    def parser(self, **opts) -> Parser:
408        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
410    def generator(self, **opts) -> Generator:
411        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
417def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
418    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
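
rename_func is typically installed in a dialect Generator's TRANSFORMS mapping. A minimal sketch, with MyGenerator as a hypothetical subclass:

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.dialects.dialect import rename_func

class MyGenerator(Generator):
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # Re-emit exp.ApproxDistinct as APPROX_COUNT_DISTINCT(...)
        exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
    }
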
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
421def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
422    if expression.args.get("accuracy"):
423        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
424    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
427def if_sql(
428    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
429) -> t.Callable[[Generator, exp.If], str]:
430    def _if_sql(self: Generator, expression: exp.If) -> str:
431        return self.func(
432            name,
433            expression.this,
434            expression.args.get("true"),
435            expression.args.get("false") or false_value,
436        )
437
438    return _if_sql
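
Reusing the MyGenerator sketch above, a dialect that spells IF as IIF could register the following (a hedged, hypothetical mapping):

    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # Emit exp.If as IIF(cond, true, false), substituting NULL when no false branch is given
        exp.If: if_sql(name="IIF", false_value="NULL"),
    }
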
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
441def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
442    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
445def arrow_json_extract_scalar_sql(
446    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
447) -> str:
448    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
451def inline_array_sql(self: Generator, expression: exp.Array) -> str:
452    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
455def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
456    return self.like_sql(
457        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
458    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
461def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
462    zone = self.sql(expression, "this")
463    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
466def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
467    if expression.args.get("recursive"):
468        self.unsupported("Recursive CTEs are unsupported")
469        expression.args["recursive"] = False
470    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
473def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
474    n = self.sql(expression, "this")
475    d = self.sql(expression, "expression")
476    return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
479def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
480    self.unsupported("TABLESAMPLE unsupported")
481    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
484def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
485    self.unsupported("PIVOT unsupported")
486    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
489def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
490    return self.cast_sql(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
493def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
494    self.unsupported("Properties unsupported")
495    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
498def no_comment_column_constraint_sql(
499    self: Generator, expression: exp.CommentColumnConstraint
500) -> str:
501    self.unsupported("CommentColumnConstraint unsupported")
502    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
505def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
506    self.unsupported("MAP_FROM_ENTRIES unsupported")
507    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
510def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
511    this = self.sql(expression, "this")
512    substr = self.sql(expression, "substr")
513    position = self.sql(expression, "position")
514    if position:
515        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
516    return f"STRPOS({this}, {substr})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
519def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
520    return (
521        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
522    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
525def var_map_sql(
526    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
527) -> str:
528    keys = expression.args["keys"]
529    values = expression.args["values"]
530
531    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
532        self.unsupported("Cannot convert array columns into map.")
533        return self.func(map_func_name, keys, values)
534
535    args = []
536    for key, value in zip(keys.expressions, values.expressions):
537        args.append(self.sql(key))
538        args.append(self.sql(value))
539
540    return self.func(map_func_name, *args)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
543def format_time_lambda(
544    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
545) -> t.Callable[[t.List], E]:
546    """Helper used for time expressions.
547
548    Args:
549        exp_class: the expression class to instantiate.
550        dialect: target sql dialect.
551        default: the default format, True being time.
552
553    Returns:
554        A callable that can be used to return the appropriately formatted time expression.
555    """
556
557    def _format_time(args: t.List):
558        return exp_class(
559            this=seq_get(args, 0),
560            format=Dialect[dialect].format_time(
561                seq_get(args, 1)
562                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
563            ),
564        )
565
566    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
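
A hedged sketch of how a dialect parser might register this helper (MyParser and the TO_DATE mapping are illustrative):

from sqlglot import exp
from sqlglot.parser import Parser
from sqlglot.dialects.dialect import format_time_lambda

class MyParser(Parser):
    FUNCTIONS = {
        **Parser.FUNCTIONS,
        # Parse TO_DATE(x[, fmt]) into exp.TsOrDsToDate; default=True falls back to the dialect's TIME_FORMAT
        "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive", default=True),
    }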

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
569def time_format(
570    dialect: DialectType = None,
571) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
572    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
573        """
574        Returns the time format for a given expression, unless it's equivalent
575        to the default time format of the dialect of interest.
576        """
577        time_format = self.format_time(expression)
578        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
579
580    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
583def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
584    """
585    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
586    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
587    columns are removed from the create statement.
588    """
589    has_schema = isinstance(expression.this, exp.Schema)
590    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
591
592    if has_schema and is_partitionable:
593        prop = expression.find(exp.PartitionedByProperty)
594        if prop and prop.this and not isinstance(prop.this, exp.Schema):
595            schema = expression.this
596            columns = {v.name.upper() for v in prop.this.expressions}
597            partitions = [col for col in schema.expressions if col.name.upper() in columns]
598            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
599            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
600            expression.set("this", schema)
601
602    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
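
Roughly, for Spark 3-style input the rewrite does the following (a simplified sketch, not verbatim generator output):

# Input:  CREATE TABLE x (a INT, dt STRING) PARTITIONED BY (dt)
# Output: CREATE TABLE x (a INT) PARTITIONED BY (dt STRING)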

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
605def parse_date_delta(
606    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
607) -> t.Callable[[t.List], E]:
608    def inner_func(args: t.List) -> E:
609        unit_based = len(args) == 3
610        this = args[2] if unit_based else seq_get(args, 0)
611        unit = args[0] if unit_based else exp.Literal.string("DAY")
612        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
613        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
614
615    return inner_func
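
A hedged registration sketch for a T-SQL-style DATEADD (the unit_mapping entry is illustrative):

    FUNCTIONS = {
        **Parser.FUNCTIONS,
        # DATEADD(unit, amount, expr) -> exp.DateAdd(this=expr, expression=amount, unit=...)
        "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping={"dd": "DAY"}),
    }
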
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
618def parse_date_delta_with_interval(
619    expression_class: t.Type[E],
620) -> t.Callable[[t.List], t.Optional[E]]:
621    def func(args: t.List) -> t.Optional[E]:
622        if len(args) < 2:
623            return None
624
625        interval = args[1]
626
627        if not isinstance(interval, exp.Interval):
628            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
629
630        expression = interval.this
631        if expression and expression.is_string:
632            expression = exp.Literal.number(expression.this)
633
634        return expression_class(
635            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
636        )
637
638    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
641def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
642    unit = seq_get(args, 0)
643    this = seq_get(args, 1)
644
645    if isinstance(this, exp.Cast) and this.is_type("date"):
646        return exp.DateTrunc(unit=unit, this=this)
647    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
650def date_add_interval_sql(
651    data_type: str, kind: str
652) -> t.Callable[[Generator, exp.Expression], str]:
653    def func(self: Generator, expression: exp.Expression) -> str:
654        this = self.sql(expression, "this")
655        unit = expression.args.get("unit")
656        unit = exp.var(unit.name.upper() if unit else "DAY")
657        interval = exp.Interval(this=expression.expression, unit=unit)
658        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
659
660    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
663def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
664    return self.func(
665        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
666    )
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
669def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
670    if not expression.expression:
671        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
672    if expression.text("expression").lower() in TIMEZONES:
673        return self.sql(
674            exp.AtTimeZone(
675                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
676                zone=expression.expression,
677            )
678        )
679    return self.function_fallback_sql(expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
682def locate_to_strposition(args: t.List) -> exp.Expression:
683    return exp.StrPosition(
684        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
685    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
688def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
689    return self.func(
690        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
691    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
694def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
695    return self.sql(
696        exp.Substring(
697            this=expression.this, start=exp.Literal.number(1), length=expression.expression
698        )
699    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
702def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
703    return self.sql(
704        exp.Substring(
705            this=expression.this,
706            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
707        )
708    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
711def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
712    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
715def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
716    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
720def encode_decode_sql(
721    self: Generator, expression: exp.Expression, name: str, replace: bool = True
722) -> str:
723    charset = expression.args.get("charset")
724    if charset and charset.name.lower() != "utf-8":
725        self.unsupported(f"Expected utf-8 character set, got {charset}.")
726
727    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
730def min_or_least(self: Generator, expression: exp.Min) -> str:
731    name = "LEAST" if expression.expressions else "MIN"
732    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
735def max_or_greatest(self: Generator, expression: exp.Max) -> str:
736    name = "GREATEST" if expression.expressions else "MAX"
737    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
740def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
741    cond = expression.this
742
743    if isinstance(expression.this, exp.Distinct):
744        cond = expression.this.expressions[0]
745        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
746
747    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
750def trim_sql(self: Generator, expression: exp.Trim) -> str:
751    target = self.sql(expression, "this")
752    trim_type = self.sql(expression, "position")
753    remove_chars = self.sql(expression, "expression")
754    collation = self.sql(expression, "collation")
755
756    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
757    if not remove_chars and not collation:
758        return self.trim_sql(expression)
759
760    trim_type = f"{trim_type} " if trim_type else ""
761    remove_chars = f"{remove_chars} " if remove_chars else ""
762    from_part = "FROM " if trim_type or remove_chars else ""
763    collation = f" COLLATE {collation}" if collation else ""
764    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
767def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
768    return self.func("STRPTIME", expression.this, self.format_time(expression))
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
771def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
772    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
773        _dialect = Dialect.get_or_raise(dialect)
774        time_format = self.format_time(expression)
775        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
776            return self.sql(
777                exp.cast(
778                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
779                    "date",
780                )
781            )
782        return self.sql(exp.cast(expression.this, "date"))
783
784    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
787def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
788    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
791def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
792    delim, *rest_args = expression.expressions
793    return self.sql(
794        reduce(
795            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
796            rest_args,
797        )
798    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
801def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
802    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
803    if bad_args:
804        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
805
806    return self.func(
807        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
808    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
811def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
812    bad_args = list(
813        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
814    )
815    if bad_args:
816        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
817
818    return self.func(
819        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
820    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
823def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
824    names = []
825    for agg in aggregations:
826        if isinstance(agg, exp.Alias):
827            names.append(agg.alias)
828        else:
829            """
830            This case corresponds to aggregations without aliases being used as suffixes
831            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
832            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
833            Otherwise, we'd end up with `col_avg(`foo`)` (note the nested backquotes).
834            """
835            agg_all_unquoted = agg.transform(
836                lambda node: exp.Identifier(this=node.name, quoted=False)
837                if isinstance(node, exp.Identifier)
838                else node
839            )
840            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
841
842    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
845def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
846    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def parse_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
850def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
851    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
854def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
855    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
858def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
859    a = self.sql(expression.left)
860    b = self.sql(expression.right)
861    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def json_keyvalue_comma_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONKeyValue) -> str:
865def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
866    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
869def is_parse_json(expression: exp.Expression) -> bool:
870    return isinstance(expression, exp.ParseJSON) or (
871        isinstance(expression, exp.Cast) and expression.is_type("json")
872    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
875def isnull_to_is_null(args: t.List) -> exp.Expression:
876    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
879def generatedasidentitycolumnconstraint_sql(
880    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
881) -> str:
882    start = self.sql(expression, "start") or "1"
883    increment = self.sql(expression, "increment") or "1"
884    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
887def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
888    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
889        if expression.args.get("count"):
890            self.unsupported(f"Only two arguments are supported in function {name}.")
891
892        return self.func(name, expression.this, expression.expression)
893
894    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
897def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
898    this = expression.this.copy()
899
900    return_type = expression.return_type
901    if return_type.is_type(exp.DataType.Type.DATE):
902        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
903        # can truncate timestamp strings, because some dialects can't cast them to DATE
904        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
905
906    expression.this.replace(exp.cast(this, return_type))
907    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
910def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
911    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
912        if cast and isinstance(expression, exp.TsOrDsAdd):
913            expression = ts_or_ds_add_cast(expression)
914
915        return self.func(
916            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
917        )
918
919    return _delta_sql
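
A hedged usage sketch, mirroring how a dialect generator could wire this helper up (names are illustrative):

class MyGenerator(Generator):
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # e.g. exp.DateAdd(this=x, expression=1, unit=DAY) -> DATEADD(DAY, 1, x)
        exp.DateAdd: date_delta_sql("DATEADD"),
        exp.DateDiff: date_delta_sql("DATEDIFF"),
    }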