
sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
  26logger = logging.getLogger("sqlglot")
  27
  28UNESCAPED_SEQUENCES = {
  29    "\\a": "\a",
  30    "\\b": "\b",
  31    "\\f": "\f",
  32    "\\n": "\n",
  33    "\\r": "\r",
  34    "\\t": "\t",
  35    "\\v": "\v",
  36    "\\\\": "\\",
  37}
  38
  39
  40class Dialects(str, Enum):
  41    """Dialects supported by SQLGLot."""
  42
  43    DIALECT = ""
  44
  45    ATHENA = "athena"
  46    BIGQUERY = "bigquery"
  47    CLICKHOUSE = "clickhouse"
  48    DATABRICKS = "databricks"
  49    DORIS = "doris"
  50    DRILL = "drill"
  51    DUCKDB = "duckdb"
  52    HIVE = "hive"
  53    MATERIALIZE = "materialize"
  54    MYSQL = "mysql"
  55    ORACLE = "oracle"
  56    POSTGRES = "postgres"
  57    PRESTO = "presto"
  58    PRQL = "prql"
  59    REDSHIFT = "redshift"
  60    RISINGWAVE = "risingwave"
  61    SNOWFLAKE = "snowflake"
  62    SPARK = "spark"
  63    SPARK2 = "spark2"
  64    SQLITE = "sqlite"
  65    STARROCKS = "starrocks"
  66    TABLEAU = "tableau"
  67    TERADATA = "teradata"
  68    TRINO = "trino"
  69    TSQL = "tsql"
  70
  71
  72class NormalizationStrategy(str, AutoName):
  73    """Specifies the strategy according to which identifiers should be normalized."""
  74
  75    LOWERCASE = auto()
  76    """Unquoted identifiers are lowercased."""
  77
  78    UPPERCASE = auto()
  79    """Unquoted identifiers are uppercased."""
  80
  81    CASE_SENSITIVE = auto()
  82    """Always case-sensitive, regardless of quotes."""
  83
  84    CASE_INSENSITIVE = auto()
  85    """Always case-insensitive, regardless of quotes."""
  86
  87
  88class _Dialect(type):
  89    classes: t.Dict[str, t.Type[Dialect]] = {}
  90
  91    def __eq__(cls, other: t.Any) -> bool:
  92        if cls is other:
  93            return True
  94        if isinstance(other, str):
  95            return cls is cls.get(other)
  96        if isinstance(other, Dialect):
  97            return cls is type(other)
  98
  99        return False
 100
 101    def __hash__(cls) -> int:
 102        return hash(cls.__name__.lower())
 103
 104    @classmethod
 105    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 106        return cls.classes[key]
 107
 108    @classmethod
 109    def get(
 110        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 111    ) -> t.Optional[t.Type[Dialect]]:
 112        return cls.classes.get(key, default)
 113
 114    def __new__(cls, clsname, bases, attrs):
 115        klass = super().__new__(cls, clsname, bases, attrs)
 116        enum = Dialects.__members__.get(clsname.upper())
 117        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 118
 119        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 120        klass.FORMAT_TRIE = (
 121            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 122        )
 123        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 124        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 125        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
 126        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)
 127
 128        base = seq_get(bases, 0)
 129        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
 130        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
 131        base_parser = (getattr(base, "parser_class", Parser),)
 132        base_generator = (getattr(base, "generator_class", Generator),)
 133
 134        klass.tokenizer_class = klass.__dict__.get(
 135            "Tokenizer", type("Tokenizer", base_tokenizer, {})
 136        )
 137        klass.jsonpath_tokenizer_class = klass.__dict__.get(
 138            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
 139        )
 140        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
 141        klass.generator_class = klass.__dict__.get(
 142            "Generator", type("Generator", base_generator, {})
 143        )
 144
 145        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 146        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 147            klass.tokenizer_class._IDENTIFIERS.items()
 148        )[0]
 149
 150        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 151            return next(
 152                (
 153                    (s, e)
 154                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 155                    if t == token_type
 156                ),
 157                (None, None),
 158            )
 159
 160        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 161        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 162        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 163        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 164
 165        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
 166            klass.UNESCAPED_SEQUENCES = {
 167                **UNESCAPED_SEQUENCES,
 168                **klass.UNESCAPED_SEQUENCES,
 169            }
 170
 171        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
 172
 173        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS
 174
 175        if enum not in ("", "bigquery"):
 176            klass.generator_class.SELECT_KINDS = ()
 177
 178        if enum not in ("", "athena", "presto", "trino"):
 179            klass.generator_class.TRY_SUPPORTED = False
 180            klass.generator_class.SUPPORTS_UESCAPE = False
 181
 182        if enum not in ("", "databricks", "hive", "spark", "spark2"):
 183            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
 184            for modifier in ("cluster", "distribute", "sort"):
 185                modifier_transforms.pop(modifier, None)
 186
 187            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
 188
 189        if enum not in ("", "doris", "mysql"):
 190            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
 191                TokenType.STRAIGHT_JOIN,
 192            }
 193            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 194                TokenType.STRAIGHT_JOIN,
 195            }
 196
 197        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 198            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 199                TokenType.ANTI,
 200                TokenType.SEMI,
 201            }
 202
 203        return klass
 204
 205
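# A minimal sketch of what the metaclass above buys subclasses: merely defining a
# subclass of `Dialect` (see below) registers it in the global registry under its
# lowercased class name (or the matching `Dialects` value), so it can be looked up
# by name afterwards. The `Custom` dialect here is illustrative only.
#
#     >>> class Custom(Dialect):
#     ...     pass
#     >>> Dialect.get("custom") is Custom
#     True
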
 206class Dialect(metaclass=_Dialect):
 207    INDEX_OFFSET = 0
 208    """The base index offset for arrays."""
 209
 210    WEEK_OFFSET = 0
 211    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 212
 213    UNNEST_COLUMN_ONLY = False
 214    """Whether `UNNEST` table aliases are treated as column aliases."""
 215
 216    ALIAS_POST_TABLESAMPLE = False
 217    """Whether the table alias comes after tablesample."""
 218
 219    TABLESAMPLE_SIZE_IS_PERCENT = False
 220    """Whether a size in the table sample clause represents percentage."""
 221
 222    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 223    """Specifies the strategy according to which identifiers should be normalized."""
 224
 225    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 226    """Whether an unquoted identifier can start with a digit."""
 227
 228    DPIPE_IS_STRING_CONCAT = True
 229    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 230
 231    STRICT_STRING_CONCAT = False
 232    """Whether `CONCAT`'s arguments must be strings."""
 233
 234    SUPPORTS_USER_DEFINED_TYPES = True
 235    """Whether user-defined data types are supported."""
 236
 237    SUPPORTS_SEMI_ANTI_JOIN = True
 238    """Whether `SEMI` or `ANTI` joins are supported."""
 239
 240    SUPPORTS_COLUMN_JOIN_MARKS = False
 241    """Whether the old-style outer join (+) syntax is supported."""
 242
 243    COPY_PARAMS_ARE_CSV = True
 244    """Separator of COPY statement parameters."""
 245
 246    NORMALIZE_FUNCTIONS: bool | str = "upper"
 247    """
 248    Determines how function names are going to be normalized.
 249    Possible values:
 250        "upper" or True: Convert names to uppercase.
 251        "lower": Convert names to lowercase.
 252        False: Disables function name normalization.
 253    """
 254
 255    LOG_BASE_FIRST: t.Optional[bool] = True
 256    """
 257    Whether the base comes first in the `LOG` function.
 258    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
 259    """
 260
 261    NULL_ORDERING = "nulls_are_small"
 262    """
 263    Default `NULL` ordering method to use if not explicitly set.
 264    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 265    """
 266
 267    TYPED_DIVISION = False
 268    """
 269    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 270    False means `a / b` is always float division.
 271    True means `a / b` is integer division if both `a` and `b` are integers.
 272    """
 273
 274    SAFE_DIVISION = False
 275    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 276
 277    CONCAT_COALESCE = False
 278    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 279
 280    HEX_LOWERCASE = False
 281    """Whether the `HEX` function returns a lowercase hexadecimal string."""
 282
 283    DATE_FORMAT = "'%Y-%m-%d'"
 284    DATEINT_FORMAT = "'%Y%m%d'"
 285    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 286
 287    TIME_MAPPING: t.Dict[str, str] = {}
 288    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 289
 290    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 291    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 292    FORMAT_MAPPING: t.Dict[str, str] = {}
 293    """
 294    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 295    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 296    """
 297
 298    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
 299    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 300
 301    PSEUDOCOLUMNS: t.Set[str] = set()
 302    """
 303    Columns that are auto-generated by the engine corresponding to this dialect.
 304    For example, such columns may be excluded from `SELECT *` queries.
 305    """
 306
 307    PREFER_CTE_ALIAS_COLUMN = False
 308    """
 309    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 310    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 311    any projection aliases in the subquery.
 312
 313    For example,
 314        WITH y(c) AS (
 315            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 316        ) SELECT c FROM y;
 317
 318        will be rewritten as
 319
 320        WITH y(c) AS (
 321            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 322        ) SELECT c FROM y;
 323    """
 324
 325    COPY_PARAMS_ARE_CSV = True
 326    """
 327    Whether COPY statement parameters are separated by comma or whitespace
 328    """
 329
 330    FORCE_EARLY_ALIAS_REF_EXPANSION = False
 331    """
 332    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
 333
 334    For example:
 335        WITH data AS (
 336        SELECT
 337            1 AS id,
 338            2 AS my_id
 339        )
 340        SELECT
 341            id AS my_id
 342        FROM
 343            data
 344        WHERE
 345            my_id = 1
 346        GROUP BY
 347            my_id
 348        HAVING
 349            my_id = 1
 350
 351    In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
 352        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
 353        - ClickHouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
 354    """
 355
 356    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
 357    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""
 358
 359    SUPPORTS_ORDER_BY_ALL = False
 360    """
 361    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
 362    """
 363
 364    # --- Autofilled ---
 365
 366    tokenizer_class = Tokenizer
 367    jsonpath_tokenizer_class = JSONPathTokenizer
 368    parser_class = Parser
 369    generator_class = Generator
 370
 371    # A trie of the time_mapping keys
 372    TIME_TRIE: t.Dict = {}
 373    FORMAT_TRIE: t.Dict = {}
 374
 375    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 376    INVERSE_TIME_TRIE: t.Dict = {}
 377    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
 378    INVERSE_FORMAT_TRIE: t.Dict = {}
 379
 380    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 381
 382    # Delimiters for string literals and identifiers
 383    QUOTE_START = "'"
 384    QUOTE_END = "'"
 385    IDENTIFIER_START = '"'
 386    IDENTIFIER_END = '"'
 387
 388    # Delimiters for bit, hex, byte and unicode literals
 389    BIT_START: t.Optional[str] = None
 390    BIT_END: t.Optional[str] = None
 391    HEX_START: t.Optional[str] = None
 392    HEX_END: t.Optional[str] = None
 393    BYTE_START: t.Optional[str] = None
 394    BYTE_END: t.Optional[str] = None
 395    UNICODE_START: t.Optional[str] = None
 396    UNICODE_END: t.Optional[str] = None
 397
 398    DATE_PART_MAPPING = {
 399        "Y": "YEAR",
 400        "YY": "YEAR",
 401        "YYY": "YEAR",
 402        "YYYY": "YEAR",
 403        "YR": "YEAR",
 404        "YEARS": "YEAR",
 405        "YRS": "YEAR",
 406        "MM": "MONTH",
 407        "MON": "MONTH",
 408        "MONS": "MONTH",
 409        "MONTHS": "MONTH",
 410        "D": "DAY",
 411        "DD": "DAY",
 412        "DAYS": "DAY",
 413        "DAYOFMONTH": "DAY",
 414        "DAY OF WEEK": "DAYOFWEEK",
 415        "WEEKDAY": "DAYOFWEEK",
 416        "DOW": "DAYOFWEEK",
 417        "DW": "DAYOFWEEK",
 418        "WEEKDAY_ISO": "DAYOFWEEKISO",
 419        "DOW_ISO": "DAYOFWEEKISO",
 420        "DW_ISO": "DAYOFWEEKISO",
 421        "DAY OF YEAR": "DAYOFYEAR",
 422        "DOY": "DAYOFYEAR",
 423        "DY": "DAYOFYEAR",
 424        "W": "WEEK",
 425        "WK": "WEEK",
 426        "WEEKOFYEAR": "WEEK",
 427        "WOY": "WEEK",
 428        "WY": "WEEK",
 429        "WEEK_ISO": "WEEKISO",
 430        "WEEKOFYEARISO": "WEEKISO",
 431        "WEEKOFYEAR_ISO": "WEEKISO",
 432        "Q": "QUARTER",
 433        "QTR": "QUARTER",
 434        "QTRS": "QUARTER",
 435        "QUARTERS": "QUARTER",
 436        "H": "HOUR",
 437        "HH": "HOUR",
 438        "HR": "HOUR",
 439        "HOURS": "HOUR",
 440        "HRS": "HOUR",
 441        "M": "MINUTE",
 442        "MI": "MINUTE",
 443        "MIN": "MINUTE",
 444        "MINUTES": "MINUTE",
 445        "MINS": "MINUTE",
 446        "S": "SECOND",
 447        "SEC": "SECOND",
 448        "SECONDS": "SECOND",
 449        "SECS": "SECOND",
 450        "MS": "MILLISECOND",
 451        "MSEC": "MILLISECOND",
 452        "MSECS": "MILLISECOND",
 453        "MSECOND": "MILLISECOND",
 454        "MSECONDS": "MILLISECOND",
 455        "MILLISEC": "MILLISECOND",
 456        "MILLISECS": "MILLISECOND",
 457        "MILLISECON": "MILLISECOND",
 458        "MILLISECONDS": "MILLISECOND",
 459        "US": "MICROSECOND",
 460        "USEC": "MICROSECOND",
 461        "USECS": "MICROSECOND",
 462        "MICROSEC": "MICROSECOND",
 463        "MICROSECS": "MICROSECOND",
 464        "USECOND": "MICROSECOND",
 465        "USECONDS": "MICROSECOND",
 466        "MICROSECONDS": "MICROSECOND",
 467        "NS": "NANOSECOND",
 468        "NSEC": "NANOSECOND",
 469        "NANOSEC": "NANOSECOND",
 470        "NSECOND": "NANOSECOND",
 471        "NSECONDS": "NANOSECOND",
 472        "NANOSECS": "NANOSECOND",
 473        "EPOCH_SECOND": "EPOCH",
 474        "EPOCH_SECONDS": "EPOCH",
 475        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
 476        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
 477        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
 478        "TZH": "TIMEZONE_HOUR",
 479        "TZM": "TIMEZONE_MINUTE",
 480        "DEC": "DECADE",
 481        "DECS": "DECADE",
 482        "DECADES": "DECADE",
 483        "MIL": "MILLENIUM",
 484        "MILS": "MILLENIUM",
 485        "MILLENIA": "MILLENIUM",
 486        "C": "CENTURY",
 487        "CENT": "CENTURY",
 488        "CENTS": "CENTURY",
 489        "CENTURIES": "CENTURY",
 490    }
 491
 492    @classmethod
 493    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 494        """
 495        Look up a dialect in the global dialect registry and return it if it exists.
 496
 497        Args:
 498            dialect: The target dialect. If this is a string, it can be optionally followed by
 499                additional key-value pairs that are separated by commas and are used to specify
 500                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 501
 502        Example:
 503            >>> dialect = Dialect.get_or_raise("duckdb")
 504            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
 505
 506        Returns:
 507            The corresponding Dialect instance.
 508        """
 509
 510        if not dialect:
 511            return cls()
 512        if isinstance(dialect, _Dialect):
 513            return dialect()
 514        if isinstance(dialect, Dialect):
 515            return dialect
 516        if isinstance(dialect, str):
 517            try:
 518                dialect_name, *kv_strings = dialect.split(",")
 519                kv_pairs = (kv.split("=") for kv in kv_strings)
 520                kwargs = {}
 521                for pair in kv_pairs:
 522                    key = pair[0].strip()
 523                    value: t.Union[bool | str | None] = None
 524
 525                    if len(pair) == 1:
 526                        # Default initialize standalone settings to True
 527                        value = True
 528                    elif len(pair) == 2:
 529                        value = pair[1].strip()
 530
 531                        # Coerce the value to boolean if it matches to the truthy/falsy values below
 532                        value_lower = value.lower()
 533                        if value_lower in ("true", "1"):
 534                            value = True
 535                        elif value_lower in ("false", "0"):
 536                            value = False
 537
 538                    kwargs[key] = value
 539
 540            except ValueError:
 541                raise ValueError(
 542                    f"Invalid dialect format: '{dialect}'. "
 543                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
 544                )
 545
 546            result = cls.get(dialect_name.strip())
 547            if not result:
 548                from difflib import get_close_matches
 549
 550                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 551                if similar:
 552                    similar = f" Did you mean {similar}?"
 553
 554                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 555
 556            return result(**kwargs)
 557
 558        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 559
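    # A short usage sketch for the settings string described above (the "version"
    # key is illustrative only): extra key-value pairs become keyword arguments and
    # land in `settings`, while `normalization_strategy` is applied in `__init__`.
    #
    #     >>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive, version = 8")
    #     >>> d.normalization_strategy
    #     <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>
    #     >>> d.settings
    #     {'version': '8'}
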
 560    @classmethod
 561    def format_time(
 562        cls, expression: t.Optional[str | exp.Expression]
 563    ) -> t.Optional[exp.Expression]:
 564        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 565        if isinstance(expression, str):
 566            return exp.Literal.string(
 567                # the time formats are quoted
 568                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 569            )
 570
 571        if expression and expression.is_string:
 572            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 573
 574        return expression
 575
 576    def __init__(self, **kwargs) -> None:
 577        normalization_strategy = kwargs.pop("normalization_strategy", None)
 578
 579        if normalization_strategy is None:
 580            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 581        else:
 582            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 583
 584        self.settings = kwargs
 585
 586    def __eq__(self, other: t.Any) -> bool:
 587        # Does not currently take dialect state into account
 588        return type(self) == other
 589
 590    def __hash__(self) -> int:
 591        # Does not currently take dialect state into account
 592        return hash(type(self))
 593
 594    def normalize_identifier(self, expression: E) -> E:
 595        """
 596        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 597
 598        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 599        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 600        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 601        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 602
 603        There are also dialects like Spark, which are case-insensitive even when quotes are
 604        present, and dialects like MySQL, whose resolution rules match those employed by the
 605        underlying operating system, for example they may always be case-sensitive in Linux.
 606
 607        Finally, the normalization behavior of some engines can even be controlled through flags,
 608        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 609
 610        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 611        that it can analyze queries in the optimizer and successfully capture their semantics.
 612        """
 613        if (
 614            isinstance(expression, exp.Identifier)
 615            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 616            and (
 617                not expression.quoted
 618                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 619            )
 620        ):
 621            expression.set(
 622                "this",
 623                (
 624                    expression.this.upper()
 625                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 626                    else expression.this.lower()
 627                ),
 628            )
 629
 630        return expression
 631
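    # A brief sketch of the behavior described above, using `exp.to_identifier`
    # from sqlglot: unquoted identifiers follow the dialect's normalization
    # strategy, while quoted ones are preserved by case-sensitive dialects.
    #
    #     >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
    #     'foo'
    #     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
    #     'FOO'
    #     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO", quoted=True)).name
    #     'FoO'
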
 632    def case_sensitive(self, text: str) -> bool:
 633        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 634        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 635            return False
 636
 637        unsafe = (
 638            str.islower
 639            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 640            else str.isupper
 641        )
 642        return any(unsafe(char) for char in text)
 643
 644    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 645        """Checks if text can be identified given an identify option.
 646
 647        Args:
 648            text: The text to check.
 649            identify:
 650                `"always"` or `True`: Always returns `True`.
 651                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 652
 653        Returns:
 654            Whether the given text can be identified.
 655        """
 656        if identify is True or identify == "always":
 657            return True
 658
 659        if identify == "safe":
 660            return not self.case_sensitive(text)
 661
 662        return False
 663
 664    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 665        """
 666        Adds quotes to a given identifier.
 667
 668        Args:
 669            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 670            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 671                "unsafe", with respect to its characters and this dialect's normalization strategy.
 672        """
 673        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 674            name = expression.this
 675            expression.set(
 676                "quoted",
 677                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 678            )
 679
 680        return expression
 681
 682    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 683        if isinstance(path, exp.Literal):
 684            path_text = path.name
 685            if path.is_number:
 686                path_text = f"[{path_text}]"
 687            try:
 688                return parse_json_path(path_text, self)
 689            except ParseError as e:
 690                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 691
 692        return path
 693
 694    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 695        return self.parser(**opts).parse(self.tokenize(sql), sql)
 696
 697    def parse_into(
 698        self, expression_type: exp.IntoType, sql: str, **opts
 699    ) -> t.List[t.Optional[exp.Expression]]:
 700        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 701
 702    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 703        return self.generator(**opts).generate(expression, copy=copy)
 704
 705    def transpile(self, sql: str, **opts) -> t.List[str]:
 706        return [
 707            self.generate(expression, copy=False, **opts) if expression else ""
 708            for expression in self.parse(sql)
 709        ]
 710
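    # A minimal round-trip sketch for the parse/generate/transpile helpers above:
    #
    #     >>> d = Dialect.get_or_raise("duckdb")
    #     >>> [tree] = d.parse("select 1 as x")
    #     >>> d.generate(tree)
    #     'SELECT 1 AS x'
    #     >>> d.transpile("select 1 as x")
    #     ['SELECT 1 AS x']
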
 711    def tokenize(self, sql: str) -> t.List[Token]:
 712        return self.tokenizer.tokenize(sql)
 713
 714    @property
 715    def tokenizer(self) -> Tokenizer:
 716        return self.tokenizer_class(dialect=self)
 717
 718    @property
 719    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
 720        return self.jsonpath_tokenizer_class(dialect=self)
 721
 722    def parser(self, **opts) -> Parser:
 723        return self.parser_class(dialect=self, **opts)
 724
 725    def generator(self, **opts) -> Generator:
 726        return self.generator_class(dialect=self, **opts)
 727
 728
 729DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 730
 731
 732def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 733    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 734
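# Helper transforms like `rename_func` and the generator helpers below are typically
# wired into a dialect's generator via its TRANSFORMS mapping. A sketch (the
# `MyDialect` name is illustrative only):
#
#     class MyDialect(Dialect):
#         class Generator(Generator):
#             TRANSFORMS = {
#                 **Generator.TRANSFORMS,
#                 exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
#             }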
 735
 736def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 737    if expression.args.get("accuracy"):
 738        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 739    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 740
 741
 742def if_sql(
 743    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 744) -> t.Callable[[Generator, exp.If], str]:
 745    def _if_sql(self: Generator, expression: exp.If) -> str:
 746        return self.func(
 747            name,
 748            expression.this,
 749            expression.args.get("true"),
 750            expression.args.get("false") or false_value,
 751        )
 752
 753    return _if_sql
 754
 755
 756def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 757    this = expression.this
 758    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 759        this.replace(exp.cast(this, exp.DataType.Type.JSON))
 760
 761    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 762
 763
 764def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 765    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
 766
 767
 768def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
 769    elem = seq_get(expression.expressions, 0)
 770    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
 771        return self.func("ARRAY", elem)
 772    return inline_array_sql(self, expression)
 773
 774
 775def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 776    return self.like_sql(
 777        exp.Like(
 778            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
 779        )
 780    )
 781
 782
 783def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 784    zone = self.sql(expression, "this")
 785    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 786
 787
 788def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 789    if expression.args.get("recursive"):
 790        self.unsupported("Recursive CTEs are unsupported")
 791        expression.args["recursive"] = False
 792    return self.with_sql(expression)
 793
 794
 795def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 796    n = self.sql(expression, "this")
 797    d = self.sql(expression, "expression")
 798    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 799
 800
 801def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 802    self.unsupported("TABLESAMPLE unsupported")
 803    return self.sql(expression.this)
 804
 805
 806def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 807    self.unsupported("PIVOT unsupported")
 808    return ""
 809
 810
 811def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 812    return self.cast_sql(expression)
 813
 814
 815def no_comment_column_constraint_sql(
 816    self: Generator, expression: exp.CommentColumnConstraint
 817) -> str:
 818    self.unsupported("CommentColumnConstraint unsupported")
 819    return ""
 820
 821
 822def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 823    self.unsupported("MAP_FROM_ENTRIES unsupported")
 824    return ""
 825
 826
 827def str_position_sql(
 828    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
 829) -> str:
 830    this = self.sql(expression, "this")
 831    substr = self.sql(expression, "substr")
 832    position = self.sql(expression, "position")
 833    instance = expression.args.get("instance") if generate_instance else None
 834    position_offset = ""
 835
 836    if position:
 837        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
 838        this = self.func("SUBSTR", this, position)
 839        position_offset = f" + {position} - 1"
 840
 841    return self.func("STRPOS", this, substr, instance) + position_offset
 842
 843
 844def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 845    return (
 846        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 847    )
 848
 849
 850def var_map_sql(
 851    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 852) -> str:
 853    keys = expression.args["keys"]
 854    values = expression.args["values"]
 855
 856    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 857        self.unsupported("Cannot convert array columns into map.")
 858        return self.func(map_func_name, keys, values)
 859
 860    args = []
 861    for key, value in zip(keys.expressions, values.expressions):
 862        args.append(self.sql(key))
 863        args.append(self.sql(value))
 864
 865    return self.func(map_func_name, *args)
 866
 867
 868def build_formatted_time(
 869    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 870) -> t.Callable[[t.List], E]:
 871    """Helper used for time expressions.
 872
 873    Args:
 874        exp_class: the expression class to instantiate.
 875        dialect: target sql dialect.
 876        default: the default format, True being time.
 877
 878    Returns:
 879        A callable that can be used to return the appropriately formatted time expression.
 880    """
 881
 882    def _builder(args: t.List):
 883        return exp_class(
 884            this=seq_get(args, 0),
 885            format=Dialect[dialect].format_time(
 886                seq_get(args, 1)
 887                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 888            ),
 889        )
 890
 891    return _builder
 892
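# A sketch of how such a builder is typically registered in a dialect's parser, so
# that e.g. DATE_FORMAT(x, 'fmt') parses into an expression whose `format` argument
# is normalized to Python strftime tokens (the names below are illustrative only):
#
#     class MyDialect(Dialect):
#         class Parser(Parser):
#             FUNCTIONS = {
#                 **Parser.FUNCTIONS,
#                 "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "mydialect"),
#             }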
 893
 894def time_format(
 895    dialect: DialectType = None,
 896) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 897    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 898        """
 899        Returns the time format for a given expression, unless it's equivalent
 900        to the default time format of the dialect of interest.
 901        """
 902        time_format = self.format_time(expression)
 903        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 904
 905    return _time_format
 906
 907
 908def build_date_delta(
 909    exp_class: t.Type[E],
 910    unit_mapping: t.Optional[t.Dict[str, str]] = None,
 911    default_unit: t.Optional[str] = "DAY",
 912) -> t.Callable[[t.List], E]:
 913    def _builder(args: t.List) -> E:
 914        unit_based = len(args) == 3
 915        this = args[2] if unit_based else seq_get(args, 0)
 916        unit = None
 917        if unit_based or default_unit:
 918            unit = args[0] if unit_based else exp.Literal.string(default_unit)
 919            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 920        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 921
 922    return _builder
 923
 924
 925def build_date_delta_with_interval(
 926    expression_class: t.Type[E],
 927) -> t.Callable[[t.List], t.Optional[E]]:
 928    def _builder(args: t.List) -> t.Optional[E]:
 929        if len(args) < 2:
 930            return None
 931
 932        interval = args[1]
 933
 934        if not isinstance(interval, exp.Interval):
 935            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 936
 937        expression = interval.this
 938        if expression and expression.is_string:
 939            expression = exp.Literal.number(expression.this)
 940
 941        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 942
 943    return _builder
 944
 945
 946def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 947    unit = seq_get(args, 0)
 948    this = seq_get(args, 1)
 949
 950    if isinstance(this, exp.Cast) and this.is_type("date"):
 951        return exp.DateTrunc(unit=unit, this=this)
 952    return exp.TimestampTrunc(this=this, unit=unit)
 953
 954
 955def date_add_interval_sql(
 956    data_type: str, kind: str
 957) -> t.Callable[[Generator, exp.Expression], str]:
 958    def func(self: Generator, expression: exp.Expression) -> str:
 959        this = self.sql(expression, "this")
 960        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
 961        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 962
 963    return func
 964
 965
 966def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
 967    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 968        args = [unit_to_str(expression), expression.this]
 969        if zone:
 970            args.append(expression.args.get("zone"))
 971        return self.func("DATE_TRUNC", *args)
 972
 973    return _timestamptrunc_sql
 974
 975
 976def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 977    zone = expression.args.get("zone")
 978    if not zone:
 979        from sqlglot.optimizer.annotate_types import annotate_types
 980
 981        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 982        return self.sql(exp.cast(expression.this, target_type))
 983    if zone.name.lower() in TIMEZONES:
 984        return self.sql(
 985            exp.AtTimeZone(
 986                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
 987                zone=zone,
 988            )
 989        )
 990    return self.func("TIMESTAMP", expression.this, zone)
 991
 992
 993def no_time_sql(self: Generator, expression: exp.Time) -> str:
 994    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
 995    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
 996    expr = exp.cast(
 997        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
 998    )
 999    return self.sql(expr)
1000
1001
1002def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
1003    this = expression.this
1004    expr = expression.expression
1005
1006    if expr.name.lower() in TIMEZONES:
1007        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
1008        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
1009        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
1010        return self.sql(this)
1011
1012    this = exp.cast(this, exp.DataType.Type.DATE)
1013    expr = exp.cast(expr, exp.DataType.Type.TIME)
1014
1015    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1016
1017
1018def locate_to_strposition(args: t.List) -> exp.Expression:
1019    return exp.StrPosition(
1020        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
1021    )
1022
1023
1024def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
1025    return self.func(
1026        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
1027    )
1028
1029
1030def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
1031    return self.sql(
1032        exp.Substring(
1033            this=expression.this, start=exp.Literal.number(1), length=expression.expression
1034        )
1035    )
1036
1037
1038def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
1039    return self.sql(
1040        exp.Substring(
1041            this=expression.this,
1042            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
1043        )
1044    )
1045
1046
1047def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
1048    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
1049
1050
1051def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
1052    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
1053
1054
1055# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
1056def encode_decode_sql(
1057    self: Generator, expression: exp.Expression, name: str, replace: bool = True
1058) -> str:
1059    charset = expression.args.get("charset")
1060    if charset and charset.name.lower() != "utf-8":
1061        self.unsupported(f"Expected utf-8 character set, got {charset}.")
1062
1063    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1064
1065
1066def min_or_least(self: Generator, expression: exp.Min) -> str:
1067    name = "LEAST" if expression.expressions else "MIN"
1068    return rename_func(name)(self, expression)
1069
1070
1071def max_or_greatest(self: Generator, expression: exp.Max) -> str:
1072    name = "GREATEST" if expression.expressions else "MAX"
1073    return rename_func(name)(self, expression)
1074
1075
1076def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
1077    cond = expression.this
1078
1079    if isinstance(expression.this, exp.Distinct):
1080        cond = expression.this.expressions[0]
1081        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
1082
1083    return self.func("sum", exp.func("if", cond, 1, 0))
1084
1085
1086def trim_sql(self: Generator, expression: exp.Trim) -> str:
1087    target = self.sql(expression, "this")
1088    trim_type = self.sql(expression, "position")
1089    remove_chars = self.sql(expression, "expression")
1090    collation = self.sql(expression, "collation")
1091
1092    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
1093    if not remove_chars and not collation:
1094        return self.trim_sql(expression)
1095
1096    trim_type = f"{trim_type} " if trim_type else ""
1097    remove_chars = f"{remove_chars} " if remove_chars else ""
1098    from_part = "FROM " if trim_type or remove_chars else ""
1099    collation = f" COLLATE {collation}" if collation else ""
1100    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1101
1102
1103def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
1104    return self.func("STRPTIME", expression.this, self.format_time(expression))
1105
1106
1107def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
1108    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
1109
1110
1111def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
1112    delim, *rest_args = expression.expressions
1113    return self.sql(
1114        reduce(
1115            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
1116            rest_args,
1117        )
1118    )
1119
1120
1121def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
1122    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
1123    if bad_args:
1124        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
1125
1126    return self.func(
1127        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
1128    )
1129
1130
1131def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
1132    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
1133    if bad_args:
1134        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
1135
1136    return self.func(
1137        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
1138    )
1139
1140
1141def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
1142    names = []
1143    for agg in aggregations:
1144        if isinstance(agg, exp.Alias):
1145            names.append(agg.alias)
1146        else:
1147            """
1148            This case corresponds to aggregations without aliases being used as suffixes
1149            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
1150            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
1151            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
1152            """
1153            agg_all_unquoted = agg.transform(
1154                lambda node: (
1155                    exp.Identifier(this=node.name, quoted=False)
1156                    if isinstance(node, exp.Identifier)
1157                    else node
1158                )
1159            )
1160            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
1161
1162    return names
1163
1164
1165def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
1166    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
1167
1168
1169# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
1170def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
1171    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
1172
1173
1174def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
1175    return self.func("MAX", expression.this)
1176
1177
1178def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
1179    a = self.sql(expression.left)
1180    b = self.sql(expression.right)
1181    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
1182
1183
1184def is_parse_json(expression: exp.Expression) -> bool:
1185    return isinstance(expression, exp.ParseJSON) or (
1186        isinstance(expression, exp.Cast) and expression.is_type("json")
1187    )
1188
1189
1190def isnull_to_is_null(args: t.List) -> exp.Expression:
1191    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
1192
1193
1194def generatedasidentitycolumnconstraint_sql(
1195    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
1196) -> str:
1197    start = self.sql(expression, "start") or "1"
1198    increment = self.sql(expression, "increment") or "1"
1199    return f"IDENTITY({start}, {increment})"
1200
1201
1202def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
1203    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
1204        if expression.args.get("count"):
1205            self.unsupported(f"Only two arguments are supported in function {name}.")
1206
1207        return self.func(name, expression.this, expression.expression)
1208
1209    return _arg_max_or_min_sql
1210
1211
1212def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
1213    this = expression.this.copy()
1214
1215    return_type = expression.return_type
1216    if return_type.is_type(exp.DataType.Type.DATE):
1217        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
1218        # can truncate timestamp strings, because some dialects can't cast them to DATE
1219        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
1220
1221    expression.this.replace(exp.cast(this, return_type))
1222    return expression
1223
1224
1225def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1226    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1227        if cast and isinstance(expression, exp.TsOrDsAdd):
1228            expression = ts_or_ds_add_cast(expression)
1229
1230        return self.func(
1231            name,
1232            unit_to_var(expression),
1233            expression.expression,
1234            expression.this,
1235        )
1236
1237    return _delta_sql
1238
1239
1240def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1241    unit = expression.args.get("unit")
1242
1243    if isinstance(unit, exp.Placeholder):
1244        return unit
1245    if unit:
1246        return exp.Literal.string(unit.name)
1247    return exp.Literal.string(default) if default else None
1248
1249
1250def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1251    unit = expression.args.get("unit")
1252
1253    if isinstance(unit, (exp.Var, exp.Placeholder)):
1254        return unit
1255    return exp.Var(this=default) if default else None
1256
1257
1258@t.overload
1259def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
1260    pass
1261
1262
1263@t.overload
1264def map_date_part(
1265    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
1266) -> t.Optional[exp.Expression]:
1267    pass
1268
1269
1270def map_date_part(part, dialect: DialectType = Dialect):
1271    mapped = (
1272        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
1273    )
1274    return exp.var(mapped) if mapped else part
1275
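# A small sketch of the lookup above, using the base Dialect's DATE_PART_MAPPING
# defined earlier: recognized abbreviations are canonicalized, anything else is
# passed through unchanged.
#
#     >>> map_date_part(exp.var("dow")).name
#     'DAYOFWEEK'
#     >>> map_date_part(exp.var("year")).name
#     'year'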
1276
1277def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1278    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1279    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1280    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1281
1282    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1283
1284
1285def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1286    """Remove table refs from columns in when statements."""
1287    alias = expression.this.args.get("alias")
1288
1289    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1290        return self.dialect.normalize_identifier(identifier).name if identifier else None
1291
1292    targets = {normalize(expression.this.this)}
1293
1294    if alias:
1295        targets.add(normalize(alias.this))
1296
1297    for when in expression.expressions:
1298        when.transform(
1299            lambda node: (
1300                exp.column(node.this)
1301                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1302                else node
1303            ),
1304            copy=False,
1305        )
1306
1307    return self.merge_sql(expression)
1308
1309
1310def build_json_extract_path(
1311    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1312) -> t.Callable[[t.List], F]:
1313    def _builder(args: t.List) -> F:
1314        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1315        for arg in args[1:]:
1316            if not isinstance(arg, exp.Literal):
1317                # We use the fallback parser because we can't really transpile non-literals safely
1318                return expr_type.from_arg_list(args)
1319
1320            text = arg.name
1321            if is_int(text):
1322                index = int(text)
1323                segments.append(
1324                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1325                )
1326            else:
1327                segments.append(exp.JSONPathKey(this=text))
1328
1329        # This is done to avoid failing in the expression validator due to the arg count
1330        del args[2:]
1331        return expr_type(
1332            this=seq_get(args, 0),
1333            expression=exp.JSONPath(expressions=segments),
1334            only_json_types=arrow_req_json_type,
1335        )
1336
1337    return _builder
1338
1339
1340def json_extract_segments(
1341    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1342) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1343    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1344        path = expression.expression
1345        if not isinstance(path, exp.JSONPath):
1346            return rename_func(name)(self, expression)
1347
1348        segments = []
1349        for segment in path.expressions:
1350            path = self.sql(segment)
1351            if path:
1352                if isinstance(segment, exp.JSONPathPart) and (
1353                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1354                ):
1355                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1356
1357                segments.append(path)
1358
1359        if op:
1360            return f" {op} ".join([self.sql(expression.this), *segments])
1361        return self.func(name, expression.this, *segments)
1362
1363    return _json_extract_segments
1364
1365
1366def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1367    if isinstance(expression.this, exp.JSONPathWildcard):
1368        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1369
1370    return expression.name
1371
1372
1373def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1374    cond = expression.expression
1375    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1376        alias = cond.expressions[0]
1377        cond = cond.this
1378    elif isinstance(cond, exp.Predicate):
1379        alias = "_u"
1380    else:
1381        self.unsupported("Unsupported filter condition")
1382        return ""
1383
1384    unnest = exp.Unnest(expressions=[expression.this])
1385    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1386    return self.sql(exp.Array(expressions=[filtered]))
1387
1388
1389def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1390    return self.func(
1391        "TO_NUMBER",
1392        expression.this,
1393        expression.args.get("format"),
1394        expression.args.get("nlsparam"),
1395    )
1396
1397
1398def build_default_decimal_type(
1399    precision: t.Optional[int] = None, scale: t.Optional[int] = None
1400) -> t.Callable[[exp.DataType], exp.DataType]:
1401    def _builder(dtype: exp.DataType) -> exp.DataType:
1402        if dtype.expressions or precision is None:
1403            return dtype
1404
1405        params = f"{precision}{f', {scale}' if scale is not None else ''}"
1406        return exp.DataType.build(f"DECIMAL({params})")
1407
1408    return _builder
1409
1410
1411def build_timestamp_from_parts(args: t.List) -> exp.Func:
1412    if len(args) == 2:
1413        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
1414        # so we parse this into Anonymous for now instead of introducing complexity
1415        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
1416
1417    return exp.TimestampFromParts.from_arg_list(args)
1418
1419
1420def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
1421    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
logger = <Logger sqlglot (WARNING)>
UNESCAPED_SEQUENCES = {'\\a': '\x07', '\\b': '\x08', '\\f': '\x0c', '\\n': '\n', '\\r': '\r', '\\t': '\t', '\\v': '\x0b', '\\\\': '\\'}
class Dialects(builtins.str, enum.Enum):
41class Dialects(str, Enum):
42    """Dialects supported by SQLGLot."""
43
44    DIALECT = ""
45
46    ATHENA = "athena"
47    BIGQUERY = "bigquery"
48    CLICKHOUSE = "clickhouse"
49    DATABRICKS = "databricks"
50    DORIS = "doris"
51    DRILL = "drill"
52    DUCKDB = "duckdb"
53    HIVE = "hive"
54    MATERIALIZE = "materialize"
55    MYSQL = "mysql"
56    ORACLE = "oracle"
57    POSTGRES = "postgres"
58    PRESTO = "presto"
59    PRQL = "prql"
60    REDSHIFT = "redshift"
61    RISINGWAVE = "risingwave"
62    SNOWFLAKE = "snowflake"
63    SPARK = "spark"
64    SPARK2 = "spark2"
65    SQLITE = "sqlite"
66    STARROCKS = "starrocks"
67    TABLEAU = "tableau"
68    TERADATA = "teradata"
69    TRINO = "trino"
70    TSQL = "tsql"

Dialects supported by SQLGlot.

DIALECT = <Dialects.DIALECT: ''>
ATHENA = <Dialects.ATHENA: 'athena'>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MATERIALIZE = <Dialects.MATERIALIZE: 'materialize'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
PRQL = <Dialects.PRQL: 'prql'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
RISINGWAVE = <Dialects.RISINGWAVE: 'risingwave'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
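
These values double as the dialect names (plain strings) accepted throughout the public API. A minimal usage sketch; the query and dialect pair are illustrative:

>>> import sqlglot
>>> sqlglot.transpile("SELECT 1 AS x", read="mysql", write="postgres")
['SELECT 1 AS x']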
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
73class NormalizationStrategy(str, AutoName):
74    """Specifies the strategy according to which identifiers should be normalized."""
75
76    LOWERCASE = auto()
77    """Unquoted identifiers are lowercased."""
78
79    UPPERCASE = auto()
80    """Unquoted identifiers are uppercased."""
81
82    CASE_SENSITIVE = auto()
83    """Always case-sensitive, regardless of quotes."""
84
85    CASE_INSENSITIVE = auto()
86    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.
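
A minimal sketch: the strategy can be overridden per dialect instance, for example through the comma-separated settings string accepted by Dialect.get_or_raise (documented further below).

>>> import sqlglot  # importing sqlglot registers the bundled dialects
>>> from sqlglot.dialects.dialect import Dialect, NormalizationStrategy
>>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
>>> d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
True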

class Dialect:
207class Dialect(metaclass=_Dialect):
208    INDEX_OFFSET = 0
209    """The base index offset for arrays."""
210
211    WEEK_OFFSET = 0
212    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
213
214    UNNEST_COLUMN_ONLY = False
215    """Whether `UNNEST` table aliases are treated as column aliases."""
216
217    ALIAS_POST_TABLESAMPLE = False
218    """Whether the table alias comes after tablesample."""
219
220    TABLESAMPLE_SIZE_IS_PERCENT = False
221    """Whether a size in the table sample clause represents percentage."""
222
223    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
224    """Specifies the strategy according to which identifiers should be normalized."""
225
226    IDENTIFIERS_CAN_START_WITH_DIGIT = False
227    """Whether an unquoted identifier can start with a digit."""
228
229    DPIPE_IS_STRING_CONCAT = True
230    """Whether the DPIPE token (`||`) is a string concatenation operator."""
231
232    STRICT_STRING_CONCAT = False
233    """Whether `CONCAT`'s arguments must be strings."""
234
235    SUPPORTS_USER_DEFINED_TYPES = True
236    """Whether user-defined data types are supported."""
237
238    SUPPORTS_SEMI_ANTI_JOIN = True
239    """Whether `SEMI` or `ANTI` joins are supported."""
240
241    SUPPORTS_COLUMN_JOIN_MARKS = False
242    """Whether the old-style outer join (+) syntax is supported."""
243
244    COPY_PARAMS_ARE_CSV = True
245    """Separator of COPY statement parameters."""
246
247    NORMALIZE_FUNCTIONS: bool | str = "upper"
248    """
249    Determines how function names are going to be normalized.
250    Possible values:
251        "upper" or True: Convert names to uppercase.
252        "lower": Convert names to lowercase.
253        False: Disables function name normalization.
254    """
255
256    LOG_BASE_FIRST: t.Optional[bool] = True
257    """
258    Whether the base comes first in the `LOG` function.
259    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
260    """
261
262    NULL_ORDERING = "nulls_are_small"
263    """
264    Default `NULL` ordering method to use if not explicitly set.
265    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
266    """
267
268    TYPED_DIVISION = False
269    """
270    Whether the behavior of `a / b` depends on the types of `a` and `b`.
271    False means `a / b` is always float division.
272    True means `a / b` is integer division if both `a` and `b` are integers.
273    """
274
275    SAFE_DIVISION = False
276    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
277
278    CONCAT_COALESCE = False
279    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
280
281    HEX_LOWERCASE = False
282    """Whether the `HEX` function returns a lowercase hexadecimal string."""
283
284    DATE_FORMAT = "'%Y-%m-%d'"
285    DATEINT_FORMAT = "'%Y%m%d'"
286    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
287
288    TIME_MAPPING: t.Dict[str, str] = {}
289    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
290
291    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
292    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
293    FORMAT_MAPPING: t.Dict[str, str] = {}
294    """
295    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
296    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
297    """
298
299    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
300    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
301
302    PSEUDOCOLUMNS: t.Set[str] = set()
303    """
304    Columns that are auto-generated by the engine corresponding to this dialect.
305    For example, such columns may be excluded from `SELECT *` queries.
306    """
307
308    PREFER_CTE_ALIAS_COLUMN = False
309    """
310    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
311    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
312    any projection aliases in the subquery.
313
314    For example,
315        WITH y(c) AS (
316            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
317        ) SELECT c FROM y;
318
319        will be rewritten as
320
321        WITH y(c) AS (
322            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
323        ) SELECT c FROM y;
324    """
325
326    COPY_PARAMS_ARE_CSV = True
327    """
328    Whether COPY statement parameters are separated by comma or whitespace
329    """
330
331    FORCE_EARLY_ALIAS_REF_EXPANSION = False
332    """
333    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
334
335    For example:
336        WITH data AS (
337        SELECT
338            1 AS id,
339            2 AS my_id
340        )
341        SELECT
342            id AS my_id
343        FROM
344            data
345        WHERE
346            my_id = 1
347        GROUP BY
348            my_id,
349        HAVING
350            my_id = 1
351
352    In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
353        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
354        - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
355    """
356
357    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
358    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""
359
360    SUPPORTS_ORDER_BY_ALL = False
361    """
362    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
363    """
364
365    # --- Autofilled ---
366
367    tokenizer_class = Tokenizer
368    jsonpath_tokenizer_class = JSONPathTokenizer
369    parser_class = Parser
370    generator_class = Generator
371
372    # A trie of the time_mapping keys
373    TIME_TRIE: t.Dict = {}
374    FORMAT_TRIE: t.Dict = {}
375
376    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
377    INVERSE_TIME_TRIE: t.Dict = {}
378    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
379    INVERSE_FORMAT_TRIE: t.Dict = {}
380
381    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
382
383    # Delimiters for string literals and identifiers
384    QUOTE_START = "'"
385    QUOTE_END = "'"
386    IDENTIFIER_START = '"'
387    IDENTIFIER_END = '"'
388
389    # Delimiters for bit, hex, byte and unicode literals
390    BIT_START: t.Optional[str] = None
391    BIT_END: t.Optional[str] = None
392    HEX_START: t.Optional[str] = None
393    HEX_END: t.Optional[str] = None
394    BYTE_START: t.Optional[str] = None
395    BYTE_END: t.Optional[str] = None
396    UNICODE_START: t.Optional[str] = None
397    UNICODE_END: t.Optional[str] = None
398
399    DATE_PART_MAPPING = {
400        "Y": "YEAR",
401        "YY": "YEAR",
402        "YYY": "YEAR",
403        "YYYY": "YEAR",
404        "YR": "YEAR",
405        "YEARS": "YEAR",
406        "YRS": "YEAR",
407        "MM": "MONTH",
408        "MON": "MONTH",
409        "MONS": "MONTH",
410        "MONTHS": "MONTH",
411        "D": "DAY",
412        "DD": "DAY",
413        "DAYS": "DAY",
414        "DAYOFMONTH": "DAY",
415        "DAY OF WEEK": "DAYOFWEEK",
416        "WEEKDAY": "DAYOFWEEK",
417        "DOW": "DAYOFWEEK",
418        "DW": "DAYOFWEEK",
419        "WEEKDAY_ISO": "DAYOFWEEKISO",
420        "DOW_ISO": "DAYOFWEEKISO",
421        "DW_ISO": "DAYOFWEEKISO",
422        "DAY OF YEAR": "DAYOFYEAR",
423        "DOY": "DAYOFYEAR",
424        "DY": "DAYOFYEAR",
425        "W": "WEEK",
426        "WK": "WEEK",
427        "WEEKOFYEAR": "WEEK",
428        "WOY": "WEEK",
429        "WY": "WEEK",
430        "WEEK_ISO": "WEEKISO",
431        "WEEKOFYEARISO": "WEEKISO",
432        "WEEKOFYEAR_ISO": "WEEKISO",
433        "Q": "QUARTER",
434        "QTR": "QUARTER",
435        "QTRS": "QUARTER",
436        "QUARTERS": "QUARTER",
437        "H": "HOUR",
438        "HH": "HOUR",
439        "HR": "HOUR",
440        "HOURS": "HOUR",
441        "HRS": "HOUR",
442        "M": "MINUTE",
443        "MI": "MINUTE",
444        "MIN": "MINUTE",
445        "MINUTES": "MINUTE",
446        "MINS": "MINUTE",
447        "S": "SECOND",
448        "SEC": "SECOND",
449        "SECONDS": "SECOND",
450        "SECS": "SECOND",
451        "MS": "MILLISECOND",
452        "MSEC": "MILLISECOND",
453        "MSECS": "MILLISECOND",
454        "MSECOND": "MILLISECOND",
455        "MSECONDS": "MILLISECOND",
456        "MILLISEC": "MILLISECOND",
457        "MILLISECS": "MILLISECOND",
458        "MILLISECON": "MILLISECOND",
459        "MILLISECONDS": "MILLISECOND",
460        "US": "MICROSECOND",
461        "USEC": "MICROSECOND",
462        "USECS": "MICROSECOND",
463        "MICROSEC": "MICROSECOND",
464        "MICROSECS": "MICROSECOND",
465        "USECOND": "MICROSECOND",
466        "USECONDS": "MICROSECOND",
467        "MICROSECONDS": "MICROSECOND",
468        "NS": "NANOSECOND",
469        "NSEC": "NANOSECOND",
470        "NANOSEC": "NANOSECOND",
471        "NSECOND": "NANOSECOND",
472        "NSECONDS": "NANOSECOND",
473        "NANOSECS": "NANOSECOND",
474        "EPOCH_SECOND": "EPOCH",
475        "EPOCH_SECONDS": "EPOCH",
476        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
477        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
478        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
479        "TZH": "TIMEZONE_HOUR",
480        "TZM": "TIMEZONE_MINUTE",
481        "DEC": "DECADE",
482        "DECS": "DECADE",
483        "DECADES": "DECADE",
484        "MIL": "MILLENIUM",
485        "MILS": "MILLENIUM",
486        "MILLENIA": "MILLENIUM",
487        "C": "CENTURY",
488        "CENT": "CENTURY",
489        "CENTS": "CENTURY",
490        "CENTURIES": "CENTURY",
491    }
492
493    @classmethod
494    def get_or_raise(cls, dialect: DialectType) -> Dialect:
495        """
496        Look up a dialect in the global dialect registry and return it if it exists.
497
498        Args:
499            dialect: The target dialect. If this is a string, it can be optionally followed by
500                additional key-value pairs that are separated by commas and are used to specify
501                dialect settings, such as whether the dialect's identifiers are case-sensitive.
502
503        Example:
504            >>> dialect = dialect_class = get_or_raise("duckdb")
505            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
506
507        Returns:
508            The corresponding Dialect instance.
509        """
510
511        if not dialect:
512            return cls()
513        if isinstance(dialect, _Dialect):
514            return dialect()
515        if isinstance(dialect, Dialect):
516            return dialect
517        if isinstance(dialect, str):
518            try:
519                dialect_name, *kv_strings = dialect.split(",")
520                kv_pairs = (kv.split("=") for kv in kv_strings)
521                kwargs = {}
522                for pair in kv_pairs:
523                    key = pair[0].strip()
524                    value: t.Union[bool | str | None] = None
525
526                    if len(pair) == 1:
527                        # Default initialize standalone settings to True
528                        value = True
529                    elif len(pair) == 2:
530                        value = pair[1].strip()
531
532                        # Coerce the value to boolean if it matches to the truthy/falsy values below
533                        value_lower = value.lower()
534                        if value_lower in ("true", "1"):
535                            value = True
536                        elif value_lower in ("false", "0"):
537                            value = False
538
539                    kwargs[key] = value
540
541            except ValueError:
542                raise ValueError(
543                    f"Invalid dialect format: '{dialect}'. "
544                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
545                )
546
547            result = cls.get(dialect_name.strip())
548            if not result:
549                from difflib import get_close_matches
550
551                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
552                if similar:
553                    similar = f" Did you mean {similar}?"
554
555                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
556
557            return result(**kwargs)
558
559        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
560
561    @classmethod
562    def format_time(
563        cls, expression: t.Optional[str | exp.Expression]
564    ) -> t.Optional[exp.Expression]:
565        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
566        if isinstance(expression, str):
567            return exp.Literal.string(
568                # the time formats are quoted
569                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
570            )
571
572        if expression and expression.is_string:
573            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
574
575        return expression
576
577    def __init__(self, **kwargs) -> None:
578        normalization_strategy = kwargs.pop("normalization_strategy", None)
579
580        if normalization_strategy is None:
581            self.normalization_strategy = self.NORMALIZATION_STRATEGY
582        else:
583            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
584
585        self.settings = kwargs
586
587    def __eq__(self, other: t.Any) -> bool:
588        # Does not currently take dialect state into account
589        return type(self) == other
590
591    def __hash__(self) -> int:
592        # Does not currently take dialect state into account
593        return hash(type(self))
594
595    def normalize_identifier(self, expression: E) -> E:
596        """
597        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
598
599        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
600        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
601        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
602        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
603
604        There are also dialects like Spark, which are case-insensitive even when quotes are
605        present, and dialects like MySQL, whose resolution rules match those employed by the
606        underlying operating system, for example they may always be case-sensitive in Linux.
607
608        Finally, the normalization behavior of some engines can even be controlled through flags,
609        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
610
611        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
612        that it can analyze queries in the optimizer and successfully capture their semantics.
613        """
614        if (
615            isinstance(expression, exp.Identifier)
616            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
617            and (
618                not expression.quoted
619                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
620            )
621        ):
622            expression.set(
623                "this",
624                (
625                    expression.this.upper()
626                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
627                    else expression.this.lower()
628                ),
629            )
630
631        return expression
632
633    def case_sensitive(self, text: str) -> bool:
634        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
635        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
636            return False
637
638        unsafe = (
639            str.islower
640            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
641            else str.isupper
642        )
643        return any(unsafe(char) for char in text)
644
645    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
646        """Checks if text can be identified given an identify option.
647
648        Args:
649            text: The text to check.
650            identify:
651                `"always"` or `True`: Always returns `True`.
652                `"safe"`: Only returns `True` if the identifier is case-insensitive.
653
654        Returns:
655            Whether the given text can be identified.
656        """
657        if identify is True or identify == "always":
658            return True
659
660        if identify == "safe":
661            return not self.case_sensitive(text)
662
663        return False
664
665    def quote_identifier(self, expression: E, identify: bool = True) -> E:
666        """
667        Adds quotes to a given identifier.
668
669        Args:
670            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
671            identify: If set to `False`, the quotes will only be added if the identifier is deemed
672                "unsafe", with respect to its characters and this dialect's normalization strategy.
673        """
674        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
675            name = expression.this
676            expression.set(
677                "quoted",
678                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
679            )
680
681        return expression
682
683    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
684        if isinstance(path, exp.Literal):
685            path_text = path.name
686            if path.is_number:
687                path_text = f"[{path_text}]"
688            try:
689                return parse_json_path(path_text, self)
690            except ParseError as e:
691                logger.warning(f"Invalid JSON path syntax. {str(e)}")
692
693        return path
694
695    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
696        return self.parser(**opts).parse(self.tokenize(sql), sql)
697
698    def parse_into(
699        self, expression_type: exp.IntoType, sql: str, **opts
700    ) -> t.List[t.Optional[exp.Expression]]:
701        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
702
703    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
704        return self.generator(**opts).generate(expression, copy=copy)
705
706    def transpile(self, sql: str, **opts) -> t.List[str]:
707        return [
708            self.generate(expression, copy=False, **opts) if expression else ""
709            for expression in self.parse(sql)
710        ]
711
712    def tokenize(self, sql: str) -> t.List[Token]:
713        return self.tokenizer.tokenize(sql)
714
715    @property
716    def tokenizer(self) -> Tokenizer:
717        return self.tokenizer_class(dialect=self)
718
719    @property
720    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
721        return self.jsonpath_tokenizer_class(dialect=self)
722
723    def parser(self, **opts) -> Parser:
724        return self.parser_class(dialect=self, **opts)
725
726    def generator(self, **opts) -> Generator:
727        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
577    def __init__(self, **kwargs) -> None:
578        normalization_strategy = kwargs.pop("normalization_strategy", None)
579
580        if normalization_strategy is None:
581            self.normalization_strategy = self.NORMALIZATION_STRATEGY
582        else:
583            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
584
585        self.settings = kwargs
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

SUPPORTS_COLUMN_JOIN_MARKS = False

Whether the old-style outer join (+) syntax is supported.

COPY_PARAMS_ARE_CSV = True

Whether COPY statement parameters are separated by comma or whitespace

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

  • "upper" or True: Convert names to uppercase.
  • "lower": Convert names to lowercase.
  • False: Disables function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG)

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

HEX_LOWERCASE = False

Whether the HEX function returns a lowercase hexadecimal string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (e.g. the two characters \n) to its unescaped version (the corresponding control character, e.g. a literal newline).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;

FORCE_EARLY_ALIAS_REF_EXPANSION = False

Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).

For example:

WITH data AS (
    SELECT
        1 AS id,
        2 AS my_id
)
SELECT
    id AS my_id
FROM
    data
WHERE
    my_id = 1
GROUP BY
    my_id,
HAVING
    my_id = 1

In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:

  • BigQuery, which forwards the alias to the GROUP BY and HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
  • Clickhouse, which forwards the alias across the whole query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"

EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False

Whether alias reference expansion before qualification should only happen for the GROUP BY clause.

SUPPORTS_ORDER_BY_ALL = False

Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks

tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
jsonpath_tokenizer_class = <class 'sqlglot.tokens.JSONPathTokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_FORMAT_MAPPING: Dict[str, str] = {}
INVERSE_FORMAT_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
DATE_PART_MAPPING = {'Y': 'YEAR', 'YY': 'YEAR', 'YYY': 'YEAR', 'YYYY': 'YEAR', 'YR': 'YEAR', 'YEARS': 'YEAR', 'YRS': 'YEAR', 'MM': 'MONTH', 'MON': 'MONTH', 'MONS': 'MONTH', 'MONTHS': 'MONTH', 'D': 'DAY', 'DD': 'DAY', 'DAYS': 'DAY', 'DAYOFMONTH': 'DAY', 'DAY OF WEEK': 'DAYOFWEEK', 'WEEKDAY': 'DAYOFWEEK', 'DOW': 'DAYOFWEEK', 'DW': 'DAYOFWEEK', 'WEEKDAY_ISO': 'DAYOFWEEKISO', 'DOW_ISO': 'DAYOFWEEKISO', 'DW_ISO': 'DAYOFWEEKISO', 'DAY OF YEAR': 'DAYOFYEAR', 'DOY': 'DAYOFYEAR', 'DY': 'DAYOFYEAR', 'W': 'WEEK', 'WK': 'WEEK', 'WEEKOFYEAR': 'WEEK', 'WOY': 'WEEK', 'WY': 'WEEK', 'WEEK_ISO': 'WEEKISO', 'WEEKOFYEARISO': 'WEEKISO', 'WEEKOFYEAR_ISO': 'WEEKISO', 'Q': 'QUARTER', 'QTR': 'QUARTER', 'QTRS': 'QUARTER', 'QUARTERS': 'QUARTER', 'H': 'HOUR', 'HH': 'HOUR', 'HR': 'HOUR', 'HOURS': 'HOUR', 'HRS': 'HOUR', 'M': 'MINUTE', 'MI': 'MINUTE', 'MIN': 'MINUTE', 'MINUTES': 'MINUTE', 'MINS': 'MINUTE', 'S': 'SECOND', 'SEC': 'SECOND', 'SECONDS': 'SECOND', 'SECS': 'SECOND', 'MS': 'MILLISECOND', 'MSEC': 'MILLISECOND', 'MSECS': 'MILLISECOND', 'MSECOND': 'MILLISECOND', 'MSECONDS': 'MILLISECOND', 'MILLISEC': 'MILLISECOND', 'MILLISECS': 'MILLISECOND', 'MILLISECON': 'MILLISECOND', 'MILLISECONDS': 'MILLISECOND', 'US': 'MICROSECOND', 'USEC': 'MICROSECOND', 'USECS': 'MICROSECOND', 'MICROSEC': 'MICROSECOND', 'MICROSECS': 'MICROSECOND', 'USECOND': 'MICROSECOND', 'USECONDS': 'MICROSECOND', 'MICROSECONDS': 'MICROSECOND', 'NS': 'NANOSECOND', 'NSEC': 'NANOSECOND', 'NANOSEC': 'NANOSECOND', 'NSECOND': 'NANOSECOND', 'NSECONDS': 'NANOSECOND', 'NANOSECS': 'NANOSECOND', 'EPOCH_SECOND': 'EPOCH', 'EPOCH_SECONDS': 'EPOCH', 'EPOCH_MILLISECONDS': 'EPOCH_MILLISECOND', 'EPOCH_MICROSECONDS': 'EPOCH_MICROSECOND', 'EPOCH_NANOSECONDS': 'EPOCH_NANOSECOND', 'TZH': 'TIMEZONE_HOUR', 'TZM': 'TIMEZONE_MINUTE', 'DEC': 'DECADE', 'DECS': 'DECADE', 'DECADES': 'DECADE', 'MIL': 'MILLENIUM', 'MILS': 'MILLENIUM', 'MILLENIA': 'MILLENIUM', 'C': 'CENTURY', 'CENT': 'CENTURY', 'CENTS': 'CENTURY', 'CENTURIES': 'CENTURY'}
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
493    @classmethod
494    def get_or_raise(cls, dialect: DialectType) -> Dialect:
495        """
496        Look up a dialect in the global dialect registry and return it if it exists.
497
498        Args:
499            dialect: The target dialect. If this is a string, it can be optionally followed by
500                additional key-value pairs that are separated by commas and are used to specify
501                dialect settings, such as whether the dialect's identifiers are case-sensitive.
502
503        Example:
504            >>> dialect = dialect_class = get_or_raise("duckdb")
505            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
506
507        Returns:
508            The corresponding Dialect instance.
509        """
510
511        if not dialect:
512            return cls()
513        if isinstance(dialect, _Dialect):
514            return dialect()
515        if isinstance(dialect, Dialect):
516            return dialect
517        if isinstance(dialect, str):
518            try:
519                dialect_name, *kv_strings = dialect.split(",")
520                kv_pairs = (kv.split("=") for kv in kv_strings)
521                kwargs = {}
522                for pair in kv_pairs:
523                    key = pair[0].strip()
524                    value: t.Union[bool | str | None] = None
525
526                    if len(pair) == 1:
527                        # Default initialize standalone settings to True
528                        value = True
529                    elif len(pair) == 2:
530                        value = pair[1].strip()
531
532                        # Coerce the value to boolean if it matches to the truthy/falsy values below
533                        value_lower = value.lower()
534                        if value_lower in ("true", "1"):
535                            value = True
536                        elif value_lower in ("false", "0"):
537                            value = False
538
539                    kwargs[key] = value
540
541            except ValueError:
542                raise ValueError(
543                    f"Invalid dialect format: '{dialect}'. "
544                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
545                )
546
547            result = cls.get(dialect_name.strip())
548            if not result:
549                from difflib import get_close_matches
550
551                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
552                if similar:
553                    similar = f" Did you mean {similar}?"
554
555                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
556
557            return result(**kwargs)
558
559        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
561    @classmethod
562    def format_time(
563        cls, expression: t.Optional[str | exp.Expression]
564    ) -> t.Optional[exp.Expression]:
565        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
566        if isinstance(expression, str):
567            return exp.Literal.string(
568                # the time formats are quoted
569                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
570            )
571
572        if expression and expression.is_string:
573            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
574
575        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
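
For instance, a small sketch using the Hive dialect (this assumes Hive's TIME_MAPPING translates yyyy, MM and dd into %Y, %m and %d):

>>> import sqlglot
>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect["hive"].format_time(exp.Literal.string("yyyy-MM-dd")).sql()
"'%Y-%m-%d'"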

settings
def normalize_identifier(self, expression: ~E) -> ~E:
595    def normalize_identifier(self, expression: E) -> E:
596        """
597        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
598
599        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
600        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
601        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
602        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
603
604        There are also dialects like Spark, which are case-insensitive even when quotes are
605        present, and dialects like MySQL, whose resolution rules match those employed by the
606        underlying operating system, for example they may always be case-sensitive in Linux.
607
608        Finally, the normalization behavior of some engines can even be controlled through flags,
609        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
610
611        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
612        that it can analyze queries in the optimizer and successfully capture their semantics.
613        """
614        if (
615            isinstance(expression, exp.Identifier)
616            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
617            and (
618                not expression.quoted
619                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
620            )
621        ):
622            expression.set(
623                "this",
624                (
625                    expression.this.upper()
626                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
627                    else expression.this.lower()
628                ),
629            )
630
631        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
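
A short sketch of the behaviors described above (Postgres lowercases unquoted identifiers, Snowflake uppercases them):

>>> import sqlglot
>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
'foo'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
'FOO'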

def case_sensitive(self, text: str) -> bool:
633    def case_sensitive(self, text: str) -> bool:
634        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
635        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
636            return False
637
638        unsafe = (
639            str.islower
640            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
641            else str.isupper
642        )
643        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
645    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
646        """Checks if text can be identified given an identify option.
647
648        Args:
649            text: The text to check.
650            identify:
651                `"always"` or `True`: Always returns `True`.
652                `"safe"`: Only returns `True` if the identifier is case-insensitive.
653
654        Returns:
655            Whether the given text can be identified.
656        """
657        if identify is True or identify == "always":
658            return True
659
660        if identify == "safe":
661            return not self.case_sensitive(text)
662
663        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
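
For example, under the base dialect's default LOWERCASE normalization strategy (a minimal sketch):

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect()
>>> d.can_identify("foo", "safe"), d.can_identify("Foo", "safe"), d.can_identify("Foo", "always")
(True, False, True)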

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
665    def quote_identifier(self, expression: E, identify: bool = True) -> E:
666        """
667        Adds quotes to a given identifier.
668
669        Args:
670            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
671            identify: If set to `False`, the quotes will only be added if the identifier is deemed
672                "unsafe", with respect to its characters and this dialect's normalization strategy.
673        """
674        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
675            name = expression.this
676            expression.set(
677                "quoted",
678                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
679            )
680
681        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
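
A quick sketch with the base dialect (LOWERCASE normalization, double-quoted identifiers):

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect()
>>> d.quote_identifier(exp.to_identifier("MyCol"), identify=False).sql()
'"MyCol"'
>>> d.quote_identifier(exp.to_identifier("mycol"), identify=False).sql()
'mycol'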
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
683    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
684        if isinstance(path, exp.Literal):
685            path_text = path.name
686            if path.is_number:
687                path_text = f"[{path_text}]"
688            try:
689                return parse_json_path(path_text, self)
690            except ParseError as e:
691                logger.warning(f"Invalid JSON path syntax. {str(e)}")
692
693        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
695    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
696        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
698    def parse_into(
699        self, expression_type: exp.IntoType, sql: str, **opts
700    ) -> t.List[t.Optional[exp.Expression]]:
701        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
703    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
704        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
706    def transpile(self, sql: str, **opts) -> t.List[str]:
707        return [
708            self.generate(expression, copy=False, **opts) if expression else ""
709            for expression in self.parse(sql)
710        ]
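
A minimal round-trip sketch of the parse/generate/transpile helpers above:

>>> import sqlglot
>>> from sqlglot.dialects.dialect import Dialect
>>> duckdb = Dialect.get_or_raise("duckdb")
>>> duckdb.generate(duckdb.parse("SELECT 1 AS x")[0])
'SELECT 1 AS x'
>>> duckdb.transpile("SELECT 1 AS x")
['SELECT 1 AS x']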
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
712    def tokenize(self, sql: str) -> t.List[Token]:
713        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
715    @property
716    def tokenizer(self) -> Tokenizer:
717        return self.tokenizer_class(dialect=self)
jsonpath_tokenizer: sqlglot.jsonpath.JSONPathTokenizer
719    @property
720    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
721        return self.jsonpath_tokenizer_class(dialect=self)
def parser(self, **opts) -> sqlglot.parser.Parser:
723    def parser(self, **opts) -> Parser:
724        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
726    def generator(self, **opts) -> Generator:
727        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
733def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
734    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
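
A sketch of how rename_func is typically wired into a dialect. The MyDialect class below is hypothetical; it follows the nested-Generator pattern used by the bundled dialects to render exp.ApproxDistinct under a different function name:

from sqlglot import exp, generator
from sqlglot.dialects.dialect import Dialect, rename_func

class MyDialect(Dialect):  # illustrative, not part of sqlglot
    class Generator(generator.Generator):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            # Emit APPROX_COUNT_DISTINCT(...) for exp.ApproxDistinct nodes
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
        }

print(MyDialect().generate(exp.ApproxDistinct(this=exp.column("x"))))
# expected output: APPROX_COUNT_DISTINCT(x)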
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
737def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
738    if expression.args.get("accuracy"):
739        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
740    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
743def if_sql(
744    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
745) -> t.Callable[[Generator, exp.If], str]:
746    def _if_sql(self: Generator, expression: exp.If) -> str:
747        return self.func(
748            name,
749            expression.this,
750            expression.args.get("true"),
751            expression.args.get("false") or false_value,
752        )
753
754    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
757def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
758    this = expression.this
759    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
760        this.replace(exp.cast(this, exp.DataType.Type.JSON))
761
762    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
765def inline_array_sql(self: Generator, expression: exp.Array) -> str:
766    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
def inline_array_unless_query( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
769def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
770    elem = seq_get(expression.expressions, 0)
771    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
772        return self.func("ARRAY", elem)
773    return inline_array_sql(self, expression)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
776def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
777    return self.like_sql(
778        exp.Like(
779            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
780        )
781    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
784def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
785    zone = self.sql(expression, "this")
786    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
789def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
790    if expression.args.get("recursive"):
791        self.unsupported("Recursive CTEs are unsupported")
792        expression.args["recursive"] = False
793    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
796def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
797    n = self.sql(expression, "this")
798    d = self.sql(expression, "expression")
799    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
802def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
803    self.unsupported("TABLESAMPLE unsupported")
804    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
807def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
808    self.unsupported("PIVOT unsupported")
809    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
812def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
813    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
816def no_comment_column_constraint_sql(
817    self: Generator, expression: exp.CommentColumnConstraint
818) -> str:
819    self.unsupported("CommentColumnConstraint unsupported")
820    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
823def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
824    self.unsupported("MAP_FROM_ENTRIES unsupported")
825    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
828def str_position_sql(
829    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
830) -> str:
831    this = self.sql(expression, "this")
832    substr = self.sql(expression, "substr")
833    position = self.sql(expression, "position")
834    instance = expression.args.get("instance") if generate_instance else None
835    position_offset = ""
836
837    if position:
838        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
839        this = self.func("SUBSTR", this, position)
840        position_offset = f" + {position} - 1"
841
842    return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
845def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
846    return (
847        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
848    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
851def var_map_sql(
852    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
853) -> str:
854    keys = expression.args["keys"]
855    values = expression.args["values"]
856
857    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
858        self.unsupported("Cannot convert array columns into map.")
859        return self.func(map_func_name, keys, values)
860
861    args = []
862    for key, value in zip(keys.expressions, values.expressions):
863        args.append(self.sql(key))
864        args.append(self.sql(value))
865
866    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
869def build_formatted_time(
870    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
871) -> t.Callable[[t.List], E]:
872    """Helper used for time expressions.
873
874    Args:
875        exp_class: the expression class to instantiate.
876        dialect: target sql dialect.
877        default: the default format, True being time.
878
879    Returns:
880        A callable that can be used to return the appropriately formatted time expression.
881    """
882
883    def _builder(args: t.List):
884        return exp_class(
885            this=seq_get(args, 0),
886            format=Dialect[dialect].format_time(
887                seq_get(args, 1)
888                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
889            ),
890        )
891
892    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
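
For instance, a sketch that builds a parser callable for a Hive-style date function (assuming Hive's TIME_MAPPING translates yyyy-MM-dd into %Y-%m-%d):

>>> import sqlglot
>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import build_formatted_time
>>> builder = build_formatted_time(exp.StrToDate, "hive")
>>> node = builder([exp.column("col"), exp.Literal.string("yyyy-MM-dd")])
>>> isinstance(node, exp.StrToDate)
True
>>> node.args["format"].sql()
"'%Y-%m-%d'"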

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
895def time_format(
896    dialect: DialectType = None,
897) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
898    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
899        """
900        Returns the time format for a given expression, unless it's equivalent
901        to the default time format of the dialect of interest.
902        """
903        time_format = self.format_time(expression)
904        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
905
906    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None, default_unit: Optional[str] = 'DAY') -> Callable[[List], ~E]:
909def build_date_delta(
910    exp_class: t.Type[E],
911    unit_mapping: t.Optional[t.Dict[str, str]] = None,
912    default_unit: t.Optional[str] = "DAY",
913) -> t.Callable[[t.List], E]:
914    def _builder(args: t.List) -> E:
915        unit_based = len(args) == 3
916        this = args[2] if unit_based else seq_get(args, 0)
917        unit = None
918        if unit_based or default_unit:
919            unit = args[0] if unit_based else exp.Literal.string(default_unit)
920            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
921        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
922
923    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
926def build_date_delta_with_interval(
927    expression_class: t.Type[E],
928) -> t.Callable[[t.List], t.Optional[E]]:
929    def _builder(args: t.List) -> t.Optional[E]:
930        if len(args) < 2:
931            return None
932
933        interval = args[1]
934
935        if not isinstance(interval, exp.Interval):
936            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
937
938        expression = interval.this
939        if expression and expression.is_string:
940            expression = exp.Literal.number(expression.this)
941
942        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
943
944    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
947def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
948    unit = seq_get(args, 0)
949    this = seq_get(args, 1)
950
951    if isinstance(this, exp.Cast) and this.is_type("date"):
952        return exp.DateTrunc(unit=unit, this=this)
953    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
956def date_add_interval_sql(
957    data_type: str, kind: str
958) -> t.Callable[[Generator, exp.Expression], str]:
959    def func(self: Generator, expression: exp.Expression) -> str:
960        this = self.sql(expression, "this")
961        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
962        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
963
964    return func
def timestamptrunc_sql( zone: bool = False) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.TimestampTrunc], str]:
967def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
968    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
969        args = [unit_to_str(expression), expression.this]
970        if zone:
971            args.append(expression.args.get("zone"))
972        return self.func("DATE_TRUNC", *args)
973
974    return _timestamptrunc_sql
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
977def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
978    zone = expression.args.get("zone")
979    if not zone:
980        from sqlglot.optimizer.annotate_types import annotate_types
981
982        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
983        return self.sql(exp.cast(expression.this, target_type))
984    if zone.name.lower() in TIMEZONES:
985        return self.sql(
986            exp.AtTimeZone(
987                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
988                zone=zone,
989            )
990        )
991    return self.func("TIMESTAMP", expression.this, zone)
def no_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Time) -> str:
 994def no_time_sql(self: Generator, expression: exp.Time) -> str:
 995    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
 996    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
 997    expr = exp.cast(
 998        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
 999    )
1000    return self.sql(expr)
def no_datetime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Datetime) -> str:
1003def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
1004    this = expression.this
1005    expr = expression.expression
1006
1007    if expr.name.lower() in TIMEZONES:
1008        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
1009        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
1010        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
1011        return self.sql(this)
1012
1013    this = exp.cast(this, exp.DataType.Type.DATE)
1014    expr = exp.cast(expr, exp.DataType.Type.TIME)
1015
1016    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
1019def locate_to_strposition(args: t.List) -> exp.Expression:
1020    return exp.StrPosition(
1021        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
1022    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
1025def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
1026    return self.func(
1027        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
1028    )
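
These two helpers are intended to round-trip between a LOCATE-style call and the dialect-agnostic StrPosition node; a small, hand-constructed sketch of the parsing direction:

    from sqlglot import exp
    from sqlglot.dialects.dialect import locate_to_strposition

    # LOCATE(substr, string[, position]) -> StrPosition(this=string, substr=substr, position=position)
    node = locate_to_strposition([exp.Literal.string("a"), exp.column("s"), exp.Literal.number(2)])
    assert isinstance(node, exp.StrPosition) and node.this.name == "s"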
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
1031def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
1032    return self.sql(
1033        exp.Substring(
1034            this=expression.this, start=exp.Literal.number(1), length=expression.expression
1035        )
1036    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
1039def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
1040    return self.sql(
1041        exp.Substring(
1042            this=expression.this,
1043            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
1044        )
1045    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
1048def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
1049    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
1052def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
1053    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
1057def encode_decode_sql(
1058    self: Generator, expression: exp.Expression, name: str, replace: bool = True
1059) -> str:
1060    charset = expression.args.get("charset")
1061    if charset and charset.name.lower() != "utf-8":
1062        self.unsupported(f"Expected utf-8 character set, got {charset}.")
1063
1064    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
1067def min_or_least(self: Generator, expression: exp.Min) -> str:
1068    name = "LEAST" if expression.expressions else "MIN"
1069    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
1072def max_or_greatest(self: Generator, expression: exp.Max) -> str:
1073    name = "GREATEST" if expression.expressions else "MAX"
1074    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
1077def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
1078    cond = expression.this
1079
1080    if isinstance(expression.this, exp.Distinct):
1081        cond = expression.this.expressions[0]
1082        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
1083
1084    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
1087def trim_sql(self: Generator, expression: exp.Trim) -> str:
1088    target = self.sql(expression, "this")
1089    trim_type = self.sql(expression, "position")
1090    remove_chars = self.sql(expression, "expression")
1091    collation = self.sql(expression, "collation")
1092
1093    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
1094    if not remove_chars and not collation:
1095        return self.trim_sql(expression)
1096
1097    trim_type = f"{trim_type} " if trim_type else ""
1098    remove_chars = f"{remove_chars} " if remove_chars else ""
1099    from_part = "FROM " if trim_type or remove_chars else ""
1100    collation = f" COLLATE {collation}" if collation else ""
1101    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
1104def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
1105    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
1108def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
1109    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
1112def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
1113    delim, *rest_args = expression.expressions
1114    return self.sql(
1115        reduce(
1116            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
1117            rest_args,
1118        )
1119    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
1122def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
1123    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
1124    if bad_args:
1125        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
1126
1127    return self.func(
1128        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
1129    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
1132def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
1133    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
1134    if bad_args:
1135        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
1136
1137    return self.func(
1138        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
1139    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
1142def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
1143    names = []
1144    for agg in aggregations:
1145        if isinstance(agg, exp.Alias):
1146            names.append(agg.alias)
1147        else:
1148            """
1149            This case corresponds to aggregations without aliases being used as suffixes
1150            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
1151            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
1152            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
1153            """
1154            agg_all_unquoted = agg.transform(
1155                lambda node: (
1156                    exp.Identifier(this=node.name, quoted=False)
1157                    if isinstance(node, exp.Identifier)
1158                    else node
1159                )
1160            )
1161            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
1162
1163    return names
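
A rough illustration with a hand-built aggregation list: aliased aggregations keep their alias, while unaliased ones are rendered (unquoted, lower-cased function names) so they can serve as column-name suffixes.

    from sqlglot import exp
    from sqlglot.dialects.dialect import pivot_column_names

    aggs = [
        exp.func("avg", exp.column("foo")),                    # no alias -> rendered as a suffix
        exp.alias_(exp.func("sum", exp.column("bar")), "s"),   # aliased -> alias is used
    ]
    pivot_column_names(aggs, dialect="duckdb")  # roughly ['avg(foo)', 's']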
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
1166def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
1167    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
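
For example (illustrative only), mapping a two-argument function such as POW straight onto its binary expression node:

    from sqlglot import exp
    from sqlglot.dialects.dialect import binary_from_function

    build_pow = binary_from_function(exp.Pow)
    node = build_pow([exp.column("x"), exp.Literal.number(2)])  # Pow(this=x, expression=2)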
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
1171def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
1172    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
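
Builders like this are normally registered in a dialect parser's FUNCTIONS mapping; a hypothetical subclass (not a real sqlglot dialect) might register it for a DATE_TRUNC(unit, expr) call order:

    from sqlglot import parser
    from sqlglot.dialects.dialect import Dialect, build_timestamp_trunc

    class MyOtherDialect(Dialect):
        class Parser(parser.Parser):
            FUNCTIONS = {
                **parser.Parser.FUNCTIONS,
                "DATE_TRUNC": build_timestamp_trunc,
            }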
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
1175def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
1176    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
1179def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
1180    a = self.sql(expression.left)
1181    b = self.sql(expression.right)
1182    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
1185def is_parse_json(expression: exp.Expression) -> bool:
1186    return isinstance(expression, exp.ParseJSON) or (
1187        isinstance(expression, exp.Cast) and expression.is_type("json")
1188    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
1191def isnull_to_is_null(args: t.List) -> exp.Expression:
1192    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
1195def generatedasidentitycolumnconstraint_sql(
1196    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
1197) -> str:
1198    start = self.sql(expression, "start") or "1"
1199    increment = self.sql(expression, "increment") or "1"
1200    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
1203def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
1204    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
1205        if expression.args.get("count"):
1206            self.unsupported(f"Only two arguments are supported in function {name}.")
1207
1208        return self.func(name, expression.this, expression.expression)
1209
1210    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
1213def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
1214    this = expression.this.copy()
1215
1216    return_type = expression.return_type
1217    if return_type.is_type(exp.DataType.Type.DATE):
1218        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
1219        # can truncate timestamp strings, because some dialects can't cast them to DATE
1220        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
1221
1222    expression.this.replace(exp.cast(this, return_type))
1223    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
1226def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1227    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1228        if cast and isinstance(expression, exp.TsOrDsAdd):
1229            expression = ts_or_ds_add_cast(expression)
1230
1231        return self.func(
1232            name,
1233            unit_to_var(expression),
1234            expression.expression,
1235            expression.this,
1236        )
1237
1238    return _delta_sql
def unit_to_str( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1241def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1242    unit = expression.args.get("unit")
1243
1244    if isinstance(unit, exp.Placeholder):
1245        return unit
1246    if unit:
1247        return exp.Literal.string(unit.name)
1248    return exp.Literal.string(default) if default else None
def unit_to_var( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1251def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1252    unit = expression.args.get("unit")
1253
1254    if isinstance(unit, (exp.Var, exp.Placeholder)):
1255        return unit
1256    return exp.Var(this=default) if default else None
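
Both helpers read the same `unit` arg; they differ only in how the unit is re-emitted (string literal vs. bare keyword). A small hand-built example:

    from sqlglot import exp
    from sqlglot.dialects.dialect import unit_to_str, unit_to_var

    node = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1), unit=exp.var("WEEK"))
    unit_to_str(node)  # Literal.string("WEEK"), e.g. for DATE_ADD(d, 1, 'WEEK')
    unit_to_var(node)  # Var(WEEK),              e.g. for DATE_ADD(d, 1, WEEK)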
def map_date_part( part, dialect: Union[str, Dialect, Type[Dialect], NoneType] = Dialect):
1271def map_date_part(part, dialect: DialectType = Dialect):
1272    mapped = (
1273        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
1274    )
1275    return exp.var(mapped) if mapped else part
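
An illustrative call; whether a given abbreviation is rewritten depends entirely on the chosen dialect's DATE_PART_MAPPING:

    from sqlglot import exp
    from sqlglot.dialects.dialect import map_date_part

    part = map_date_part(exp.var("mm"), "snowflake")
    # -> Var(MONTH) if "MM" appears in that dialect's DATE_PART_MAPPING, otherwise the input unchanged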
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
1278def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1279    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1280    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1281    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1282
1283    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
1286def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1287    """Remove table refs from columns in when statements."""
1288    alias = expression.this.args.get("alias")
1289
1290    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1291        return self.dialect.normalize_identifier(identifier).name if identifier else None
1292
1293    targets = {normalize(expression.this.this)}
1294
1295    if alias:
1296        targets.add(normalize(alias.this))
1297
1298    for when in expression.expressions:
1299        when.transform(
1300            lambda node: (
1301                exp.column(node.this)
1302                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1303                else node
1304            ),
1305            copy=False,
1306        )
1307
1308    return self.merge_sql(expression)

Remove table refs from columns in when statements.

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False) -> Callable[[List], ~F]:
1311def build_json_extract_path(
1312    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1313) -> t.Callable[[t.List], F]:
1314    def _builder(args: t.List) -> F:
1315        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1316        for arg in args[1:]:
1317            if not isinstance(arg, exp.Literal):
1318                # We use the fallback parser because we can't really transpile non-literals safely
1319                return expr_type.from_arg_list(args)
1320
1321            text = arg.name
1322            if is_int(text):
1323                index = int(text)
1324                segments.append(
1325                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1326                )
1327            else:
1328                segments.append(exp.JSONPathKey(this=text))
1329
1330        # This is done to avoid failing in the expression validator due to the arg count
1331        del args[2:]
1332        return expr_type(
1333            this=seq_get(args, 0),
1334            expression=exp.JSONPath(expressions=segments),
1335            only_json_types=arrow_req_json_type,
1336        )
1337
1338    return _builder
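
A hand-built sketch: literal arguments are folded into a single JSONPath expression, while any non-literal argument falls back to the plain arg-list constructor:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    builder = build_json_extract_path(exp.JSONExtract)
    node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
    # node.expression is a JSONPath roughly equivalent to $.a[0] (zero-based by default)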
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1341def json_extract_segments(
1342    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1343) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1344    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1345        path = expression.expression
1346        if not isinstance(path, exp.JSONPath):
1347            return rename_func(name)(self, expression)
1348
1349        segments = []
1350        for segment in path.expressions:
1351            path = self.sql(segment)
1352            if path:
1353                if isinstance(segment, exp.JSONPathPart) and (
1354                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1355                ):
1356                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1357
1358                segments.append(path)
1359
1360        if op:
1361            return f" {op} ".join([self.sql(expression.this), *segments])
1362        return self.func(name, expression.this, *segments)
1363
1364    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1367def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1368    if isinstance(expression.this, exp.JSONPathWildcard):
1369        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1370
1371    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1374def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1375    cond = expression.expression
1376    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1377        alias = cond.expressions[0]
1378        cond = cond.this
1379    elif isinstance(cond, exp.Predicate):
1380        alias = "_u"
1381    else:
1382        self.unsupported("Unsupported filter condition")
1383        return ""
1384
1385    unnest = exp.Unnest(expressions=[expression.this])
1386    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1387    return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ToNumber) -> str:
1390def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1391    return self.func(
1392        "TO_NUMBER",
1393        expression.this,
1394        expression.args.get("format"),
1395        expression.args.get("nlsparam"),
1396    )
def build_default_decimal_type( precision: Optional[int] = None, scale: Optional[int] = None) -> Callable[[sqlglot.expressions.DataType], sqlglot.expressions.DataType]:
1399def build_default_decimal_type(
1400    precision: t.Optional[int] = None, scale: t.Optional[int] = None
1401) -> t.Callable[[exp.DataType], exp.DataType]:
1402    def _builder(dtype: exp.DataType) -> exp.DataType:
1403        if dtype.expressions or precision is None:
1404            return dtype
1405
1406        params = f"{precision}{f', {scale}' if scale is not None else ''}"
1407        return exp.DataType.build(f"DECIMAL({params})")
1408
1409    return _builder
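
For example (illustrative precision/scale values), a bare DECIMAL gets the default parameters while an already-parameterized one is returned untouched:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_default_decimal_type

    builder = build_default_decimal_type(precision=38, scale=9)
    builder(exp.DataType.build("DECIMAL"))         # -> DECIMAL(38, 9)
    builder(exp.DataType.build("DECIMAL(10, 2)"))  # -> DECIMAL(10, 2), unchanged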
def build_timestamp_from_parts(args: List) -> sqlglot.expressions.Func:
1412def build_timestamp_from_parts(args: t.List) -> exp.Func:
1413    if len(args) == 2:
1414        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
1415        # so we parse this into Anonymous for now instead of introducing complexity
1416        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
1417
1418    return exp.TimestampFromParts.from_arg_list(args)
def sha256_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SHA2) -> str:
1421def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
1422    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)