sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)


class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and DuckDB, which use functions that don't support charset and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"
An enumeration.
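Because Dialects subclasses str, its members can be passed anywhere a dialect name is expected. A minimal sketch (assumes sqlglot is installed; the exact output depends on the version):

import sqlglot
from sqlglot.dialects.dialect import Dialects

# MySQL's IF() is parsed and re-emitted in whatever form DuckDB's generator uses.
print(sqlglot.transpile("SELECT IF(x > 0, 1, 0)", read=Dialects.MYSQL, write=Dialects.DUCKDB)[0])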
class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
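A short sketch of driving a Dialect instance directly, using the same tokenize, parse and generate pipeline that sqlglot.transpile runs under the hood (assumes sqlglot is installed):

from sqlglot.dialects.dialect import Dialect

duckdb = Dialect.get_or_raise("duckdb")()      # instantiate the registered class
expression = duckdb.parse("SELECT 1 AS x")[0]  # tokenize + parse into an AST
print(duckdb.generate(expression))             # SELECT 1 AS x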
    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result
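The lookup accepts strings, Dialect instances and Dialect subclasses; a sketch of each branch:

from sqlglot.dialects.dialect import Dialect

assert Dialect.get_or_raise("duckdb") is Dialect["duckdb"]  # string lookup
assert Dialect.get_or_raise(None) is Dialect                # falsy input returns the base class
try:
    Dialect.get_or_raise("nosuchdialect")
except ValueError as error:
    print(error)  # Unknown dialect 'nosuchdialect'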
    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression
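TIME_MAPPING is empty on the base class, so the conversion is easiest to see through a dialect that defines one. A sketch using Hive, whose mapping includes yyyy, MM and dd:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

hive = Dialect["hive"]
# Hive's time-format tokens are rewritten into python strftime tokens.
print(hive.format_time(exp.Literal.string("yyyy-MM-dd")).sql())  # '%Y-%m-%d'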
    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized to lowercase regardless of being quoted or not.
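An illustration of the default lowercase normalization versus a dialect that resolves unquoted identifiers as uppercase (Snowflake does, at the time of writing):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("Foo", quoted=False)
print(Dialect["duckdb"].normalize_identifier(ident.copy()).this)     # foo
print(Dialect["snowflake"].normalize_identifier(ident.copy()).this)  # FOO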
    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
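Under the default rules (RESOLVES_IDENTIFIERS_AS_UPPERCASE = False), uppercase characters are the "unsafe" ones, since lowercase normalization would fold them away:

from sqlglot.dialects.dialect import Dialect

print(Dialect.case_sensitive("foo"))  # False - survives lowercase normalization
print(Dialect.case_sensitive("Foo"))  # True - the F would be folded away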
    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or True: always returns True. "safe": True if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
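A quick sketch of the three outcomes on the base dialect:

from sqlglot.dialects.dialect import Dialect

print(Dialect.can_identify("foo", "safe"))    # True - already case-insensitive
print(Dialect.can_identify("Foo", "safe"))    # False - mixed case is unsafe
print(Dialect.can_identify("Foo", "always"))  # True - identification is forced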
    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression
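With identify=False, only names that actually need quotes receive them: case-sensitive names and names rejected by SAFE_IDENTIFIER_RE. A sketch on the base dialect:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

print(Dialect.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo
print(Dialect.quote_identifier(exp.to_identifier("Foo"), identify=False).sql())  # "Foo"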
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
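A hypothetical wiring sketch: this is the shape in which dialects register the helper in their Generator's TRANSFORMS mapping, here renaming IF to IIF:

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import if_sql
from sqlglot.generator import Generator

class IifGenerator(Generator):
    # exp.If nodes now render as IIF(condition, true, false)
    TRANSFORMS = {**Generator.TRANSFORMS, exp.If: if_sql(name="IIF")}

print(IifGenerator().generate(parse_one("IF(x > 0, 1, 0)")))  # IIF(x > 0, 1, 0)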
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
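A sketch of the returned callable, using Hive's TIME_MAPPING; the second argument is converted from dialect notation into python strftime notation (exp.StrToTime is just an illustrative choice of expression class):

from sqlglot import exp
from sqlglot.dialects.dialect import format_time_lambda

parse_str_to_time = format_time_lambda(exp.StrToTime, "hive")
node = parse_str_to_time([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
print(node.sql())  # STR_TO_TIME(x, '%Y-%m-%d')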
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
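A hedged end-to-end illustration (the exact output can vary by sqlglot version): Spark declares the partition column inside the schema, and generating Hive DDL moves it out of the column list into PARTITIONED BY with its type:

import sqlglot

sql = "CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)"
print(sqlglot.transpile(sql, read="spark", write="hive")[0])
# e.g. CREATE TABLE t (a INT) PARTITIONED BY (b INT)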
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
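MySQL's DATE_ADD is a typical consumer of this helper; a small sketch of the resulting node, with the interval unpacked into an expression plus a string unit:

import sqlglot

node = sqlglot.parse_one("DATE_ADD(x, INTERVAL 1 DAY)", read="mysql")
print(type(node).__name__, node.text("unit"))  # DateAdd DAY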
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
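A hypothetical wiring sketch, rendering exp.ArgMax as MAX_BY and warning when the unsupported third (count) argument is present:

from sqlglot import exp
from sqlglot.dialects.dialect import arg_max_or_min_no_count
from sqlglot.generator import Generator

class MaxByGenerator(Generator):
    # exp.ArgMax nodes now render as MAX_BY(this, expression)
    TRANSFORMS = {**Generator.TRANSFORMS, exp.ArgMax: arg_max_or_min_no_count("MAX_BY")}

node = exp.ArgMax(this=exp.column("salary"), expression=exp.column("name"))
print(MaxByGenerator().generate(node))  # MAX_BY(salary, name)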