Edit on GitHub

sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
  26logger = logging.getLogger("sqlglot")
  27
  28
  29class Dialects(str, Enum):
  30    """Dialects supported by SQLGLot."""
  31
  32    DIALECT = ""
  33
  34    ATHENA = "athena"
  35    BIGQUERY = "bigquery"
  36    CLICKHOUSE = "clickhouse"
  37    DATABRICKS = "databricks"
  38    DORIS = "doris"
  39    DRILL = "drill"
  40    DUCKDB = "duckdb"
  41    HIVE = "hive"
  42    MYSQL = "mysql"
  43    ORACLE = "oracle"
  44    POSTGRES = "postgres"
  45    PRESTO = "presto"
  46    PRQL = "prql"
  47    REDSHIFT = "redshift"
  48    SNOWFLAKE = "snowflake"
  49    SPARK = "spark"
  50    SPARK2 = "spark2"
  51    SQLITE = "sqlite"
  52    STARROCKS = "starrocks"
  53    TABLEAU = "tableau"
  54    TERADATA = "teradata"
  55    TRINO = "trino"
  56    TSQL = "tsql"
  57
  58
  59class NormalizationStrategy(str, AutoName):
  60    """Specifies the strategy according to which identifiers should be normalized."""
  61
  62    LOWERCASE = auto()
  63    """Unquoted identifiers are lowercased."""
  64
  65    UPPERCASE = auto()
  66    """Unquoted identifiers are uppercased."""
  67
  68    CASE_SENSITIVE = auto()
  69    """Always case-sensitive, regardless of quotes."""
  70
  71    CASE_INSENSITIVE = auto()
  72    """Always case-insensitive, regardless of quotes."""
  73
  74
  75class _Dialect(type):
  76    classes: t.Dict[str, t.Type[Dialect]] = {}
  77
  78    def __eq__(cls, other: t.Any) -> bool:
  79        if cls is other:
  80            return True
  81        if isinstance(other, str):
  82            return cls is cls.get(other)
  83        if isinstance(other, Dialect):
  84            return cls is type(other)
  85
  86        return False
  87
  88    def __hash__(cls) -> int:
  89        return hash(cls.__name__.lower())
  90
  91    @classmethod
  92    def __getitem__(cls, key: str) -> t.Type[Dialect]:
  93        return cls.classes[key]
  94
  95    @classmethod
  96    def get(
  97        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
  98    ) -> t.Optional[t.Type[Dialect]]:
  99        return cls.classes.get(key, default)
 100
 101    def __new__(cls, clsname, bases, attrs):
 102        klass = super().__new__(cls, clsname, bases, attrs)
 103        enum = Dialects.__members__.get(clsname.upper())
 104        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 105
 106        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 107        klass.FORMAT_TRIE = (
 108            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 109        )
 110        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 111        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 112
 113        base = seq_get(bases, 0)
 114        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
 115        base_parser = (getattr(base, "parser_class", Parser),)
 116        base_generator = (getattr(base, "generator_class", Generator),)
 117
 118        klass.tokenizer_class = klass.__dict__.get(
 119            "Tokenizer", type("Tokenizer", base_tokenizer, {})
 120        )
 121        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
 122        klass.generator_class = klass.__dict__.get(
 123            "Generator", type("Generator", base_generator, {})
 124        )
 125
 126        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 127        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 128            klass.tokenizer_class._IDENTIFIERS.items()
 129        )[0]
 130
 131        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 132            return next(
 133                (
 134                    (s, e)
 135                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 136                    if t == token_type
 137                ),
 138                (None, None),
 139            )
 140
 141        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 142        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 143        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 144        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 145
 146        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
 147            klass.UNESCAPED_SEQUENCES = {
 148                "\\a": "\a",
 149                "\\b": "\b",
 150                "\\f": "\f",
 151                "\\n": "\n",
 152                "\\r": "\r",
 153                "\\t": "\t",
 154                "\\v": "\v",
 155                "\\\\": "\\",
 156                **klass.UNESCAPED_SEQUENCES,
 157            }
 158
 159        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
 160
 161        if enum not in ("", "bigquery"):
 162            klass.generator_class.SELECT_KINDS = ()
 163
 164        if enum not in ("", "databricks", "hive", "spark", "spark2"):
 165            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
 166            for modifier in ("cluster", "distribute", "sort"):
 167                modifier_transforms.pop(modifier, None)
 168
 169            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
 170
 171        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 172            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 173                TokenType.ANTI,
 174                TokenType.SEMI,
 175            }
 176
 177        return klass
 178
 179
 180class Dialect(metaclass=_Dialect):
 181    INDEX_OFFSET = 0
 182    """The base index offset for arrays."""
 183
 184    WEEK_OFFSET = 0
 185    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 186
 187    UNNEST_COLUMN_ONLY = False
 188    """Whether `UNNEST` table aliases are treated as column aliases."""
 189
 190    ALIAS_POST_TABLESAMPLE = False
 191    """Whether the table alias comes after tablesample."""
 192
 193    TABLESAMPLE_SIZE_IS_PERCENT = False
 194    """Whether a size in the table sample clause represents percentage."""
 195
 196    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 197    """Specifies the strategy according to which identifiers should be normalized."""
 198
 199    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 200    """Whether an unquoted identifier can start with a digit."""
 201
 202    DPIPE_IS_STRING_CONCAT = True
 203    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 204
 205    STRICT_STRING_CONCAT = False
 206    """Whether `CONCAT`'s arguments must be strings."""
 207
 208    SUPPORTS_USER_DEFINED_TYPES = True
 209    """Whether user-defined data types are supported."""
 210
 211    SUPPORTS_SEMI_ANTI_JOIN = True
 212    """Whether `SEMI` or `ANTI` joins are supported."""
 213
 214    NORMALIZE_FUNCTIONS: bool | str = "upper"
 215    """
 216    Determines how function names are going to be normalized.
 217    Possible values:
 218        "upper" or True: Convert names to uppercase.
 219        "lower": Convert names to lowercase.
 220        False: Disables function name normalization.
 221    """
 222
 223    LOG_BASE_FIRST: t.Optional[bool] = True
 224    """
 225    Whether the base comes first in the `LOG` function.
 226    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
 227    """
 228
 229    NULL_ORDERING = "nulls_are_small"
 230    """
 231    Default `NULL` ordering method to use if not explicitly set.
 232    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 233    """
 234
 235    TYPED_DIVISION = False
 236    """
 237    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 238    False means `a / b` is always float division.
 239    True means `a / b` is integer division if both `a` and `b` are integers.
 240    """
 241
 242    SAFE_DIVISION = False
 243    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 244
 245    CONCAT_COALESCE = False
 246    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 247
 248    DATE_FORMAT = "'%Y-%m-%d'"
 249    DATEINT_FORMAT = "'%Y%m%d'"
 250    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 251
 252    TIME_MAPPING: t.Dict[str, str] = {}
 253    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 254
 255    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 256    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 257    FORMAT_MAPPING: t.Dict[str, str] = {}
 258    """
 259    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 260    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 261    """
 262
 263    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
 264    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 265
 266    PSEUDOCOLUMNS: t.Set[str] = set()
 267    """
 268    Columns that are auto-generated by the engine corresponding to this dialect.
 269    For example, such columns may be excluded from `SELECT *` queries.
 270    """
 271
 272    PREFER_CTE_ALIAS_COLUMN = False
 273    """
 274    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 275    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 276    any projection aliases in the subquery.
 277
 278    For example,
 279        WITH y(c) AS (
 280            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 281        ) SELECT c FROM y;
 282
 283        will be rewritten as
 284
 285        WITH y(c) AS (
 286            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 287        ) SELECT c FROM y;
 288    """
 289
 290    # --- Autofilled ---
 291
 292    tokenizer_class = Tokenizer
 293    parser_class = Parser
 294    generator_class = Generator
 295
 296    # A trie of the time_mapping keys
 297    TIME_TRIE: t.Dict = {}
 298    FORMAT_TRIE: t.Dict = {}
 299
 300    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 301    INVERSE_TIME_TRIE: t.Dict = {}
 302
 303    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 304
 305    # Delimiters for string literals and identifiers
 306    QUOTE_START = "'"
 307    QUOTE_END = "'"
 308    IDENTIFIER_START = '"'
 309    IDENTIFIER_END = '"'
 310
 311    # Delimiters for bit, hex, byte and unicode literals
 312    BIT_START: t.Optional[str] = None
 313    BIT_END: t.Optional[str] = None
 314    HEX_START: t.Optional[str] = None
 315    HEX_END: t.Optional[str] = None
 316    BYTE_START: t.Optional[str] = None
 317    BYTE_END: t.Optional[str] = None
 318    UNICODE_START: t.Optional[str] = None
 319    UNICODE_END: t.Optional[str] = None
 320
 321    @classmethod
 322    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 323        """
 324        Look up a dialect in the global dialect registry and return it if it exists.
 325
 326        Args:
 327            dialect: The target dialect. If this is a string, it can be optionally followed by
 328                additional key-value pairs that are separated by commas and are used to specify
 329                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 330
 331        Example:
 332            >>> dialect = dialect_class = get_or_raise("duckdb")
 333            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
 334
 335        Returns:
 336            The corresponding Dialect instance.
 337        """
 338
 339        if not dialect:
 340            return cls()
 341        if isinstance(dialect, _Dialect):
 342            return dialect()
 343        if isinstance(dialect, Dialect):
 344            return dialect
 345        if isinstance(dialect, str):
 346            try:
 347                dialect_name, *kv_pairs = dialect.split(",")
 348                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 349            except ValueError:
 350                raise ValueError(
 351                    f"Invalid dialect format: '{dialect}'. "
 352                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
 353                )
 354
 355            result = cls.get(dialect_name.strip())
 356            if not result:
 357                from difflib import get_close_matches
 358
 359                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 360                if similar:
 361                    similar = f" Did you mean {similar}?"
 362
 363                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 364
 365            return result(**kwargs)
 366
 367        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 368
 369    @classmethod
 370    def format_time(
 371        cls, expression: t.Optional[str | exp.Expression]
 372    ) -> t.Optional[exp.Expression]:
 373        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 374        if isinstance(expression, str):
 375            return exp.Literal.string(
 376                # the time formats are quoted
 377                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 378            )
 379
 380        if expression and expression.is_string:
 381            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 382
 383        return expression
 384
 385    def __init__(self, **kwargs) -> None:
 386        normalization_strategy = kwargs.get("normalization_strategy")
 387
 388        if normalization_strategy is None:
 389            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 390        else:
 391            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 392
 393    def __eq__(self, other: t.Any) -> bool:
 394        # Does not currently take dialect state into account
 395        return type(self) == other
 396
 397    def __hash__(self) -> int:
 398        # Does not currently take dialect state into account
 399        return hash(type(self))
 400
 401    def normalize_identifier(self, expression: E) -> E:
 402        """
 403        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 404
 405        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 406        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 407        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 408        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 409
 410        There are also dialects like Spark, which are case-insensitive even when quotes are
 411        present, and dialects like MySQL, whose resolution rules match those employed by the
 412        underlying operating system, for example they may always be case-sensitive in Linux.
 413
 414        Finally, the normalization behavior of some engines can even be controlled through flags,
 415        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 416
 417        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 418        that it can analyze queries in the optimizer and successfully capture their semantics.
 419        """
 420        if (
 421            isinstance(expression, exp.Identifier)
 422            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 423            and (
 424                not expression.quoted
 425                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 426            )
 427        ):
 428            expression.set(
 429                "this",
 430                (
 431                    expression.this.upper()
 432                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 433                    else expression.this.lower()
 434                ),
 435            )
 436
 437        return expression
 438
 439    def case_sensitive(self, text: str) -> bool:
 440        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 441        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 442            return False
 443
 444        unsafe = (
 445            str.islower
 446            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 447            else str.isupper
 448        )
 449        return any(unsafe(char) for char in text)
 450
 451    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 452        """Checks if text can be identified given an identify option.
 453
 454        Args:
 455            text: The text to check.
 456            identify:
 457                `"always"` or `True`: Always returns `True`.
 458                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 459
 460        Returns:
 461            Whether the given text can be identified.
 462        """
 463        if identify is True or identify == "always":
 464            return True
 465
 466        if identify == "safe":
 467            return not self.case_sensitive(text)
 468
 469        return False
 470
 471    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 472        """
 473        Adds quotes to a given identifier.
 474
 475        Args:
 476            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 477            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 478                "unsafe", with respect to its characters and this dialect's normalization strategy.
 479        """
 480        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 481            name = expression.this
 482            expression.set(
 483                "quoted",
 484                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 485            )
 486
 487        return expression
 488
 489    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 490        if isinstance(path, exp.Literal):
 491            path_text = path.name
 492            if path.is_number:
 493                path_text = f"[{path_text}]"
 494
 495            try:
 496                return parse_json_path(path_text)
 497            except ParseError as e:
 498                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 499
 500        return path
 501
 502    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 503        return self.parser(**opts).parse(self.tokenize(sql), sql)
 504
 505    def parse_into(
 506        self, expression_type: exp.IntoType, sql: str, **opts
 507    ) -> t.List[t.Optional[exp.Expression]]:
 508        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 509
 510    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 511        return self.generator(**opts).generate(expression, copy=copy)
 512
 513    def transpile(self, sql: str, **opts) -> t.List[str]:
 514        return [
 515            self.generate(expression, copy=False, **opts) if expression else ""
 516            for expression in self.parse(sql)
 517        ]
 518
 519    def tokenize(self, sql: str) -> t.List[Token]:
 520        return self.tokenizer.tokenize(sql)
 521
 522    @property
 523    def tokenizer(self) -> Tokenizer:
 524        if not hasattr(self, "_tokenizer"):
 525            self._tokenizer = self.tokenizer_class(dialect=self)
 526        return self._tokenizer
 527
 528    def parser(self, **opts) -> Parser:
 529        return self.parser_class(dialect=self, **opts)
 530
 531    def generator(self, **opts) -> Generator:
 532        return self.generator_class(dialect=self, **opts)
 533
 534
 535DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 536
 537
 538def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 539    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 540
 541
 542def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 543    if expression.args.get("accuracy"):
 544        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 545    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 546
 547
 548def if_sql(
 549    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 550) -> t.Callable[[Generator, exp.If], str]:
 551    def _if_sql(self: Generator, expression: exp.If) -> str:
 552        return self.func(
 553            name,
 554            expression.this,
 555            expression.args.get("true"),
 556            expression.args.get("false") or false_value,
 557        )
 558
 559    return _if_sql
 560
 561
 562def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 563    this = expression.this
 564    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 565        this.replace(exp.cast(this, exp.DataType.Type.JSON))
 566
 567    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 568
 569
 570def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 571    return f"[{self.expressions(expression, flat=True)}]"
 572
 573
 574def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
 575    elem = seq_get(expression.expressions, 0)
 576    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
 577        return self.func("ARRAY", elem)
 578    return inline_array_sql(self, expression)
 579
 580
 581def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 582    return self.like_sql(
 583        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 584    )
 585
 586
 587def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 588    zone = self.sql(expression, "this")
 589    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 590
 591
 592def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 593    if expression.args.get("recursive"):
 594        self.unsupported("Recursive CTEs are unsupported")
 595        expression.args["recursive"] = False
 596    return self.with_sql(expression)
 597
 598
 599def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 600    n = self.sql(expression, "this")
 601    d = self.sql(expression, "expression")
 602    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 603
 604
 605def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 606    self.unsupported("TABLESAMPLE unsupported")
 607    return self.sql(expression.this)
 608
 609
 610def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 611    self.unsupported("PIVOT unsupported")
 612    return ""
 613
 614
 615def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 616    return self.cast_sql(expression)
 617
 618
 619def no_comment_column_constraint_sql(
 620    self: Generator, expression: exp.CommentColumnConstraint
 621) -> str:
 622    self.unsupported("CommentColumnConstraint unsupported")
 623    return ""
 624
 625
 626def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 627    self.unsupported("MAP_FROM_ENTRIES unsupported")
 628    return ""
 629
 630
 631def str_position_sql(
 632    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
 633) -> str:
 634    this = self.sql(expression, "this")
 635    substr = self.sql(expression, "substr")
 636    position = self.sql(expression, "position")
 637    instance = expression.args.get("instance") if generate_instance else None
 638    position_offset = ""
 639
 640    if position:
 641        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
 642        this = self.func("SUBSTR", this, position)
 643        position_offset = f" + {position} - 1"
 644
 645    return self.func("STRPOS", this, substr, instance) + position_offset
 646
 647
 648def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 649    return (
 650        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 651    )
 652
 653
 654def var_map_sql(
 655    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 656) -> str:
 657    keys = expression.args["keys"]
 658    values = expression.args["values"]
 659
 660    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 661        self.unsupported("Cannot convert array columns into map.")
 662        return self.func(map_func_name, keys, values)
 663
 664    args = []
 665    for key, value in zip(keys.expressions, values.expressions):
 666        args.append(self.sql(key))
 667        args.append(self.sql(value))
 668
 669    return self.func(map_func_name, *args)
 670
 671
 672def build_formatted_time(
 673    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 674) -> t.Callable[[t.List], E]:
 675    """Helper used for time expressions.
 676
 677    Args:
 678        exp_class: the expression class to instantiate.
 679        dialect: target sql dialect.
 680        default: the default format, True being time.
 681
 682    Returns:
 683        A callable that can be used to return the appropriately formatted time expression.
 684    """
 685
 686    def _builder(args: t.List):
 687        return exp_class(
 688            this=seq_get(args, 0),
 689            format=Dialect[dialect].format_time(
 690                seq_get(args, 1)
 691                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 692            ),
 693        )
 694
 695    return _builder
 696
 697
 698def time_format(
 699    dialect: DialectType = None,
 700) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 701    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 702        """
 703        Returns the time format for a given expression, unless it's equivalent
 704        to the default time format of the dialect of interest.
 705        """
 706        time_format = self.format_time(expression)
 707        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 708
 709    return _time_format
 710
 711
 712def build_date_delta(
 713    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 714) -> t.Callable[[t.List], E]:
 715    def _builder(args: t.List) -> E:
 716        unit_based = len(args) == 3
 717        this = args[2] if unit_based else seq_get(args, 0)
 718        unit = args[0] if unit_based else exp.Literal.string("DAY")
 719        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 720        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 721
 722    return _builder
 723
 724
 725def build_date_delta_with_interval(
 726    expression_class: t.Type[E],
 727) -> t.Callable[[t.List], t.Optional[E]]:
 728    def _builder(args: t.List) -> t.Optional[E]:
 729        if len(args) < 2:
 730            return None
 731
 732        interval = args[1]
 733
 734        if not isinstance(interval, exp.Interval):
 735            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 736
 737        expression = interval.this
 738        if expression and expression.is_string:
 739            expression = exp.Literal.number(expression.this)
 740
 741        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 742
 743    return _builder
 744
 745
 746def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 747    unit = seq_get(args, 0)
 748    this = seq_get(args, 1)
 749
 750    if isinstance(this, exp.Cast) and this.is_type("date"):
 751        return exp.DateTrunc(unit=unit, this=this)
 752    return exp.TimestampTrunc(this=this, unit=unit)
 753
 754
 755def date_add_interval_sql(
 756    data_type: str, kind: str
 757) -> t.Callable[[Generator, exp.Expression], str]:
 758    def func(self: Generator, expression: exp.Expression) -> str:
 759        this = self.sql(expression, "this")
 760        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
 761        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 762
 763    return func
 764
 765
 766def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 767    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
 768
 769
 770def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 771    if not expression.expression:
 772        from sqlglot.optimizer.annotate_types import annotate_types
 773
 774        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 775        return self.sql(exp.cast(expression.this, target_type))
 776    if expression.text("expression").lower() in TIMEZONES:
 777        return self.sql(
 778            exp.AtTimeZone(
 779                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
 780                zone=expression.expression,
 781            )
 782        )
 783    return self.func("TIMESTAMP", expression.this, expression.expression)
 784
 785
 786def locate_to_strposition(args: t.List) -> exp.Expression:
 787    return exp.StrPosition(
 788        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 789    )
 790
 791
 792def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 793    return self.func(
 794        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 795    )
 796
 797
 798def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 799    return self.sql(
 800        exp.Substring(
 801            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 802        )
 803    )
 804
 805
 806def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 807    return self.sql(
 808        exp.Substring(
 809            this=expression.this,
 810            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 811        )
 812    )
 813
 814
 815def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 816    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
 817
 818
 819def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 820    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
 821
 822
 823# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
 824def encode_decode_sql(
 825    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 826) -> str:
 827    charset = expression.args.get("charset")
 828    if charset and charset.name.lower() != "utf-8":
 829        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 830
 831    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 832
 833
 834def min_or_least(self: Generator, expression: exp.Min) -> str:
 835    name = "LEAST" if expression.expressions else "MIN"
 836    return rename_func(name)(self, expression)
 837
 838
 839def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 840    name = "GREATEST" if expression.expressions else "MAX"
 841    return rename_func(name)(self, expression)
 842
 843
 844def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 845    cond = expression.this
 846
 847    if isinstance(expression.this, exp.Distinct):
 848        cond = expression.this.expressions[0]
 849        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 850
 851    return self.func("sum", exp.func("if", cond, 1, 0))
 852
 853
 854def trim_sql(self: Generator, expression: exp.Trim) -> str:
 855    target = self.sql(expression, "this")
 856    trim_type = self.sql(expression, "position")
 857    remove_chars = self.sql(expression, "expression")
 858    collation = self.sql(expression, "collation")
 859
 860    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 861    if not remove_chars and not collation:
 862        return self.trim_sql(expression)
 863
 864    trim_type = f"{trim_type} " if trim_type else ""
 865    remove_chars = f"{remove_chars} " if remove_chars else ""
 866    from_part = "FROM " if trim_type or remove_chars else ""
 867    collation = f" COLLATE {collation}" if collation else ""
 868    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 869
 870
 871def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 872    return self.func("STRPTIME", expression.this, self.format_time(expression))
 873
 874
 875def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 876    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 877
 878
 879def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 880    delim, *rest_args = expression.expressions
 881    return self.sql(
 882        reduce(
 883            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 884            rest_args,
 885        )
 886    )
 887
 888
 889def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 890    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 891    if bad_args:
 892        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 893
 894    return self.func(
 895        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 896    )
 897
 898
 899def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 900    bad_args = list(
 901        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
 902    )
 903    if bad_args:
 904        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 905
 906    return self.func(
 907        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 908    )
 909
 910
 911def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 912    names = []
 913    for agg in aggregations:
 914        if isinstance(agg, exp.Alias):
 915            names.append(agg.alias)
 916        else:
 917            """
 918            This case corresponds to aggregations without aliases being used as suffixes
 919            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 920            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
 921            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
 922            """
 923            agg_all_unquoted = agg.transform(
 924                lambda node: (
 925                    exp.Identifier(this=node.name, quoted=False)
 926                    if isinstance(node, exp.Identifier)
 927                    else node
 928                )
 929            )
 930            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 931
 932    return names
 933
 934
 935def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 936    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 937
 938
 939# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 940def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 941    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 942
 943
 944def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 945    return self.func("MAX", expression.this)
 946
 947
 948def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 949    a = self.sql(expression.left)
 950    b = self.sql(expression.right)
 951    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 952
 953
 954def is_parse_json(expression: exp.Expression) -> bool:
 955    return isinstance(expression, exp.ParseJSON) or (
 956        isinstance(expression, exp.Cast) and expression.is_type("json")
 957    )
 958
 959
 960def isnull_to_is_null(args: t.List) -> exp.Expression:
 961    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 962
 963
 964def generatedasidentitycolumnconstraint_sql(
 965    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 966) -> str:
 967    start = self.sql(expression, "start") or "1"
 968    increment = self.sql(expression, "increment") or "1"
 969    return f"IDENTITY({start}, {increment})"
 970
 971
 972def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 973    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 974        if expression.args.get("count"):
 975            self.unsupported(f"Only two arguments are supported in function {name}.")
 976
 977        return self.func(name, expression.this, expression.expression)
 978
 979    return _arg_max_or_min_sql
 980
 981
 982def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 983    this = expression.this.copy()
 984
 985    return_type = expression.return_type
 986    if return_type.is_type(exp.DataType.Type.DATE):
 987        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 988        # can truncate timestamp strings, because some dialects can't cast them to DATE
 989        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 990
 991    expression.this.replace(exp.cast(this, return_type))
 992    return expression
 993
 994
 995def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 996    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
 997        if cast and isinstance(expression, exp.TsOrDsAdd):
 998            expression = ts_or_ds_add_cast(expression)
 999
1000        return self.func(
1001            name,
1002            unit_to_var(expression),
1003            expression.expression,
1004            expression.this,
1005        )
1006
1007    return _delta_sql
1008
1009
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the node's `unit` arg as a string literal.

    A Placeholder unit is returned unchanged. Any other unit is converted to a string
    literal of its name. With no unit, `default` is used as a string literal, or None is
    returned when `default` is empty.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if not unit:
        return exp.Literal.string(default) if default else None
    return exp.Literal.string(unit.name)
1018
1019
def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the node's `unit` arg as a Var (an unquoted unit keyword such as DAY).

    Var and Placeholder units are returned unchanged. Any other unit (e.g. a string
    literal 'month') is converted to a Var of its name, mirroring `unit_to_str`, rather
    than being silently replaced by `default`. When there is no unit, or it has an empty
    name, `default` is used as a Var. None is returned if `default` is empty too.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit

    value = (unit.name if unit else "") or default
    return exp.Var(this=value) if value else None
1026
1027
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate LAST_DAY(d) for dialects lacking it.

    The date is truncated to the start of its month, one month is added, one day is
    subtracted, and the result is cast to DATE.
    """
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    last_day = exp.func("date_sub", next_month_start, 1, "day")
    return self.sql(exp.cast(last_day, exp.DataType.Type.DATE))
1034
1035
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements.

    Within each WHEN clause, a column qualified by the MERGE target (by name or by alias,
    compared after dialect normalization) is replaced in place by an unqualified column.
    The statement is then generated with `merge_sql`.
    """
    target = expression.this
    alias = target.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        # NOTE(review): normalize_identifier may modify the identifier in place
        # (Dialect.normalize_identifier calls `.set`); kept as in the original.
        if not identifier:
            return None
        return self.dialect.normalize_identifier(identifier).name

    target_names = {normalize(target.this)}
    if alias:
        target_names.add(normalize(alias.this))

    def unqualify(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Column) and normalize(node.args.get("table")) in target_names:
            return exp.column(node.this)
        return node

    for when in expression.expressions:
        when.transform(unqualify, copy=False)

    return self.merge_sql(expression)
1059
1060
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    """Build a parser callback for JSON extraction calls of the form f(json, p1, p2, ...).

    Literal path arguments are folded into one JSONPath that starts at the root. An
    integer literal becomes a subscript (decremented by one unless
    `zero_based_indexing`), and any other literal becomes a key. If any path argument
    is not a literal, the node is built from the raw arguments via
    `expr_type.from_arg_list`.
    """

    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]

        for path_arg in args[1:]:
            # Non-literal paths can't be transpiled safely, so use the fallback parser.
            if not isinstance(path_arg, exp.Literal):
                return expr_type.from_arg_list(args)

            key = path_arg.name
            if not is_int(key):
                segments.append(exp.JSONPathKey(this=key))
                continue

            position = int(key)
            if not zero_based_indexing:
                position -= 1
            segments.append(exp.JSONPathSubscript(this=position))

        # Drop the consumed path args so the expression validator accepts the arg count.
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
1089
1090
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Build a generator that spreads a JSONPath into separate segment arguments.

    Args:
        name: function name used for `name(json, seg1, seg2, ...)`, and for the fallback
            when the path is not a JSONPath node.
        quoted_index: whether subscript segments are wrapped in the dialect's string quotes
            (key segments always are).
        op: if given, segments are joined with this operator instead of a function call,
            e.g. `json op seg1 op seg2`.
    """

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        rendered: t.List[str] = []
        for segment in json_path.expressions:
            segment_sql = self.sql(segment)
            if not segment_sql:
                continue

            needs_quotes = isinstance(segment, exp.JSONPathPart) and (
                quoted_index or not isinstance(segment, exp.JSONPathSubscript)
            )
            if needs_quotes:
                segment_sql = f"{self.dialect.QUOTE_START}{segment_sql}{self.dialect.QUOTE_END}"

            rendered.append(segment_sql)

        if op:
            return f" {op} ".join([self.sql(expression.this), *rendered])
        return self.func(name, expression.this, *rendered)

    return _json_extract_segments
1115
1116
def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    """Render a JSONPathKey as its bare name; wildcard keys get an unsupported warning."""
    key = expression.this
    if isinstance(key, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")
    return expression.name
1122
1123
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    """Rewrite an array filter as ARRAY(SELECT ... FROM UNNEST(arr) ... WHERE cond).

    The condition may be a single-parameter lambda, whose parameter names the unnested
    column. It may also be a bare predicate, for which the column is named `_u`. Any other
    condition gets an unsupported warning and an empty string.
    """
    condition = expression.expression

    if isinstance(condition, exp.Lambda) and len(condition.expressions) == 1:
        column_alias = condition.expressions[0]
        condition = condition.this
    elif isinstance(condition, exp.Predicate):
        column_alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnested = exp.Unnest(expressions=[expression.this])
    source = exp.alias_(unnested, None, table=[column_alias])
    subquery = exp.select(column_alias).from_(source).where(condition)
    return self.sql(exp.Array(expressions=[subquery]))
1138
1139
def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    """Render TO_NUMBER(this[, format[, nlsparam]]), forwarding the NLS parameter.

    Unset optional args are passed as None to `self.func`.
    """
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )
logger = <Logger sqlglot (WARNING)>
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    # The empty value denotes the base (generic) dialect.
    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"

Dialects supported by SQLGlot.

DIALECT = <Dialects.DIALECT: ''>
ATHENA = <Dialects.ATHENA: 'athena'>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
PRQL = <Dialects.PRQL: 'prql'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
class NormalizationStrategy(str, AutoName):
    """Determines how identifiers are case-normalized when they are resolved."""

    # Each value is spelled out explicitly; it equals the member name, which is what
    # AutoName would generate via auto().

    LOWERCASE = "LOWERCASE"
    """Unquoted identifiers are lowercased."""

    UPPERCASE = "UPPERCASE"
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = "CASE_SENSITIVE"
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = "CASE_INSENSITIVE"
    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Class attributes describe a dialect's syntax and semantics. Instances expose the
    tokenize / parse / generate / transpile entry points, built on the dialect's
    `tokenizer_class`, `parser_class` and `generator_class`.
    """

    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.

        Raises:
            ValueError: if the string is malformed, names an unknown dialect, or `dialect`
                has an unsupported type.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            # Strip once, so the registry lookup, the "did you mean" search and the error
            # message all use the same name (input like "duckdbb , k = v" has padding).
            dialect_name = dialect_name.strip()
            result = cls.get(dialect_name)
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.

        Note: the identifier is modified in place and also returned.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        """Parse a literal JSON path into a JSONPath expression.

        Numeric literals are treated as a subscript (`[n]`). If parsing fails, a warning is
        logged and the input is returned unchanged; non-literal inputs are returned as-is.
        """
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                # Lazy %-args: the message is only formatted if the record is emitted.
                logger.warning("Invalid JSON path syntax. %s", e)

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Created lazily and cached on the instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
386    def __init__(self, **kwargs) -> None:
387        normalization_strategy = kwargs.get("normalization_strategy")
388
389        if normalization_strategy is None:
390            self.normalization_strategy = self.NORMALIZATION_STRATEGY
391        else:
392            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after the TABLESAMPLE clause.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the TABLESAMPLE clause represents a percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

  • "upper" or True: Convert names to uppercase.
  • "lower": Convert names to lowercase.
  • False: Disable function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, or None (LOG does not support a two-argument form).

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

By default, a NULL argument makes CONCAT return NULL, but in some dialects a NULL argument is treated as an empty string instead.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie is constructed from TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (e.g. the two characters \n) to its unescaped version (e.g. an actual newline character).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
322    @classmethod
323    def get_or_raise(cls, dialect: DialectType) -> Dialect:
324        """
325        Look up a dialect in the global dialect registry and return it if it exists.
326
327        Args:
328            dialect: The target dialect. If this is a string, it can be optionally followed by
329                additional key-value pairs that are separated by commas and are used to specify
330                dialect settings, such as whether the dialect's identifiers are case-sensitive.
331
332        Example:
333            >>> dialect = dialect_class = get_or_raise("duckdb")
334            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
335
336        Returns:
337            The corresponding Dialect instance.
338        """
339
340        if not dialect:
341            return cls()
342        if isinstance(dialect, _Dialect):
343            return dialect()
344        if isinstance(dialect, Dialect):
345            return dialect
346        if isinstance(dialect, str):
347            try:
348                dialect_name, *kv_pairs = dialect.split(",")
349                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
350            except ValueError:
351                raise ValueError(
352                    f"Invalid dialect format: '{dialect}'. "
353                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
354                )
355
356            result = cls.get(dialect_name.strip())
357            if not result:
358                from difflib import get_close_matches
359
360                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
361                if similar:
362                    similar = f" Did you mean {similar}?"
363
364                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
365
366            return result(**kwargs)
367
368        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
370    @classmethod
371    def format_time(
372        cls, expression: t.Optional[str | exp.Expression]
373    ) -> t.Optional[exp.Expression]:
374        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
375        if isinstance(expression, str):
376            return exp.Literal.string(
377                # the time formats are quoted
378                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
379            )
380
381        if expression and expression.is_string:
382            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
383
384        return expression

Converts a time format in this dialect to its equivalent Python strftime format.

def normalize_identifier(self, expression: ~E) -> ~E:
402    def normalize_identifier(self, expression: E) -> E:
403        """
404        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
405
406        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
407        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
408        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
409        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
410
411        There are also dialects like Spark, which are case-insensitive even when quotes are
412        present, and dialects like MySQL, whose resolution rules match those employed by the
413        underlying operating system, for example they may always be case-sensitive in Linux.
414
415        Finally, the normalization behavior of some engines can even be controlled through flags,
416        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
417
418        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
419        that it can analyze queries in the optimizer and successfully capture their semantics.
420        """
421        if (
422            isinstance(expression, exp.Identifier)
423            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
424            and (
425                not expression.quoted
426                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
427            )
428        ):
429            expression.set(
430                "this",
431                (
432                    expression.this.upper()
433                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
434                    else expression.this.lower()
435                ),
436            )
437
438        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.

def case_sensitive(self, text: str) -> bool:
440    def case_sensitive(self, text: str) -> bool:
441        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
442        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
443            return False
444
445        unsafe = (
446            str.islower
447            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
448            else str.isupper
449        )
450        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
452    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
453        """Checks if text can be identified given an identify option.
454
455        Args:
456            text: The text to check.
457            identify:
458                `"always"` or `True`: Always returns `True`.
459                `"safe"`: Only returns `True` if the identifier is case-insensitive.
460
461        Returns:
462            Whether the given text can be identified.
463        """
464        if identify is True or identify == "always":
465            return True
466
467        if identify == "safe":
468            return not self.case_sensitive(text)
469
470        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
472    def quote_identifier(self, expression: E, identify: bool = True) -> E:
473        """
474        Adds quotes to a given identifier.
475
476        Args:
477            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
478            identify: If set to `False`, the quotes will only be added if the identifier is deemed
479                "unsafe", with respect to its characters and this dialect's normalization strategy.
480        """
481        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
482            name = expression.this
483            expression.set(
484                "quoted",
485                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
486            )
487
488        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
490    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
491        if isinstance(path, exp.Literal):
492            path_text = path.name
493            if path.is_number:
494                path_text = f"[{path_text}]"
495
496            try:
497                return parse_json_path(path_text)
498            except ParseError as e:
499                logger.warning(f"Invalid JSON path syntax. {str(e)}")
500
501        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
503    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
504        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
506    def parse_into(
507        self, expression_type: exp.IntoType, sql: str, **opts
508    ) -> t.List[t.Optional[exp.Expression]]:
509        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
511    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
512        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
514    def transpile(self, sql: str, **opts) -> t.List[str]:
515        return [
516            self.generate(expression, copy=False, **opts) if expression else ""
517            for expression in self.parse(sql)
518        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
520    def tokenize(self, sql: str) -> t.List[Token]:
521        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
523    @property
524    def tokenizer(self) -> Tokenizer:
525        if not hasattr(self, "_tokenizer"):
526            self._tokenizer = self.tokenizer_class(dialect=self)
527        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
529    def parser(self, **opts) -> Parser:
530        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
532    def generator(self, **opts) -> Generator:
533        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
539def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
540    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
543def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
544    if expression.args.get("accuracy"):
545        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
546    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
549def if_sql(
550    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
551) -> t.Callable[[Generator, exp.If], str]:
552    def _if_sql(self: Generator, expression: exp.If) -> str:
553        return self.func(
554            name,
555            expression.this,
556            expression.args.get("true"),
557            expression.args.get("false") or false_value,
558        )
559
560    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
563def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
564    this = expression.this
565    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
566        this.replace(exp.cast(this, exp.DataType.Type.JSON))
567
568    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
571def inline_array_sql(self: Generator, expression: exp.Array) -> str:
572    return f"[{self.expressions(expression, flat=True)}]"
def inline_array_unless_query( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
575def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
576    elem = seq_get(expression.expressions, 0)
577    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
578        return self.func("ARRAY", elem)
579    return inline_array_sql(self, expression)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
582def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
583    return self.like_sql(
584        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
585    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
588def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
589    zone = self.sql(expression, "this")
590    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
593def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
594    if expression.args.get("recursive"):
595        self.unsupported("Recursive CTEs are unsupported")
596        expression.args["recursive"] = False
597    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
600def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
601    n = self.sql(expression, "this")
602    d = self.sql(expression, "expression")
603    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
606def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
607    self.unsupported("TABLESAMPLE unsupported")
608    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
611def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
612    self.unsupported("PIVOT unsupported")
613    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
616def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
617    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
620def no_comment_column_constraint_sql(
621    self: Generator, expression: exp.CommentColumnConstraint
622) -> str:
623    self.unsupported("CommentColumnConstraint unsupported")
624    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
627def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
628    self.unsupported("MAP_FROM_ENTRIES unsupported")
629    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
632def str_position_sql(
633    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
634) -> str:
635    this = self.sql(expression, "this")
636    substr = self.sql(expression, "substr")
637    position = self.sql(expression, "position")
638    instance = expression.args.get("instance") if generate_instance else None
639    position_offset = ""
640
641    if position:
642        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
643        this = self.func("SUBSTR", this, position)
644        position_offset = f" + {position} - 1"
645
646    return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
649def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
650    return (
651        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
652    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
655def var_map_sql(
656    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
657) -> str:
658    keys = expression.args["keys"]
659    values = expression.args["values"]
660
661    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
662        self.unsupported("Cannot convert array columns into map.")
663        return self.func(map_func_name, keys, values)
664
665    args = []
666    for key, value in zip(keys.expressions, values.expressions):
667        args.append(self.sql(key))
668        args.append(self.sql(value))
669
670    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
673def build_formatted_time(
674    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
675) -> t.Callable[[t.List], E]:
676    """Helper used for time expressions.
677
678    Args:
679        exp_class: the expression class to instantiate.
680        dialect: target sql dialect.
681        default: the default format, True being time.
682
683    Returns:
684        A callable that can be used to return the appropriately formatted time expression.
685    """
686
687    def _builder(args: t.List):
688        return exp_class(
689            this=seq_get(args, 0),
690            format=Dialect[dialect].format_time(
691                seq_get(args, 1)
692                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
693            ),
694        )
695
696    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format to use when no format argument is given; True means the dialect's TIME_FORMAT.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
699def time_format(
700    dialect: DialectType = None,
701) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
702    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
703        """
704        Returns the time format for a given expression, unless it's equivalent
705        to the default time format of the dialect of interest.
706        """
707        time_format = self.format_time(expression)
708        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
709
710    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
713def build_date_delta(
714    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
715) -> t.Callable[[t.List], E]:
716    def _builder(args: t.List) -> E:
717        unit_based = len(args) == 3
718        this = args[2] if unit_based else seq_get(args, 0)
719        unit = args[0] if unit_based else exp.Literal.string("DAY")
720        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
721        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
722
723    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
726def build_date_delta_with_interval(
727    expression_class: t.Type[E],
728) -> t.Callable[[t.List], t.Optional[E]]:
729    def _builder(args: t.List) -> t.Optional[E]:
730        if len(args) < 2:
731            return None
732
733        interval = args[1]
734
735        if not isinstance(interval, exp.Interval):
736            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
737
738        expression = interval.this
739        if expression and expression.is_string:
740            expression = exp.Literal.number(expression.this)
741
742        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
743
744    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
747def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
748    unit = seq_get(args, 0)
749    this = seq_get(args, 1)
750
751    if isinstance(this, exp.Cast) and this.is_type("date"):
752        return exp.DateTrunc(unit=unit, this=this)
753    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
756def date_add_interval_sql(
757    data_type: str, kind: str
758) -> t.Callable[[Generator, exp.Expression], str]:
759    def func(self: Generator, expression: exp.Expression) -> str:
760        this = self.sql(expression, "this")
761        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
762        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
763
764    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
767def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
768    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
771def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
772    if not expression.expression:
773        from sqlglot.optimizer.annotate_types import annotate_types
774
775        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
776        return self.sql(exp.cast(expression.this, target_type))
777    if expression.text("expression").lower() in TIMEZONES:
778        return self.sql(
779            exp.AtTimeZone(
780                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
781                zone=expression.expression,
782            )
783        )
784    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
787def locate_to_strposition(args: t.List) -> exp.Expression:
788    return exp.StrPosition(
789        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
790    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
793def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
794    return self.func(
795        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
796    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
799def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
800    return self.sql(
801        exp.Substring(
802            this=expression.this, start=exp.Literal.number(1), length=expression.expression
803        )
804    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
807def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
808    return self.sql(
809        exp.Substring(
810            this=expression.this,
811            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
812        )
813    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
816def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
817    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
820def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
821    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
825def encode_decode_sql(
826    self: Generator, expression: exp.Expression, name: str, replace: bool = True
827) -> str:
828    charset = expression.args.get("charset")
829    if charset and charset.name.lower() != "utf-8":
830        self.unsupported(f"Expected utf-8 character set, got {charset}.")
831
832    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
835def min_or_least(self: Generator, expression: exp.Min) -> str:
836    name = "LEAST" if expression.expressions else "MIN"
837    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
840def max_or_greatest(self: Generator, expression: exp.Max) -> str:
841    name = "GREATEST" if expression.expressions else "MAX"
842    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
845def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
846    cond = expression.this
847
848    if isinstance(expression.this, exp.Distinct):
849        cond = expression.this.expressions[0]
850        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
851
852    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM in the ``TRIM([position] [chars] FROM target [COLLATE c])`` form.

    When neither removal characters nor a collation are present, the generator's own
    ``trim_sql`` is used instead.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Nothing dialect-specific to express: defer to the default TRIM/LTRIM/RTRIM output.
    if not (chars or collation):
        return self.trim_sql(expression)

    prefix = "".join(f"{part} " for part in (position, chars) if part)
    from_keyword = "FROM " if prefix else ""
    collate = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{from_keyword}{target}{collate})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as ``STRPTIME(value, format)``."""
    time_format = self.format_time(expression)
    return self.func("STRPTIME", expression.this, time_format)
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Render CONCAT(a, b, c, ...) as a left-nested chain of ``||`` (DPipe) nodes."""

    def _pipe(left: exp.Expression, right: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=left, expression=right)

    return self.sql(reduce(_pipe, expression.expressions))
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render CONCAT_WS(sep, a, b, ...) as ``a || sep || b || ...`` via DPipe nodes."""
    separator, *operands = expression.expressions

    def _join(acc: exp.Expression, item: exp.Expression) -> exp.DPipe:
        # Each step appends ``sep || item`` to the chain built so far.
        return exp.DPipe(this=acc, expression=exp.DPipe(this=separator, expression=item))

    return self.sql(reduce(_join, operands))
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT(this, pattern[, group]).

    ``position``, ``occurrence`` and ``parameters`` have no place in this form. Any of
    them that are set are dropped and reported as unsupported.
    """
    unsupported_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if unsupported_args:
        self.unsupported(
            f"REGEXP_EXTRACT does not support the following arg(s): {unsupported_args}"
        )

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE(this, pattern, replacement).

    ``position``, ``occurrence``, ``parameters`` and ``modifiers`` cannot be expressed
    here. Any of them that are set are dropped and reported as unsupported.
    ``replacement`` is read with ``[]``, so it must be present in ``expression.args``.
    """
    unsupported_args = [
        arg
        for arg in ("position", "occurrence", "parameters", "modifiers")
        if expression.args.get(arg)
    ]
    if unsupported_args:
        self.unsupported(
            f"REGEXP_REPLACE does not support the following arg(s): {unsupported_args}"
        )

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Derive the PIVOT output-column name suffix for each aggregation.

    Args:
        aggregations: the aggregation expressions of the PIVOT.
        dialect: the dialect used to render unaliased aggregations.

    Returns:
        One name per aggregation. For an aliased aggregation this is its alias. Otherwise
        it is the aggregation's SQL in ``dialect``, with lower-cased function names and
        all identifiers unquoted.
    """

    def _unquote(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Unaliased aggregations are used as suffixes (e.g. col_avg(foo)). Identifiers are
        # unquoted here because the result is quoted again in the base parser's
        # `_parse_pivot` method via `to_identifier`. Otherwise we'd end up with
        # `col_avg(`foo`)`, i.e. the inner identifier quoted twice.
        agg_all_unquoted = agg.transform(_unquote)
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Return a builder that maps ``FUNC(a, b)`` arguments onto the binary node ``expr_type``."""

    def _build(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _build
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build TIMESTAMP_TRUNC from arguments ordered as ``(unit, value)``."""
    unit, value = seq_get(args, 0), seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Emit ANY_VALUE(x) as MAX(x)."""
    value = expression.this
    return self.func("MAX", value)
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Expand ``a XOR b`` into ``(a AND (NOT b)) OR ((NOT a) AND b)``.

    AND/OR operands are parenthesized, because ``NOT`` and ``AND`` bind tighter than
    them and would otherwise regroup the operand (e.g. ``NOT x AND y``). The whole
    expansion is parenthesized when it is nested directly inside a connector
    (AND/OR/XOR) or ``NOT``, because its top-level ``OR`` would otherwise bind looser
    than its surroundings. A nested XOR operand is covered by this second rule, since
    its parent is this XOR.

    Args:
        self: the generator producing SQL.
        expression: the XOR node to expand.

    Returns:
        The boolean expansion as SQL.
    """

    def _operand_sql(node: t.Optional[exp.Expression]) -> str:
        sql = self.sql(node)
        return f"({sql})" if isinstance(node, (exp.And, exp.Or)) else sql

    a = _operand_sql(expression.left)
    b = _operand_sql(expression.right)
    sql = f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"

    if isinstance(expression.parent, (exp.Connector, exp.Not)):
        sql = f"({sql})"
    return sql
def is_parse_json(expression: exp.Expression) -> bool:
    """Return True if ``expression`` is PARSE_JSON(...) or a CAST to the JSON type."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Build ``(arg IS NULL)`` from the arguments of an ISNULL(arg) call."""
    is_null = exp.Is(this=seq_get(args, 0), expression=exp.null())
    return exp.Paren(this=is_null)
def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render an identity column constraint as ``IDENTITY(start, increment)``.

    A missing (empty) ``start`` or ``increment`` falls back to ``1``.
    """
    start, increment = (self.sql(expression, key) or "1" for key in ("start", "increment"))
    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build a generator for ARG_MAX/ARG_MIN that emits ``name(this, expression)``.

    The optional ``count`` argument cannot be emitted. If set, it is dropped and
    reported as unsupported.
    """

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        has_count = bool(expression.args.get("count"))
        if has_count:
            self.unsupported(f"Only two arguments are supported in function {name}.")
        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap the operand of ``expression`` in a CAST to its return type, in place.

    For a DATE return type the operand is cast to TIMESTAMP first. The same (mutated)
    ``expression`` is returned.
    """
    operand = expression.this.copy()
    target_type = expression.return_type

    # Some dialects can't cast timestamp strings straight to DATE. Going through
    # TIMESTAMP lets such strings be truncated to a date.
    if target_type.is_type(exp.DataType.Type.DATE):
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(operand, target_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a generator emitting ``name(unit, expression, this)`` for date add/diff nodes.

    When ``cast`` is true, TsOrDsAdd nodes are first passed through ``ts_or_ds_add_cast``.
    The unit is rendered via ``unit_to_var``.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        needs_cast = cast and isinstance(expression, exp.TsOrDsAdd)
        if needs_cast:
            expression = ts_or_ds_add_cast(expression)

        unit = unit_to_var(expression)
        return self.func(name, unit, expression.expression, expression.this)

    return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the ``unit`` arg of ``expression`` as a string literal.

    Placeholder units are returned unchanged. Without a unit, ``default`` is returned as
    a string literal, or None if ``default`` is empty.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if not unit:
        return exp.Literal.string(default) if default else None
    return exp.Literal.string(unit.name)
def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the ``unit`` arg of ``expression`` as a Var (or Placeholder).

    Var and Placeholder units are returned unchanged. Any other unit expression, such as
    a string literal ``'month'``, is converted to a Var holding its name. This matches
    how ``unit_to_str`` uses ``unit.name``.

    Args:
        expression: the node whose ``unit`` arg is read.
        default: unit name used when there is no unit or its name is empty.

    Returns:
        The unit as a Var/Placeholder, or None when no unit applies and ``default`` is empty.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit

    # Keep the unit's name rather than discarding it; replacing e.g. a 'month' literal
    # with the default DAY would silently change the query's meaning.
    value = (unit.name if unit else "") or default
    return exp.Var(this=value) if value else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate LAST_DAY(d) for dialects that lack it.

    Truncates d to the start of its month, adds one month, subtracts one day, and
    casts the result to DATE.
    """
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    last_day = exp.func("date_sub", next_month_start, 1, "day")
    return self.sql(exp.cast(last_day, exp.DataType.Type.DATE))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in the THEN action of each WHEN clause.

    Column qualifiers naming the MERGE target table (or its alias) are stripped from
    the THEN actions only. The optional ``WHEN [NOT] MATCHED AND <condition>`` part is
    left untouched. Qualified target columns are valid there and may be needed to tell
    them apart from same-named source columns.

    Args:
        self: the generator producing SQL.
        expression: the MERGE statement; its WHEN clauses are modified in place.

    Returns:
        The generated MERGE SQL.
    """
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        then = when.args.get("then")
        if not then:
            continue

        then.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)

Remove table references from columns in the `WHEN` clauses of a `MERGE` statement.

def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    """Create a parser builder turning ``FUNC(json, p1, p2, ...)`` into ``expr_type``.

    Each literal path argument becomes one JSONPath segment after a leading root.
    Integer-looking values become subscripts, shifted down by one when
    ``zero_based_indexing`` is False. Everything else becomes a key. If any path
    argument is not a literal, the builder falls back to ``expr_type.from_arg_list(args)``.
    """

    def _builder(args: t.List) -> F:
        path_args = args[1:]

        # Non-literal paths can't be transpiled safely, so use the generic builder instead.
        if not all(isinstance(arg, exp.Literal) for arg in path_args):
            return expr_type.from_arg_list(args)

        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in path_args:
            text = arg.name
            if is_int(text):
                index = int(text) if zero_based_indexing else int(text) - 1
                segments.append(exp.JSONPathSubscript(this=index))
            else:
                segments.append(exp.JSONPathKey(this=text))

        # Truncate the caller's list in place so the expression validator's arg-count
        # check doesn't fail on the extra path arguments.
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Build a generator that renders a JSON extraction as separate path segments.

    With ``op`` the output is ``this op seg1 op seg2 ...``; otherwise it is
    ``name(this, seg1, seg2, ...)``. Segments that render to an empty string are
    skipped. The rest are wrapped in the dialect's quote characters, except subscripts
    when ``quoted_index`` is False. Non-JSONPath paths fall back to ``rename_func(name)``.
    """

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        rendered = []
        for segment in json_path.expressions:
            segment_sql = self.sql(segment)
            if not segment_sql:
                continue

            needs_quotes = isinstance(segment, exp.JSONPathPart) and (
                quoted_index or not isinstance(segment, exp.JSONPathSubscript)
            )
            if needs_quotes:
                segment_sql = f"{self.dialect.QUOTE_START}{segment_sql}{self.dialect.QUOTE_END}"

            rendered.append(segment_sql)

        if op:
            return f" {op} ".join([self.sql(expression.this), *rendered])
        return self.func(name, expression.this, *rendered)

    return _json_extract_segments
def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    """Render a JSONPath key as its bare name; wildcard keys are reported as unsupported."""
    is_wildcard = isinstance(expression.this, exp.JSONPathWildcard)
    if is_wildcard:
        self.unsupported("Unsupported wildcard in JSONPathKey expression")
    return expression.name
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    """Emulate FILTER(array, cond) as ``ARRAY(SELECT alias FROM UNNEST(array) ... WHERE cond)``.

    A single-parameter lambda supplies the element alias and the condition. A bare
    predicate uses the alias ``_u``. Any other condition is reported as unsupported and
    an empty string is returned.
    """
    condition = expression.expression

    if isinstance(condition, exp.Lambda) and len(condition.expressions) == 1:
        # The lambda's only parameter names each unnested element.
        alias = condition.expressions[0]
        condition = condition.this
    elif isinstance(condition, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    source = exp.alias_(exp.Unnest(expressions=[expression.this]), None, table=[alias])
    subquery = exp.select(alias).from_(source).where(condition)
    return self.sql(exp.Array(expressions=[subquery]))
def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    """Render TO_NUMBER(this, format, nlsparam), passing the NLS parameter argument through."""
    args = expression.args
    return self.func("TO_NUMBER", expression.this, args.get("format"), args.get("nlsparam"))