Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5from functools import reduce
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.errors import ParseError
 10from sqlglot.generator import Generator
 11from sqlglot.helper import flatten, seq_get
 12from sqlglot.parser import Parser
 13from sqlglot.time import format_time
 14from sqlglot.tokens import Token, Tokenizer, TokenType
 15from sqlglot.trie import new_trie
 16
 17B = t.TypeVar("B", bound=exp.Binary)
 18
 19
 20class Dialects(str, Enum):
 21    DIALECT = ""
 22
 23    BIGQUERY = "bigquery"
 24    CLICKHOUSE = "clickhouse"
 25    DATABRICKS = "databricks"
 26    DRILL = "drill"
 27    DUCKDB = "duckdb"
 28    HIVE = "hive"
 29    MYSQL = "mysql"
 30    ORACLE = "oracle"
 31    POSTGRES = "postgres"
 32    PRESTO = "presto"
 33    REDSHIFT = "redshift"
 34    SNOWFLAKE = "snowflake"
 35    SPARK = "spark"
 36    SPARK2 = "spark2"
 37    SQLITE = "sqlite"
 38    STARROCKS = "starrocks"
 39    TABLEAU = "tableau"
 40    TERADATA = "teradata"
 41    TRINO = "trino"
 42    TSQL = "tsql"
 43    Doris = "doris"
 44
 45
class _Dialect(type):
    """Metaclass of `Dialect` that registers each dialect class and autofills its derived attributes.

    Creating a class with this metaclass stores it in `classes`, builds its time-format
    tries, resolves its tokenizer, parser and generator classes and copies the dialect's
    properties onto those classes (see `__new__`).
    """

    # Registry of every dialect class created so far, keyed by dialect name
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        """Compare a dialect class with another class, a dialect name or a `Dialect` instance."""
        if cls is other:
            return True
        if isinstance(other, str):
            # A string is equal to the class that is registered under that name
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        """Hash a dialect class by its lowercased class name."""
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        """Return the dialect class registered under `key`; a missing key raises `KeyError`."""
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Return the dialect class registered under `key`, or `default` if there is none."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        """Create a dialect class, register it and autofill its derived attributes."""
        klass = super().__new__(cls, clsname, bases, attrs)
        # Register under the matching `Dialects` value, or the lowercased class name if
        # there is no enum member named after the uppercased class name
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute the tries and the inverse mapping derived from the time mappings;
        # FORMAT_TRIE falls back to TIME_TRIE when no FORMAT_MAPPING is defined
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Use the Tokenizer / Parser / Generator attributes of the dialect class if it
        # has them, otherwise fall back to the base implementations
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first entry of the tokenizer's quote / identifier tables provides the
        # dialect's start and end delimiters
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # First (start, end) delimiter pair in the tokenizer's _FORMAT_STRINGS whose
            # token type matches, or (None, None) if there is no such pair
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Collect the attributes defined directly on this class that are neither callables,
        # classmethods nor dunder names, plus the resolved tokenizer class
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        # SELECT_KINDS is emptied for every dialect except the base one ("") and bigquery;
        # `enum` is a `Dialects` member (a str subclass) or None, and None also takes this branch
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                # Only attributes that the target class already defines are overwritten
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # With non-strict concatenation, '||' is parsed into exp.SafeDPipe
        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

        return klass
131
132
class Dialect(metaclass=_Dialect):
    """Base SQL dialect: bundles dialect settings with a tokenizer, a parser and a generator.

    The uppercase class attributes are settings that the `_Dialect` metaclass reads when a
    dialect class is created; the `*_class` attributes and the tries are filled in by it.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Default formats; the values include their surrounding single quotes
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        """Compare by dialect class, delegating to the class-level equality of the metaclass."""
        return type(self) == other

    def __hash__(self) -> int:
        """Hash an instance the same way as its class."""
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` to a dialect class.

        Args:
            dialect: a dialect name, a dialect class, a dialect instance, or a falsy value.

        Returns:
            `cls` for a falsy value, the class itself for a dialect class, the instance's
            class for a dialect instance, otherwise the class registered under that name.

        Raises:
            ValueError: if no dialect is registered under the given name.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format using this dialect's `TIME_MAPPING`.

        A `str` is expected to be quoted: its first and last characters are dropped before
        conversion. A string-literal expression has its value converted. In both cases a new
        string literal is returned; anything else is returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.

        The identifier is modified in place and returned; any other expression is
        returned untouched.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            # A setting of None is falsy, so case-insensitive dialects normalize to lowercase
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # Characters whose case differs from the case the dialect resolves to are "unsafe"
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` flag of an identifier in place and return the expression.

        The flag is truthy if `identify` is truthy, if the name contains case sensitive
        characters, or if the name doesn't match `exp.SAFE_IDENTIFIER_RE`. Expressions
        other than identifiers are returned untouched.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it with a new parser created from `opts`."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into `expression_type` with a new parser created from `opts`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression` with a new generator created from `opts`."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and generate one SQL string per parsed expression; `opts` go to the generator."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this instance's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """The tokenizer of this instance, created on first access and reused afterwards."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Create a new parser of this dialect's parser class from `opts`."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Create a new generator of this dialect's generator class from `opts`."""
        return self.generator_class(**opts)
319
320
321DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
322
323
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator callback that renders an expression as a call to `name`.

    The callback passes all of the expression's argument values, flattened, as the
    function's arguments.
    """

    def _renamed(self: Generator, expression: exp.Expression) -> str:
        arguments = flatten(expression.args.values())
        return self.func(name, *arguments)

    return _renamed
326
327
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render `APPROX_COUNT_DISTINCT(this)`, reporting a given `accuracy` as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
332
333
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an `exp.If` as `IF(condition, true, false)`."""
    condition = expression.this
    when_true = expression.args.get("true")
    when_false = expression.args.get("false")
    return self.func("IF", condition, when_true, when_false)
338
339
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render a JSON extraction as a binary expression using the `->` operator."""
    operator = "->"
    return self.binary(expression, operator)
342
343
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render a scalar JSON extraction as a binary expression using the `->>` operator."""
    operator = "->>"
    return self.binary(expression, operator)
348
349
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as its flat list of elements wrapped in square brackets."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
352
353
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Render ILIKE as a LIKE whose left operand is wrapped in LOWER.

    Both operands are copied, so the given expression is left untouched.
    """
    # NOTE(review): only the left operand is lowercased; the pattern is emitted
    # as-is - confirm that this is the intended emulation
    lowered = exp.Lower(this=expression.this.copy())
    pattern = expression.expression.copy()
    return self.like_sql(exp.Like(this=lowered, expression=pattern))
360
361
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending AT TIME ZONE if a zone is given."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"

    return f"CURRENT_DATE AT TIME ZONE {zone}"
365
366
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause for dialects that don't support recursive CTEs.

    A recursive WITH is reported as unsupported and then rendered as a regular one.

    Args:
        expression: the WITH expression to render; it is never modified.

    Returns:
        The SQL of the WITH clause, without the recursive flag.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag on a copy so that generating SQL doesn't mutate the caller's AST
        expression = expression.copy()
        expression.args["recursive"] = False
    return self.with_sql(expression)
372
373
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate a safe division with `IF(d <> 0, n / d, NULL)`.

    Args:
        expression: the safe division, with the numerator in `this` and the
            denominator in `expression`.

    Returns:
        The SQL of the emulated division. Operands that are binary expressions are
        parenthesized, all other operands are rendered unchanged.
    """

    def _operand(key: str) -> str:
        sql = self.sql(expression, key)
        # Without parentheses a binary operand such as `a + b` would be re-associated
        # by the surrounding `/` and `<>` operators
        if isinstance(expression.args.get(key), exp.Binary):
            return f"({sql})"
        return sql

    n = _operand("this")
    d = _operand("expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"
378
379
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
383
384
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and generate no SQL for it."""
    self.unsupported("PIVOT unsupported")
    return ""
388
389
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TRY_CAST through the generator's regular cast rendering."""
    cast_sql = self.cast_sql(expression)
    return cast_sql
392
393
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and generate no SQL for them."""
    self.unsupported("Properties unsupported")
    return ""
397
398
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report a column comment constraint as unsupported and generate no SQL for it."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
404
405
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Report MAP_FROM_ENTRIES as unsupported and generate no SQL for it."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
409
410
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a string position lookup with STRPOS.

    When a start position is given, the search runs on the substring that starts at
    that position and the result is shifted back by `position - 1`.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    # NOTE(review): the offset is added unconditionally, so presumably a "not found"
    # result of STRPOS is shifted as well - confirm against the target engines
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
418
419
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render a struct field access using dot notation: `<struct>.<field>`."""
    struct = self.sql(expression, "this")
    # The field name is rendered as an identifier
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct}.{field}"
424
425
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(key1, value1, key2, value2, ...)`.

    If the keys or the values aren't array expressions, the conversion is reported as
    unsupported and both are passed to the function unchanged.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave the rendered keys and values: key, value, key, value, ...
    pairs = zip(keys.expressions, values.expressions)
    args = [self.sql(node) for pair in pairs for node in pair]

    return self.func(map_func_name, *args)
442
443
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Create a builder for time expressions whose format is converted for a dialect.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose time mapping converts the format.
        default: the format to use when none is passed; `True` selects the
            dialect's `TIME_FORMAT`.

    Returns:
        A callable that builds `exp_class` from a list of arguments: the first one becomes
        `this` and the second one, if present, the format.
    """

    def _format_time(args: t.List):
        this = seq_get(args, 0)
        # The dialect is looked up on every call, not when the builder is created
        dialect_class = Dialect[dialect]

        time_format = seq_get(args, 1)
        if not time_format:
            if default is True:
                time_format = dialect_class.TIME_FORMAT
            else:
                time_format = default or None

        return exp_class(this=this, format=dialect_class.format_time(time_format))

    return _format_time
468
469
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Create a helper that returns an expression's time format unless it is the default one."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_format = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt == default_format:
            return None
        return fmt

    return _time_format
482
483
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.

    The rewrite is applied to a copy of the expression, and only to TABLE and VIEW
    statements that have a schema.
    """
    partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
    if not (isinstance(expression.this, exp.Schema) and partitionable):
        return self.create_sql(expression)

    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    # Nothing to rewrite if there's no PARTITIONED BY value or it already is a schema
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return self.create_sql(expression)

    schema = expression.this
    # Column names are matched case-insensitively
    partition_names = {column.name.upper() for column in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return self.create_sql(expression)
505
506
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Create a builder for date delta expressions.

    With exactly three arguments they are read as (unit, expression, this); otherwise
    as (this, expression) with a default unit of DAY. If `unit_mapping` is given, the
    lowercased unit name is translated through it, falling back to the original name.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
518
519
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Create a builder for date deltas whose second argument is an INTERVAL expression.

    The builder returns None for fewer than two arguments and raises `ParseError` if
    the second argument isn't an interval.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        # A string interval value is turned into a number literal
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=this, expression=amount, unit=unit)

    return func
541
542
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build a truncation from (unit, value) arguments.

    A value that is a cast to DATE yields `exp.DateTrunc`, anything else
    `exp.TimestampTrunc`.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    if not (isinstance(this, exp.Cast) and this.is_type("date")):
        return exp.TimestampTrunc(this=this, unit=unit)

    return exp.DateTrunc(unit=unit, this=this)
550
551
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render a timestamp truncation as `DATE_TRUNC('<unit>', this)`, defaulting the unit to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
556
557
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Build `exp.StrPosition` from LOCATE-style arguments: (substr, this[, position])."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
562
563
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a string position lookup as `LOCATE(substr, this, position)`."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
568
569
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render `LEFT(this, n)` as a substring of length `n` that starts at position 1.

    The substring is built from a copy of the expression.
    """
    node = expression.copy()
    substring = exp.Substring(
        this=node.this, start=exp.Literal.number(1), length=node.expression
    )
    return self.sql(substring)
577
578
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render `RIGHT(this, n)` as a substring that starts at `LENGTH(this) - (n - 1)`.

    Args:
        expression: the RIGHT expression, with the string in `this` and the number of
            characters in `expression`; it is copied before being rewritten.

    Returns:
        The SQL of the equivalent substring expression, which has no length argument.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # `n - 1` is parenthesized so that it is subtracted from the length as a whole
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
587
588
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time string conversion as a cast to "timestamp"."""
    cast = exp.cast(expression.this, "timestamp")
    return self.sql(cast)
591
592
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date string conversion as a cast to "date"."""
    cast = exp.cast(expression.this, "date")
    return self.sql(cast)
595
596
# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode expression as a call to `name` without a charset argument.

    A charset other than utf-8 is reported as unsupported. The expression's `replace`
    argument is only passed along when `replace` is true.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
606
607
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN as LEAST when it has additional expressions, otherwise as MIN."""
    render = rename_func("LEAST" if expression.expressions else "MIN")
    return render(self, expression)
611
612
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX as GREATEST when it has additional expressions, otherwise as MAX."""
    render = rename_func("GREATEST" if expression.expressions else "MAX")
    return render(self, expression)
616
617
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(cond) as `sum(if(cond, 1, 0))`.

    A DISTINCT condition is reported as unsupported and replaced by its first expression.
    """
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    counted = exp.func("if", cond.copy(), 1, 0)
    return self.func("sum", counted)
626
627
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM using the `TRIM([position] [characters] FROM target [COLLATE ...])` syntax.

    Without characters to remove and without a collation, rendering is delegated to
    the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    characters = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (characters or collation):
        return self.trim_sql(expression)

    # Everything in front of the target: optional position and characters, then FROM
    prefix = ""
    if position:
        prefix += f"{position} "
    if characters:
        prefix += f"{characters} "
    if prefix:
        prefix += "FROM "

    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
643
644
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as `STRPTIME(this, format)`."""
    time_format = self.format_time(expression)
    return self.func("STRPTIME", expression.this, time_format)
647
648
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Create a generator callback that renders `exp.TsOrDsToDate` as a cast to date.

    If the expression has a time format other than the dialect's default time and date
    formats, its value is parsed with that format (see `str_to_time_sql`) before the cast.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        default_formats = (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT)

        if not time_format or time_format in default_formats:
            return self.sql(exp.cast(self.sql(expression, "this"), "date"))

        return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

    return _ts_or_ds_to_date_sql
659
660
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render CONCAT as a left-nested chain of `exp.DPipe` expressions, built from a copy."""

    def _pipe(left: exp.Expression, right: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=left, expression=right)

    operands = expression.copy().expressions
    return self.sql(reduce(_pipe, operands))
664
665
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render CONCAT_WS as `exp.DPipe` expressions with the delimiter between the arguments.

    The first expression is the delimiter, the remaining ones are the values to join.
    """
    delim, *rest_args = expression.copy().expressions

    def _join(joined: exp.Expression, arg: exp.Expression) -> exp.DPipe:
        # joined || (delim || arg)
        return exp.DPipe(this=joined, expression=exp.DPipe(this=delim, expression=arg))

    return self.sql(reduce(_join, rest_args))
675
676
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render `REGEXP_EXTRACT(this, pattern[, group])`, reporting the other arguments as unsupported."""
    unsupported_keys = ("position", "occurrence", "parameters")
    bad_args = [key for key in unsupported_keys if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
685
686
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render `REGEXP_REPLACE(this, pattern, replacement)`, reporting the other arguments as unsupported."""
    unsupported_keys = ("position", "occurrence", "parameters")
    bad_args = [key for key in unsupported_keys if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
695
696
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return one column name per pivot aggregation.

    An aliased aggregation contributes its alias; any other aggregation contributes its
    SQL in the given dialect, with lowercased function names and unquoted identifiers.
    """

    def _unquote(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quotes).
        agg_all_unquoted = agg.transform(_unquote)
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
717
718
def simplify_literal(expression: E) -> E:
    """Run the optimizer's `simplify` on the expression's `expression` argument unless it is a literal.

    The given expression itself is returned in both cases.
    """
    if isinstance(expression.expression, exp.Literal):
        return expression

    # Imported here rather than at module level, as in the original code
    from sqlglot.optimizer.simplify import simplify

    # The return value is not used - NOTE(review): this presumably relies on `simplify`
    # updating the tree in place; confirm
    simplify(expression.expression)
    return expression
726
727
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Create a builder that turns a function's first two arguments into a binary expression."""

    def _build(args: t.List) -> B:
        left, right = seq_get(args, 0), seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build
730
731
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build `exp.TimestampTrunc` from (unit, this) arguments."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)
    return exp.TimestampTrunc(this=this, unit=unit)
735
736
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE(this) as `MAX(this)`."""
    value = expression.this
    return self.func("MAX", value)
739
740
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair as `<key>, <value>`.

    Args:
        expression: the pair, with the key in `this` and the value in `expression`.

    Returns:
        The SQL of the key and of the value, separated by a comma and a space.
    """
    key = self.sql(expression, "this")
    value = self.sql(expression, "expression")
    return f"{key}, {value}"
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """Enumeration of the supported SQL dialect names.

    Every member is also a `str`, so it compares equal to its lowercase dialect
    name (e.g. `Dialects.HIVE == "hive"`). Member names are the uppercased
    dialect class names, which is how a dialect class is matched to its member.
    """

    # The base dialect has an empty name
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # Canonical, uppercase name so that lookups by uppercased class name succeed
    DORIS = "doris"

    # Backward-compatible alias for the original mixed-case spelling; because it
    # shares its value with DORIS, Enum treats it as an alias of that member
    Doris = "doris"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    """Base dialect: shared settings plus entry points to tokenize, parse and generate SQL."""

    # Base index offset for arrays
    INDEX_OFFSET = 0

    # Whether unnest table aliases are treated only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers are resolved as uppercase.
    # None means the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier may start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Whether CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Whether user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # How function names are normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Note that the format strings below include their surrounding single quotes
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings: keys are dialect time formats, values are python time formats
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns auto-generated by the engine corresponding to this dialect.
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # The comparison is carried out on the instance's class
        klass = type(self)
        return klass == other

    def __hash__(self) -> int:
        klass = type(self)
        return hash(klass)

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` to a dialect class.

        A falsy value yields `cls`; a dialect class is returned unchanged; a
        dialect instance yields its class; anything else is looked up with
        `cls.get`.

        Raises:
            ValueError: if the lookup finds nothing.
        """
        if dialect:
            if isinstance(dialect, _Dialect):
                return dialect
            if isinstance(dialect, Dialect):
                return dialect.__class__

            found = cls.get(dialect)
            if found:
                return found

            raise ValueError(f"Unknown dialect '{dialect}'")

        return cls

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format through this dialect's TIME_MAPPING.

        Plain strings and string expressions are converted into a new string
        literal; any other value is returned unchanged.
        """
        if isinstance(expression, str):
            # the time formats are quoted
            raw_format = expression[1:-1]
        elif expression and expression.is_string:
            raw_format = expression.this
        else:
            return expression

        return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to lower or upper case, which makes it
        case-insensitive. When the dialect treats all identifiers as case-insensitive,
        quoted identifiers are normalized as well.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
            return expression

        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            normalized = expression.this.upper()
        else:
            normalized = expression.this.lower()

        expression.set("this", normalized)
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks whether text has any case sensitive character under the dialect's rules."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        if resolves_upper:
            return any(str.islower(char) for char in text)
        return any(str.isupper(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify != "safe":
            return False

        return not cls.case_sensitive(text)

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` flag on an identifier; other expressions are returned untouched.

        The flag is truthy when `identify` is truthy, when the name is case
        sensitive, or when it does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        name = expression.this
        should_quote = (
            identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", should_quote)
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse the tokens with a parser built from `opts`."""
        parser = self.parser(**opts)
        return parser.parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse the tokens into `expression_type`."""
        parser = self.parser(**opts)
        return parser.parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression` with a generator built from `opts`."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql`, then generate one SQL string per parsed expression."""
        parsed = self.parse(sql)
        return [self.generate(expression, **opts) for expression in parsed]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this instance's tokenizer."""
        tokenizer = self.tokenizer
        return tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Created on first access, then reused for this instance
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Instantiate `parser_class` with `opts`."""
        parser_class = self.parser_class
        return parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Instantiate `generator_class` with `opts`."""
        generator_class = self.generator_class
        return generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = True
NORMALIZE_FUNCTIONS: bool | str = 'upper'
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Type[Dialect]:
202    @classmethod
203    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
204        if not dialect:
205            return cls
206        if isinstance(dialect, _Dialect):
207            return dialect
208        if isinstance(dialect, Dialect):
209            return dialect.__class__
210
211        result = cls.get(dialect)
212        if not result:
213            raise ValueError(f"Unknown dialect '{dialect}'")
214
215        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
217    @classmethod
218    def format_time(
219        cls, expression: t.Optional[str | exp.Expression]
220    ) -> t.Optional[exp.Expression]:
221        if isinstance(expression, str):
222            return exp.Literal.string(
223                # the time formats are quoted
224                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
225            )
226
227        if expression and expression.is_string:
228            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
229
230        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
232    @classmethod
233    def normalize_identifier(cls, expression: E) -> E:
234        """
235        Normalizes an unquoted identifier to either lower or upper case, thus essentially
236        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
237        they will be normalized regardless of being quoted or not.
238        """
239        if isinstance(expression, exp.Identifier) and (
240            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
241        ):
242            expression.set(
243                "this",
244                expression.this.upper()
245                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
246                else expression.this.lower(),
247            )
248
249        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
251    @classmethod
252    def case_sensitive(cls, text: str) -> bool:
253        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
254        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
255            return False
256
257        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
258        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
260    @classmethod
261    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
262        """Checks if text can be identified given an identify option.
263
264        Args:
265            text: The text to check.
266            identify:
267                "always" or `True`: Always returns true.
268                "safe": True if the identifier is case-insensitive.
269
270        Returns:
271            Whether or not the given text can be identified.
272        """
273        if identify is True or identify == "always":
274            return True
275
276        if identify == "safe":
277            return not cls.case_sensitive(text)
278
279        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
281    @classmethod
282    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
283        if isinstance(expression, exp.Identifier):
284            name = expression.this
285            expression.set(
286                "quoted",
287                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
288            )
289
290        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
292    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
293        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
295    def parse_into(
296        self, expression_type: exp.IntoType, sql: str, **opts
297    ) -> t.List[t.Optional[exp.Expression]]:
298        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
300    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
301        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
303    def transpile(self, sql: str, **opts) -> t.List[str]:
304        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
306    def tokenize(self, sql: str) -> t.List[Token]:
307        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
315    def parser(self, **opts) -> Parser:
316        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
318    def generator(self, **opts) -> Generator:
319        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator callback that renders an expression as a call to `name`.

    The call's arguments are the flattened values of the expression's `args`.
    """

    def _renamed(self: Generator, expression: exp.Expression) -> str:
        arguments = flatten(expression.args.values())
        return self.func(name, *arguments)

    return _renamed
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT(<this>), reporting a given accuracy as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    operand = expression.this
    return self.func("APPROX_COUNT_DISTINCT", operand)
def if_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.If) -> str:
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an `If` expression as IF(<this>, <true>, <false>)."""
    branches = expression.args
    when_true = branches.get("true")
    when_false = branches.get("false")
    return self.func("IF", expression.this, when_true, when_false)
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render a JSON extraction as a binary expression using the `->` operator."""
    operator = "->"
    return self.binary(expression, operator)
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render a scalar JSON extraction as a binary expression using the `->>` operator."""
    operator = "->>"
    return self.binary(expression, operator)
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as its flat expression list wrapped in square brackets."""
    items = self.expressions(expression, flat=True)
    return "[{}]".format(items)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Render ILIKE as LIKE with the left operand wrapped in LOWER.

    Both operands are copied before being placed in the new `Like` node.
    """
    # NOTE(review): only the left side is lowered, so a pattern containing
    # upper-case characters would not match -- confirm this is intended.
    lowered = exp.Lower(this=expression.this.copy())
    pattern = expression.expression.copy()
    like = exp.Like(this=lowered, expression=pattern)
    return self.like_sql(like)
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, adding AT TIME ZONE when a zone is set."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause, reporting and clearing its `recursive` flag if set.

    The flag is cleared on the given expression itself (in place).
    """
    args = expression.args
    if args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        args["recursive"] = False

    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Render a safe division as an IF that yields NULL when the divisor is 0."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the wrapped `this` node."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and render nothing for it."""
    self.unsupported("PIVOT unsupported")
    rendered = ""
    return rendered
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TRY_CAST expression through the generator's regular `cast_sql`."""
    render_cast = self.cast_sql
    return render_cast(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and render nothing for them."""
    self.unsupported("Properties unsupported")
    rendered = ""
    return rendered
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report a column comment constraint as unsupported and render nothing for it."""
    self.unsupported("CommentColumnConstraint unsupported")
    rendered = ""
    return rendered
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Report MAP_FROM_ENTRIES as unsupported and render nothing for it."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    rendered = ""
    return rendered
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a string position lookup with STRPOS.

    When a start position is given, the search runs on SUBSTR(<this>, <position>)
    and the result is shifted by `<position> - 1`.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render a struct field access as `<this>.<field>`.

    The field is rendered from an identifier built from the name of
    `expression.expression`.
    """
    parent = self.sql(expression, "this")
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{parent}.{field}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(key1, value1, key2, value2, ...)`.

    When keys or values are not `exp.Array` nodes they cannot be interleaved;
    this is reported as unsupported and both are passed through as the two
    arguments of the call.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave rendered keys and values, pair by pair
    pairs = zip(keys.expressions, values.expressions)
    rendered = [self.sql(node) for pair in pairs for node in pair]

    return self.func(map_func_name, *rendered)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Create a builder for time expressions that carry a format argument.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the fallback format; `True` selects the dialect's TIME_FORMAT.

    Returns:
        A callable that builds `exp_class` from an argument list, with the format
        passed through the dialect's `format_time`.
    """

    def _format_time(args: t.List):
        fmt = seq_get(args, 1)
        if not fmt:
            # No explicit format was given: fall back to the configured default
            fmt = Dialect[dialect].TIME_FORMAT if default is True else default or None

        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(fmt),
        )

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Create a callback that yields an expression's time format unless it is the default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Gives back the expression's time format, or None when it matches
        the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt != default_fmt:
            return fmt
        return None

    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property extends a table's schema. If the
    PARTITIONED BY value is a list of column names, those names are turned into a schema
    and the matching columns are taken out of the create statement's own schema.
    """
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
    if not (isinstance(expression.this, exp.Schema) and is_partitionable):
        return self.create_sql(expression)

    # Work on a copy so the caller's expression is left untouched
    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {column.name.upper() for column in prop.this.expressions}
        partitions = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining = [column for column in schema.expressions if column not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Create a builder for date-delta expressions.

    With exactly three arguments they are read as (unit, expression, this);
    otherwise as (this, expression) with a "DAY" string literal as the unit.
    When `unit_mapping` is given, the unit name is translated through it
    (looked up in lower case, falling back to the name itself).
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Create a builder for date-delta expressions whose second argument is an INTERVAL.

    The built callable returns None when given fewer than two arguments and
    raises `ParseError` when the second argument is not an `exp.Interval`.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        target, interval = args[0], args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # A string amount is re-created as a number literal
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=target, expression=amount, unit=unit)

    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build a truncation node from (unit, this) arguments.

    A `DateTrunc` is built when the value is a cast to "date"; otherwise a
    `TimestampTrunc` is built.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    if not (isinstance(this, exp.Cast) and this.is_type("date")):
        return exp.TimestampTrunc(this=this, unit=unit)

    return exp.DateTrunc(unit=unit, this=this)
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render a timestamp truncation as DATE_TRUNC(<unit>, <this>); the unit defaults to "day"."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Build a `StrPosition` from arguments ordered as (substr, this, position)."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a string position lookup as LOCATE(<substr>, <this>, <position>)."""
    args = expression.args
    substr = args.get("substr")
    position = args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render LEFT(<this>, <n>) as a substring starting at 1 with length <n>.

    The rewrite is done on a copy of the expression.
    """
    node = expression.copy()
    substring = exp.Substring(
        this=node.this, start=exp.Literal.number(1), length=node.expression
    )
    return self.sql(substring)
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render RIGHT(<this>, <n>) as a substring without a length.

    The substring starts at `LENGTH(<this>) - (<n> - 1)`, which selects the
    last <n> characters. The rewrite is done on a copy of the expression.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # Parenthesize `<n> - 1` so the outer subtraction applies to the whole term
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as a cast of `this` to "timestamp"."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as a cast of `this` to "date"."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode expression as a call to `name`.

    A charset other than utf-8 (compared in lower case) is reported as
    unsupported. The expression's `replace` argument is forwarded only when the
    `replace` flag is truthy; otherwise None is passed in its place.
    """
    args = expression.args
    charset = args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replacement = args.get("replace") if replace else None
    return self.func(name, expression.this, replacement)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render as LEAST when the expression has extra `expressions`, otherwise as MIN."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render as GREATEST when the expression has extra `expressions`, otherwise as MAX."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(<cond>) as sum(if(<cond>, 1, 0)).

    A DISTINCT wrapper is reported as unsupported and only its first
    expression is used as the condition.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    flag = exp.func("if", condition.copy(), 1, 0)
    return self.func("sum", flag)
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM in the `TRIM([position] [chars] FROM target [COLLATE c])` form.

    Without characters to remove and without a collation, rendering is left to
    the generator's own `trim_sql`.
    """
    subject = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    position_part = f"{position} " if position else ""
    chars_part = f"{chars} " if chars else ""
    from_part = "FROM " if position_part or chars_part else ""
    collate_part = f" COLLATE {collation}" if collation else ""

    pieces = ("TRIM(", position_part, chars_part, from_part, subject, collate_part, ")")
    return "".join(pieces)
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render as STRPTIME(<this>, <format>), with the format taken from `self.format_time`."""
    value = expression.this
    fmt = self.format_time(expression)
    return self.func("STRPTIME", value, fmt)
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Create a callback that renders a timestamp-or-date-string conversion as a cast to date.

    When the expression carries a format other than the dialect's TIME_FORMAT or
    DATE_FORMAT, the value is first rendered with `str_to_time_sql`.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        if fmt and fmt not in (target.TIME_FORMAT, target.DATE_FORMAT):
            source = str_to_time_sql(self, expression)
        else:
            source = self.sql(expression, "this")

        return self.sql(exp.cast(source, "date"))

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat | sqlglot.expressions.SafeConcat) -> str:
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render a concat expression as a left-nested chain of ``exp.DPipe`` nodes.

    The expression is copied first, so the chain is built from the copy's
    operands rather than the caller's nodes.

    Args:
        self: The generator used to render the resulting chain.
        expression: The concat expression whose ``expressions`` are chained.

    Returns:
        The SQL produced by ``self.sql`` for the chained expression.
    """

    def _pipe(left, right):
        return exp.DPipe(this=left, expression=right)

    operands = expression.copy().expressions
    # No initial value: a single operand is rendered as-is.
    return self.sql(reduce(_pipe, operands))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render a ``ConcatWs`` expression as a chain of ``exp.DPipe`` nodes.

    The first entry of ``expressions`` is taken as the delimiter; every
    following operand is appended as ``DPipe(acc, DPipe(delimiter, operand))``.

    Args:
        self: The generator used to render the resulting chain.
        expression: The expression to convert; it is copied before use.

    Returns:
        The SQL produced by ``self.sql`` for the chained expression.
    """
    separator, *operands = expression.copy().expressions

    def _join(left, right):
        return exp.DPipe(this=left, expression=exp.DPipe(this=separator, expression=right))

    # No initial value: a single operand is rendered without the delimiter.
    return self.sql(reduce(_join, operands))
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Build a ``REGEXP_EXTRACT`` call, flagging arguments that are dropped.

    Only ``this``, ``expression`` and the optional ``group`` argument are
    passed on. If ``position``, ``occurrence`` or ``parameters`` is set, the
    generator's ``unsupported`` hook is called with their names.

    Args:
        self: The generator; its ``func`` and ``unsupported`` helpers are used.
        expression: The regexp-extract expression to render.

    Returns:
        Whatever ``self.func`` returns for the call.
    """
    dropped = [
        name
        for name in ("position", "occurrence", "parameters")
        if expression.args.get(name)
    ]
    if dropped:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {dropped}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Build a ``REGEXP_REPLACE`` call, flagging arguments that are dropped.

    Only ``this``, ``expression`` and ``replacement`` are passed on. If
    ``position``, ``occurrence`` or ``parameters`` is set, the generator's
    ``unsupported`` hook is called with their names.

    Args:
        self: The generator; its ``func`` and ``unsupported`` helpers are used.
        expression: The regexp-replace expression to render.

    Returns:
        Whatever ``self.func`` returns for the call.

    Raises:
        KeyError: If ``expression.args`` has no ``"replacement"`` entry.
    """
    dropped = [
        name
        for name in ("position", "occurrence", "parameters")
        if expression.args.get(name)
    ]
    if dropped:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {dropped}")

    # Accessed by subscript on purpose: the replacement must be present.
    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute one column-name fragment per pivot aggregation.

    Aliased aggregations contribute their alias. Unaliased ones contribute
    their own SQL text, generated for ``dialect`` with lower-cased function
    names and with every identifier unquoted.

    Args:
        aggregations: The aggregation expressions of the pivot.
        dialect: The dialect used to generate SQL for unaliased aggregations.

    Returns:
        The names, in the same order as ``aggregations``.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    def _name_of(agg):
        if isinstance(agg, exp.Alias):
            return agg.alias

        # This case corresponds to aggregations without aliases being used as
        # suffixes (e.g. col_avg(foo)). Identifiers are unquoted here because
        # they are going to be quoted again in the base parser's
        # `_parse_pivot` method, due to `to_identifier`; otherwise the name
        # would contain nested quotes, e.g. col_avg(`foo`).
        return agg.transform(_unquote).sql(dialect=dialect, normalize_functions="lower")

    return [_name_of(agg) for agg in aggregations]
def simplify_literal(expression: ~E) -> ~E:
def simplify_literal(expression: E) -> E:
    """Run the optimizer's ``simplify`` on a non-literal right-hand operand.

    If ``expression.expression`` is already an ``exp.Literal`` nothing is
    done. Otherwise ``simplify`` is invoked on it.

    Args:
        expression: The expression whose ``expression`` operand is inspected.

    Returns:
        The same ``expression`` object that was passed in.
    """
    operand = expression.expression
    if isinstance(operand, exp.Literal):
        return expression

    # Imported lazily inside the function, as in the original code
    # (presumably to avoid an import cycle -- confirm).
    from sqlglot.optimizer.simplify import simplify

    # The return value is intentionally not used here; this relies on
    # `simplify` updating the tree in place -- NOTE(review): confirm.
    simplify(operand)
    return expression
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Create a builder that turns a two-argument list into ``expr_type``.

    Args:
        expr_type: The expression class to instantiate.

    Returns:
        A callable taking an argument list; element 0 becomes ``this`` and
        element 1 becomes ``expression`` (each fetched with ``seq_get``).
    """

    def _build(args: t.List) -> B:
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build
def parse_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build an ``exp.TimestampTrunc`` from a ``(unit, value)`` argument list.

    Args:
        args: The parsed function arguments; element 0 is used as the unit
            and element 1 as the value being truncated.

    Returns:
        The constructed ``exp.TimestampTrunc`` node.
    """
    # Note the swapped order: the value is the second argument.
    value = seq_get(args, 1)
    unit = seq_get(args, 0)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render an ``AnyValue`` expression as a ``MAX`` call on its operand.

    Args:
        self: The generator; its ``func`` helper is used.
        expression: The expression whose ``this`` becomes the ``MAX`` argument.

    Returns:
        Whatever ``self.func`` returns for ``MAX(this)``.
    """
    operand = expression.this
    return self.func("MAX", operand)
def json_keyvalue_comma_sql(self, expression: sqlglot.expressions.JSONKeyValue) -> str:
def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair as ``<key>, <value>``.

    Args:
        self: The generator used to render both sides.
        expression: The pair; ``this`` is rendered first, ``expression`` second.

    Returns:
        The two rendered parts joined by a comma and a space.
    """
    key_sql = self.sql(expression, "this")
    value_sql = self.sql(expression, "expression")
    return f"{key_sql}, {value_sql}"