Edit on GitHub

sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20
  21if t.TYPE_CHECKING:
  22    from sqlglot._typing import B, E, F
  23
  24    JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  25
  26logger = logging.getLogger("sqlglot")
  27
  28
  29class Dialects(str, Enum):
  30    """Dialects supported by SQLGLot."""
  31
  32    DIALECT = ""
  33
  34    BIGQUERY = "bigquery"
  35    CLICKHOUSE = "clickhouse"
  36    DATABRICKS = "databricks"
  37    DORIS = "doris"
  38    DRILL = "drill"
  39    DUCKDB = "duckdb"
  40    HIVE = "hive"
  41    MYSQL = "mysql"
  42    ORACLE = "oracle"
  43    POSTGRES = "postgres"
  44    PRESTO = "presto"
  45    REDSHIFT = "redshift"
  46    SNOWFLAKE = "snowflake"
  47    SPARK = "spark"
  48    SPARK2 = "spark2"
  49    SQLITE = "sqlite"
  50    STARROCKS = "starrocks"
  51    TABLEAU = "tableau"
  52    TERADATA = "teradata"
  53    TRINO = "trino"
  54    TSQL = "tsql"
  55
  56
  57class NormalizationStrategy(str, AutoName):
  58    """Specifies the strategy according to which identifiers should be normalized."""
  59
  60    LOWERCASE = auto()
  61    """Unquoted identifiers are lowercased."""
  62
  63    UPPERCASE = auto()
  64    """Unquoted identifiers are uppercased."""
  65
  66    CASE_SENSITIVE = auto()
  67    """Always case-sensitive, regardless of quotes."""
  68
  69    CASE_INSENSITIVE = auto()
  70    """Always case-insensitive, regardless of quotes."""
  71
  72
  73class _Dialect(type):
  74    classes: t.Dict[str, t.Type[Dialect]] = {}
  75
  76    def __eq__(cls, other: t.Any) -> bool:
  77        if cls is other:
  78            return True
  79        if isinstance(other, str):
  80            return cls is cls.get(other)
  81        if isinstance(other, Dialect):
  82            return cls is type(other)
  83
  84        return False
  85
  86    def __hash__(cls) -> int:
  87        return hash(cls.__name__.lower())
  88
  89    @classmethod
  90    def __getitem__(cls, key: str) -> t.Type[Dialect]:
  91        return cls.classes[key]
  92
  93    @classmethod
  94    def get(
  95        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
  96    ) -> t.Optional[t.Type[Dialect]]:
  97        return cls.classes.get(key, default)
  98
  99    def __new__(cls, clsname, bases, attrs):
 100        klass = super().__new__(cls, clsname, bases, attrs)
 101        enum = Dialects.__members__.get(clsname.upper())
 102        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 103
 104        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 105        klass.FORMAT_TRIE = (
 106            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 107        )
 108        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 109        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 110
 111        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
 112
 113        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 114        klass.parser_class = getattr(klass, "Parser", Parser)
 115        klass.generator_class = getattr(klass, "Generator", Generator)
 116
 117        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 118        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 119            klass.tokenizer_class._IDENTIFIERS.items()
 120        )[0]
 121
 122        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 123            return next(
 124                (
 125                    (s, e)
 126                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 127                    if t == token_type
 128                ),
 129                (None, None),
 130            )
 131
 132        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 133        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 134        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 135        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 136
 137        if enum not in ("", "bigquery"):
 138            klass.generator_class.SELECT_KINDS = ()
 139
 140        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 141            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 142                TokenType.ANTI,
 143                TokenType.SEMI,
 144            }
 145
 146        return klass
 147
 148
 149class Dialect(metaclass=_Dialect):
 150    INDEX_OFFSET = 0
 151    """Determines the base index offset for arrays."""
 152
 153    WEEK_OFFSET = 0
 154    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 155
 156    UNNEST_COLUMN_ONLY = False
 157    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""
 158
 159    ALIAS_POST_TABLESAMPLE = False
 160    """Determines whether or not the table alias comes after tablesample."""
 161
 162    TABLESAMPLE_SIZE_IS_PERCENT = False
 163    """Determines whether or not a size in the table sample clause represents percentage."""
 164
 165    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 166    """Specifies the strategy according to which identifiers should be normalized."""
 167
 168    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 169    """Determines whether or not an unquoted identifier can start with a digit."""
 170
 171    DPIPE_IS_STRING_CONCAT = True
 172    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
 173
 174    STRICT_STRING_CONCAT = False
 175    """Determines whether or not `CONCAT`'s arguments must be strings."""
 176
 177    SUPPORTS_USER_DEFINED_TYPES = True
 178    """Determines whether or not user-defined data types are supported."""
 179
 180    SUPPORTS_SEMI_ANTI_JOIN = True
 181    """Determines whether or not `SEMI` or `ANTI` joins are supported."""
 182
 183    NORMALIZE_FUNCTIONS: bool | str = "upper"
 184    """Determines how function names are going to be normalized."""
 185
 186    LOG_BASE_FIRST = True
 187    """Determines whether the base comes first in the `LOG` function."""
 188
 189    NULL_ORDERING = "nulls_are_small"
 190    """
 191    Indicates the default `NULL` ordering method to use if not explicitly set.
 192    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 193    """
 194
 195    TYPED_DIVISION = False
 196    """
 197    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 198    False means `a / b` is always float division.
 199    True means `a / b` is integer division if both `a` and `b` are integers.
 200    """
 201
 202    SAFE_DIVISION = False
 203    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 204
 205    CONCAT_COALESCE = False
 206    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 207
 208    DATE_FORMAT = "'%Y-%m-%d'"
 209    DATEINT_FORMAT = "'%Y%m%d'"
 210    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 211
 212    TIME_MAPPING: t.Dict[str, str] = {}
 213    """Associates this dialect's time formats with their equivalent Python `strftime` format."""
 214
 215    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 216    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 217    FORMAT_MAPPING: t.Dict[str, str] = {}
 218    """
 219    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 220    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 221    """
 222
 223    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 224    """Mapping of an unescaped escape sequence to the corresponding character."""
 225
 226    PSEUDOCOLUMNS: t.Set[str] = set()
 227    """
 228    Columns that are auto-generated by the engine corresponding to this dialect.
 229    For example, such columns may be excluded from `SELECT *` queries.
 230    """
 231
 232    PREFER_CTE_ALIAS_COLUMN = False
 233    """
 234    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 235    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 236    any projection aliases in the subquery.
 237
 238    For example,
 239        WITH y(c) AS (
 240            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 241        ) SELECT c FROM y;
 242
 243        will be rewritten as
 244
 245        WITH y(c) AS (
 246            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 247        ) SELECT c FROM y;
 248    """
 249
 250    # --- Autofilled ---
 251
 252    tokenizer_class = Tokenizer
 253    parser_class = Parser
 254    generator_class = Generator
 255
 256    # A trie of the time_mapping keys
 257    TIME_TRIE: t.Dict = {}
 258    FORMAT_TRIE: t.Dict = {}
 259
 260    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 261    INVERSE_TIME_TRIE: t.Dict = {}
 262
 263    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 264
 265    # Delimiters for string literals and identifiers
 266    QUOTE_START = "'"
 267    QUOTE_END = "'"
 268    IDENTIFIER_START = '"'
 269    IDENTIFIER_END = '"'
 270
 271    # Delimiters for bit, hex, byte and unicode literals
 272    BIT_START: t.Optional[str] = None
 273    BIT_END: t.Optional[str] = None
 274    HEX_START: t.Optional[str] = None
 275    HEX_END: t.Optional[str] = None
 276    BYTE_START: t.Optional[str] = None
 277    BYTE_END: t.Optional[str] = None
 278    UNICODE_START: t.Optional[str] = None
 279    UNICODE_END: t.Optional[str] = None
 280
 281    @classmethod
 282    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 283        """
 284        Look up a dialect in the global dialect registry and return it if it exists.
 285
 286        Args:
 287            dialect: The target dialect. If this is a string, it can be optionally followed by
 288                additional key-value pairs that are separated by commas and are used to specify
 289                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 290
 291        Example:
 292            >>> dialect = dialect_class = get_or_raise("duckdb")
 293            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
 294
 295        Returns:
 296            The corresponding Dialect instance.
 297        """
 298
 299        if not dialect:
 300            return cls()
 301        if isinstance(dialect, _Dialect):
 302            return dialect()
 303        if isinstance(dialect, Dialect):
 304            return dialect
 305        if isinstance(dialect, str):
 306            try:
 307                dialect_name, *kv_pairs = dialect.split(",")
 308                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 309            except ValueError:
 310                raise ValueError(
 311                    f"Invalid dialect format: '{dialect}'. "
 312                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
 313                )
 314
 315            result = cls.get(dialect_name.strip())
 316            if not result:
 317                from difflib import get_close_matches
 318
 319                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 320                if similar:
 321                    similar = f" Did you mean {similar}?"
 322
 323                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 324
 325            return result(**kwargs)
 326
 327        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 328
 329    @classmethod
 330    def format_time(
 331        cls, expression: t.Optional[str | exp.Expression]
 332    ) -> t.Optional[exp.Expression]:
 333        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 334        if isinstance(expression, str):
 335            return exp.Literal.string(
 336                # the time formats are quoted
 337                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 338            )
 339
 340        if expression and expression.is_string:
 341            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 342
 343        return expression
 344
 345    def __init__(self, **kwargs) -> None:
 346        normalization_strategy = kwargs.get("normalization_strategy")
 347
 348        if normalization_strategy is None:
 349            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 350        else:
 351            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 352
 353    def __eq__(self, other: t.Any) -> bool:
 354        # Does not currently take dialect state into account
 355        return type(self) == other
 356
 357    def __hash__(self) -> int:
 358        # Does not currently take dialect state into account
 359        return hash(type(self))
 360
 361    def normalize_identifier(self, expression: E) -> E:
 362        """
 363        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 364
 365        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 366        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 367        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 368        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 369
 370        There are also dialects like Spark, which are case-insensitive even when quotes are
 371        present, and dialects like MySQL, whose resolution rules match those employed by the
 372        underlying operating system, for example they may always be case-sensitive in Linux.
 373
 374        Finally, the normalization behavior of some engines can even be controlled through flags,
 375        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 376
 377        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 378        that it can analyze queries in the optimizer and successfully capture their semantics.
 379        """
 380        if (
 381            isinstance(expression, exp.Identifier)
 382            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 383            and (
 384                not expression.quoted
 385                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 386            )
 387        ):
 388            expression.set(
 389                "this",
 390                (
 391                    expression.this.upper()
 392                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 393                    else expression.this.lower()
 394                ),
 395            )
 396
 397        return expression
 398
 399    def case_sensitive(self, text: str) -> bool:
 400        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 401        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 402            return False
 403
 404        unsafe = (
 405            str.islower
 406            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 407            else str.isupper
 408        )
 409        return any(unsafe(char) for char in text)
 410
 411    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 412        """Checks if text can be identified given an identify option.
 413
 414        Args:
 415            text: The text to check.
 416            identify:
 417                `"always"` or `True`: Always returns `True`.
 418                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 419
 420        Returns:
 421            Whether or not the given text can be identified.
 422        """
 423        if identify is True or identify == "always":
 424            return True
 425
 426        if identify == "safe":
 427            return not self.case_sensitive(text)
 428
 429        return False
 430
 431    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 432        """
 433        Adds quotes to a given identifier.
 434
 435        Args:
 436            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 437            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 438                "unsafe", with respect to its characters and this dialect's normalization strategy.
 439        """
 440        if isinstance(expression, exp.Identifier):
 441            name = expression.this
 442            expression.set(
 443                "quoted",
 444                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 445            )
 446
 447        return expression
 448
 449    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 450        if isinstance(path, exp.Literal):
 451            path_text = path.name
 452            if path.is_number:
 453                path_text = f"[{path_text}]"
 454
 455            try:
 456                return parse_json_path(path_text)
 457            except ParseError as e:
 458                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 459
 460        return path
 461
 462    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 463        return self.parser(**opts).parse(self.tokenize(sql), sql)
 464
 465    def parse_into(
 466        self, expression_type: exp.IntoType, sql: str, **opts
 467    ) -> t.List[t.Optional[exp.Expression]]:
 468        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 469
 470    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 471        return self.generator(**opts).generate(expression, copy=copy)
 472
 473    def transpile(self, sql: str, **opts) -> t.List[str]:
 474        return [
 475            self.generate(expression, copy=False, **opts) if expression else ""
 476            for expression in self.parse(sql)
 477        ]
 478
 479    def tokenize(self, sql: str) -> t.List[Token]:
 480        return self.tokenizer.tokenize(sql)
 481
 482    @property
 483    def tokenizer(self) -> Tokenizer:
 484        if not hasattr(self, "_tokenizer"):
 485            self._tokenizer = self.tokenizer_class(dialect=self)
 486        return self._tokenizer
 487
 488    def parser(self, **opts) -> Parser:
 489        return self.parser_class(dialect=self, **opts)
 490
 491    def generator(self, **opts) -> Generator:
 492        return self.generator_class(dialect=self, **opts)
 493
 494
 495DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 496
 497
 498def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 499    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 500
 501
 502def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 503    if expression.args.get("accuracy"):
 504        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 505    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 506
 507
 508def if_sql(
 509    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 510) -> t.Callable[[Generator, exp.If], str]:
 511    def _if_sql(self: Generator, expression: exp.If) -> str:
 512        return self.func(
 513            name,
 514            expression.this,
 515            expression.args.get("true"),
 516            expression.args.get("false") or false_value,
 517        )
 518
 519    return _if_sql
 520
 521
 522def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 523    this = expression.this
 524    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 525        this.replace(exp.cast(this, "json"))
 526
 527    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 528
 529
 530def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 531    return f"[{self.expressions(expression, flat=True)}]"
 532
 533
 534def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 535    return self.like_sql(
 536        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 537    )
 538
 539
 540def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 541    zone = self.sql(expression, "this")
 542    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 543
 544
 545def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 546    if expression.args.get("recursive"):
 547        self.unsupported("Recursive CTEs are unsupported")
 548        expression.args["recursive"] = False
 549    return self.with_sql(expression)
 550
 551
 552def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 553    n = self.sql(expression, "this")
 554    d = self.sql(expression, "expression")
 555    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 556
 557
 558def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 559    self.unsupported("TABLESAMPLE unsupported")
 560    return self.sql(expression.this)
 561
 562
 563def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 564    self.unsupported("PIVOT unsupported")
 565    return ""
 566
 567
 568def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 569    return self.cast_sql(expression)
 570
 571
 572def no_comment_column_constraint_sql(
 573    self: Generator, expression: exp.CommentColumnConstraint
 574) -> str:
 575    self.unsupported("CommentColumnConstraint unsupported")
 576    return ""
 577
 578
 579def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 580    self.unsupported("MAP_FROM_ENTRIES unsupported")
 581    return ""
 582
 583
 584def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
 585    this = self.sql(expression, "this")
 586    substr = self.sql(expression, "substr")
 587    position = self.sql(expression, "position")
 588    if position:
 589        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
 590    return f"STRPOS({this}, {substr})"
 591
 592
 593def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 594    return (
 595        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 596    )
 597
 598
 599def var_map_sql(
 600    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 601) -> str:
 602    keys = expression.args["keys"]
 603    values = expression.args["values"]
 604
 605    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 606        self.unsupported("Cannot convert array columns into map.")
 607        return self.func(map_func_name, keys, values)
 608
 609    args = []
 610    for key, value in zip(keys.expressions, values.expressions):
 611        args.append(self.sql(key))
 612        args.append(self.sql(value))
 613
 614    return self.func(map_func_name, *args)
 615
 616
 617def format_time_lambda(
 618    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 619) -> t.Callable[[t.List], E]:
 620    """Helper used for time expressions.
 621
 622    Args:
 623        exp_class: the expression class to instantiate.
 624        dialect: target sql dialect.
 625        default: the default format, True being time.
 626
 627    Returns:
 628        A callable that can be used to return the appropriately formatted time expression.
 629    """
 630
 631    def _format_time(args: t.List):
 632        return exp_class(
 633            this=seq_get(args, 0),
 634            format=Dialect[dialect].format_time(
 635                seq_get(args, 1)
 636                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 637            ),
 638        )
 639
 640    return _format_time
 641
 642
 643def time_format(
 644    dialect: DialectType = None,
 645) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 646    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 647        """
 648        Returns the time format for a given expression, unless it's equivalent
 649        to the default time format of the dialect of interest.
 650        """
 651        time_format = self.format_time(expression)
 652        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 653
 654    return _time_format
 655
 656
 657def parse_date_delta(
 658    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 659) -> t.Callable[[t.List], E]:
 660    def inner_func(args: t.List) -> E:
 661        unit_based = len(args) == 3
 662        this = args[2] if unit_based else seq_get(args, 0)
 663        unit = args[0] if unit_based else exp.Literal.string("DAY")
 664        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 665        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 666
 667    return inner_func
 668
 669
 670def parse_date_delta_with_interval(
 671    expression_class: t.Type[E],
 672) -> t.Callable[[t.List], t.Optional[E]]:
 673    def func(args: t.List) -> t.Optional[E]:
 674        if len(args) < 2:
 675            return None
 676
 677        interval = args[1]
 678
 679        if not isinstance(interval, exp.Interval):
 680            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 681
 682        expression = interval.this
 683        if expression and expression.is_string:
 684            expression = exp.Literal.number(expression.this)
 685
 686        return expression_class(
 687            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
 688        )
 689
 690    return func
 691
 692
 693def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 694    unit = seq_get(args, 0)
 695    this = seq_get(args, 1)
 696
 697    if isinstance(this, exp.Cast) and this.is_type("date"):
 698        return exp.DateTrunc(unit=unit, this=this)
 699    return exp.TimestampTrunc(this=this, unit=unit)
 700
 701
 702def date_add_interval_sql(
 703    data_type: str, kind: str
 704) -> t.Callable[[Generator, exp.Expression], str]:
 705    def func(self: Generator, expression: exp.Expression) -> str:
 706        this = self.sql(expression, "this")
 707        unit = expression.args.get("unit")
 708        unit = exp.var(unit.name.upper() if unit else "DAY")
 709        interval = exp.Interval(this=expression.expression, unit=unit)
 710        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 711
 712    return func
 713
 714
 715def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 716    return self.func(
 717        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
 718    )
 719
 720
 721def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 722    if not expression.expression:
 723        from sqlglot.optimizer.annotate_types import annotate_types
 724
 725        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 726        return self.sql(exp.cast(expression.this, to=target_type))
 727    if expression.text("expression").lower() in TIMEZONES:
 728        return self.sql(
 729            exp.AtTimeZone(
 730                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
 731                zone=expression.expression,
 732            )
 733        )
 734    return self.func("TIMESTAMP", expression.this, expression.expression)
 735
 736
 737def locate_to_strposition(args: t.List) -> exp.Expression:
 738    return exp.StrPosition(
 739        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 740    )
 741
 742
 743def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 744    return self.func(
 745        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 746    )
 747
 748
 749def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 750    return self.sql(
 751        exp.Substring(
 752            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 753        )
 754    )
 755
 756
 757def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 758    return self.sql(
 759        exp.Substring(
 760            this=expression.this,
 761            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 762        )
 763    )
 764
 765
 766def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 767    return self.sql(exp.cast(expression.this, "timestamp"))
 768
 769
 770def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 771    return self.sql(exp.cast(expression.this, "date"))
 772
 773
 774# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
 775def encode_decode_sql(
 776    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 777) -> str:
 778    charset = expression.args.get("charset")
 779    if charset and charset.name.lower() != "utf-8":
 780        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 781
 782    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 783
 784
 785def min_or_least(self: Generator, expression: exp.Min) -> str:
 786    name = "LEAST" if expression.expressions else "MIN"
 787    return rename_func(name)(self, expression)
 788
 789
 790def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 791    name = "GREATEST" if expression.expressions else "MAX"
 792    return rename_func(name)(self, expression)
 793
 794
 795def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 796    cond = expression.this
 797
 798    if isinstance(expression.this, exp.Distinct):
 799        cond = expression.this.expressions[0]
 800        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 801
 802    return self.func("sum", exp.func("if", cond, 1, 0))
 803
 804
 805def trim_sql(self: Generator, expression: exp.Trim) -> str:
 806    target = self.sql(expression, "this")
 807    trim_type = self.sql(expression, "position")
 808    remove_chars = self.sql(expression, "expression")
 809    collation = self.sql(expression, "collation")
 810
 811    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 812    if not remove_chars and not collation:
 813        return self.trim_sql(expression)
 814
 815    trim_type = f"{trim_type} " if trim_type else ""
 816    remove_chars = f"{remove_chars} " if remove_chars else ""
 817    from_part = "FROM " if trim_type or remove_chars else ""
 818    collation = f" COLLATE {collation}" if collation else ""
 819    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 820
 821
 822def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 823    return self.func("STRPTIME", expression.this, self.format_time(expression))
 824
 825
 826def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 827    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 828
 829
 830def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 831    delim, *rest_args = expression.expressions
 832    return self.sql(
 833        reduce(
 834            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 835            rest_args,
 836        )
 837    )
 838
 839
 840def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 841    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 842    if bad_args:
 843        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 844
 845    return self.func(
 846        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 847    )
 848
 849
 850def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 851    bad_args = list(
 852        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
 853    )
 854    if bad_args:
 855        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 856
 857    return self.func(
 858        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 859    )
 860
 861
 862def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 863    names = []
 864    for agg in aggregations:
 865        if isinstance(agg, exp.Alias):
 866            names.append(agg.alias)
 867        else:
 868            """
 869            This case corresponds to aggregations without aliases being used as suffixes
 870            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 871            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
 872            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
 873            """
 874            agg_all_unquoted = agg.transform(
 875                lambda node: (
 876                    exp.Identifier(this=node.name, quoted=False)
 877                    if isinstance(node, exp.Identifier)
 878                    else node
 879                )
 880            )
 881            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 882
 883    return names
 884
 885
 886def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 887    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 888
 889
 890# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 891def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 892    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 893
 894
 895def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 896    return self.func("MAX", expression.this)
 897
 898
 899def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 900    a = self.sql(expression.left)
 901    b = self.sql(expression.right)
 902    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 903
 904
 905def is_parse_json(expression: exp.Expression) -> bool:
 906    return isinstance(expression, exp.ParseJSON) or (
 907        isinstance(expression, exp.Cast) and expression.is_type("json")
 908    )
 909
 910
 911def isnull_to_is_null(args: t.List) -> exp.Expression:
 912    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 913
 914
 915def generatedasidentitycolumnconstraint_sql(
 916    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 917) -> str:
 918    start = self.sql(expression, "start") or "1"
 919    increment = self.sql(expression, "increment") or "1"
 920    return f"IDENTITY({start}, {increment})"
 921
 922
 923def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 924    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 925        if expression.args.get("count"):
 926            self.unsupported(f"Only two arguments are supported in function {name}.")
 927
 928        return self.func(name, expression.this, expression.expression)
 929
 930    return _arg_max_or_min_sql
 931
 932
 933def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 934    this = expression.this.copy()
 935
 936    return_type = expression.return_type
 937    if return_type.is_type(exp.DataType.Type.DATE):
 938        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 939        # can truncate timestamp strings, because some dialects can't cast them to DATE
 940        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 941
 942    expression.this.replace(exp.cast(this, return_type))
 943    return expression
 944
 945
 946def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 947    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
 948        if cast and isinstance(expression, exp.TsOrDsAdd):
 949            expression = ts_or_ds_add_cast(expression)
 950
 951        return self.func(
 952            name,
 953            exp.var(expression.text("unit").upper() or "DAY"),
 954            expression.expression,
 955            expression.this,
 956        )
 957
 958    return _delta_sql
 959
 960
 961def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
 962    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
 963    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
 964    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
 965
 966    return self.sql(exp.cast(minus_one_day, "date"))
 967
 968
 969def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 970    """Remove table refs from columns in when statements."""
 971    alias = expression.this.args.get("alias")
 972
 973    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
 974        return self.dialect.normalize_identifier(identifier).name if identifier else None
 975
 976    targets = {normalize(expression.this.this)}
 977
 978    if alias:
 979        targets.add(normalize(alias.this))
 980
 981    for when in expression.expressions:
 982        when.transform(
 983            lambda node: (
 984                exp.column(node.this)
 985                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
 986                else node
 987            ),
 988            copy=False,
 989        )
 990
 991    return self.merge_sql(expression)
 992
 993
 994def parse_json_extract_path(
 995    expr_type: t.Type[F], zero_based_indexing: bool = True
 996) -> t.Callable[[t.List], F]:
 997    def _parse_json_extract_path(args: t.List) -> F:
 998        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
 999        for arg in args[1:]:
1000            if not isinstance(arg, exp.Literal):
1001                # We use the fallback parser because we can't really transpile non-literals safely
1002                return expr_type.from_arg_list(args)
1003
1004            text = arg.name
1005            if is_int(text):
1006                index = int(text)
1007                segments.append(
1008                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1009                )
1010            else:
1011                segments.append(exp.JSONPathKey(this=text))
1012
1013        # This is done to avoid failing in the expression validator due to the arg count
1014        del args[2:]
1015        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
1016
1017    return _parse_json_extract_path
1018
1019
def json_extract_segments(
    name: str, quoted_index: bool = True
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """
    Build a generator that renders a JSON extraction as `name(json, seg1, seg2, ...)`.

    Each non-empty rendered path part becomes its own argument, wrapped in the dialect's
    string quotes. Subscripts are left unquoted when `quoted_index` is False. If the path
    is not a JSONPath, the function is simply renamed.
    """

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        rendered = []
        for part in json_path.expressions:
            text = self.sql(part)
            if not text:
                continue

            needs_quotes = isinstance(part, exp.JSONPathPart) and (
                quoted_index or not isinstance(part, exp.JSONPathSubscript)
            )
            if needs_quotes:
                text = f"{self.dialect.QUOTE_START}{text}{self.dialect.QUOTE_END}"

            rendered.append(text)

        return self.func(name, expression.this, *rendered)

    return _json_extract_segments
1042
1043
def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    """Render a JSON path key as its bare name. Wildcard keys are reported as unsupported."""
    key = expression.this
    if isinstance(key, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name
logger = <Logger sqlglot (WARNING)>
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """Dialects supported by SQLGlot.

    Members mix in `str`, so each member can be used directly as its dialect name string.
    `DIALECT` is the empty string; an empty/falsy dialect makes `Dialect.get_or_raise`
    return the base `Dialect`.
    """

    # The base (dialect-agnostic) dialect.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"

Dialects supported by SQLGlot.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized.

    Each member's value equals its name (e.g. "LOWERCASE"). `Dialect.__init__` upper-cases
    a user-supplied strategy string before looking it up, so settings such as
    `normalization_strategy = case_sensitive` are accepted.
    """

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    """
    Base class for SQL dialects.

    The class-level settings below control how SQL is tokenized, parsed and generated
    for a dialect. Dialect classes are looked up by name through the registry on the
    `_Dialect` metaclass (see `get_or_raise`). Instances also carry per-instance settings
    passed as keyword arguments; currently only `normalization_strategy` is read.
    """

    INDEX_OFFSET = 0
    """Determines the base index offset for arrays."""

    WEEK_OFFSET = 0
    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Determines whether or not the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Determines whether or not a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Determines whether or not an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Determines whether or not `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Determines whether or not user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Determines whether or not `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """Determines how function names are going to be normalized."""

    LOG_BASE_FIRST = True
    """Determines whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Indicates the default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    # Default format literals (quoted, in Python strftime notation).
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` format."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance. A falsy `dialect` yields an instance of `cls`,
            a Dialect class is instantiated, and a Dialect instance is returned as-is.

        Raises:
            ValueError: if the string is malformed, names an unknown dialect, or `dialect`
                has an unsupported type.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            # Strip once, so the registry lookup, the fuzzy suggestion and the error message
            # all see the same name ("duckdb , k = v" splits into "duckdb " with a trailing space).
            dialect_name = dialect_name.strip()

            result = cls.get(dialect_name)
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """
        Converts a time format in this dialect to its equivalent Python `strftime` format.

        A raw `str` is expected to include its surrounding quotes, which are stripped.
        A string literal expression is converted too. Anything else (including None) is
        returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        """
        Create a dialect instance.

        Keyword Args:
            normalization_strategy: optional strategy name (any case) overriding the class
                default `NORMALIZATION_STRATEGY`.
        """
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account.
        # `type(self) == other` delegates to the metaclass (`_Dialect`) comparison, which is
        # defined elsewhere; presumably it handles names, classes and instances - verify there.
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.

        Note: identifiers are modified in place and the same object is returned; pass a copy
        if the original must be preserved. Non-identifiers are returned untouched.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        # A character is "unsafe" if normalization would change it: lowercase letters under
        # an uppercasing strategy, uppercase letters otherwise.
        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified. Any other `identify` value
            yields `False`.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.

        Returns:
            The same expression, with its `quoted` flag updated in place.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        """
        Parse a literal JSON path into a structured path expression.

        Numeric literals are treated as a subscript (`[n]`). If the path cannot be parsed,
        a warning is logged and the original expression is returned. Non-literals are
        returned unchanged.
        """
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` with this dialect; `opts` are passed to the parser."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql` into the given expression type(s); `opts` are passed to the parser."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Generate SQL for `expression`; `opts` are passed to the generator."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """
        Parse `sql` and regenerate each statement in this dialect.

        `opts` are forwarded to the generator only. Empty statements yield "".
        """
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """This dialect's tokenizer, created on first access and cached on the instance."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Create a new parser bound to this dialect."""
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        """Create a new generator bound to this dialect."""
        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
346    def __init__(self, **kwargs) -> None:
347        normalization_strategy = kwargs.get("normalization_strategy")
348
349        if normalization_strategy is None:
350            self.normalization_strategy = self.NORMALIZATION_STRATEGY
351        else:
352            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0

Determines the base index offset for arrays.

WEEK_OFFSET = 0

Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Determines whether or not UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Determines whether or not the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Determines whether or not a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Determines whether or not an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Determines whether or not the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Determines whether or not CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Determines whether or not user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Determines whether or not SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

LOG_BASE_FIRST = True

Determines whether the base comes first in the LOG function.

NULL_ORDERING = 'nulls_are_small'

Indicates the default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Determines whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime format.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

ESCAPE_SEQUENCES: Dict[str, str] = {}

Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
    """
    Look up a dialect in the global dialect registry and return it if it exists.

    Args:
        dialect: The target dialect. If this is a string, it can be optionally followed by
            additional key-value pairs that are separated by commas and are used to specify
            dialect settings, such as whether the dialect's identifiers are case-sensitive.

    Example:
        >>> dialect = get_or_raise("duckdb")
        >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

    Returns:
        The corresponding Dialect instance.

    Raises:
        ValueError: If a settings string is malformed, the dialect name is unknown, or
            `dialect` has an unsupported type.
    """
    if not dialect:
        return cls()
    if isinstance(dialect, _Dialect):
        # `dialect` is a class whose metaclass is `_Dialect` (presumably a Dialect subclass),
        # so it is instantiated here.
        return dialect()
    if isinstance(dialect, Dialect):
        return dialect
    if isinstance(dialect, str):
        try:
            dialect_name, *kv_pairs = dialect.split(",")
            # Each setting must be exactly "key = value"; anything else fails to unpack
            # into (k, v) and raises ValueError.
            kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
        except ValueError:
            raise ValueError(
                f"Invalid dialect format: '{dialect}'. "
                "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
            ) from None

        # Strip once so the registry lookup, the suggestion and the error all use the same name.
        dialect_name = dialect_name.strip()
        result = cls.get(dialect_name)
        if not result:
            from difflib import get_close_matches

            similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
            if similar:
                similar = f" Did you mean {similar}?"

            raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

        return result(**kwargs)

    raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.

@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """
    Translate a time format written in this dialect into the equivalent Python `strftime` format.

    A plain string is expected to still carry its surrounding quote characters, which are
    dropped before translation. A string literal expression is translated from its text.
    Anything else (including None) is returned unchanged.
    """
    if isinstance(expression, str):
        inner = expression[1:-1]  # drop the enclosing quote characters
        return exp.Literal.string(format_time(inner, cls.TIME_MAPPING, cls.TIME_TRIE))

    if not (expression and expression.is_string):
        return expression

    return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

Converts a time format in this dialect to its equivalent Python strftime format.

def normalize_identifier(self, expression: E) -> E:
    """
    Normalize an identifier the way this dialect would resolve it.

    Unquoted identifiers are folded to upper case under the UPPERCASE strategy and to lower
    case otherwise. For example, `FoO` resolves as `foo` in Postgres but as `FOO` in Snowflake.

    Quoted identifiers are left untouched, because changing them could "break" them. The
    exception is the CASE_INSENSITIVE strategy (e.g. Spark), where quotes do not affect
    resolution. Under CASE_SENSITIVE nothing is ever changed.

    Engines whose rules depend on the environment are modelled through the configured
    strategy. Examples are MySQL, which follows the operating system, and Redshift, which
    has `enable_case_sensitive_identifier`. This lets the optimizer capture query semantics
    correctly.

    Non-identifier expressions are returned as-is. Identifiers are modified in place and
    returned.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    strategy = self.normalization_strategy
    if strategy is NormalizationStrategy.CASE_SENSITIVE:
        return expression
    if expression.quoted and strategy is not NormalizationStrategy.CASE_INSENSITIVE:
        return expression

    text = expression.this
    folded = text.upper() if strategy is NormalizationStrategy.UPPERCASE else text.lower()
    expression.set("this", folded)
    return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.

def case_sensitive(self, text: str) -> bool:
    """
    Report whether `text` contains characters that are case sensitive under this dialect.

    Under CASE_INSENSITIVE nothing is case sensitive. Under UPPERCASE, lower-case letters
    are flagged. Under any other strategy, upper-case letters are flagged.
    """
    strategy = self.normalization_strategy
    if strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False
    if strategy is NormalizationStrategy.UPPERCASE:
        return any(ch.islower() for ch in text)
    return any(ch.isupper() for ch in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """
    Decide whether `text` may be quoted as an identifier for the given `identify` option.

    Args:
        text: The text to check.
        identify: `True` or `"always"` always allows it. `"safe"` allows it only when `text`
            has no case-sensitive characters. Any other value disallows it.

    Returns:
        Whether the text can be identified.
    """
    if identify is True or identify == "always":
        return True
    return identify == "safe" and not self.case_sensitive(text)

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: always returns True. "safe": returns True only if the text contains no case-sensitive characters. Any other value returns False.
Returns:

Whether or not the given text can be identified.

def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Set the `quoted` flag on an identifier.

    Args:
        expression: The expression to update. Non-`Identifier` inputs are returned unchanged.
        identify: If True, always quote. If False, quote only when the name is "unsafe":
            it contains case-sensitive characters for this dialect, or it does not match
            `exp.SAFE_IDENTIFIER_RE`.

    Returns:
        The same expression, modified in place.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    name = expression.this
    needs_quotes = (
        identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
    )
    expression.set("quoted", needs_quotes)
    return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    """
    Convert a literal JSON path into a parsed JSON path expression.

    A numeric literal is treated as an array index (`1` becomes `[1]`) before parsing.
    If the path text cannot be parsed, a warning is logged and `path` is returned unchanged.
    Non-literal inputs (including None) are also returned unchanged.
    """
    if isinstance(path, exp.Literal):
        path_text = path.name
        if path.is_number:
            path_text = f"[{path_text}]"

        try:
            return parse_json_path(path_text)
        except ParseError as e:
            # Best effort: keep the original expression instead of failing the whole parse.
            # Lazy %-formatting defers building the message to the logging framework.
            logger.warning("Invalid JSON path syntax. %s", e)

    return path
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
    """Tokenize `sql` and parse it with a new parser configured by `opts`."""
    parser = self.parser(**opts)
    return parser.parse(self.tokenize(sql), sql)
def parse_into(
    self, expression_type: exp.IntoType, sql: str, **opts
) -> t.List[t.Optional[exp.Expression]]:
    """Tokenize `sql` and parse it into `expression_type` with a new parser built from `opts`."""
    parser = self.parser(**opts)
    return parser.parse_into(expression_type, self.tokenize(sql), sql)
def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
    """Render `expression` as SQL for this dialect. `copy` is forwarded to the generator."""
    generator = self.generator(**opts)
    return generator.generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> t.List[str]:
    """
    Parse `sql` with this dialect and render every result back to SQL.

    `opts` configure the generator only. Empty parse results become "".
    """
    statements = self.parse(sql)
    return [self.generate(stmt, copy=False, **opts) if stmt else "" for stmt in statements]
def tokenize(self, sql: str) -> t.List[Token]:
    """Split `sql` into tokens using this dialect's cached tokenizer."""
    tokenizer = self.tokenizer
    return tokenizer.tokenize(sql)
@property
def tokenizer(self) -> Tokenizer:
    """Lazily create a tokenizer bound to this dialect, caching it on `self._tokenizer`."""
    try:
        return self._tokenizer
    except AttributeError:
        self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer
def parser(self, **opts) -> Parser:
    """Create a new parser for this dialect; `opts` are forwarded to `parser_class`."""
    parser_cls = self.parser_class
    return parser_cls(dialect=self, **opts)
def generator(self, **opts) -> Generator:
    """Create a new SQL generator for this dialect; `opts` are forwarded to `generator_class`."""
    generator_cls = self.generator_class
    return generator_cls(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Build a transform that renders an expression as a call to the function `name`.

    Every value in the expression's args is flattened and passed as a call argument.
    """

    def _renamed_sql(self: Generator, expression: exp.Expression) -> str:
        call_args = flatten(expression.args.values())
        return self.func(name, *call_args)

    return _renamed_sql
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT(this), reporting a given `accuracy` as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """
    Build a transform that renders an If expression as `name(cond, true, false)`.

    `false_value` is used when the expression has no (or a falsy) `false` branch.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        branches = expression.args
        otherwise = branches.get("false") or false_value
        return self.func(name, expression.this, branches.get("true"), otherwise)

    return _if_sql
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    """
    Render JSON extraction with the arrow operators.

    JSONExtract uses `->`; any other accepted type (JSONExtractScalar) uses `->>`. When the
    generator sets JSON_TYPE_REQUIRED_FOR_EXTRACTION, a string literal operand is wrapped in
    a cast to json. This replaces the node in the input tree, so it mutates `expression`.
    """
    target = expression.this
    must_cast = (
        self.JSON_TYPE_REQUIRED_FOR_EXTRACTION
        and isinstance(target, exp.Literal)
        and target.is_string
    )
    if must_cast:
        target.replace(exp.cast(target, "json"))

    operator = "->" if isinstance(expression, exp.JSONExtract) else "->>"
    return self.binary(expression, operator)
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array by wrapping its flat element list in square brackets."""
    elements = self.expressions(expression, flat=True)
    return "[" + elements + "]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """
    Emulate ILIKE with LIKE by lower-casing both the value and the pattern.

    Lower-casing only the value would make any pattern that contains upper-case characters
    (e.g. `x ILIKE 'ABC%'`) impossible to match. The pattern therefore has to be lowered too.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=exp.Lower(this=expression.expression),
        )
    )
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with AT TIME ZONE when a zone is present."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """
    Render a WITH clause for dialects without recursive CTEs.

    A recursive flag is reported as unsupported and reset to False in place before the
    clause is generated.
    """
    is_recursive = expression.args.get("recursive")
    if is_recursive:
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with IF: NULL when the divisor is 0, otherwise the quotient."""
    numerator = self.sql(expression, "this")
    divisor = self.sql(expression, "expression")
    return f"IF(({divisor}) <> 0, ({numerator}) / ({divisor}), NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the sampled source."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and emit nothing in its place."""
    self.unsupported("PIVOT unsupported")
    return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST through the generator's regular CAST rendering."""
    cast_sql = self.cast_sql
    return cast_sql(expression)
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report a column COMMENT constraint as unsupported and drop it from the output."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Report MAP_FROM_ENTRIES as unsupported and emit nothing in its place."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """
    Render StrPosition with STRPOS, emulating an optional start position.

    With a start position, the search runs on SUBSTR(this, position). The match index is
    then shifted by `position - 1` so it is relative to the whole string.

    STRPOS returns 0 when the substring is not found. That 0 must not be shifted, or a miss
    would be reported as `position - 1`. The shifted form is therefore guarded with a
    portable CASE expression.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"

    strpos = f"STRPOS(SUBSTR({this}, {position}), {substr})"
    return f"CASE WHEN {strpos} = 0 THEN 0 ELSE {strpos} + {position} - 1 END"
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access in dot notation: `<struct>.<field identifier>`."""
    struct_sql = self.sql(expression, "this")
    field = exp.to_identifier(expression.expression.name)
    return f"{struct_sql}.{self.sql(field)}"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """
    Render a map as `map_func_name(k1, v1, k2, v2, ...)`.

    This requires both `keys` and `values` to be array literals. Otherwise the conversion
    is reported as unsupported and the call is emitted as `map_func_name(keys, values)`.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        interleaved = [
            self.sql(item)
            for pair in zip(keys.expressions, values.expressions)
            for item in pair
        ]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for time expressions.

    The callback reads `[value, format]` from the function arguments. It translates the
    format using `dialect`'s time mapping and instantiates `exp_class`.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose time format conventions apply.
        default: format used when none is given. `True` means the dialect's TIME_FORMAT.

    Returns:
        A callable that turns an argument list into the formatted time expression.
    """

    def _format_time(args: t.List):
        source_dialect = Dialect[dialect]
        fmt = seq_get(args, 1)
        if not fmt:
            fmt = source_dialect.TIME_FORMAT if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=source_dialect.format_time(fmt))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a callback that returns an expression's time format, or None for the dialect default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Return the time format of `expression`. Return None instead when it equals the
        default TIME_FORMAT of the dialect of interest.
        """
        fmt = self.format_time(expression)
        if fmt == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return fmt

    return _time_format
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """
    Build a parser callback for date-delta functions.

    With three arguments the layout is `(unit, delta, value)`. Otherwise it is
    `(value, delta)` with the unit defaulting to 'DAY'. If `unit_mapping` is given, the
    unit's lower-cased name is translated through it (falling back to the name itself)
    and wrapped in a Var.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """
    Build a parser callback for `FUNC(value, INTERVAL n unit)` style date arithmetic.

    The callback returns None when fewer than two arguments are given. It raises ParseError
    when the second argument is not an Interval. A string interval amount is converted to
    a number literal.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=amount, unit=unit)

    return func
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """
    Build a truncation expression from `[unit, value]` arguments.

    DateTrunc is used when the value is a cast to DATE, and TimestampTrunc otherwise.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type("date")
    if not is_date_cast:
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Build a transform that renders `<data_type>_<kind>(this, INTERVAL ...)`.

    The interval unit is the expression's unit name upper-cased, defaulting to DAY.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit_arg = expression.args.get("unit")
        unit_name = unit_arg.name.upper() if unit_arg else "DAY"
        interval = exp.Interval(this=expression.expression, unit=exp.var(unit_name))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('<UNIT>', this), defaulting the unit to DAY."""
    unit = expression.text("unit").upper() or "DAY"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """
    Render TIMESTAMP(...) for dialects without that function.

    - One argument: cast to the annotated type of the expression, falling back to TIMESTAMP.
    - Second argument naming a known time zone: cast to TIMESTAMP, then AT TIME ZONE.
    - Any other second argument: emit TIMESTAMP(this, arg) unchanged.
    """
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))

    if expression.text("expression").lower() not in TIMEZONES:
        return self.func("TIMESTAMP", expression.this, expression.expression)

    as_timestamp = exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
    return self.sql(exp.AtTimeZone(this=as_timestamp, zone=expression.expression))
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE(substr, string[, position]) into a StrPosition expression."""
    substr, string, position = (seq_get(args, i) for i in range(3))
    return exp.StrPosition(this=string, substr=substr, position=position)
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, string[, position])."""
    opts = expression.args
    return self.func("LOCATE", opts.get("substr"), expression.this, opts.get("position"))
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render LEFT(s, n) as a Substring of s starting at position 1 with length n."""
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """
    Render RIGHT(s, n) as a Substring expression.

    The last n characters of s start at position LENGTH(s) - (n - 1), so the substring is
    taken from there to the end of s (no length argument).
    """
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a cast of the string to timestamp."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a cast of the string to date."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """
    Render an encode/decode expression as `name(this[, replace])`.

    Only UTF-8 is supported: any other `charset` is reported as unsupported. When `replace`
    is False, the expression's `replace` argument is not emitted.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replacement = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replacement)
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render Min as MIN, or as LEAST when it carries additional expressions."""
    func_name = "MIN" if not expression.expressions else "LEAST"
    return rename_func(func_name)(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render Max as MAX, or as GREATEST when it carries additional expressions."""
    func_name = "MAX" if not expression.expressions else "GREATEST"
    return rename_func(func_name)(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """
    Rewrite COUNT_IF(cond) as sum(if(cond, 1, 0)).

    DISTINCT cannot be expressed this way. It is reported as unsupported, and only its
    first expression is kept as the condition.
    """
    arg = expression.this
    is_distinct = isinstance(arg, exp.Distinct)
    condition = arg.expressions[0] if is_distinct else arg
    if is_distinct:
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    as_flag = exp.func("if", condition, 1, 0)
    return self.func("sum", as_flag)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """
    Render TRIM using the full `TRIM([type] [chars] FROM target [COLLATE c])` syntax.

    When neither removal characters nor a collation are given, this defers to the
    generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not remove_chars and not collation:
        return self.trim_sql(expression)

    prefix = "".join(f"{part} " for part in (trim_type, remove_chars) if part)
    from_kw = "FROM " if prefix else ""
    collate = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{from_kw}{target}{collate})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render string-to-time parsing as STRPTIME(value, format), translating the format."""
    value = expression.this
    return self.func("STRPTIME", value, self.format_time(expression))
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Render CONCAT(a, b, c, ...) as a left-nested chain of `||` (DPipe) operators."""

    def _pipe(left: exp.Expression, right: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=left, expression=right)

    return self.sql(reduce(_pipe, expression.expressions))
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render CONCAT_WS(sep, a, b, ...) as `a || sep || b || ...` using DPipe nodes."""
    delim, *rest_args = expression.expressions

    def _append(acc: exp.Expression, arg: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=acc, expression=exp.DPipe(this=delim, expression=arg))

    return self.sql(reduce(_append, rest_args))
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """
    Render REGEXP_EXTRACT(this, pattern[, group]).

    Any of position/occurrence/parameters that are set are reported as unsupported.
    """
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """
    Render REGEXP_REPLACE(this, pattern, replacement).

    Any of position/occurrence/parameters/modifiers that are set are reported as unsupported.
    """
    bad_args = [
        arg
        for arg in ("position", "occurrence", "parameters", "modifiers")
        if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """
    Compute the name suffixes contributed by a PIVOT's aggregations.

    An aliased aggregation contributes its alias. An unaliased one contributes its SQL text,
    rendered in `dialect` with lower-cased function names and all identifiers unquoted.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Unaliased aggregations are used as suffixes (e.g. col_avg(foo)). Their identifiers
        # are unquoted here because the base parser's `_parse_pivot` quotes the final name
        # via `to_identifier`. Otherwise we'd end up with nested quoting, e.g. col_avg(`foo`).
        agg_all_unquoted = agg.transform(
            lambda node: (
                exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
        )
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser callback that turns `[left, right]` function args into `expr_type`."""

    def _build(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _build
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build a TimestampTrunc from an argument list ordered as (unit, value)."""
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render an ANY_VALUE node as ``MAX(<argument>)``.

    Only the node's ``this`` argument is emitted.
    """
    argument = expression.this
    return self.func("MAX", argument)
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Expand ``a XOR b`` into ``(a AND (NOT b)) OR ((NOT a) AND b)``.

    Operands that are themselves boolean connectors (AND / OR / XOR) are parenthesized so
    the surrounding ``NOT`` / ``AND`` cannot re-associate them: ``(x AND y) XOR z`` must
    expand to ``... (NOT (x AND y)) ...``, not ``... (NOT x AND y) ...``.

    The expansion's top-level operator is ``OR``. When the XOR node is a direct operand of
    ``NOT`` or of an AND / OR connector, the whole expansion is parenthesized as well.
    Otherwise the parent operator would bind tighter than the ``OR`` and change the
    meaning. A parent XOR is excluded because it already parenthesizes connector operands
    (see ``_operand``).

    Simple operands in a non-nested position produce exactly the unparenthesized form above.
    """

    def _operand(node: t.Optional[exp.Expression]) -> str:
        sql = self.sql(node)
        return f"({sql})" if isinstance(node, exp.Connector) else sql

    a = _operand(expression.left)
    b = _operand(expression.right)
    sql = f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"

    parent = expression.parent
    if isinstance(parent, exp.Not) or (
        isinstance(parent, exp.Connector) and not isinstance(parent, exp.Xor)
    ):
        return f"({sql})"
    return sql
def is_parse_json(expression: exp.Expression) -> bool:
    """Return True if ``expression`` is a PARSE_JSON node or a CAST whose type is json."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Build a parenthesized ``<first arg> IS NULL`` expression from a function's args."""
    operand = seq_get(args, 0)
    is_null = exp.Is(this=operand, expression=exp.null())
    return exp.Paren(this=is_null)
def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render an identity column constraint as ``IDENTITY(start, increment)``.

    A start or increment that renders as an empty string is replaced by ``1``.
    """

    def _arg_or_one(key: str) -> str:
        return self.sql(expression, key) or "1"

    return f"IDENTITY({_arg_or_one('start')}, {_arg_or_one('increment')})"
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Return a generator for ArgMax/ArgMin nodes that emits ``name(this, expression)``.

    A ``count`` argument is never emitted; when present it is reported through
    ``self.unsupported``.
    """
    unsupported_message = f"Only two arguments are supported in function {name}."

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        has_count = bool(expression.args.get("count"))
        if has_count:
            self.unsupported(unsupported_message)
        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap the ``this`` operand of ``expression`` in a cast to its return type.

    When the return type is DATE, the operand is cast to TIMESTAMP first. This ensures
    timestamp strings can be truncated, because some dialects can't cast them to DATE
    directly. ``expression`` is modified in place (its ``this`` is replaced) and returned.
    """
    operand = expression.this.copy()
    target_type = expression.return_type

    needs_timestamp_hop = target_type.is_type(exp.DataType.Type.DATE)
    if needs_timestamp_hop:
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(operand, target_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Return a generator that emits ``name(UNIT, expression, this)`` for date add/diff nodes.

    The unit text is upper-cased and defaults to ``DAY`` when empty. If ``cast`` is True,
    TsOrDsAdd nodes are first passed through ``ts_or_ds_add_cast``.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        unit_name = expression.text("unit").upper() or "DAY"
        return self.func(name, exp.var(unit_name), expression.expression, expression.this)

    return _delta_sql
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate LAST_DAY(x) by date arithmetic.

    The value is truncated to the start of its month, one month is added, one day is
    subtracted, and the result is cast to DATE.
    """
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    month_end = exp.func("date_sub", next_month_start, 1, "day")
    return self.sql(exp.cast(month_end, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove target-table qualifiers from columns in the WHEN clauses of a MERGE.

    A column in a WHEN clause is replaced by an unqualified column when its table qualifier
    matches the MERGE target's name or alias. Names are compared after normalizing them
    with the dialect's identifier rules.
    """
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        # NOTE: `normalize_identifier` is not side-effect free (the base implementation
        # rewrites the identifier's text). This helper exists only for comparisons, so
        # normalize a copy. That way the target name and column qualifiers keep their
        # original spelling in the generated SQL.
        if not identifier:
            return None
        return self.dialect.normalize_identifier(identifier.copy()).name

    targets = {normalize(expression.this.this)}
    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)

Remove table references from columns in the WHEN clauses of a MERGE statement.

def parse_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    """Return a builder that turns JSON_EXTRACT_PATH-style arguments into ``expr_type``.

    The first argument is the JSON value and every following argument is one path step.
    - If all steps are literals, they are folded into a single ``exp.JSONPath``.
      Integer-like steps become subscripts, shifted down by one when
      ``zero_based_indexing`` is False. All other steps become keys.
    - If any step is not a literal, ``expr_type.from_arg_list`` is used on the original
      args instead.
    """

    def _to_path_part(literal: exp.Literal) -> exp.JSONPathPart:
        text = literal.name
        if not is_int(text):
            return exp.JSONPathKey(this=text)
        index = int(text)
        return exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)

    def _parse_json_extract_path(args: t.List) -> F:
        steps = args[1:]
        if not all(isinstance(step, exp.Literal) for step in steps):
            # We use the fallback parser because we can't really transpile non-literals safely
            return expr_type.from_arg_list(args)

        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        segments.extend(_to_path_part(step) for step in steps)

        # Trim args in place down to (value, path) so the expression validator doesn't fail
        # due to the arg count.
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _parse_json_extract_path
def json_extract_segments(
    name: str, quoted_index: bool = True
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Return a generator that emits ``name(json, segment1, segment2, ...)``.

    Each segment of the node's JSONPath is rendered separately and passed as its own
    argument.
    - Segments that render to an empty string are skipped.
    - JSONPathPart segments are wrapped in the dialect's quote characters, except
      subscripts when ``quoted_index`` is False.
    - If the node's path is not a JSONPath, the node is rendered with
      ``rename_func(name)`` instead.
    """

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        rendered_segments = []
        for part in json_path.expressions:
            part_sql = self.sql(part)
            if not part_sql:
                continue

            needs_quotes = isinstance(part, exp.JSONPathPart) and (
                quoted_index or not isinstance(part, exp.JSONPathSubscript)
            )
            if needs_quotes:
                part_sql = f"{self.dialect.QUOTE_START}{part_sql}{self.dialect.QUOTE_END}"

            rendered_segments.append(part_sql)

        return self.func(name, expression.this, *rendered_segments)

    return _json_extract_segments
def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    """Render a JSONPathKey as its bare name.

    A wildcard key is reported through ``self.unsupported`` before the name is returned.
    """
    is_wildcard = isinstance(expression.this, exp.JSONPathWildcard)
    if is_wildcard:
        self.unsupported("Unsupported wildcard in JSONPathKey expression")
    return expression.name