Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5from functools import reduce
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.errors import ParseError
 10from sqlglot.generator import Generator
 11from sqlglot.helper import flatten, seq_get
 12from sqlglot.parser import Parser
 13from sqlglot.time import format_time
 14from sqlglot.tokens import Token, Tokenizer, TokenType
 15from sqlglot.trie import new_trie
 16
# Type variable bound to exp.Binary, used by helpers such as `binary_from_function` so the
# concrete binary-expression subclass passed in is the one reported as the return type.
B = t.TypeVar("B", bound=exp.Binary)
 18
 19
 20class Dialects(str, Enum):
 21    DIALECT = ""
 22
 23    BIGQUERY = "bigquery"
 24    CLICKHOUSE = "clickhouse"
 25    DATABRICKS = "databricks"
 26    DRILL = "drill"
 27    DUCKDB = "duckdb"
 28    HIVE = "hive"
 29    MYSQL = "mysql"
 30    ORACLE = "oracle"
 31    POSTGRES = "postgres"
 32    PRESTO = "presto"
 33    REDSHIFT = "redshift"
 34    SNOWFLAKE = "snowflake"
 35    SPARK = "spark"
 36    SPARK2 = "spark2"
 37    SQLITE = "sqlite"
 38    STARROCKS = "starrocks"
 39    TABLEAU = "tableau"
 40    TERADATA = "teradata"
 41    TRINO = "trino"
 42    TSQL = "tsql"
 43    Doris = "doris"
 44
 45
 46class _Dialect(type):
 47    classes: t.Dict[str, t.Type[Dialect]] = {}
 48
 49    def __eq__(cls, other: t.Any) -> bool:
 50        if cls is other:
 51            return True
 52        if isinstance(other, str):
 53            return cls is cls.get(other)
 54        if isinstance(other, Dialect):
 55            return cls is type(other)
 56
 57        return False
 58
 59    def __hash__(cls) -> int:
 60        return hash(cls.__name__.lower())
 61
 62    @classmethod
 63    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 64        return cls.classes[key]
 65
 66    @classmethod
 67    def get(
 68        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 69    ) -> t.Optional[t.Type[Dialect]]:
 70        return cls.classes.get(key, default)
 71
 72    def __new__(cls, clsname, bases, attrs):
 73        klass = super().__new__(cls, clsname, bases, attrs)
 74        enum = Dialects.__members__.get(clsname.upper())
 75        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 76
 77        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 78        klass.FORMAT_TRIE = (
 79            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 80        )
 81        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 82        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 83
 84        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 85        klass.parser_class = getattr(klass, "Parser", Parser)
 86        klass.generator_class = getattr(klass, "Generator", Generator)
 87
 88        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 89        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 90            klass.tokenizer_class._IDENTIFIERS.items()
 91        )[0]
 92
 93        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 94            return next(
 95                (
 96                    (s, e)
 97                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 98                    if t == token_type
 99                ),
100                (None, None),
101            )
102
103        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
104        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
105        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
106
107        dialect_properties = {
108            **{
109                k: v
110                for k, v in vars(klass).items()
111                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
112            },
113            "TOKENIZER_CLASS": klass.tokenizer_class,
114        }
115
116        if enum not in ("", "bigquery"):
117            dialect_properties["SELECT_KINDS"] = ()
118
119        # Pass required dialect properties to the tokenizer, parser and generator classes
120        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
121            for name, value in dialect_properties.items():
122                if hasattr(subclass, name):
123                    setattr(subclass, name, value)
124
125        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
126            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
127
128        klass.generator_class.can_identify = klass.can_identify
129
130        return klass
131
132
class Dialect(metaclass=_Dialect):
    """Base SQL dialect.

    Class attributes configure dialect behaviour. The `_Dialect` metaclass copies them onto
    the tokenizer, parser and generator classes that declare an attribute of the same name.
    Instances expose tokenize / parse / generate / transpile helpers.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Delegate to the metaclass comparison (names, classes and instances).
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name, instance or class to a dialect class.

        A falsy `dialect` resolves to `cls` itself.

        Raises:
            ValueError: if `dialect` is a name that is not registered.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return type(dialect)

        found = cls.get(dialect)
        if not found:
            raise ValueError(f"Unknown dialect '{dialect}'")
        return found

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time format into a Python-style time format literal.

        Plain strings are expected to still carry their surrounding quotes. String literal
        expressions are translated; anything else is returned unchanged.
        """
        if isinstance(expression, str):
            # Drop the surrounding quote characters before translating.
            unquoted = expression[1:-1]
            return exp.Literal.string(format_time(unquoted, cls.TIME_MAPPING, cls.TIME_TRIE))

        if not expression or not expression.is_string:
            return expression

        return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        # Quoted identifiers keep their case unless the dialect is fully case-insensitive.
        if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
            return expression

        text = expression.this
        normalized = text.upper() if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else text.lower()
        expression.set("this", normalized)
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        # Characters of the "other" case would change meaning if left unquoted.
        is_unsafe = str.islower if resolves_upper else str.isupper
        for char in text:
            if is_unsafe(char):
                return True
        return False

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True
        return identify == "safe" and not cls.case_sensitive(text)

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` flag of an identifier.

        The flag is set when `identify` is truthy, when the name is case sensitive, or when
        it does not match `exp.SAFE_IDENTIFIER_RE`. Non-identifiers are returned untouched.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        name = expression.this
        needs_quotes = (
            identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", needs_quotes)
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        parser = self.parser(**opts)
        return parser.parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        parser = self.parser(**opts)
        return parser.parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        # `opts` configure generation only; parsing uses default parser options.
        parsed = self.parse(sql)
        return [self.generate(node, **opts) for node in parsed]

    def tokenize(self, sql: str) -> t.List[Token]:
        tokenizer = self.tokenizer
        return tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily build one tokenizer per dialect instance and reuse it afterwards.
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        parser_cls = self.parser_class
        return parser_cls(**opts)

    def generator(self, **opts) -> Generator:
        generator_cls = self.generator_class
        return generator_cls(**opts)
312        return self.parser_class(**opts)
313
314    def generator(self, **opts) -> Generator:
315        return self.generator_class(**opts)
316
317
# Anything `Dialect.get_or_raise` can resolve: a registered dialect name, a dialect instance,
# a dialect class, or None (resolved to the class `get_or_raise` is called on).
DialectType = t.Optional[t.Union[str, Dialect, t.Type[Dialect]]]
319
320
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform that renders an expression as a call to `name`.

    All argument values of the expression are flattened and passed positionally, in the
    order they are stored in `expression.args`.
    """

    def _renamed(self: Generator, expression: exp.Expression) -> str:
        call_args = flatten(expression.args.values())
        return self.func(name, *call_args)

    return _renamed
323
324
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT(x), warning when an accuracy argument has to be dropped."""
    has_accuracy = expression.args.get("accuracy")
    if has_accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    target = expression.this
    return self.func("APPROX_COUNT_DISTINCT", target)
329
330
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an `exp.If` node as IF(condition, true_value, false_value)."""
    branches = expression.args
    return self.func("IF", expression.this, branches.get("true"), branches.get("false"))
335
336
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render a JSON extraction using the `->` binary operator."""
    operator = "->"
    return self.binary(expression, operator)
339
340
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render a scalar JSON extraction using the `->>` binary operator."""
    operator = "->>"
    return self.binary(expression, operator)
345
346
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal with square brackets, e.g. [1, 2, 3]."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
349
350
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate `a ILIKE b` for dialects without ILIKE as `LOWER(a) LIKE LOWER(b)`.

    Both operands are lowercased. Lowering only the left-hand side would make any pattern
    that contains uppercase characters impossible to match, which breaks ILIKE's
    case-insensitivity. Operands are copied so the caller's AST is not re-parented.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this.copy()),
            expression=exp.Lower(this=expression.expression.copy()),
        )
    )
357
358
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, adding AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
362
363
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause for dialects that lack recursive CTEs.

    If the clause is marked recursive, an "unsupported" warning is emitted and the flag is
    dropped from the output. The flag is cleared on a copy, because generating SQL must not
    mutate the caller's AST (it may be generated again for another dialect).
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression = expression.copy()
        expression.set("recursive", False)
    return self.with_sql(expression)
369
370
def no_safe_divide_sql(
    self: Generator, expression: exp.SafeDivide, if_func: str = "IF"
) -> str:
    """Emulate SAFE_DIVIDE(n, d) as `IF((d) <> 0, (n) / (d), NULL)`.

    Args:
        expression: the SafeDivide node.
        if_func: name of the conditional function to emit (defaults to IF).

    Both operands are already-rendered SQL text, so they are parenthesized. Otherwise an
    operand such as `a + b` would bind incorrectly (`x / a + b` instead of `x / (a + b)`).
    """
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"{if_func}(({d}) <> 0, ({n}) / ({d}), NULL)"
375
376
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE with a warning and render only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
380
381
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop a PIVOT clause entirely, reporting it as unsupported."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
385
386
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TRY_CAST node using the generator's regular CAST output."""
    render_cast = self.cast_sql
    return render_cast(expression)
389
390
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop a properties clause entirely, reporting it as unsupported."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
394
395
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop a column COMMENT constraint entirely, reporting it as unsupported."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
401
402
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Drop MAP_FROM_ENTRIES entirely, reporting it as unsupported."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
406
407
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a substring search with STRPOS.

    STRPOS takes no start-position argument. When a position is given, the search therefore
    runs over SUBSTR(this, position) and the result is shifted back by `position - 1`.
    STRPOS returns 0 when the substring is not found, and shifting that would wrongly yield
    `position - 1`. A CASE expression keeps the not-found result at 0.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"

    found = f"STRPOS(SUBSTR({this}, {position}), {substr})"
    return f"CASE WHEN {found} = 0 THEN 0 ELSE {found} + {position} - 1 END"
415
416
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `<struct>.<field>`, emitting the field as an identifier."""
    struct = self.sql(expression, "this")
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct}.{field}"
421
422
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor as MAP(k1, v1, k2, v2, ...).

    Keys and values are interleaved only when both are `exp.Array` nodes. Otherwise a
    warning is emitted and the two arguments are passed through unchanged.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        # zip() pairs keys with values and stops at the shorter of the two lists.
        interleaved = [
            self.sql(item)
            for pair in zip(keys.expressions, values.expressions)
            for item in pair
        ]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
439
440
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        target = Dialect[dialect]
        fmt = seq_get(args, 1)
        if not fmt:
            # No explicit format: fall back to the dialect's TIME_FORMAT (default=True)
            # or to the given default string, if any.
            fmt = target.TIME_FORMAT if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=target.format_time(fmt))

    return _format_time
465
466
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a helper that yields an expression's time format unless it is the dialect default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt == default_fmt:
            return None
        return fmt

    return _time_format
479
480
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
    if not (has_schema and is_partitionable):
        return self.create_sql(expression)

    # Work on a copy so the caller's AST is not modified.
    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return self.create_sql(expression)

    schema = expression.this
    partition_names = {column.name.upper() for column in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining = [column for column in schema.expressions if column not in partition_columns]
    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return self.create_sql(expression)
502
503
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-arithmetic functions.

    With three arguments they are read as (unit, amount, date). Otherwise they are read as
    (date, amount) with a unit of 'DAY'. When `unit_mapping` is given, the unit name
    (lower-cased) is translated through it into an `exp.var`.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            this = args[2]
            unit = args[0]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
515
516
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for FUNC(date, INTERVAL <amount> <unit>) style date arithmetic.

    The returned callable yields None when given fewer than two arguments, and raises
    ParseError when the second argument is not an INTERVAL. A string interval amount is
    converted to a number literal.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        date, interval = args[0], args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=date, expression=amount, unit=unit)

    return func
538
539
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value).

    The result is a DateTrunc when the value is an explicit cast to DATE, and a
    TimestampTrunc otherwise.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)
    is_date = isinstance(this, exp.Cast) and this.is_type("date")
    trunc_class = exp.DateTrunc if is_date else exp.TimestampTrunc
    return trunc_class(this=this, unit=unit)
547
548
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TIMESTAMP truncation as DATE_TRUNC('<unit>', value), defaulting the unit to 'day'."""
    unit = expression.text("unit") or "day"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
553
554
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE(substr, string[, position]) into StrPosition.

    The needle comes first in LOCATE, so the first two arguments are swapped.
    """
    substr, this, position = (seq_get(args, index) for index in range(3))
    return exp.StrPosition(this=this, substr=substr, position=position)
559
560
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, string[, position]) (the needle comes first)."""
    args = expression.args
    return self.func("LOCATE", args.get("substr"), expression.this, args.get("position"))
565
566
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render LEFT(s, n) as SUBSTRING(s, 1, n), working on a copy of the node."""
    copied = expression.copy()
    substring = exp.Substring(
        this=copied.this,
        start=exp.Literal.number(1),
        length=copied.expression,
    )
    return self.sql(substring)
574
575
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1)).

    The string operand appears twice in the rewritten tree. The second occurrence is a
    separate copy, so the same node object is never attached at two places in the AST.
    """
    expression = expression.copy()
    this = expression.this
    return self.sql(
        exp.Substring(
            this=this,
            start=exp.Length(this=this.copy()) - exp.paren(expression.expression - 1),
        )
    )
584
585
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as a cast of the value to timestamp."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
588
589
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as a cast of the value to date."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
592
593
# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode-style call named `name`.

    A charset other than utf-8 is reported as unsupported and otherwise dropped. When
    `replace` is true, the expression's "replace" argument is passed as the second argument.
    """
    args = expression.args
    charset = args.get("charset")
    unexpected_charset = charset and charset.name.lower() != "utf-8"
    if unexpected_charset:
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replacement = args.get("replace") if replace else None
    return self.func(name, expression.this, replacement)
603
604
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN(x) for one argument, or LEAST(...) when additional expressions are present."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
608
609
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX(x) for one argument, or GREATEST(...) when additional expressions are present."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
613
614
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as a sum of IF(cond, 1, 0).

    A DISTINCT inside COUNT_IF cannot be expressed this way. Its first expression is used as
    the condition and a warning is emitted.
    """
    this = expression.this
    is_distinct = isinstance(this, exp.Distinct)
    condition = this.expressions[0] if is_distinct else this
    if is_distinct:
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", condition.copy(), 1, 0))
623
624
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, using the explicit `TRIM([type] [chars] FROM target [COLLATE c])` form.

    When neither characters to remove nor a collation are given, rendering is delegated
    to the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if trim_type:
        pieces.append(f"{trim_type} ")
    if remove_chars:
        pieces.append(f"{remove_chars} ")
    if trim_type or remove_chars:
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")
    return f"TRIM({''.join(pieces)})"
640
641
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    value = expression.this
    fmt = self.format_time(expression)
    return self.func("STRPTIME", value, fmt)
644
645
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator transform for TsOrDsToDate in the given dialect.

    When the expression carries a format other than the dialect's TIME_FORMAT or
    DATE_FORMAT, the value is parsed with STRPTIME first. Either way the result is cast
    to date.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        default_formats = (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT)

        if time_format and time_format not in default_formats:
            value = str_to_time_sql(self, expression)
        else:
            value = self.sql(expression, "this")

        return self.sql(exp.cast(value, "date"))

    return _ts_or_ds_to_date_sql
656
657
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render CONCAT(a, b, c) as a left-nested chain of DPipe (`||`) nodes."""
    operands = expression.copy().expressions

    def _pipe(left, right):
        return exp.DPipe(this=left, expression=right)

    return self.sql(reduce(_pipe, operands))
661
662
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render CONCAT_WS(sep, a, b, ...) as chained DPipe (`||`) nodes.

    The separator is inserted before every argument after the first.
    """
    copied = expression.copy()
    delim, *rest_args = copied.expressions

    def _join(acc, arg):
        return exp.DPipe(this=acc, expression=exp.DPipe(this=delim, expression=arg))

    return self.sql(reduce(_join, rest_args))
672
673
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT(this, pattern[, group]).

    Any position/occurrence/parameters arguments are reported as unsupported and dropped.
    """
    args = expression.args
    bad_args = [key for key in ("position", "occurrence", "parameters") if args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, args.get("group"))
682
683
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE(this, pattern, replacement).

    Any position/occurrence/parameters arguments are reported as unsupported and dropped.
    """
    args = expression.args
    bad_args = [key for key in ("position", "occurrence", "parameters") if args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, args["replacement"]
    )
692
693
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the name used for each PIVOT aggregation.

    Aliased aggregations contribute their alias. Unaliased ones contribute their SQL text,
    rendered in `dialect` with lower-cased function names and all identifiers unquoted.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Unaliased aggregations are used as suffixes (e.g. col_avg(foo)). Identifiers are
        # unquoted here because the base parser's `_parse_pivot` method quotes the final
        # name via `to_identifier`. Otherwise we'd end up with `col_avg(`foo`)`, i.e. the
        # identifier quoted twice.
        agg_all_unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
714
715
def simplify_literal(expression: E) -> E:
    """Simplify `expression.expression` when it is not already a literal; return `expression`.

    The value returned by `simplify` is written back into the node. Assigning the result is
    correct whether `simplify` mutates its input in place or hands back a different root
    node. Discarding the return value (as before) loses the simplification in the latter case.
    """
    if not isinstance(expression.expression, exp.Literal):
        from sqlglot.optimizer.simplify import simplify

        expression.set("expression", simplify(expression.expression))

    return expression
723
724
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser mapping FUNC(a, b) onto `expr_type(this=a, expression=b)`."""

    def _parse(args: t.List) -> B:
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _parse
727
728
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value) into a TimestampTrunc node."""
    value = seq_get(args, 1)
    unit = seq_get(args, 0)
    return exp.TimestampTrunc(this=value, unit=unit)
732
733
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE(x) as MAX(x)."""
    target = expression.this
    return self.func("MAX", target)
736
737
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair as `key, value` (comma-separated rather than `key: value`)."""
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """Names of the SQL dialects known to sqlglot.

    Each member's value is the key under which the matching `Dialect` subclass is registered
    by the `_Dialect` metaclass (the class name upper-cased is looked up among these members).
    """

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    DORIS = "doris"
    # Backwards-compatible alias for the original, inconsistently-cased member name.
    # A duplicate value makes this an enum alias, so `Dialects.Doris is Dialects.DORIS`.
    Doris = "doris"

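Enumeration of the SQL dialect names known to sqlglot. Each member's value is the key under which the corresponding `Dialect` subclass is registered.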

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
134class Dialect(metaclass=_Dialect):
135    # Determines the base index offset for arrays
136    INDEX_OFFSET = 0
137
138    # If true unnest table aliases are considered only as column aliases
139    UNNEST_COLUMN_ONLY = False
140
141    # Determines whether or not the table alias comes after tablesample
142    ALIAS_POST_TABLESAMPLE = False
143
144    # Determines whether or not unquoted identifiers are resolved as uppercase
145    # When set to None, it means that the dialect treats all identifiers as case-insensitive
146    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False
147
148    # Determines whether or not an unquoted identifier can start with a digit
149    IDENTIFIERS_CAN_START_WITH_DIGIT = False
150
151    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
152    DPIPE_IS_STRING_CONCAT = True
153
154    # Determines whether or not CONCAT's arguments must be strings
155    STRICT_STRING_CONCAT = False
156
157    # Determines how function names are going to be normalized
158    NORMALIZE_FUNCTIONS: bool | str = "upper"
159
160    # Indicates the default null ordering method to use if not explicitly set
161    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
162    NULL_ORDERING = "nulls_are_small"
163
164    DATE_FORMAT = "'%Y-%m-%d'"
165    DATEINT_FORMAT = "'%Y%m%d'"
166    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
167
168    # Custom time mappings in which the key represents dialect time format
169    # and the value represents a python time format
170    TIME_MAPPING: t.Dict[str, str] = {}
171
172    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
173    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
174    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
175    FORMAT_MAPPING: t.Dict[str, str] = {}
176
177    # Columns that are auto-generated by the engine corresponding to this dialect
178    # Such columns may be excluded from SELECT * queries, for example
179    PSEUDOCOLUMNS: t.Set[str] = set()
180
181    # Autofilled
182    tokenizer_class = Tokenizer
183    parser_class = Parser
184    generator_class = Generator
185
186    # A trie of the time_mapping keys
187    TIME_TRIE: t.Dict = {}
188    FORMAT_TRIE: t.Dict = {}
189
190    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
191    INVERSE_TIME_TRIE: t.Dict = {}
192
193    def __eq__(self, other: t.Any) -> bool:
194        return type(self) == other
195
196    def __hash__(self) -> int:
197        return hash(type(self))
198
199    @classmethod
200    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
201        if not dialect:
202            return cls
203        if isinstance(dialect, _Dialect):
204            return dialect
205        if isinstance(dialect, Dialect):
206            return dialect.__class__
207
208        result = cls.get(dialect)
209        if not result:
210            raise ValueError(f"Unknown dialect '{dialect}'")
211
212        return result
213
214    @classmethod
215    def format_time(
216        cls, expression: t.Optional[str | exp.Expression]
217    ) -> t.Optional[exp.Expression]:
218        if isinstance(expression, str):
219            return exp.Literal.string(
220                # the time formats are quoted
221                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
222            )
223
224        if expression and expression.is_string:
225            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
226
227        return expression
228
229    @classmethod
230    def normalize_identifier(cls, expression: E) -> E:
231        """
232        Normalizes an unquoted identifier to either lower or upper case, thus essentially
233        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
234        they will be normalized regardless of being quoted or not.
235        """
236        if isinstance(expression, exp.Identifier) and (
237            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
238        ):
239            expression.set(
240                "this",
241                expression.this.upper()
242                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
243                else expression.this.lower(),
244            )
245
246        return expression
247
248    @classmethod
249    def case_sensitive(cls, text: str) -> bool:
250        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
251        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
252            return False
253
254        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
255        return any(unsafe(char) for char in text)
256
257    @classmethod
258    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
259        """Checks if text can be identified given an identify option.
260
261        Args:
262            text: The text to check.
263            identify:
264                "always" or `True`: Always returns true.
265                "safe": True if the identifier is case-insensitive.
266
267        Returns:
268            Whether or not the given text can be identified.
269        """
270        if identify is True or identify == "always":
271            return True
272
273        if identify == "safe":
274            return not cls.case_sensitive(text)
275
276        return False
277
278    @classmethod
279    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
280        if isinstance(expression, exp.Identifier):
281            name = expression.this
282            expression.set(
283                "quoted",
284                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
285            )
286
287        return expression
288
289    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
290        return self.parser(**opts).parse(self.tokenize(sql), sql)
291
292    def parse_into(
293        self, expression_type: exp.IntoType, sql: str, **opts
294    ) -> t.List[t.Optional[exp.Expression]]:
295        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
296
297    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
298        return self.generator(**opts).generate(expression)
299
300    def transpile(self, sql: str, **opts) -> t.List[str]:
301        return [self.generate(expression, **opts) for expression in self.parse(sql)]
302
303    def tokenize(self, sql: str) -> t.List[Token]:
304        return self.tokenizer.tokenize(sql)
305
306    @property
307    def tokenizer(self) -> Tokenizer:
308        if not hasattr(self, "_tokenizer"):
309            self._tokenizer = self.tokenizer_class()
310        return self._tokenizer
311
312    def parser(self, **opts) -> Parser:
313        return self.parser_class(**opts)
314
315    def generator(self, **opts) -> Generator:
316        return self.generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = 'upper'
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Type[Dialect]:
199    @classmethod
200    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
201        if not dialect:
202            return cls
203        if isinstance(dialect, _Dialect):
204            return dialect
205        if isinstance(dialect, Dialect):
206            return dialect.__class__
207
208        result = cls.get(dialect)
209        if not result:
210            raise ValueError(f"Unknown dialect '{dialect}'")
211
212        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
214    @classmethod
215    def format_time(
216        cls, expression: t.Optional[str | exp.Expression]
217    ) -> t.Optional[exp.Expression]:
218        if isinstance(expression, str):
219            return exp.Literal.string(
220                # the time formats are quoted
221                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
222            )
223
224        if expression and expression.is_string:
225            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
226
227        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
229    @classmethod
230    def normalize_identifier(cls, expression: E) -> E:
231        """
232        Normalizes an unquoted identifier to either lower or upper case, thus essentially
233        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
234        they will be normalized regardless of being quoted or not.
235        """
236        if isinstance(expression, exp.Identifier) and (
237            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
238        ):
239            expression.set(
240                "this",
241                expression.this.upper()
242                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
243                else expression.this.lower(),
244            )
245
246        return expression

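Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. Upper case is used when the dialect's RESOLVES_IDENTIFIERS_AS_UPPERCASE is true; otherwise lower case is used. If a dialect treats all identifiers as case-insensitive (RESOLVES_IDENTIFIERS_AS_UPPERCASE is None), identifiers are normalized regardless of whether they are quoted.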

@classmethod
def case_sensitive(cls, text: str) -> bool:
248    @classmethod
249    def case_sensitive(cls, text: str) -> bool:
250        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
251        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
252            return False
253
254        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
255        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
257    @classmethod
258    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
259        """Checks if text can be identified given an identify option.
260
261        Args:
262            text: The text to check.
263            identify:
264                "always" or `True`: Always returns true.
265                "safe": True if the identifier is case-insensitive.
266
267        Returns:
268            Whether or not the given text can be identified.
269        """
270        if identify is True or identify == "always":
271            return True
272
273        if identify == "safe":
274            return not cls.case_sensitive(text)
275
276        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: always returns True. "safe": returns True only if the text contains no case-sensitive characters under the dialect's rules. Any other value returns False.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
278    @classmethod
279    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
280        if isinstance(expression, exp.Identifier):
281            name = expression.this
282            expression.set(
283                "quoted",
284                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
285            )
286
287        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
289    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
290        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
292    def parse_into(
293        self, expression_type: exp.IntoType, sql: str, **opts
294    ) -> t.List[t.Optional[exp.Expression]]:
295        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
297    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
298        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
300    def transpile(self, sql: str, **opts) -> t.List[str]:
301        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
303    def tokenize(self, sql: str) -> t.List[Token]:
304        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
312    def parser(self, **opts) -> Parser:
313        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
315    def generator(self, **opts) -> Generator:
316        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
322def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
323    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
326def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
327    if expression.args.get("accuracy"):
328        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
329    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.If) -> str:
332def if_sql(self: Generator, expression: exp.If) -> str:
333    return self.func(
334        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
335    )
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
338def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
339    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
342def arrow_json_extract_scalar_sql(
343    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
344) -> str:
345    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
348def inline_array_sql(self: Generator, expression: exp.Array) -> str:
349    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
352def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
353    return self.like_sql(
354        exp.Like(
355            this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy()
356        )
357    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
360def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
361    zone = self.sql(expression, "this")
362    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
365def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
366    if expression.args.get("recursive"):
367        self.unsupported("Recursive CTEs are unsupported")
368        expression.args["recursive"] = False
369    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
372def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
373    n = self.sql(expression, "this")
374    d = self.sql(expression, "expression")
375    return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
378def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
379    self.unsupported("TABLESAMPLE unsupported")
380    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
383def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
384    self.unsupported("PIVOT unsupported")
385    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
388def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
389    return self.cast_sql(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
392def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
393    self.unsupported("Properties unsupported")
394    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
397def no_comment_column_constraint_sql(
398    self: Generator, expression: exp.CommentColumnConstraint
399) -> str:
400    self.unsupported("CommentColumnConstraint unsupported")
401    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
404def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
405    self.unsupported("MAP_FROM_ENTRIES unsupported")
406    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
409def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
410    this = self.sql(expression, "this")
411    substr = self.sql(expression, "substr")
412    position = self.sql(expression, "position")
413    if position:
414        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
415    return f"STRPOS({this}, {substr})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
418def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
419    return (
420        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
421    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
424def var_map_sql(
425    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
426) -> str:
427    keys = expression.args["keys"]
428    values = expression.args["values"]
429
430    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
431        self.unsupported("Cannot convert array columns into map.")
432        return self.func(map_func_name, keys, values)
433
434    args = []
435    for key, value in zip(keys.expressions, values.expressions):
436        args.append(self.sql(key))
437        args.append(self.sql(value))
438
439    return self.func(map_func_name, *args)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
442def format_time_lambda(
443    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
444) -> t.Callable[[t.List], E]:
445    """Helper used for time expressions.
446
447    Args:
448        exp_class: the expression class to instantiate.
449        dialect: target sql dialect.
450        default: the default format, True being time.
451
452    Returns:
453        A callable that can be used to return the appropriately formatted time expression.
454    """
455
456    def _format_time(args: t.List):
457        return exp_class(
458            this=seq_get(args, 0),
459            format=Dialect[dialect].format_time(
460                seq_get(args, 1)
461                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
462            ),
463        )
464
465    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the format to use when no format argument is given. If True, the dialect's TIME_FORMAT is used.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
468def time_format(
469    dialect: DialectType = None,
470) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
471    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
472        """
473        Returns the time format for a given expression, unless it's equivalent
474        to the default time format of the dialect of interest.
475        """
476        time_format = self.format_time(expression)
477        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
478
479    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
482def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
483    """
484    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
485    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
486    columns are removed from the create statement.
487    """
488    has_schema = isinstance(expression.this, exp.Schema)
489    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
490
491    if has_schema and is_partitionable:
492        expression = expression.copy()
493        prop = expression.find(exp.PartitionedByProperty)
494        if prop and prop.this and not isinstance(prop.this, exp.Schema):
495            schema = expression.this
496            columns = {v.name.upper() for v in prop.this.expressions}
497            partitions = [col for col in schema.expressions if col.name.upper() in columns]
498            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
499            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
500            expression.set("this", schema)
501
502    return self.create_sql(expression)

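In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is a list of column names rather than a schema, it is replaced by a schema built from the matching column definitions. The corresponding columns are removed from the CREATE statement's own schema. Column names are matched case-insensitively. Only CREATE TABLE and CREATE VIEW statements that declare a schema are affected.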

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
505def parse_date_delta(
506    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
507) -> t.Callable[[t.List], E]:
508    def inner_func(args: t.List) -> E:
509        unit_based = len(args) == 3
510        this = args[2] if unit_based else seq_get(args, 0)
511        unit = args[0] if unit_based else exp.Literal.string("DAY")
512        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
513        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
514
515    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
518def parse_date_delta_with_interval(
519    expression_class: t.Type[E],
520) -> t.Callable[[t.List], t.Optional[E]]:
521    def func(args: t.List) -> t.Optional[E]:
522        if len(args) < 2:
523            return None
524
525        interval = args[1]
526
527        if not isinstance(interval, exp.Interval):
528            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
529
530        expression = interval.this
531        if expression and expression.is_string:
532            expression = exp.Literal.number(expression.this)
533
534        return expression_class(
535            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
536        )
537
538    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
541def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
542    unit = seq_get(args, 0)
543    this = seq_get(args, 1)
544
545    if isinstance(this, exp.Cast) and this.is_type("date"):
546        return exp.DateTrunc(unit=unit, this=this)
547    return exp.TimestampTrunc(this=this, unit=unit)
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
550def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
551    return self.func(
552        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
553    )
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
556def locate_to_strposition(args: t.List) -> exp.Expression:
557    return exp.StrPosition(
558        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
559    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
562def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
563    return self.func(
564        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
565    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
568def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
569    expression = expression.copy()
570    return self.sql(
571        exp.Substring(
572            this=expression.this, start=exp.Literal.number(1), length=expression.expression
573        )
574    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
577def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
578    expression = expression.copy()
579    return self.sql(
580        exp.Substring(
581            this=expression.this,
582            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
583        )
584    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
587def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
588    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
591def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
592    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
596def encode_decode_sql(
597    self: Generator, expression: exp.Expression, name: str, replace: bool = True
598) -> str:
599    charset = expression.args.get("charset")
600    if charset and charset.name.lower() != "utf-8":
601        self.unsupported(f"Expected utf-8 character set, got {charset}.")
602
603    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
606def min_or_least(self: Generator, expression: exp.Min) -> str:
607    name = "LEAST" if expression.expressions else "MIN"
608    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
611def max_or_greatest(self: Generator, expression: exp.Max) -> str:
612    name = "GREATEST" if expression.expressions else "MAX"
613    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
616def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
617    cond = expression.this
618
619    if isinstance(expression.this, exp.Distinct):
620        cond = expression.this.expressions[0]
621        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
622
623    return self.func("sum", exp.func("if", cond.copy(), 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
626def trim_sql(self: Generator, expression: exp.Trim) -> str:
627    target = self.sql(expression, "this")
628    trim_type = self.sql(expression, "position")
629    remove_chars = self.sql(expression, "expression")
630    collation = self.sql(expression, "collation")
631
632    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
633    if not remove_chars and not collation:
634        return self.trim_sql(expression)
635
636    trim_type = f"{trim_type} " if trim_type else ""
637    remove_chars = f"{remove_chars} " if remove_chars else ""
638    from_part = "FROM " if trim_type or remove_chars else ""
639    collation = f" COLLATE {collation}" if collation else ""
640    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
643def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
644    return self.func("STRPTIME", expression.this, self.format_time(expression))
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
647def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
648    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
649        _dialect = Dialect.get_or_raise(dialect)
650        time_format = self.format_time(expression)
651        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
652            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))
653
654        return self.sql(exp.cast(self.sql(expression, "this"), "date"))
655
656    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render ``CONCAT(a, b, c, ...)`` as the left-nested chain ``a || b || c ...``.

    The arguments are folded left-to-right into nested ``exp.DPipe`` nodes, which are
    then generated as SQL.
    """
    # Build from a copy so the caller's node is not used as part of the new tree.
    # NOTE(review): presumably because wrapping nodes in DPipe re-parents them.
    operands = expression.copy().expressions

    def _pipe(left: exp.Expression, right: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=left, expression=right)

    return self.sql(reduce(_pipe, operands))
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render ``CONCAT_WS(sep, a, b, ...)`` as ``a || (sep || b) || ...`` using DPipe nodes.

    The first argument is the separator. Every following argument after the first is
    prefixed with it, and the results are folded left-to-right.
    """
    # Work on a copy so the original node's children are not reused in the new tree.
    separator, *parts = expression.copy().expressions

    def _join(accumulated: exp.Expression, part: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=accumulated, expression=exp.DPipe(this=separator, expression=part))

    return self.sql(reduce(_join, parts))
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render ``REGEXP_EXTRACT(this, pattern[, group])``.

    The ``position``, ``occurrence`` and ``parameters`` arguments are not emitted. If any
    of them is set, their names are reported through ``self.unsupported``.
    """
    unsupported_args = [
        name for name in ("position", "occurrence", "parameters") if expression.args.get(name)
    ]
    if unsupported_args:
        self.unsupported(
            f"REGEXP_EXTRACT does not support the following arg(s): {unsupported_args}"
        )

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render ``REGEXP_REPLACE(this, pattern, replacement)``.

    The ``position``, ``occurrence`` and ``parameters`` arguments are not emitted. If any
    of them is set, their names are reported through ``self.unsupported``.

    Raises:
        KeyError: if the node has no ``replacement`` argument.
    """
    ignored = [
        name for name in ("position", "occurrence", "parameters") if expression.args.get(name)
    ]
    if ignored:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {ignored}")

    # Indexed directly (unlike the optional args above), so a missing replacement raises.
    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the column-name suffixes for the aggregations of a PIVOT.

    Args:
        aggregations: the aggregation expressions of the PIVOT clause.
        dialect: the dialect used to render unaliased aggregations to SQL.

    Returns:
        One name per aggregation, in order. Aliased aggregations contribute their alias.
        Other aggregations contribute their SQL text, with all identifiers unquoted and
        function names lower-cased.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # Aggregations without aliases are used verbatim as suffixes (e.g. col_avg(foo)).
            # Identifiers are unquoted here because the resulting name is quoted again in
            # the base parser's `_parse_pivot` method, due to `to_identifier`. Otherwise
            # we'd end up with col_avg(`foo`) wrapped in another layer of quotes.
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def simplify_literal(expression: E) -> E:
    """Run the optimizer's ``simplify`` on ``expression.expression`` unless it is a Literal.

    The return value of ``simplify`` is discarded. This therefore relies on ``simplify``
    updating the tree in place. The same ``expression`` object is always returned.
    """
    operand = expression.expression
    if isinstance(operand, exp.Literal):
        return expression

    # Imported lazily. NOTE(review): presumably to avoid an import cycle with the optimizer.
    from sqlglot.optimizer.simplify import simplify

    simplify(operand)
    return expression
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Return a parser callback that builds ``expr_type`` from a function's argument list.

    The first argument becomes ``this`` and the second becomes ``expression``. Each is
    fetched with ``seq_get``.
    """

    def _build_binary(args: t.List) -> B:
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build_binary
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build a TimestampTrunc from a function argument list whose unit comes first.

    ``args[0]`` is the truncation unit and ``args[1]`` is the value being truncated.
    Each is fetched with ``seq_get``.
    """
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ``ANY_VALUE(x)`` as ``MAX(x)``.

    NOTE(review): presumably used for targets lacking ANY_VALUE, since MAX yields one of
    the group's values.
    """
    argument = expression.this
    return self.func("MAX", argument)
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair as ``key, value`` (comma-separated, not ``key: value``).

    Args:
        self: the generator performing the rendering.
        expression: the JSONKeyValue node. Its ``this`` arg is the key and its
            ``expression`` arg is the value.

    Returns:
        The generated ``"<key>, <value>"`` SQL fragment.
    """
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"