Edit on GitHub

sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
# Unions of expression types that share SQL-generation helpers; JSON_EXTRACT_TYPE,
# for instance, annotates the expression accepted by arrow_json_extract_sql.
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


# Names needed only by the type checker (e.g. `E`, used as a generic type in helper
# signatures such as `t.Type[E]`); they are not imported at runtime.
if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

# Shared "sqlglot" logger; used in this module to warn about unparsable JSON paths.
logger = logging.getLogger("sqlglot")
  27
  28
class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    # The empty-string member stands for the base `Dialect` class. The other values are
    # the registry names under which the `_Dialect` metaclass stores a dialect class
    # whose upper-cased class name matches the member name.
    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
  57
  58
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    # Member values are looked up by upper-cased name elsewhere in this module
    # (`NormalizationStrategy(name.upper())` in `Dialect.__init__`), so a user can pass
    # e.g. "case_sensitive" as a dialect setting.

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
  73
  74
class _Dialect(type):
    """Metaclass that registers dialect classes and derives their per-dialect settings."""

    # Global registry of dialect classes, keyed by registry name (see __new__).
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class equals itself, its registry name, and any instance of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # NOTE(review): hashes the lowercased class name. That matches the registry key
        # for most dialects, but the base `Dialect` is registered under "" while hashing
        # as "dialect", so `Dialect == ""` holds without the hashes being equal.
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Enables `Dialect["duckdb"]`; raises KeyError for unregistered names.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        # Non-raising registry lookup.
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        # Register under the matching Dialects enum value (e.g. class "TSQL" -> "tsql",
        # base class "Dialect" -> ""), or under the lowercased class name otherwise.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Tries used to translate time formats to and from Python strftime formats;
        # FORMAT_TRIE falls back to TIME_TRIE when FORMAT_MAPPING is empty.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Components inherit from the first base's components, falling back to the
        # generic Tokenizer / Parser / Generator.
        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        # Use a nested Tokenizer / Parser / Generator defined directly in the class body;
        # otherwise use a fresh empty subclass of the base component. The `type(...)`
        # default is built eagerly, even when the nested class exists.
        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        # Expose the first configured string-quote and identifier delimiter pairs.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Return the (start, end) delimiters of the first format string whose token
            # type is `token_type`, or (None, None) if the tokenizer defines none. The loop
            # variable `t` shadows the `typing` alias only inside the generator expression.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        # With backslash string escapes, seed the common C-style sequences. Entries the
        # class already defines come last in the literal and therefore take precedence.
        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                "\\a": "\a",
                "\\b": "\b",
                "\\f": "\f",
                "\\n": "\n",
                "\\r": "\r",
                "\\t": "\t",
                "\\v": "\v",
                "\\\\": "\\",
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        # Dialects is a str enum, so members compare equal to their string values. `enum`
        # is None for dialects not listed in Dialects, and None is in neither tuple below.
        # Only the base dialect and BigQuery keep the generator's SELECT_KINDS.
        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        # Outside the base / Databricks / Hive / Spark family, drop the "cluster",
        # "distribute" and "sort" modifier transforms. This works on a copy, so the dict
        # currently in use (possibly shared with a parent class) is not mutated.
        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        # Without SEMI/ANTI joins, those keywords may be used as table aliases.
        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
 178
 179
 180class Dialect(metaclass=_Dialect):
 181    INDEX_OFFSET = 0
 182    """The base index offset for arrays."""
 183
 184    WEEK_OFFSET = 0
 185    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 186
 187    UNNEST_COLUMN_ONLY = False
 188    """Whether `UNNEST` table aliases are treated as column aliases."""
 189
 190    ALIAS_POST_TABLESAMPLE = False
 191    """Whether the table alias comes after tablesample."""
 192
 193    TABLESAMPLE_SIZE_IS_PERCENT = False
 194    """Whether a size in the table sample clause represents percentage."""
 195
 196    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 197    """Specifies the strategy according to which identifiers should be normalized."""
 198
 199    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 200    """Whether an unquoted identifier can start with a digit."""
 201
 202    DPIPE_IS_STRING_CONCAT = True
 203    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 204
 205    STRICT_STRING_CONCAT = False
 206    """Whether `CONCAT`'s arguments must be strings."""
 207
 208    SUPPORTS_USER_DEFINED_TYPES = True
 209    """Whether user-defined data types are supported."""
 210
 211    SUPPORTS_SEMI_ANTI_JOIN = True
 212    """Whether `SEMI` or `ANTI` joins are supported."""
 213
 214    NORMALIZE_FUNCTIONS: bool | str = "upper"
 215    """
 216    Determines how function names are going to be normalized.
 217    Possible values:
 218        "upper" or True: Convert names to uppercase.
 219        "lower": Convert names to lowercase.
 220        False: Disables function name normalization.
 221    """
 222
 223    LOG_BASE_FIRST: t.Optional[bool] = True
 224    """
 225    Whether the base comes first in the `LOG` function.
 226    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
 227    """
 228
 229    NULL_ORDERING = "nulls_are_small"
 230    """
 231    Default `NULL` ordering method to use if not explicitly set.
 232    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 233    """
 234
 235    TYPED_DIVISION = False
 236    """
 237    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 238    False means `a / b` is always float division.
 239    True means `a / b` is integer division if both `a` and `b` are integers.
 240    """
 241
 242    SAFE_DIVISION = False
 243    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 244
 245    CONCAT_COALESCE = False
 246    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 247
 248    DATE_FORMAT = "'%Y-%m-%d'"
 249    DATEINT_FORMAT = "'%Y%m%d'"
 250    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 251
 252    TIME_MAPPING: t.Dict[str, str] = {}
 253    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 254
 255    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 256    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 257    FORMAT_MAPPING: t.Dict[str, str] = {}
 258    """
 259    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 260    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 261    """
 262
 263    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
 264    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 265
 266    PSEUDOCOLUMNS: t.Set[str] = set()
 267    """
 268    Columns that are auto-generated by the engine corresponding to this dialect.
 269    For example, such columns may be excluded from `SELECT *` queries.
 270    """
 271
 272    PREFER_CTE_ALIAS_COLUMN = False
 273    """
 274    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 275    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 276    any projection aliases in the subquery.
 277
 278    For example,
 279        WITH y(c) AS (
 280            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 281        ) SELECT c FROM y;
 282
 283        will be rewritten as
 284
 285        WITH y(c) AS (
 286            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 287        ) SELECT c FROM y;
 288    """
 289
 290    # --- Autofilled ---
 291
 292    tokenizer_class = Tokenizer
 293    parser_class = Parser
 294    generator_class = Generator
 295
 296    # A trie of the time_mapping keys
 297    TIME_TRIE: t.Dict = {}
 298    FORMAT_TRIE: t.Dict = {}
 299
 300    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 301    INVERSE_TIME_TRIE: t.Dict = {}
 302
 303    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 304
 305    # Delimiters for string literals and identifiers
 306    QUOTE_START = "'"
 307    QUOTE_END = "'"
 308    IDENTIFIER_START = '"'
 309    IDENTIFIER_END = '"'
 310
 311    # Delimiters for bit, hex, byte and unicode literals
 312    BIT_START: t.Optional[str] = None
 313    BIT_END: t.Optional[str] = None
 314    HEX_START: t.Optional[str] = None
 315    HEX_END: t.Optional[str] = None
 316    BYTE_START: t.Optional[str] = None
 317    BYTE_END: t.Optional[str] = None
 318    UNICODE_START: t.Optional[str] = None
 319    UNICODE_END: t.Optional[str] = None
 320
 321    @classmethod
 322    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 323        """
 324        Look up a dialect in the global dialect registry and return it if it exists.
 325
 326        Args:
 327            dialect: The target dialect. If this is a string, it can be optionally followed by
 328                additional key-value pairs that are separated by commas and are used to specify
 329                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 330
 331        Example:
 332            >>> dialect = dialect_class = get_or_raise("duckdb")
 333            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
 334
 335        Returns:
 336            The corresponding Dialect instance.
 337        """
 338
 339        if not dialect:
 340            return cls()
 341        if isinstance(dialect, _Dialect):
 342            return dialect()
 343        if isinstance(dialect, Dialect):
 344            return dialect
 345        if isinstance(dialect, str):
 346            try:
 347                dialect_name, *kv_pairs = dialect.split(",")
 348                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 349            except ValueError:
 350                raise ValueError(
 351                    f"Invalid dialect format: '{dialect}'. "
 352                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
 353                )
 354
 355            result = cls.get(dialect_name.strip())
 356            if not result:
 357                from difflib import get_close_matches
 358
 359                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 360                if similar:
 361                    similar = f" Did you mean {similar}?"
 362
 363                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 364
 365            return result(**kwargs)
 366
 367        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 368
 369    @classmethod
 370    def format_time(
 371        cls, expression: t.Optional[str | exp.Expression]
 372    ) -> t.Optional[exp.Expression]:
 373        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 374        if isinstance(expression, str):
 375            return exp.Literal.string(
 376                # the time formats are quoted
 377                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 378            )
 379
 380        if expression and expression.is_string:
 381            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 382
 383        return expression
 384
 385    def __init__(self, **kwargs) -> None:
 386        normalization_strategy = kwargs.get("normalization_strategy")
 387
 388        if normalization_strategy is None:
 389            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 390        else:
 391            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 392
 393    def __eq__(self, other: t.Any) -> bool:
 394        # Does not currently take dialect state into account
 395        return type(self) == other
 396
 397    def __hash__(self) -> int:
 398        # Does not currently take dialect state into account
 399        return hash(type(self))
 400
 401    def normalize_identifier(self, expression: E) -> E:
 402        """
 403        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 404
 405        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 406        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 407        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 408        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 409
 410        There are also dialects like Spark, which are case-insensitive even when quotes are
 411        present, and dialects like MySQL, whose resolution rules match those employed by the
 412        underlying operating system, for example they may always be case-sensitive in Linux.
 413
 414        Finally, the normalization behavior of some engines can even be controlled through flags,
 415        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 416
 417        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 418        that it can analyze queries in the optimizer and successfully capture their semantics.
 419        """
 420        if (
 421            isinstance(expression, exp.Identifier)
 422            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 423            and (
 424                not expression.quoted
 425                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 426            )
 427        ):
 428            expression.set(
 429                "this",
 430                (
 431                    expression.this.upper()
 432                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 433                    else expression.this.lower()
 434                ),
 435            )
 436
 437        return expression
 438
 439    def case_sensitive(self, text: str) -> bool:
 440        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 441        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 442            return False
 443
 444        unsafe = (
 445            str.islower
 446            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 447            else str.isupper
 448        )
 449        return any(unsafe(char) for char in text)
 450
 451    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 452        """Checks if text can be identified given an identify option.
 453
 454        Args:
 455            text: The text to check.
 456            identify:
 457                `"always"` or `True`: Always returns `True`.
 458                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 459
 460        Returns:
 461            Whether the given text can be identified.
 462        """
 463        if identify is True or identify == "always":
 464            return True
 465
 466        if identify == "safe":
 467            return not self.case_sensitive(text)
 468
 469        return False
 470
 471    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 472        """
 473        Adds quotes to a given identifier.
 474
 475        Args:
 476            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 477            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 478                "unsafe", with respect to its characters and this dialect's normalization strategy.
 479        """
 480        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 481            name = expression.this
 482            expression.set(
 483                "quoted",
 484                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 485            )
 486
 487        return expression
 488
 489    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 490        if isinstance(path, exp.Literal):
 491            path_text = path.name
 492            if path.is_number:
 493                path_text = f"[{path_text}]"
 494
 495            try:
 496                return parse_json_path(path_text)
 497            except ParseError as e:
 498                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 499
 500        return path
 501
 502    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 503        return self.parser(**opts).parse(self.tokenize(sql), sql)
 504
 505    def parse_into(
 506        self, expression_type: exp.IntoType, sql: str, **opts
 507    ) -> t.List[t.Optional[exp.Expression]]:
 508        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 509
 510    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 511        return self.generator(**opts).generate(expression, copy=copy)
 512
 513    def transpile(self, sql: str, **opts) -> t.List[str]:
 514        return [
 515            self.generate(expression, copy=False, **opts) if expression else ""
 516            for expression in self.parse(sql)
 517        ]
 518
 519    def tokenize(self, sql: str) -> t.List[Token]:
 520        return self.tokenizer.tokenize(sql)
 521
 522    @property
 523    def tokenizer(self) -> Tokenizer:
 524        if not hasattr(self, "_tokenizer"):
 525            self._tokenizer = self.tokenizer_class(dialect=self)
 526        return self._tokenizer
 527
 528    def parser(self, **opts) -> Parser:
 529        return self.parser_class(dialect=self, **opts)
 530
 531    def generator(self, **opts) -> Generator:
 532        return self.generator_class(dialect=self, **opts)
 533
 534
 535DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 536
 537
 538def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 539    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 540
 541
 542def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 543    if expression.args.get("accuracy"):
 544        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 545    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 546
 547
 548def if_sql(
 549    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 550) -> t.Callable[[Generator, exp.If], str]:
 551    def _if_sql(self: Generator, expression: exp.If) -> str:
 552        return self.func(
 553            name,
 554            expression.this,
 555            expression.args.get("true"),
 556            expression.args.get("false") or false_value,
 557        )
 558
 559    return _if_sql
 560
 561
 562def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 563    this = expression.this
 564    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 565        this.replace(exp.cast(this, "json"))
 566
 567    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 568
 569
 570def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 571    return f"[{self.expressions(expression, flat=True)}]"
 572
 573
 574def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 575    return self.like_sql(
 576        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 577    )
 578
 579
 580def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 581    zone = self.sql(expression, "this")
 582    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 583
 584
 585def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 586    if expression.args.get("recursive"):
 587        self.unsupported("Recursive CTEs are unsupported")
 588        expression.args["recursive"] = False
 589    return self.with_sql(expression)
 590
 591
 592def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 593    n = self.sql(expression, "this")
 594    d = self.sql(expression, "expression")
 595    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 596
 597
 598def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 599    self.unsupported("TABLESAMPLE unsupported")
 600    return self.sql(expression.this)
 601
 602
 603def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 604    self.unsupported("PIVOT unsupported")
 605    return ""
 606
 607
 608def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 609    return self.cast_sql(expression)
 610
 611
 612def no_comment_column_constraint_sql(
 613    self: Generator, expression: exp.CommentColumnConstraint
 614) -> str:
 615    self.unsupported("CommentColumnConstraint unsupported")
 616    return ""
 617
 618
 619def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 620    self.unsupported("MAP_FROM_ENTRIES unsupported")
 621    return ""
 622
 623
 624def str_position_sql(
 625    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
 626) -> str:
 627    this = self.sql(expression, "this")
 628    substr = self.sql(expression, "substr")
 629    position = self.sql(expression, "position")
 630    instance = expression.args.get("instance") if generate_instance else None
 631    position_offset = ""
 632
 633    if position:
 634        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
 635        this = self.func("SUBSTR", this, position)
 636        position_offset = f" + {position} - 1"
 637
 638    return self.func("STRPOS", this, substr, instance) + position_offset
 639
 640
 641def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 642    return (
 643        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 644    )
 645
 646
 647def var_map_sql(
 648    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 649) -> str:
 650    keys = expression.args["keys"]
 651    values = expression.args["values"]
 652
 653    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 654        self.unsupported("Cannot convert array columns into map.")
 655        return self.func(map_func_name, keys, values)
 656
 657    args = []
 658    for key, value in zip(keys.expressions, values.expressions):
 659        args.append(self.sql(key))
 660        args.append(self.sql(value))
 661
 662    return self.func(map_func_name, *args)
 663
 664
 665def build_formatted_time(
 666    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 667) -> t.Callable[[t.List], E]:
 668    """Helper used for time expressions.
 669
 670    Args:
 671        exp_class: the expression class to instantiate.
 672        dialect: target sql dialect.
 673        default: the default format, True being time.
 674
 675    Returns:
 676        A callable that can be used to return the appropriately formatted time expression.
 677    """
 678
 679    def _builder(args: t.List):
 680        return exp_class(
 681            this=seq_get(args, 0),
 682            format=Dialect[dialect].format_time(
 683                seq_get(args, 1)
 684                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 685            ),
 686        )
 687
 688    return _builder
 689
 690
 691def time_format(
 692    dialect: DialectType = None,
 693) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 694    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 695        """
 696        Returns the time format for a given expression, unless it's equivalent
 697        to the default time format of the dialect of interest.
 698        """
 699        time_format = self.format_time(expression)
 700        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 701
 702    return _time_format
 703
 704
 705def build_date_delta(
 706    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 707) -> t.Callable[[t.List], E]:
 708    def _builder(args: t.List) -> E:
 709        unit_based = len(args) == 3
 710        this = args[2] if unit_based else seq_get(args, 0)
 711        unit = args[0] if unit_based else exp.Literal.string("DAY")
 712        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 713        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 714
 715    return _builder
 716
 717
 718def build_date_delta_with_interval(
 719    expression_class: t.Type[E],
 720) -> t.Callable[[t.List], t.Optional[E]]:
 721    def _builder(args: t.List) -> t.Optional[E]:
 722        if len(args) < 2:
 723            return None
 724
 725        interval = args[1]
 726
 727        if not isinstance(interval, exp.Interval):
 728            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 729
 730        expression = interval.this
 731        if expression and expression.is_string:
 732            expression = exp.Literal.number(expression.this)
 733
 734        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 735
 736    return _builder
 737
 738
 739def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 740    unit = seq_get(args, 0)
 741    this = seq_get(args, 1)
 742
 743    if isinstance(this, exp.Cast) and this.is_type("date"):
 744        return exp.DateTrunc(unit=unit, this=this)
 745    return exp.TimestampTrunc(this=this, unit=unit)
 746
 747
 748def date_add_interval_sql(
 749    data_type: str, kind: str
 750) -> t.Callable[[Generator, exp.Expression], str]:
 751    def func(self: Generator, expression: exp.Expression) -> str:
 752        this = self.sql(expression, "this")
 753        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
 754        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 755
 756    return func
 757
 758
 759def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 760    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
 761
 762
 763def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 764    if not expression.expression:
 765        from sqlglot.optimizer.annotate_types import annotate_types
 766
 767        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 768        return self.sql(exp.cast(expression.this, to=target_type))
 769    if expression.text("expression").lower() in TIMEZONES:
 770        return self.sql(
 771            exp.AtTimeZone(
 772                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
 773                zone=expression.expression,
 774            )
 775        )
 776    return self.func("TIMESTAMP", expression.this, expression.expression)
 777
 778
 779def locate_to_strposition(args: t.List) -> exp.Expression:
 780    return exp.StrPosition(
 781        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 782    )
 783
 784
 785def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 786    return self.func(
 787        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 788    )
 789
 790
 791def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 792    return self.sql(
 793        exp.Substring(
 794            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 795        )
 796    )
 797
 798
 799def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 800    return self.sql(
 801        exp.Substring(
 802            this=expression.this,
 803            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 804        )
 805    )
 806
 807
 808def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 809    return self.sql(exp.cast(expression.this, "timestamp"))
 810
 811
 812def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 813    return self.sql(exp.cast(expression.this, "date"))
 814
 815
 816# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
 817def encode_decode_sql(
 818    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 819) -> str:
 820    charset = expression.args.get("charset")
 821    if charset and charset.name.lower() != "utf-8":
 822        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 823
 824    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 825
 826
 827def min_or_least(self: Generator, expression: exp.Min) -> str:
 828    name = "LEAST" if expression.expressions else "MIN"
 829    return rename_func(name)(self, expression)
 830
 831
 832def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 833    name = "GREATEST" if expression.expressions else "MAX"
 834    return rename_func(name)(self, expression)
 835
 836
 837def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 838    cond = expression.this
 839
 840    if isinstance(expression.this, exp.Distinct):
 841        cond = expression.this.expressions[0]
 842        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 843
 844    return self.func("sum", exp.func("if", cond, 1, 0))
 845
 846
 847def trim_sql(self: Generator, expression: exp.Trim) -> str:
 848    target = self.sql(expression, "this")
 849    trim_type = self.sql(expression, "position")
 850    remove_chars = self.sql(expression, "expression")
 851    collation = self.sql(expression, "collation")
 852
 853    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 854    if not remove_chars and not collation:
 855        return self.trim_sql(expression)
 856
 857    trim_type = f"{trim_type} " if trim_type else ""
 858    remove_chars = f"{remove_chars} " if remove_chars else ""
 859    from_part = "FROM " if trim_type or remove_chars else ""
 860    collation = f" COLLATE {collation}" if collation else ""
 861    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 862
 863
 864def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 865    return self.func("STRPTIME", expression.this, self.format_time(expression))
 866
 867
 868def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 869    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 870
 871
 872def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 873    delim, *rest_args = expression.expressions
 874    return self.sql(
 875        reduce(
 876            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 877            rest_args,
 878        )
 879    )
 880
 881
 882def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 883    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 884    if bad_args:
 885        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 886
 887    return self.func(
 888        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 889    )
 890
 891
 892def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 893    bad_args = list(
 894        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
 895    )
 896    if bad_args:
 897        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 898
 899    return self.func(
 900        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 901    )
 902
 903
 904def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 905    names = []
 906    for agg in aggregations:
 907        if isinstance(agg, exp.Alias):
 908            names.append(agg.alias)
 909        else:
 910            """
 911            This case corresponds to aggregations without aliases being used as suffixes
 912            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 913            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
 914            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
 915            """
 916            agg_all_unquoted = agg.transform(
 917                lambda node: (
 918                    exp.Identifier(this=node.name, quoted=False)
 919                    if isinstance(node, exp.Identifier)
 920                    else node
 921                )
 922            )
 923            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 924
 925    return names
 926
 927
 928def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 929    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 930
 931
 932# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 933def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 934    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 935
 936
 937def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 938    return self.func("MAX", expression.this)
 939
 940
 941def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 942    a = self.sql(expression.left)
 943    b = self.sql(expression.right)
 944    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 945
 946
 947def is_parse_json(expression: exp.Expression) -> bool:
 948    return isinstance(expression, exp.ParseJSON) or (
 949        isinstance(expression, exp.Cast) and expression.is_type("json")
 950    )
 951
 952
 953def isnull_to_is_null(args: t.List) -> exp.Expression:
 954    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 955
 956
 957def generatedasidentitycolumnconstraint_sql(
 958    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 959) -> str:
 960    start = self.sql(expression, "start") or "1"
 961    increment = self.sql(expression, "increment") or "1"
 962    return f"IDENTITY({start}, {increment})"
 963
 964
 965def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 966    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 967        if expression.args.get("count"):
 968            self.unsupported(f"Only two arguments are supported in function {name}.")
 969
 970        return self.func(name, expression.this, expression.expression)
 971
 972    return _arg_max_or_min_sql
 973
 974
 975def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 976    this = expression.this.copy()
 977
 978    return_type = expression.return_type
 979    if return_type.is_type(exp.DataType.Type.DATE):
 980        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 981        # can truncate timestamp strings, because some dialects can't cast them to DATE
 982        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 983
 984    expression.this.replace(exp.cast(this, return_type))
 985    return expression
 986
 987
 988def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 989    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
 990        if cast and isinstance(expression, exp.TsOrDsAdd):
 991            expression = ts_or_ds_add_cast(expression)
 992
 993        return self.func(
 994            name,
 995            unit_to_var(expression),
 996            expression.expression,
 997            expression.this,
 998        )
 999
1000    return _delta_sql
1001
1002
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the node's ``unit`` arg as a string literal.

    Placeholders pass through unchanged. A missing unit yields ``default`` as a literal,
    or None when ``default`` is empty.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if not unit:
        return exp.Literal.string(default) if default else None
    return exp.Literal.string(unit.name)
1011
1012
def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the node's ``unit`` arg as a Var (an unquoted keyword such as ``DAY``).

    Var and Placeholder units pass through unchanged. Any other unit (e.g. a string
    literal ``'month'`` or a bare column ``month``) is converted to a Var carrying its name.
    A missing or nameless unit falls back to ``default``; if that is empty, None is returned.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit

    # Previously every non-Var unit was silently replaced by the default (DAY), so e.g.
    # a 'month' string-literal unit produced a day-based delta.
    value = (unit.name if unit else None) or default
    return exp.Var(this=value) if value else None
1019
1020
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate ``LAST_DAY(d)`` as the day before the first day of the next month."""
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    last_day = exp.func("date_sub", next_month_start, 1, "day")
    return self.sql(exp.cast(last_day, "date"))
1027
1028
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements.

    Columns in the WHEN clauses that are qualified by the MERGE target's table name (or its
    alias) are replaced with unqualified columns. The comparison uses the dialect's
    identifier normalization. The clauses are rewritten in place before rendering.
    """
    target = expression.this
    alias = target.args.get("alias")

    def _normalized(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        if not identifier:
            return None
        return self.dialect.normalize_identifier(identifier).name

    target_names = {_normalized(target.this)}
    if alias:
        target_names.add(_normalized(alias.this))

    def _drop_target_qualifier(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Column) and _normalized(node.args.get("table")) in target_names:
            return exp.column(node.this)
        return node

    for when in expression.expressions:
        when.transform(_drop_target_qualifier, copy=False)

    return self.merge_sql(expression)
1052
1053
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    """Return a builder turning ``f(json, key1, key2, ...)`` args into a JSONPath-based node.

    Integer-like literal args become subscripts (shifted down by one when the source
    dialect indexes from 1); other literals become keys. If any path arg is not a literal,
    the builder falls back to ``expr_type.from_arg_list``.
    """
    index_shift = 0 if zero_based_indexing else 1

    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]

        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                segments.append(exp.JSONPathSubscript(this=int(text) - index_shift))
            else:
                segments.append(exp.JSONPathKey(this=text))

        # Trim the (caller's) arg list in place so the expression validator doesn't
        # complain about the arg count; the path args now live in `segments`.
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
1082
1083
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Return a generator rendering a JSON path as separate, individually-quoted segments.

    Args:
        name: Function name, used when ``op`` is not given or the path isn't a JSONPath.
        quoted_index: Whether subscript segments are also wrapped in string quotes.
        op: When set, segments are chained with this infix operator instead of a call.
    """

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        quote_start, quote_end = self.dialect.QUOTE_START, self.dialect.QUOTE_END
        rendered = []
        for segment in json_path.expressions:
            segment_sql = self.sql(segment)
            if not segment_sql:
                continue

            needs_quotes = isinstance(segment, exp.JSONPathPart) and (
                quoted_index or not isinstance(segment, exp.JSONPathSubscript)
            )
            if needs_quotes:
                segment_sql = f"{quote_start}{segment_sql}{quote_end}"
            rendered.append(segment_sql)

        if op:
            return f" {op} ".join([self.sql(expression.this), *rendered])
        return self.func(name, expression.this, *rendered)

    return _json_extract_segments
1108
1109
def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    """Render a JSON path key as its bare name; wildcard keys are reported as unsupported."""
    key = expression.this
    if isinstance(key, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")
    return expression.name
1115
1116
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    """Emulate ``FILTER(arr, x -> cond)`` as ``ARRAY(SELECT x FROM UNNEST(arr) AS _(x) WHERE cond)``.

    Accepts a single-parameter lambda, or a bare predicate whose element alias is ``_u``.
    Any other condition is reported as unsupported and rendered as an empty string.
    """
    cond = expression.expression

    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias, cond = cond.expressions[0], cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    source = exp.alias_(exp.Unnest(expressions=[expression.this]), None, table=[alias])
    subquery = exp.select(alias).from_(source).where(cond)
    return self.sql(exp.Array(expressions=[subquery]))
1131
1132
def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    """Render ``TO_NUMBER(value[, format[, nlsparam]])``, keeping the NLS parameter."""
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )
logger = <Logger sqlglot (WARNING)>
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """Dialects supported by SQLGlot.

    Each member's value is the lowercase name used to look the dialect up (e.g. "duckdb").
    ``DIALECT`` has the empty string as its value.
    """

    # The base dialect.
    DIALECT = ""

    # Named dialects, in alphabetical order.
    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"

Dialects supported by SQLGlot.

DIALECT = <Dialects.DIALECT: ''>
ATHENA = <Dialects.ATHENA: 'athena'>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
PRQL = <Dialects.PRQL: 'prql'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized.

    Each member's value equals its name (e.g. "LOWERCASE").
    """

    # Strategies that only affect unquoted identifiers.
    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    # Strategies that ignore quoting altogether.
    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    """Base SQL dialect: dialect-wide settings plus tokenize/parse/generate entry points.

    The class attributes below configure the dialect. Each instance also carries a
    ``normalization_strategy``, which defaults to ``NORMALIZATION_STRATEGY`` and can be
    overridden through the constructor (see ``get_or_raise``).
    """

    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (e.g. the two characters `\\n`) to its unescaped version (e.g. a newline character)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.

        Raises:
            ValueError: If the string is malformed, names an unknown dialect, or `dialect`
                has an unsupported type.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                # The unpacking error itself carries no useful information for the caller.
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                ) from None

            # Strip once so the lookup, the suggestion and the error message all agree.
            dialect_name = dialect_name.strip()
            result = cls.get(dialect_name)
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        """Parse a literal JSON path into a JSONPath node.

        A numeric literal is treated as an array subscript. Non-literals are returned as-is,
        and so are literals that fail to parse (with a logged warning).
        """
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning("Invalid JSON path syntax. %s", e)

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Built lazily on first access and cached on the instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
386    def __init__(self, **kwargs) -> None:
387        normalization_strategy = kwargs.get("normalization_strategy")
388
389        if normalization_strategy is None:
390            self.normalization_strategy = self.NORMALIZATION_STRATEGY
391        else:
392            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

- "upper" or True: convert names to uppercase.
- "lower": convert names to lowercase.
- False: disable function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG)

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (e.g. the two characters \n) to its unescaped version (e.g. a newline character).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
322    @classmethod
323    def get_or_raise(cls, dialect: DialectType) -> Dialect:
324        """
325        Look up a dialect in the global dialect registry and return it if it exists.
326
327        Args:
328            dialect: The target dialect. If this is a string, it can be optionally followed by
329                additional key-value pairs that are separated by commas and are used to specify
330                dialect settings, such as whether the dialect's identifiers are case-sensitive.
331
332        Example:
333            >>> dialect = dialect_class = get_or_raise("duckdb")
334            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
335
336        Returns:
337            The corresponding Dialect instance.
338        """
339
340        if not dialect:
341            return cls()
342        if isinstance(dialect, _Dialect):
343            return dialect()
344        if isinstance(dialect, Dialect):
345            return dialect
346        if isinstance(dialect, str):
347            try:
348                dialect_name, *kv_pairs = dialect.split(",")
349                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
350            except ValueError:
351                raise ValueError(
352                    f"Invalid dialect format: '{dialect}'. "
353                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
354                )
355
356            result = cls.get(dialect_name.strip())
357            if not result:
358                from difflib import get_close_matches
359
360                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
361                if similar:
362                    similar = f" Did you mean {similar}?"
363
364                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
365
366            return result(**kwargs)
367
368        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
370    @classmethod
371    def format_time(
372        cls, expression: t.Optional[str | exp.Expression]
373    ) -> t.Optional[exp.Expression]:
374        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
375        if isinstance(expression, str):
376            return exp.Literal.string(
377                # the time formats are quoted
378                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
379            )
380
381        if expression and expression.is_string:
382            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
383
384        return expression

Converts a time format in this dialect to its equivalent Python strftime format.

def normalize_identifier(self, expression: ~E) -> ~E:
402    def normalize_identifier(self, expression: E) -> E:
403        """
404        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
405
406        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
407        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
408        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
409        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
410
411        There are also dialects like Spark, which are case-insensitive even when quotes are
412        present, and dialects like MySQL, whose resolution rules match those employed by the
413        underlying operating system, for example they may always be case-sensitive in Linux.
414
415        Finally, the normalization behavior of some engines can even be controlled through flags,
416        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
417
418        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
419        that it can analyze queries in the optimizer and successfully capture their semantics.
420        """
421        if (
422            isinstance(expression, exp.Identifier)
423            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
424            and (
425                not expression.quoted
426                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
427            )
428        ):
429            expression.set(
430                "this",
431                (
432                    expression.this.upper()
433                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
434                    else expression.this.lower()
435                ),
436            )
437
438        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.

def case_sensitive(self, text: str) -> bool:
440    def case_sensitive(self, text: str) -> bool:
441        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
442        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
443            return False
444
445        unsafe = (
446            str.islower
447            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
448            else str.isupper
449        )
450        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
452    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
453        """Checks if text can be identified given an identify option.
454
455        Args:
456            text: The text to check.
457            identify:
458                `"always"` or `True`: Always returns `True`.
459                `"safe"`: Only returns `True` if the identifier is case-insensitive.
460
461        Returns:
462            Whether the given text can be identified.
463        """
464        if identify is True or identify == "always":
465            return True
466
467        if identify == "safe":
468            return not self.case_sensitive(text)
469
470        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
472    def quote_identifier(self, expression: E, identify: bool = True) -> E:
473        """
474        Adds quotes to a given identifier.
475
476        Args:
477            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
478            identify: If set to `False`, the quotes will only be added if the identifier is deemed
479                "unsafe", with respect to its characters and this dialect's normalization strategy.
480        """
481        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
482            name = expression.this
483            expression.set(
484                "quoted",
485                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
486            )
487
488        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
490    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
491        if isinstance(path, exp.Literal):
492            path_text = path.name
493            if path.is_number:
494                path_text = f"[{path_text}]"
495
496            try:
497                return parse_json_path(path_text)
498            except ParseError as e:
499                logger.warning(f"Invalid JSON path syntax. {str(e)}")
500
501        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
503    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
504        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
506    def parse_into(
507        self, expression_type: exp.IntoType, sql: str, **opts
508    ) -> t.List[t.Optional[exp.Expression]]:
509        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
511    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
512        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
514    def transpile(self, sql: str, **opts) -> t.List[str]:
515        return [
516            self.generate(expression, copy=False, **opts) if expression else ""
517            for expression in self.parse(sql)
518        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
520    def tokenize(self, sql: str) -> t.List[Token]:
521        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
523    @property
524    def tokenizer(self) -> Tokenizer:
525        if not hasattr(self, "_tokenizer"):
526            self._tokenizer = self.tokenizer_class(dialect=self)
527        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
529    def parser(self, **opts) -> Parser:
530        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
532    def generator(self, **opts) -> Generator:
533        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a transform that renders an expression as a call to function `name`.

    Every argument of the expression is passed through, flattened, in the order in which
    it appears in `expression.args`.
    """

    def _rename_sql(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename_sql
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT(this); an `accuracy` argument is reported and dropped."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a transform that renders `exp.If` as `name(condition, true, false)`.

    Args:
        name: The function name to emit.
        false_value: Used as the third argument when the If has no (or a falsy) "false" arg.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        branches = expression.args
        otherwise = branches.get("false") or false_value
        return self.func(name, expression.this, branches.get("true"), otherwise)

    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    """Render JSON extraction with the `->` (JSONExtract) or `->>` (scalar) operator.

    When the generator requires a JSON-typed operand, a string-literal operand is replaced
    in the tree (in place) by a CAST of it to JSON before rendering.
    """
    operand = expression.this
    needs_json_cast = (
        self.JSON_TYPE_REQUIRED_FOR_EXTRACTION
        and isinstance(operand, exp.Literal)
        and operand.is_string
    )
    if needs_json_cast:
        operand.replace(exp.cast(operand, "json"))

    operator = "->" if isinstance(expression, exp.JSONExtract) else "->>"
    return self.binary(expression, operator)
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as a bracketed, comma-separated literal: `[a, b, ...]`."""
    items = self.expressions(expression, flat=True)
    return "[" + items + "]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate `x ILIKE pattern` for dialects without ILIKE as `LOWER(x) LIKE LOWER(pattern)`.

    Both operands must be lowercased: lowering only the left side would make a pattern
    containing uppercase characters (e.g. 'ABC%') impossible to match.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=exp.Lower(this=expression.expression),
        )
    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, adding `AT TIME ZONE` when a zone is set."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause for dialects without recursive CTEs.

    A RECURSIVE flag is reported as unsupported and cleared on `expression` itself
    (in-place mutation) before rendering.
    """
    is_recursive = expression.args.get("recursive")
    if is_recursive:
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with IF: NULL when the divisor is zero, the quotient otherwise."""
    numerator = self.sql(expression, "this")
    divisor = self.sql(expression, "expression")
    return f"IF(({divisor}) <> 0, ({numerator}) / ({divisor}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the sampled source."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and emit nothing in its place."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST through the regular CAST generator.

    Note: no `unsupported()` warning is issued for this downgrade.
    """
    render_cast = self.cast_sql
    return render_cast(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report a column COMMENT constraint as unsupported and emit nothing for it."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Report MAP_FROM_ENTRIES as unsupported and emit nothing in its place."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    """Render StrPosition as STRPOS(haystack, needle[, instance]).

    A start position is emulated by searching in `SUBSTR(haystack, position)` and adding
    `position - 1` to the result. The "instance" arg is only emitted when
    `generate_instance` is set.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")
    occurrence = expression.args.get("instance") if generate_instance else None

    if not start:
        return self.func("STRPOS", haystack, needle, occurrence)

    # NOTE(review): if STRPOS reports "not found" as 0, the added offset turns that into
    # `start - 1` rather than 0 — confirm whether callers rely on this.
    shifted = self.func("SUBSTR", haystack, start)
    return self.func("STRPOS", shifted, needle, occurrence) + f" + {start} - 1"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access with dot notation: `struct.field`."""
    parent = self.sql(expression, "this")
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{parent}.{field}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(k1, v1, k2, v2, ...)`.

    Only works when both keys and values are array literals; otherwise the conversion is
    reported as unsupported and the two arguments are passed through as-is.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        interleaved = [
            self.sql(item)
            for pair in zip(keys.expressions, values.expressions)
            for item in pair
        ]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the format used when the call supplies none; `True` means the dialect's
            `TIME_FORMAT`.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        value = seq_get(args, 0)
        fmt = seq_get(args, 1)
        if not fmt:
            fmt = Dialect[dialect].TIME_FORMAT if default is True else default or None
        return exp_class(this=value, format=Dialect[dialect].format_time(fmt))

    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a callback that yields an expression's time format unless it is the default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt != default_fmt:
            return fmt
        return None

    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date-delta functions.

    Two argument shapes are accepted: `(unit, delta, this)` when exactly three args are
    given, otherwise `(this, delta)` with the unit defaulting to the string 'DAY'. When
    `unit_mapping` is provided, the unit's lowercased name is translated through it
    (falling back to the original name) and wrapped in `exp.var`.
    """

    def _builder(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for `FUNC(this, INTERVAL <amount> <unit>)` calls.

    The callback returns None for fewer than two args and raises ParseError when the second
    arg is not an Interval. A string interval amount is converted to a number literal.
    """

    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        # A quoted amount (e.g. INTERVAL '5' DAY) is normalized to a numeric literal.
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        return expression_class(this=this, expression=amount, unit=unit_to_str(interval))

    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build a truncation from `(unit, this)` args.

    Produces DateTrunc when `this` is an explicit CAST to DATE, TimestampTrunc otherwise.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)
    casts_to_date = isinstance(this, exp.Cast) and this.is_type("date")
    if not casts_to_date:
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a transform rendering `{data_type}_{kind}(this, INTERVAL <expr> <unit>)`."""

    def _interval_sql(self: Generator, expression: exp.Expression) -> str:
        target = self.sql(expression, "this")
        unit = unit_to_var(expression)
        interval_sql = self.sql(exp.Interval(this=expression.expression, unit=unit))
        return f"{data_type}_{kind}({target}, {interval_sql})"

    return _interval_sql
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as `DATE_TRUNC(unit, this)` with the unit as a string."""
    unit = unit_to_str(expression)
    return self.func("DATE_TRUNC", unit, expression.this)
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Emulate TIMESTAMP(x[, zone]) for dialects lacking it.

    - No second argument: CAST x to its annotated type (default TIMESTAMP).
    - Second argument naming a known time zone: CAST x to TIMESTAMP AT TIME ZONE zone.
    - Otherwise: fall back to a plain TIMESTAMP(x, y) call.
    """
    zone = expression.expression
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))

    if expression.text("expression").lower() in TIMEZONES:
        as_timestamp = exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
        return self.sql(exp.AtTimeZone(this=as_timestamp, zone=zone))

    return self.func("TIMESTAMP", expression.this, zone)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Build StrPosition from args ordered like `LOCATE(substr, str[, position])`."""
    substr = seq_get(args, 0)
    haystack = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=haystack, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as `LOCATE(substr, this[, position])`."""
    args = expression.args
    return self.func("LOCATE", args.get("substr"), expression.this, args.get("position"))
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Emulate LEFT(s, n) as SUBSTRING(s, 1, n)."""
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Emulate RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1)).

    The `expression` parameter was previously annotated as `exp.Left`; this transform
    handles RIGHT, so the annotation is now `exp.Right`.
    """
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a CAST of the string to TIMESTAMP."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a CAST of the string to DATE."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode call as `name(this[, replace])`.

    A charset other than utf-8 (compared case-insensitively) is reported as unsupported;
    the charset itself is never emitted. The "replace" arg is only passed when `replace`
    is true.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replacement = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replacement)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render Min as LEAST when it has extra operands, otherwise as the MIN aggregate."""
    renderer = rename_func("LEAST" if expression.expressions else "MIN")
    return renderer(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render Max as GREATEST when it has extra operands, otherwise as the MAX aggregate."""
    renderer = rename_func("GREATEST" if expression.expressions else "MAX")
    return renderer(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Emulate COUNT_IF(cond) as sum(if(cond, 1, 0)).

    DISTINCT cannot be preserved: only its first expression is used and it is reported
    as unsupported.
    """
    condition = expression.this
    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    as_flag = exp.func("if", condition, 1, 0)
    return self.func("sum", as_flag)
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM using the full `TRIM([type] [chars] FROM target [COLLATE c])` form.

    Falls back to the generator's own `trim_sql` when neither removal characters nor a
    collation are present.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not remove_chars and not collation:
        return self.trim_sql(expression)

    parts = []
    if trim_type:
        parts.append(f"{trim_type} ")
    if remove_chars:
        parts.append(f"{remove_chars} ")
    if trim_type or remove_chars:
        parts.append("FROM ")
    parts.append(target)
    if collation:
        parts.append(f" COLLATE {collation}")

    return f"TRIM({''.join(parts)})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(this, format)."""
    value = expression.this
    fmt = self.format_time(expression)
    return self.func("STRPTIME", value, fmt)
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Render CONCAT(a, b, c, ...) as a left-nested `a || b || c` chain."""

    def _pipe(left: exp.Expression, right: exp.Expression) -> exp.Expression:
        return exp.DPipe(this=left, expression=right)

    return self.sql(reduce(_pipe, expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render CONCAT_WS(delim, a, b, ...) as `a || delim || b || ...`."""
    delim, *rest_args = expression.expressions

    def _join(acc: exp.Expression, item: exp.Expression) -> exp.Expression:
        return exp.DPipe(this=acc, expression=exp.DPipe(this=delim, expression=item))

    return self.sql(reduce(_join, rest_args))
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT(this, pattern[, group]); other set args are reported and dropped."""
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE(this, pattern, replacement); other set args are reported and dropped."""
    bad_args = [
        arg
        for arg in ("position", "occurrence", "parameters", "modifiers")
        if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the name part contributed by each PIVOT aggregation.

    Args:
        aggregations: The aggregation expressions.
        dialect: The dialect used to render unaliased aggregations.

    Returns:
        One string per aggregation: its alias when aliased, otherwise its SQL text rendered
        with all identifiers unquoted and function names lowercased.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Unaliased aggregations are used as suffixes (e.g. col_avg(foo)). Identifiers are
        # unquoted here because they are quoted later in the base parser's `_parse_pivot`
        # method via `to_identifier`; otherwise the result would contain nested quoting,
        # e.g. col_avg(`foo`) inside an already-quoted name.
        agg_all_unquoted = agg.transform(
            lambda node: (
                exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
        )
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser callback mapping `FUNC(a, b)` to `expr_type(this=a, expression=b)`."""

    def _builder(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _builder
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build a ``TimestampTrunc`` node from an argument list ordered ``(unit, value)``."""
    unit, value = seq_get(args, 0), seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render an ``AnyValue`` node as a ``MAX(...)`` call over the same argument."""
    argument = expression.this
    return self.func("MAX", argument)
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Rewrite ``a XOR b`` as ``(a AND (NOT b)) OR ((NOT a) AND b)``.

    Operands that are themselves AND/OR/XOR connectors are wrapped in parentheses,
    so the ``AND``/``NOT`` introduced by the rewrite cannot regroup them.

    Args:
        self: the generator producing SQL.
        expression: the ``Xor`` node to render.

    Returns:
        The rendered boolean expression.
    """

    def _operand(node: t.Optional[exp.Expression]) -> str:
        sql = self.sql(node)
        # A nested XOR renders as "x OR y", and AND/OR operands have the same problem.
        # Unparenthesized, "NOT b AND c" would mean "(NOT b) AND c" and an OR would
        # split the surrounding AND, changing the result.
        return f"({sql})" if isinstance(node, exp.Connector) else sql

    a = _operand(expression.left)
    b = _operand(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: exp.Expression) -> bool:
    """Return True if ``expression`` is a ``ParseJSON`` call or a ``Cast`` to ``json``."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Build a parenthesized ``(x IS NULL)`` from the argument list of ``ISNULL(x)``."""
    operand = seq_get(args, 0)
    is_null = exp.Is(this=operand, expression=exp.null())
    return exp.Paren(this=is_null)
def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render an identity column constraint as ``IDENTITY(start, increment)``.

    Each value falls back to ``1`` when its ``start``/``increment`` arg renders to an
    empty string.
    """
    start, increment = (self.sql(expression, key) or "1" for key in ("start", "increment"))
    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Create a generator transform for ARG_MAX/ARG_MIN that emits exactly two arguments.

    Args:
        name: the SQL function name to emit.

    Returns:
        A transform that renders ``name(this, expression)``. A ``count`` argument,
        if set, is reported through ``self.unsupported`` and dropped.
    """

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        has_count = bool(expression.args.get("count"))
        if has_count:
            self.unsupported(f"Only two arguments are supported in function {name}.")
        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap the operand of ``expression`` in a cast to its return type, in place.

    When the return type is DATE, the operand is cast to TIMESTAMP first, so that
    timestamp strings can be truncated (some dialects can't cast them directly to
    DATE). The same, now mutated, node is returned.
    """
    operand = expression.this.copy()
    target_type = expression.return_type

    if target_type.is_type(exp.DataType.Type.DATE):
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(operand, target_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Create a transform rendering a date add/diff node as ``name(unit, expression, this)``.

    Args:
        name: the SQL function name to emit.
        cast: when True, ``TsOrDsAdd`` nodes are first passed through
            ``ts_or_ds_add_cast`` so their operand is cast to the node's return type.

    Returns:
        The generator transform.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        node = expression
        if cast and isinstance(node, exp.TsOrDsAdd):
            node = ts_or_ds_add_cast(node)

        unit = unit_to_var(node)
        return self.func(name, unit, node.expression, node.this)

    return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the ``unit`` arg of ``expression`` as a string literal.

    A ``Placeholder`` unit is returned unchanged. Any other set unit is turned into a
    string literal of its name. When no unit is set, ``default`` is used, or ``None``
    if ``default`` is empty.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    if not default:
        return None
    return exp.Literal.string(default)
def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the ``unit`` arg of ``expression`` as an unquoted ``Var`` node.

    ``Var`` and ``Placeholder`` units are returned unchanged. Any other unit node,
    such as a string literal ``'month'`` or a column ``month``, is converted to a
    ``Var`` carrying its name. When no unit is set, or its name is empty,
    ``default`` is used; an empty ``default`` yields ``None``.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit

    # Keep the unit's name instead of discarding it. Previously every non-Var unit
    # was replaced by the default, so e.g. a 'month' literal was emitted as DAY.
    value = (unit.name if unit else None) or default
    return exp.Var(this=value) if value else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate ``LAST_DAY(x)`` without a native function.

    The date is truncated to the start of its month, one month is added, one day is
    subtracted, and the result is cast to DATE.
    """
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    last_day = exp.func("date_sub", next_month_start, 1, "day")
    return self.sql(exp.cast(last_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements.

    Columns in the WHEN clauses are rewritten as unqualified columns when their
    table qualifier matches the MERGE target's name or alias; names are compared
    after ``self.dialect.normalize_identifier``. The WHEN nodes are modified in
    place (``copy=False``), then the statement is rendered with ``merge_sql``.
    """
    target = expression.this

    def _normalized_name(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        if not identifier:
            return None
        return self.dialect.normalize_identifier(identifier).name

    target_names = {_normalized_name(target.this)}
    target_alias = target.args.get("alias")
    if target_alias:
        target_names.add(_normalized_name(target_alias.this))

    def _strip_target(node: exp.Expression) -> exp.Expression:
        if (
            isinstance(node, exp.Column)
            and _normalized_name(node.args.get("table")) in target_names
        ):
            return exp.column(node.this)
        return node

    for when in expression.expressions:
        when.transform(_strip_target, copy=False)

    return self.merge_sql(expression)

def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    """Create a parser callback turning ``FUNC(json, p1, p2, ...)`` into a JSON path node.

    Each literal path argument becomes a ``JSONPathKey``, or a ``JSONPathSubscript``
    if its text is an integer; the index is shifted down by one when
    ``zero_based_indexing`` is False. If any path argument is not a literal, the
    callback falls back to ``expr_type.from_arg_list(args)``.

    Args:
        expr_type: the JSON extraction expression class to build.
        zero_based_indexing: whether integer path arguments are already 0-based.
        arrow_req_json_type: passed to the result as ``only_json_types``.
    """

    def _builder(args: t.List) -> F:
        parts: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]

        for path_arg in args[1:]:
            if not isinstance(path_arg, exp.Literal):
                # Non-literal paths can't be transpiled safely, so use the generic builder.
                return expr_type.from_arg_list(args)

            key = path_arg.name
            if not is_int(key):
                parts.append(exp.JSONPathKey(this=key))
                continue

            offset = 0 if zero_based_indexing else 1
            parts.append(exp.JSONPathSubscript(this=int(key) - offset))

        # Truncate the caller's list in place to at most two entries, to avoid failing
        # in the expression validator due to the arg count.
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=parts),
            only_json_types=arrow_req_json_type,
        )

    return _builder
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Create a transform that renders a JSON extraction as separate path segments.

    Each non-empty rendered path part is wrapped in the dialect's quote characters.
    Subscripts are left unquoted when ``quoted_index`` is False. With ``op``, the
    result is ``this <op> seg1 <op> seg2 ...``; otherwise ``name(this, seg1, ...)``.
    Non-``JSONPath`` paths fall back to ``rename_func(name)``.

    Args:
        name: the SQL function name to emit.
        quoted_index: whether subscript segments are quoted too.
        op: optional infix operator used instead of a function call.
    """

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        rendered = []
        for part in json_path.expressions:
            part_sql = self.sql(part)
            if not part_sql:
                continue

            needs_quotes = isinstance(part, exp.JSONPathPart) and (
                quoted_index or not isinstance(part, exp.JSONPathSubscript)
            )
            if needs_quotes:
                part_sql = f"{self.dialect.QUOTE_START}{part_sql}{self.dialect.QUOTE_END}"

            rendered.append(part_sql)

        if op:
            return f" {op} ".join([self.sql(expression.this), *rendered])
        return self.func(name, expression.this, *rendered)

    return _json_extract_segments
def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    """Render a JSON path key as its bare name.

    A wildcard key is reported through ``self.unsupported``; the key's name is
    still returned.
    """
    key = expression.this
    if isinstance(key, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")
    return expression.name
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    """Emulate an array filter as ``ARRAY(SELECT ... FROM UNNEST(arr) ... WHERE cond)``.

    A single-parameter lambda supplies the element alias and the condition. A bare
    predicate uses the alias ``_u``. Any other condition is reported through
    ``self.unsupported``, and an empty string is returned.
    """
    condition = expression.expression

    if isinstance(condition, exp.Lambda) and len(condition.expressions) == 1:
        element = condition.expressions[0]
        condition = condition.this
    elif isinstance(condition, exp.Predicate):
        element = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    source = exp.Unnest(expressions=[expression.this])
    aliased_source = exp.alias_(source, None, table=[element])
    subquery = exp.select(element).from_(aliased_source).where(condition)
    return self.sql(exp.Array(expressions=[subquery]))
def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    """Render a ``ToNumber`` node as ``TO_NUMBER(this, format, nlsparam)``.

    Args:
        self: the generator producing SQL.
        expression: the ``ToNumber`` node to render.

    Returns:
        The rendered SQL. Unset ``format``/``nlsparam`` args are passed to
        ``self.func`` as ``None``.
    """
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )