Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5from functools import reduce
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.errors import ParseError
 10from sqlglot.generator import Generator
 11from sqlglot.helper import flatten, seq_get
 12from sqlglot.parser import Parser
 13from sqlglot.time import format_time
 14from sqlglot.tokens import Token, Tokenizer, TokenType
 15from sqlglot.trie import new_trie
 16
 17B = t.TypeVar("B", bound=exp.Binary)
 18
 19
 20class Dialects(str, Enum):
 21    DIALECT = ""
 22
 23    BIGQUERY = "bigquery"
 24    CLICKHOUSE = "clickhouse"
 25    DATABRICKS = "databricks"
 26    DRILL = "drill"
 27    DUCKDB = "duckdb"
 28    HIVE = "hive"
 29    MYSQL = "mysql"
 30    ORACLE = "oracle"
 31    POSTGRES = "postgres"
 32    PRESTO = "presto"
 33    REDSHIFT = "redshift"
 34    SNOWFLAKE = "snowflake"
 35    SPARK = "spark"
 36    SPARK2 = "spark2"
 37    SQLITE = "sqlite"
 38    STARROCKS = "starrocks"
 39    TABLEAU = "tableau"
 40    TERADATA = "teradata"
 41    TRINO = "trino"
 42    TSQL = "tsql"
 43    Doris = "doris"
 44
 45
class _Dialect(type):
    """Metaclass that registers every dialect class and wires up its helper classes.

    Creating a class with this metaclass stores it in `classes` and derives a
    number of class attributes (tries, quote delimiters, tokenizer/parser/
    generator classes) from what the new class declares.
    """

    # Registry of dialect classes, keyed by the matching `Dialects` value or,
    # when there is no such member, by the lower-cased class name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        """Compares a dialect class with another class, a dialect name or a `Dialect` instance."""
        if cls is other:
            return True
        if isinstance(other, str):
            # A string is equal when it is the registry key of this very class
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        """Hashes the class by its lower-cased class name."""
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        """Returns the dialect class registered under `key`; raises `KeyError` if absent."""
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Returns the dialect class registered under `key`, or `default` if absent."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        """Creates the dialect class, registers it and fills in its derived attributes."""
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Lookup structures derived from the time/format mappings
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Use the nested Tokenizer/Parser/Generator attribute if present, else the base class
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first entry of the tokenizer's quote / identifier tables supplies the delimiters
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # First (start, end) pair in _FORMAT_STRINGS whose token type matches,
            # or (None, None) when the tokenizer declares none
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Attributes defined directly on this class (not inherited ones), minus
        # callables, classmethods and dunder names, plus the tokenizer class itself
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        # Every class except those matching the "" (base) and "bigquery" members gets an
        # empty SELECT_KINDS; classes without a Dialects member (enum is None) do too
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                # Only attributes the target class already has are overwritten
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # NOTE(review): this writes into the parser class's BITWISE mapping in place;
        # confirm whether that mapping can be shared with other parser classes
        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        # Without SEMI/ANTI join support, those tokens become usable as table aliases
        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        klass.generator_class.can_identify = klass.can_identify

        return klass
137
138
class Dialect(metaclass=_Dialect):
    """Base SQL dialect: configuration flags plus tokenize/parse/generate entry points.

    Subclasses override the class attributes below; the metaclass registers each
    subclass and fills in the attributes marked as autofilled.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Default formats; note that the values include the surrounding single quotes
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled by the metaclass from the nested Tokenizer/Parser/Generator classes
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys (autofilled by the metaclass)
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    # Reverse of TIME_MAPPING and its trie (autofilled by the metaclass)
    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        """Compares this instance's class with `other` using the metaclass equality."""
        return type(self) == other

    def __hash__(self) -> int:
        """Hashes the instance by its class."""
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolves `dialect` to a dialect class.

        Args:
            dialect: a dialect name, class or instance; a falsy value resolves to `cls`.

        Returns:
            The corresponding dialect class.

        Raises:
            ValueError: if `dialect` is a name that is not registered.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format through this dialect's TIME_MAPPING.

        A `str` input has its first and last characters stripped before conversion;
        a string-literal expression is converted from its `this` value. Both yield a
        new string literal. Any other input is returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.

        The identifier is modified in place and returned; non-identifier expressions
        are returned untouched.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # Characters in the opposite case of the dialect's resolution are the "unsafe" ones
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Sets the `quoted` flag of an identifier in place and returns the expression.

        The flag is truthy when `identify` is set, when the name is case sensitive for
        this dialect, or when the name does not match `exp.SAFE_IDENTIFIER_RE`.
        Non-identifier expressions are returned untouched.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it with a parser built from `opts`."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it into `expression_type` with a parser built from `opts`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Renders `expression` as SQL with a generator built from `opts`."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parses `sql` and generates one SQL string per parsed expression.

        `opts` are forwarded to the generator only; `parse` is called without them.
        """
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenizes `sql` with this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """The tokenizer instance, created on first access and cached on the dialect instance."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Returns a new parser instance of `parser_class` built from `opts`."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Returns a new generator instance of `generator_class` built from `opts`."""
        return self.generator_class(**opts)
328
329
330DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
331
332
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Builds a generator callback that renders an expression as a call to `name`.

    The call's arguments are the flattened values of the expression's `args`.
    """

    def _renamed(self, expression):
        arguments = flatten(expression.args.values())
        return self.func(name, *arguments)

    return _renamed
335
336
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Renders APPROX_COUNT_DISTINCT(this), reporting `accuracy` as unsupported when set."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
341
342
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Builds a generator callback that renders `exp.If` as a three-argument function.

    Args:
        name: the function name to emit.
        false_value: the fallback used when the expression's "false" arg is falsy.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        branches = expression.args
        otherwise = branches.get("false") or false_value
        return self.func(name, expression.this, branches.get("true"), otherwise)

    return _if_sql
355
356
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Renders a JSON extraction as a binary expression with the `->` operator."""
    operator = "->"
    return self.binary(expression, operator)
359
360
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Renders a scalar JSON extraction as a binary expression with the `->>` operator."""
    operator = "->>"
    return self.binary(expression, operator)
365
366
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Renders an array as its flat element list wrapped in square brackets."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
369
370
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Renders ILIKE as LIKE with the left operand wrapped in LOWER.

    NOTE(review): only the left operand is lower-cased; the pattern is passed
    through unchanged - confirm this is intended.
    """
    lowered_subject = exp.Lower(this=expression.this.copy())
    pattern = expression.expression.copy()
    like = exp.Like(this=lowered_subject, expression=pattern)
    return self.like_sql(like)
377
378
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Renders CURRENT_DATE without parentheses, appending AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"

    return f"CURRENT_DATE AT TIME ZONE {zone}"
382
383
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Renders a WITH clause, dropping RECURSIVE and reporting it as unsupported.

    The "recursive" arg of the given expression is reset to False in place.
    """
    is_recursive = expression.args.get("recursive")
    if not is_recursive:
        return self.with_sql(expression)

    self.unsupported("Recursive CTEs are unsupported")
    expression.args["recursive"] = False
    return self.with_sql(expression)
389
390
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Renders a safe division as IF(d <> 0, n / d, NULL).

    Operands that are themselves binary expressions are parenthesized so that
    operator precedence cannot change the meaning: without the parentheses,
    dividing `a + b` by `c - d` would be emitted as `a + b / c - d`.
    Other operands are emitted exactly as before.
    """

    def _operand(key: str) -> str:
        # Wrap only binary operands; atoms, function calls, etc. need no parentheses
        sql = self.sql(expression, key)
        return f"({sql})" if isinstance(expression.args.get(key), exp.Binary) else sql

    n = _operand("this")
    d = _operand("expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"
395
396
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Reports TABLESAMPLE as unsupported and renders only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
400
401
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Reports PIVOT as unsupported and renders nothing for it."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
405
406
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Renders a TRY_CAST through the generator's regular `cast_sql`."""
    rendered = self.cast_sql(expression)
    return rendered
409
410
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Reports properties as unsupported and renders nothing for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
414
415
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Reports column comment constraints as unsupported and renders nothing for them."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
421
422
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Reports MAP_FROM_ENTRIES as unsupported and renders nothing for it."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
426
427
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Renders a string-position lookup using STRPOS.

    Without a start position this is simply STRPOS(this, substr). With one,
    the search runs on SUBSTR(this, position) and the match index is shifted
    back by `position - 1`.

    The shift must only be applied when there is a match: STRPOS yields 0 for
    "not found", and adding the offset unconditionally would turn that into
    `position - 1`. A CASE expression keeps the 0 and also isolates the
    arithmetic from any surrounding operators.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")

    if not position:
        return f"STRPOS({this}, {substr})"

    strpos = f"STRPOS(SUBSTR({this}, {position}), {substr})"
    return f"CASE WHEN {strpos} = 0 THEN 0 ELSE {strpos} + {position} - 1 END"
435
436
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Renders a struct field access as `<struct>.<field>` with the field as an identifier."""
    struct = self.sql(expression, "this")
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct}.{field}"
441
442
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Renders a map as `map_func_name(k1, v1, k2, v2, ...)`.

    When keys or values are not array literals, the conversion is reported as
    unsupported and the two operands are passed to the function as they are.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave keys and values: key, value, key, value, ...
    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
459
460
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Builds a parser callback for time expressions that carry a format argument.

    Args:
        exp_class: the expression class to instantiate.
        dialect: key of the dialect whose time mapping converts the format.
        default: format used when none is given; `True` selects the dialect's TIME_FORMAT.

    Returns:
        A callable that turns an argument list into an `exp_class` instance whose
        `format` has been converted by the dialect.
    """

    def _format_time(args: t.List):
        this = seq_get(args, 0)
        fmt = seq_get(args, 1)
        if not fmt:
            # No explicit format: fall back to the configured default, if any
            fmt = Dialect[dialect].TIME_FORMAT if default is True else default or None

        return exp_class(this=this, format=Dialect[dialect].format_time(fmt))

    return _format_time
485
486
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Builds a callback that yields an expression's time format unless it is the default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_format = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt != default_format:
            return fmt
        return None

    return _time_format
499
500
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.

    The rewrite is applied to a copy of the expression, and only for TABLE/VIEW
    creations whose target is a schema.
    """
    targets_schema = isinstance(expression.this, exp.Schema)
    partitionable_kind = expression.args.get("kind") in ("TABLE", "VIEW")

    if not (targets_schema and partitionable_kind):
        return self.create_sql(expression)

    expression = expression.copy()
    partitioned_by = expression.find(exp.PartitionedByProperty)

    if partitioned_by and partitioned_by.this and not isinstance(partitioned_by.this, exp.Schema):
        schema = expression.this
        # Column names are matched case-insensitively
        partition_names = {item.name.upper() for item in partitioned_by.this.expressions}
        partitions = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining = [column for column in schema.expressions if column not in partitions]

        schema.set("expressions", remaining)
        partitioned_by.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)
522
523
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Builds a parser callback for date-delta functions.

    With exactly three arguments they are read as (unit, delta, this); otherwise
    as (this, delta) with a 'DAY' string literal as the unit. When `unit_mapping`
    is given, the unit is replaced by a var looked up by its lower-cased name,
    falling back to the name itself.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
535
536
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Builds a parser callback for date-delta functions whose second argument is an INTERVAL.

    The callback returns None for fewer than two arguments and raises
    `ParseError` when the second argument is not an interval.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        target, interval = args[0], args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # A quoted amount is re-created as a numeric literal
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=target, expression=amount, unit=unit)

    return func
558
559
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parses (unit, value) into DateTrunc when value is a cast to date, else TimestampTrunc."""
    unit, this = seq_get(args, 0), seq_get(args, 1)

    casts_to_date = isinstance(this, exp.Cast) and this.is_type("date")
    if not casts_to_date:
        return exp.TimestampTrunc(this=this, unit=unit)

    return exp.DateTrunc(unit=unit, this=this)
567
568
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Builds a callback rendering `<data_type>_<kind>(this, <interval>)`.

    The interval is built from a copy of the expression's operand and its unit
    (upper-cased), defaulting to DAY when no unit is set.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")

        unit_arg = expression.args.get("unit")
        unit_name = unit_arg.name.upper() if unit_arg else "DAY"

        interval = exp.Interval(this=expression.expression.copy(), unit=exp.var(unit_name))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
580
581
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Renders TimestampTrunc as DATE_TRUNC('<unit>', this), with 'day' as the default unit."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
586
587
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parses LOCATE-style arguments (substr, string, position) into a StrPosition."""
    substr = seq_get(args, 0)
    haystack = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=haystack, substr=substr, position=position)
592
593
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Renders StrPosition as LOCATE(substr, this, position)."""
    args = expression.args
    substr = args.get("substr")
    position = args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
598
599
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Renders LEFT(this, n) as a Substring starting at 1 with length n."""
    copied = expression.copy()
    substring = exp.Substring(
        this=copied.this, start=exp.Literal.number(1), length=copied.expression
    )
    return self.sql(substring)
607
608
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Renders RIGHT(this, n) as a Substring starting at LENGTH(this) - (n - 1).

    The expression is copied first, so the caller's tree is left untouched.
    """
    expression = expression.copy()
    # Start offset such that the substring covers the last n characters
    start = exp.Length(this=expression.this) - exp.paren(expression.expression - 1)
    return self.sql(exp.Substring(this=expression.this, start=start))
617
618
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Renders TimeStrToTime as a cast of its operand to timestamp."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
621
622
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Renders DateStrToDate as a cast of its operand to date."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
625
626
# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Renders an encode/decode expression as `name(this[, replace])`.

    A charset whose lower-cased name is not "utf-8" is reported as unsupported.
    The expression's "replace" arg is only forwarded when `replace` is true.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
636
637
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Renders MIN as LEAST when it has extra `expressions`, otherwise as MIN."""
    if expression.expressions:
        name = "LEAST"
    else:
        name = "MIN"

    # Same rendering as a renamed function: all args, flattened
    return self.func(name, *flatten(expression.args.values()))
641
642
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Renders MAX as GREATEST when it has extra `expressions`, otherwise as MAX."""
    if expression.expressions:
        name = "GREATEST"
    else:
        name = "MAX"

    # Same rendering as a renamed function: all args, flattened
    return self.func(name, *flatten(expression.args.values()))
646
647
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Renders COUNT_IF(cond) as sum(if(cond, 1, 0)).

    A DISTINCT wrapper is dropped (its first expression becomes the condition)
    and reported as unsupported.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition.copy(), 1, 0)
    return self.func("sum", one_or_zero)
656
657
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Renders TRIM([position ][chars ][FROM ]target[ COLLATE collation]).

    When neither characters to remove nor a collation are present, rendering is
    delegated to the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    # Everything before the target: optional position and characters, then FROM
    prefix = "".join(f"{part} " for part in (position, chars) if part)
    if prefix:
        prefix += "FROM "

    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
673
674
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Renders the expression as STRPTIME(this, <format>)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
677
678
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Builds a callback rendering TsOrDsToDate as a cast to date.

    When the expression has a format other than the dialect's TIME_FORMAT or
    DATE_FORMAT, the operand is first rendered through `str_to_time_sql`.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        dialect_class = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        default_formats = (dialect_class.TIME_FORMAT, dialect_class.DATE_FORMAT)
        if fmt and fmt not in default_formats:
            operand = str_to_time_sql(self, expression)
        else:
            operand = self.sql(expression, "this")

        return self.sql(exp.cast(operand, "date"))

    return _ts_or_ds_to_date_sql
689
690
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Renders a concatenation as a left-nested chain of DPipe expressions."""
    operands = expression.copy().expressions

    def _pipe(left, right):
        return exp.DPipe(this=left, expression=right)

    return self.sql(reduce(_pipe, operands))
694
695
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Renders CONCAT_WS as DPipe expressions, inserting the delimiter between operands.

    The first of the expression's `expressions` is the delimiter; the rest are
    the values to join.
    """
    delim, *operands = expression.copy().expressions

    def _pipe_with_delim(left, right):
        return exp.DPipe(this=left, expression=exp.DPipe(this=delim, expression=right))

    return self.sql(reduce(_pipe_with_delim, operands))
705
706
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Renders REGEXP_EXTRACT(this, pattern, group), reporting other set args as unsupported."""
    bad_args = [
        name
        for name in ("position", "occurrence", "parameters")
        if expression.args.get(name)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
715
716
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Renders REGEXP_REPLACE(this, pattern, replacement), reporting other set args as unsupported."""
    bad_args = [
        name
        for name in ("position", "occurrence", "parameters")
        if expression.args.get(name)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
725
726
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Returns one column-name suffix per pivot aggregation.

    Aliased aggregations contribute their alias. Unaliased ones contribute their
    SQL text in `dialect`, with lower-cased function names and every identifier
    unquoted.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise the identifiers inside the name would end up quoted twice.
        agg_all_unquoted = agg.transform(_unquote)
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
747
748
def simplify_literal(expression: E) -> E:
    """Runs the optimizer's `simplify` on the expression's operand unless it is a literal.

    The return value of `simplify` is discarded and the given expression is
    returned. NOTE(review): this presumably relies on `simplify` updating the
    tree in place - confirm.
    """
    operand = expression.expression
    if isinstance(operand, exp.Literal):
        return expression

    # Imported lazily, inside the function
    from sqlglot.optimizer.simplify import simplify

    simplify(operand)
    return expression
756
757
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Builds a parser callback that turns the first two function arguments into `expr_type`."""

    def _build(args):
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build
760
761
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Parses (unit, value) arguments into a TimestampTrunc."""
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
765
766
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Renders ANY_VALUE(this) as MAX(this)."""
    operand = expression.this
    return self.func("MAX", operand)
769
770
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Renders a boolean XOR as (a AND (NOT b)) OR ((NOT a) AND b).

    Operands that are themselves binary expressions are parenthesized, because
    they are placed next to AND / NOT: without the parentheses an operand such
    as `x OR y` would be regrouped by operator precedence. Other operands are
    emitted exactly as before.
    """

    def _operand(node: t.Optional[exp.Expression]) -> str:
        sql = self.sql(node)
        return f"({sql})" if isinstance(node, exp.Binary) else sql

    a = _operand(expression.left)
    b = _operand(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
775
776
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Renders a JSON key/value pair as `<key>, <value>`."""
    key = self.sql(expression, "this")
    value = self.sql(expression, "expression")
    return f"{key}, {value}"
780
781
def is_parse_json(expression: exp.Expression) -> bool:
    """Tells whether the expression is a ParseJSON or a cast to the json type."""
    if isinstance(expression, exp.ParseJSON):
        return True

    return isinstance(expression, exp.Cast) and expression.is_type("json")
786
787
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Parses ISNULL(x) into the parenthesized expression (x IS NULL)."""
    is_null = exp.Is(this=seq_get(args, 0), expression=exp.null())
    return exp.Paren(this=is_null)
790
791
def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
    """Renders an INSERT after moving a WITH clause from its source query onto the INSERT.

    The move is performed on a copy of the expression.
    """
    if not expression.expression.args.get("with"):
        return self.insert_sql(expression)

    expression = expression.copy()
    # pop() detaches the WITH clause from the inner query before re-attaching it
    with_clause = expression.expression.args["with"].pop()
    expression.set("with", with_clause)
    return self.insert_sql(expression)
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """Names of the known SQL dialects.

    Members mix in `str`, so each one compares equal to its string value
    (for example ``Dialects.DUCKDB == "duckdb"``).
    """

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    DORIS = "doris"

    # Alias kept for backward compatibility: this member was originally spelled in
    # mixed case. Because it repeats the value above, Enum treats it as another
    # name for DORIS rather than as a separate member.
    Doris = "doris"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    """Base SQL dialect: class-level settings plus tokenize / parse / generate entry points."""

    # Base index offset for arrays
    INDEX_OFFSET = 0

    # When true, unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias comes after TABLESAMPLE
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers are resolved as uppercase.
    # None means that the dialect treats all identifiers as case-insensitive.
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Whether CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Whether user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Whether SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # How function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Default null ordering method to use when none is set explicitly.
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings: the key is the dialect time format,
    # the value is the matching python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # The special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns auto-generated by the engine corresponding to this dialect.
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Equality is decided by comparing this instance's class object with `other`.
        own_class = type(self)
        return own_class == other

    def __hash__(self) -> int:
        own_class = type(self)
        return hash(own_class)

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name, class or instance to a dialect class.

        A falsy `dialect` resolves to `cls` itself.

        Raises:
            ValueError: if `cls.get` finds nothing for the given name.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        found = cls.get(dialect)
        if found:
            return found

        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format into a string literal using `TIME_MAPPING` / `TIME_TRIE`.

        Anything that is neither a plain string nor a string expression is
        returned unchanged.
        """
        if isinstance(expression, str):
            # Plain-string time formats are quoted, so drop the first and last character
            raw_format = expression[1:-1]
        elif expression and expression.is_string:
            raw_format = expression.this
        else:
            return expression

        return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to lower or upper case, which essentially makes it
        case-insensitive. For a dialect that treats all identifiers as case-insensitive,
        quoted identifiers are normalized as well.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if expression.quoted and resolves_upper is not None:
            return expression

        name = expression.this
        expression.set("this", name.upper() if resolves_upper else name.lower())
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        # A character is unsafe when its case differs from the case identifiers resolve to
        if resolves_upper:
            return any(char.islower() for char in text)
        return any(char.isupper() for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        return identify == "safe" and not cls.case_sensitive(text)

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` flag on an identifier; other expressions are returned untouched.

        The flag is truthy when `identify` is set, when the name is case sensitive for
        this dialect, or when the name does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        name = expression.this
        needs_quotes = (
            identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", needs_quotes)
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it with a parser built from `opts`."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse(tokens, sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into `expression_type` with a parser built from `opts`."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse_into(expression_type, tokens, sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression` with a generator built from `opts`."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and generate one SQL string per parsed expression.

        `opts` are forwarded to `generate` only; `parse` is called without options.
        """
        statements = self.parse(sql)
        return [self.generate(statement, **opts) for statement in statements]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this instance's tokenizer."""
        tokenizer = self.tokenizer
        return tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """Tokenizer instance, created on first access and reused afterwards."""
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Create a new parser of `parser_class` with the given options."""
        parser_class = self.parser_class
        return parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Create a new generator of `generator_class` with the given options."""
        generator_class = self.generator_class
        return generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = True
SUPPORTS_SEMI_ANTI_JOIN = True
NORMALIZE_FUNCTIONS: bool | str = 'upper'
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Type[Dialect]:
211    @classmethod
212    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
213        if not dialect:
214            return cls
215        if isinstance(dialect, _Dialect):
216            return dialect
217        if isinstance(dialect, Dialect):
218            return dialect.__class__
219
220        result = cls.get(dialect)
221        if not result:
222            raise ValueError(f"Unknown dialect '{dialect}'")
223
224        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
226    @classmethod
227    def format_time(
228        cls, expression: t.Optional[str | exp.Expression]
229    ) -> t.Optional[exp.Expression]:
230        if isinstance(expression, str):
231            return exp.Literal.string(
232                # the time formats are quoted
233                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
234            )
235
236        if expression and expression.is_string:
237            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
238
239        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
241    @classmethod
242    def normalize_identifier(cls, expression: E) -> E:
243        """
244        Normalizes an unquoted identifier to either lower or upper case, thus essentially
245        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
246        they will be normalized regardless of being quoted or not.
247        """
248        if isinstance(expression, exp.Identifier) and (
249            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
250        ):
251            expression.set(
252                "this",
253                expression.this.upper()
254                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
255                else expression.this.lower(),
256            )
257
258        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
260    @classmethod
261    def case_sensitive(cls, text: str) -> bool:
262        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
263        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
264            return False
265
266        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
267        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
269    @classmethod
270    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
271        """Checks if text can be identified given an identify option.
272
273        Args:
274            text: The text to check.
275            identify:
276                "always" or `True`: Always returns true.
277                "safe": True if the identifier is case-insensitive.
278
279        Returns:
280            Whether or not the given text can be identified.
281        """
282        if identify is True or identify == "always":
283            return True
284
285        if identify == "safe":
286            return not cls.case_sensitive(text)
287
288        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
290    @classmethod
291    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
292        if isinstance(expression, exp.Identifier):
293            name = expression.this
294            expression.set(
295                "quoted",
296                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
297            )
298
299        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
301    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
302        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
304    def parse_into(
305        self, expression_type: exp.IntoType, sql: str, **opts
306    ) -> t.List[t.Optional[exp.Expression]]:
307        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
309    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
310        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
312    def transpile(self, sql: str, **opts) -> t.List[str]:
313        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
315    def tokenize(self, sql: str) -> t.List[Token]:
316        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
324    def parser(self, **opts) -> Parser:
325        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
327    def generator(self, **opts) -> Generator:
328        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator callback that renders `name` called with the expression's arg values.

    The values of `expression.args` are flattened before being passed on.
    """

    def _renamed(self: Generator, expression: exp.Expression) -> str:
        arguments = flatten(expression.args.values())
        return self.func(name, *arguments)

    return _renamed
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, reporting a given `accuracy` argument as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    operand = expression.this
    return self.func("APPROX_COUNT_DISTINCT", operand)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a generator callback rendering an If node as `name(condition, true, false)`.

    `false_value` is used whenever the node's own `false` arg is missing or falsy.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        node_args = expression.args
        when_true = node_args.get("true")
        when_false = node_args.get("false") or false_value
        return self.func(name, expression.this, when_true, when_false)

    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render the extraction as a binary expression using the `->` operator."""
    operator = "->"
    return self.binary(expression, operator)
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render the scalar extraction as a binary expression using the `->>` operator."""
    operator = "->>"
    return self.binary(expression, operator)
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render the array's expressions (flat) wrapped in square brackets."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Render ILIKE as LIKE with the left-hand operand wrapped in LOWER.

    Only the left-hand operand is lowered; the right-hand operand is copied as is.
    """
    lowered = exp.Lower(this=expression.this.copy())
    pattern = expression.expression.copy()
    return self.like_sql(exp.Like(this=lowered, expression=pattern))
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"

    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause, reporting and clearing the `recursive` flag when it is set."""
    node_args = expression.args
    is_recursive = node_args.get("recursive")

    if is_recursive:
        self.unsupported("Recursive CTEs are unsupported")
        # The flag is cleared on the node that was passed in, not on a copy
        node_args["recursive"] = False

    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Render a safe division as an IF that yields NULL when the divisor is 0."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the node's `this` operand."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported; nothing is rendered for the node."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TryCast node through the generator's regular `cast_sql`."""
    render_cast = self.cast_sql
    return render_cast(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported; nothing is rendered for the node."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report the comment column constraint as unsupported; nothing is rendered for it."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Report MAP_FROM_ENTRIES as unsupported; nothing is rendered for the node."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a string-position lookup with STRPOS, emulating an optional start position."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    # Search within the substring beginning at `start`, then add the offset back
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render a struct field access as `<struct>.<field identifier>`."""
    struct = self.sql(expression, "this")
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct}.{field}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(k1, v1, k2, v2, ...)`.

    When either the keys or the values are not an Array node, the conversion is
    reported as unsupported and both are passed to the function as they are.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave the rendered keys and values pairwise
    interleaved = [
        self.sql(node)
        for pair in zip(keys.expressions, values.expressions)
        for node in pair
    ]
    return self.func(map_func_name, *interleaved)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the format to fall back on when none is given; `True` selects
            the dialect's `TIME_FORMAT`.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        this = seq_get(args, 0)
        dialect_class = Dialect[dialect]

        time_format = seq_get(args, 1)
        if not time_format:
            time_format = dialect_class.TIME_FORMAT if default is True else default or None

        return exp_class(this=this, format=dialect_class.format_time(time_format))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a callback that yields an expression's time format unless it is the dialect default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format of the given expression, or None when it equals
        the default time format of the dialect of interest.
        """
        rendered = self.format_time(expression)
        dialect_default = Dialect.get_or_raise(dialect).TIME_FORMAT

        if rendered != dialect_default:
            return rendered
        return None

    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema.
    When the PARTITIONED BY value is an array of column names, those names are turned into
    a schema and the corresponding columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
    if not (has_schema and is_partitionable):
        return self.create_sql(expression)

    # Work on a copy from here on
    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return self.create_sql(expression)

    schema = expression.this
    partition_names = {name.name.upper() for name in prop.this.expressions}
    partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining = [col for col in schema.expressions if col not in partitions]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date-delta functions.

    With exactly three arguments the order is taken to be `(unit, amount, value)`;
    otherwise the first argument is the value and the unit is the string literal "DAY".
    When `unit_mapping` is given, the unit name is looked up in it by its lowercase form.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for date-delta functions whose second argument is an INTERVAL.

    The callback returns None for fewer than two arguments and raises ParseError
    when the second argument is not an Interval node.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        target, interval = args[0], args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # A string amount is turned into a number literal
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=target, expression=amount, unit=unit)

    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build DateTrunc when the value is a cast to date, and TimestampTrunc otherwise.

    Arguments are read in `(unit, value)` order.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type("date")
    if not is_date_cast:
        return exp.TimestampTrunc(this=this, unit=unit)

    return exp.DateTrunc(unit=unit, this=this)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a callback rendering `<data_type>_<kind>(<this>, <interval>)`.

    The interval's unit is the uppercased name of the node's `unit` arg, or DAY
    when the node has none.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")

        unit_node = expression.args.get("unit")
        unit_name = unit_node.name.upper() if unit_node else "DAY"
        unit = exp.var(unit_name)

        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render the node as `DATE_TRUNC(<unit string>, <this>)`, with the unit defaulting to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Build a StrPosition node from arguments given in `(substr, this, position)` order."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a StrPosition node as `LOCATE(substr, this, position)`."""
    node_args = expression.args
    substr = node_args.get("substr")
    position = node_args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render `LEFT(x, n)` as a Substring of `x` that starts at 1 and has length `n`."""
    source = expression.copy()
    substring = exp.Substring(
        this=source.this,
        start=exp.Literal.number(1),
        length=source.expression,
    )
    return self.sql(substring)
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render `RIGHT(x, n)` as a Substring of `x` starting at `LENGTH(x) - (n - 1)`.

    No length is passed to Substring. The node is copied first and the new
    expression is built from the copy's operands.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # Start offset: LENGTH(x) - (n - 1)
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render the node as a cast of its operand to `timestamp`."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render the node as a cast of its operand to `date`."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode style node as a call to the function `name`.

    A `charset` argument whose name is not "utf-8" (case-insensitive) is
    reported through `self.unsupported`. The node's `replace` argument is
    forwarded to the call only when the `replace` flag is true.
    """
    node_args = expression.args

    charset = node_args.get("charset")
    if charset:
        if charset.name.lower() != "utf-8":
            self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = node_args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render a `Min` node as `LEAST` when it has `expressions`, else as `MIN`."""
    if expression.expressions:
        render = rename_func("LEAST")
    else:
        render = rename_func("MIN")
    return render(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render a `Max` node as `GREATEST` when it has `expressions`, else as `MAX`."""
    if expression.expressions:
        render = rename_func("GREATEST")
    else:
        render = rename_func("MAX")
    return render(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render a `CountIf` node as `sum(if(cond, 1, 0))`.

    When the operand is wrapped in `Distinct`, the first wrapped expression is
    used as the condition and the loss of DISTINCT is reported through
    `self.unsupported`.
    """
    operand = expression.this
    condition = operand

    if isinstance(operand, exp.Distinct):
        condition = operand.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition.copy(), 1, 0)
    return self.func("sum", one_or_zero)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render a `Trim` node using the `TRIM([position] [chars] FROM target [COLLATE c])` form.

    When the node has neither characters to remove nor a collation, rendering
    is delegated to the generator's own `trim_sql` method.
    """
    operand = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if position:
        pieces.append(f"{position} ")
    if chars:
        pieces.append(f"{chars} ")
    if pieces:
        # FROM is only emitted when a position or a character set precedes it.
        pieces.append("FROM ")
    pieces.append(operand)
    if collation:
        pieces.append(f" COLLATE {collation}")

    return f"TRIM({''.join(pieces)})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render the node as `STRPTIME(this, <format>)`, with the format from `self.format_time`."""
    value = expression.this
    time_format = self.format_time(expression)
    return self.func("STRPTIME", value, time_format)
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Return a generator function that renders `TsOrDsToDate` as a cast to `date`.

    The dialect is looked up by name each time the returned function runs. If
    the node has a time format other than the dialect's `TIME_FORMAT` or
    `DATE_FORMAT`, the operand is first rendered with `str_to_time_sql`;
    otherwise the node's `this` is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            operand = str_to_time_sql(self, expression)
        else:
            operand = self.sql(expression, "this")

        return self.sql(exp.cast(operand, "date"))

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render a concat node as a left-nested chain of `DPipe` nodes."""

    def _pipe(left, right):
        return exp.DPipe(this=left, expression=right)

    # Operate on a copy of the node; its operands are folded pairwise.
    expression = expression.copy()
    chained = reduce(_pipe, expression.expressions)
    return self.sql(chained)
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render a `ConcatWs` node as nested `DPipe` nodes.

    The first entry of `expressions` is treated as the delimiter; every
    operand after the first value is attached as `DPipe(delim, operand)` to
    the chain built so far.
    """
    expression = expression.copy()
    delim, *rest_args = expression.expressions

    def _pipe_with_delim(chain, operand):
        delimited = exp.DPipe(this=delim, expression=operand)
        return exp.DPipe(this=chain, expression=delimited)

    return self.sql(reduce(_pipe_with_delim, rest_args))
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render `REGEXP_EXTRACT(this, expression, group)`.

    Any truthy `position`, `occurrence` or `parameters` argument is reported
    through `self.unsupported` and left out of the call.
    """
    node_args = expression.args
    bad_args = [
        key for key in ("position", "occurrence", "parameters") if node_args.get(key)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = node_args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render `REGEXP_REPLACE(this, expression, replacement)`.

    Any truthy `position`, `occurrence` or `parameters` argument is reported
    through `self.unsupported` and left out of the call. The `replacement`
    argument is read with direct indexing, so it must be present.
    """
    node_args = expression.args
    bad_args = [
        key for key in ("position", "occurrence", "parameters") if node_args.get(key)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    subject = expression.this
    pattern = expression.expression
    return self.func("REGEXP_REPLACE", subject, pattern, node_args["replacement"])
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute one output name per pivot aggregation.

    Aliased aggregations contribute their alias. Unaliased ones contribute
    their SQL text in `dialect`, with function names lower-cased and every
    identifier rebuilt as unquoted.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as
        # suffixes (e.g. col_avg(foo)). Identifiers are unquoted here because
        # they are going to be quoted in the base parser's `_parse_pivot`
        # method, due to `to_identifier`. Otherwise we'd end up with
        # `col_avg(`foo`)` (notice the nested quotes).
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def simplify_literal(expression: E) -> E:
    """Run the optimizer's `simplify` on the node's `expression` operand unless it is a literal.

    The same `expression` object that was passed in is always returned.
    """
    operand = expression.expression
    if isinstance(operand, exp.Literal):
        return expression

    # Imported at call time rather than at module level.
    from sqlglot.optimizer.simplify import simplify

    # The return value is discarded, so this presumably relies on `simplify`
    # updating the tree in place -- TODO confirm.
    simplify(operand)
    return expression
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Return a builder that turns a two-argument list into an `expr_type` node.

    The builder passes the argument at index 0 as `this` and the argument at
    index 1 as `expression`, both fetched through `seq_get`.
    """

    def _build(args: t.List) -> B:
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build an `exp.TimestampTrunc` from arguments given as `(unit, value)`.

    Index 0 supplies `unit` and index 1 supplies `this`.
    """
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render an `AnyValue` node as `MAX(this)`."""
    operand = expression.this
    return self.func("MAX", operand)
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Render an `Xor` node as `(a AND (NOT b)) OR ((NOT a) AND b)`."""
    lhs = self.sql(expression.left)
    rhs = self.sql(expression.right)
    only_lhs = f"({lhs} AND (NOT {rhs}))"
    only_rhs = f"((NOT {lhs}) AND {rhs})"
    return f"{only_lhs} OR {only_rhs}"
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a `JSONKeyValue` node as `<this>, <expression>`."""
    key_sql = self.sql(expression, "this")
    value_sql = self.sql(expression, "expression")
    return f"{key_sql}, {value_sql}"
def is_parse_json(expression: exp.Expression) -> bool:
    """Tell whether the node is a `ParseJSON` or a `Cast` whose type is `json`."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Build a parenthesized `<arg0> IS NULL` node from ISNULL-style arguments."""
    operand = seq_get(args, 0)
    is_null = exp.Is(this=operand, expression=exp.null())
    return exp.Paren(this=is_null)
def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
    """Render an INSERT, moving a `with` clause from its source query onto the INSERT.

    Args:
        self: The generator; its `insert_sql` method produces the final SQL.
        expression: The insert node. Its `expression` operand (the source of
            the inserted rows) may be missing.

    Returns:
        The result of `self.insert_sql` for the (possibly rewritten) node.
    """
    source = expression.expression

    # An INSERT without a source expression has nothing to hoist; without this
    # guard the `.args` lookup below would fail on `None`.
    if source is not None and source.args.get("with"):
        # Rewrite a copy so the caller's tree keeps its original shape.
        expression = expression.copy()
        expression.set("with", expression.expression.args["with"].pop())

    return self.insert_sql(expression)