sqlglot.dialects.dialect
1from __future__ import annotations 2 3import typing as t 4from enum import Enum 5from functools import reduce 6 7from sqlglot import exp 8from sqlglot._typing import E 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator 11from sqlglot.helper import flatten, seq_get 12from sqlglot.parser import Parser 13from sqlglot.time import format_time 14from sqlglot.tokens import Token, Tokenizer, TokenType 15from sqlglot.trie import new_trie 16 17B = t.TypeVar("B", bound=exp.Binary) 18 19 20class Dialects(str, Enum): 21 DIALECT = "" 22 23 BIGQUERY = "bigquery" 24 CLICKHOUSE = "clickhouse" 25 DATABRICKS = "databricks" 26 DRILL = "drill" 27 DUCKDB = "duckdb" 28 HIVE = "hive" 29 MYSQL = "mysql" 30 ORACLE = "oracle" 31 POSTGRES = "postgres" 32 PRESTO = "presto" 33 REDSHIFT = "redshift" 34 SNOWFLAKE = "snowflake" 35 SPARK = "spark" 36 SPARK2 = "spark2" 37 SQLITE = "sqlite" 38 STARROCKS = "starrocks" 39 TABLEAU = "tableau" 40 TERADATA = "teradata" 41 TRINO = "trino" 42 TSQL = "tsql" 43 Doris = "doris" 44 45 46class _Dialect(type): 47 classes: t.Dict[str, t.Type[Dialect]] = {} 48 49 def __eq__(cls, other: t.Any) -> bool: 50 if cls is other: 51 return True 52 if isinstance(other, str): 53 return cls is cls.get(other) 54 if isinstance(other, Dialect): 55 return cls is type(other) 56 57 return False 58 59 def __hash__(cls) -> int: 60 return hash(cls.__name__.lower()) 61 62 @classmethod 63 def __getitem__(cls, key: str) -> t.Type[Dialect]: 64 return cls.classes[key] 65 66 @classmethod 67 def get( 68 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 69 ) -> t.Optional[t.Type[Dialect]]: 70 return cls.classes.get(key, default) 71 72 def __new__(cls, clsname, bases, attrs): 73 klass = super().__new__(cls, clsname, bases, attrs) 74 enum = Dialects.__members__.get(clsname.upper()) 75 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 76 77 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 78 klass.FORMAT_TRIE = ( 79 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 80 ) 81 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 82 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 83 84 klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()} 85 86 klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) 87 klass.parser_class = getattr(klass, "Parser", Parser) 88 klass.generator_class = getattr(klass, "Generator", Generator) 89 90 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 91 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 92 klass.tokenizer_class._IDENTIFIERS.items() 93 )[0] 94 95 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 96 return next( 97 ( 98 (s, e) 99 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 100 if t == token_type 101 ), 102 (None, None), 103 ) 104 105 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 106 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 107 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 108 109 dialect_properties = { 110 **{ 111 k: v 112 for k, v in vars(klass).items() 113 if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__") 114 }, 115 "TOKENIZER_CLASS": klass.tokenizer_class, 116 } 117 118 if enum not in ("", "bigquery"): 119 dialect_properties["SELECT_KINDS"] = () 120 121 # Pass required dialect properties to the tokenizer, parser and generator classes 122 for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class): 123 for name, value in dialect_properties.items(): 124 if hasattr(subclass, name): 125 setattr(subclass, name, value) 126 127 if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT: 128 klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe 129 130 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 131 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 132 TokenType.ANTI, 133 TokenType.SEMI, 134 } 135 136 klass.generator_class.can_identify = klass.can_identify 137 138 return klass 139 140 141class Dialect(metaclass=_Dialect): 142 # Determines the base index offset for arrays 143 INDEX_OFFSET = 0 144 145 # If true unnest table aliases are considered only as column aliases 146 UNNEST_COLUMN_ONLY = False 147 148 # Determines whether or not the table alias comes after tablesample 149 ALIAS_POST_TABLESAMPLE = False 150 151 # Determines whether or not unquoted identifiers are resolved as uppercase 152 # When set to None, it means that the dialect treats all identifiers as case-insensitive 153 RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False 154 155 # Determines whether or not an unquoted identifier can start with a digit 156 IDENTIFIERS_CAN_START_WITH_DIGIT = False 157 158 # Determines whether or not the DPIPE token ('||') is a string concatenation operator 159 DPIPE_IS_STRING_CONCAT = True 160 161 # Determines whether or not CONCAT's arguments must be strings 162 STRICT_STRING_CONCAT = False 163 164 # Determines whether or not user-defined data types are supported 165 SUPPORTS_USER_DEFINED_TYPES = True 166 167 # Determines whether or not SEMI/ANTI JOINs are supported 168 SUPPORTS_SEMI_ANTI_JOIN = True 169 170 # Determines how function names are going to be normalized 171 NORMALIZE_FUNCTIONS: bool | str = "upper" 172 173 # Determines whether the base comes first in the LOG function 174 LOG_BASE_FIRST = True 175 176 # Indicates the default null ordering method to use if not explicitly set 177 # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" 178 NULL_ORDERING = "nulls_are_small" 179 180 DATE_FORMAT = "'%Y-%m-%d'" 181 DATEINT_FORMAT = "'%Y%m%d'" 182 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 183 184 # Custom time mappings in which the key represents dialect time format 185 # and the value represents a python time format 186 TIME_MAPPING: t.Dict[str, str] = {} 187 188 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 189 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 190 # special syntax cast(x as date format 'yyyy') defaults to time_mapping 191 FORMAT_MAPPING: t.Dict[str, str] = {} 192 193 # Mapping of an unescaped escape sequence to the corresponding character 194 ESCAPE_SEQUENCES: t.Dict[str, str] = {} 195 196 # Columns that are auto-generated by the engine corresponding to this dialect 197 # Such columns may be excluded from SELECT * queries, for example 198 PSEUDOCOLUMNS: t.Set[str] = set() 199 200 # Autofilled 201 tokenizer_class = Tokenizer 202 parser_class = Parser 203 generator_class = Generator 204 205 # A trie of the time_mapping keys 206 TIME_TRIE: t.Dict = {} 207 FORMAT_TRIE: t.Dict = {} 208 209 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 210 INVERSE_TIME_TRIE: t.Dict = {} 211 212 INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} 213 214 def __eq__(self, other: t.Any) -> bool: 215 return type(self) == other 216 217 def __hash__(self) -> int: 218 return hash(type(self)) 219 220 @classmethod 221 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 222 if not dialect: 223 return cls 224 if isinstance(dialect, _Dialect): 225 return dialect 226 if isinstance(dialect, Dialect): 227 return dialect.__class__ 228 229 result = cls.get(dialect) 230 if not result: 231 raise ValueError(f"Unknown dialect '{dialect}'") 232 233 return result 234 235 @classmethod 236 def format_time( 237 cls, expression: t.Optional[str | exp.Expression] 238 ) -> t.Optional[exp.Expression]: 239 if isinstance(expression, str): 240 return exp.Literal.string( 241 # the time formats are quoted 242 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 243 ) 244 245 if expression and expression.is_string: 246 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 247 248 return expression 249 250 @classmethod 251 def normalize_identifier(cls, expression: E) -> E: 252 """ 253 Normalizes an unquoted identifier to either lower or upper case, thus essentially 254 making it case-insensitive. If a dialect treats all identifiers as case-insensitive, 255 they will be normalized to lowercase regardless of being quoted or not. 256 """ 257 if isinstance(expression, exp.Identifier) and ( 258 not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None 259 ): 260 expression.set( 261 "this", 262 expression.this.upper() 263 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE 264 else expression.this.lower(), 265 ) 266 267 return expression 268 269 @classmethod 270 def case_sensitive(cls, text: str) -> bool: 271 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 272 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None: 273 return False 274 275 unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper 276 return any(unsafe(char) for char in text) 277 278 @classmethod 279 def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: 280 """Checks if text can be identified given an identify option. 281 282 Args: 283 text: The text to check. 284 identify: 285 "always" or `True`: Always returns true. 286 "safe": True if the identifier is case-insensitive. 287 288 Returns: 289 Whether or not the given text can be identified. 290 """ 291 if identify is True or identify == "always": 292 return True 293 294 if identify == "safe": 295 return not cls.case_sensitive(text) 296 297 return False 298 299 @classmethod 300 def quote_identifier(cls, expression: E, identify: bool = True) -> E: 301 if isinstance(expression, exp.Identifier): 302 name = expression.this 303 expression.set( 304 "quoted", 305 identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 306 ) 307 308 return expression 309 310 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 311 return self.parser(**opts).parse(self.tokenize(sql), sql) 312 313 def parse_into( 314 self, expression_type: exp.IntoType, sql: str, **opts 315 ) -> t.List[t.Optional[exp.Expression]]: 316 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 317 318 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 319 return self.generator(**opts).generate(expression) 320 321 def transpile(self, sql: str, **opts) -> t.List[str]: 322 return [self.generate(expression, **opts) for expression in self.parse(sql)] 323 324 def tokenize(self, sql: str) -> t.List[Token]: 325 return self.tokenizer.tokenize(sql) 326 327 @property 328 def tokenizer(self) -> Tokenizer: 329 if not hasattr(self, "_tokenizer"): 330 self._tokenizer = self.tokenizer_class() 331 return self._tokenizer 332 333 def parser(self, **opts) -> Parser: 334 return self.parser_class(**opts) 335 336 def generator(self, **opts) -> Generator: 337 return self.generator_class(**opts) 338 339 340DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 341 342 343def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 344 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 345 346 347def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 348 if expression.args.get("accuracy"): 349 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 350 return self.func("APPROX_COUNT_DISTINCT", expression.this) 351 352 353def if_sql( 354 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 355) -> t.Callable[[Generator, exp.If], str]: 356 def _if_sql(self: Generator, expression: exp.If) -> str: 357 return self.func( 358 name, 359 expression.this, 360 expression.args.get("true"), 361 expression.args.get("false") or false_value, 362 ) 363 364 return _if_sql 365 366 367def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: 368 return self.binary(expression, "->") 369 370 371def arrow_json_extract_scalar_sql( 372 self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar 373) -> str: 374 return self.binary(expression, "->>") 375 376 377def inline_array_sql(self: Generator, expression: exp.Array) -> str: 378 return f"[{self.expressions(expression, flat=True)}]" 379 380 381def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 382 return self.like_sql( 383 exp.Like( 384 this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy() 385 ) 386 ) 387 388 389def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 390 zone = self.sql(expression, "this") 391 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 392 393 394def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 395 if expression.args.get("recursive"): 396 self.unsupported("Recursive CTEs are unsupported") 397 expression.args["recursive"] = False 398 return self.with_sql(expression) 399 400 401def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 402 n = self.sql(expression, "this") 403 d = self.sql(expression, "expression") 404 return f"IF({d} <> 0, {n} / {d}, NULL)" 405 406 407def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 408 self.unsupported("TABLESAMPLE unsupported") 409 return self.sql(expression.this) 410 411 412def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 413 self.unsupported("PIVOT unsupported") 414 return "" 415 416 417def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 418 return self.cast_sql(expression) 419 420 421def no_properties_sql(self: Generator, expression: exp.Properties) -> str: 422 self.unsupported("Properties unsupported") 423 return "" 424 425 426def no_comment_column_constraint_sql( 427 self: Generator, expression: exp.CommentColumnConstraint 428) -> str: 429 self.unsupported("CommentColumnConstraint unsupported") 430 return "" 431 432 433def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 434 self.unsupported("MAP_FROM_ENTRIES unsupported") 435 return "" 436 437 438def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 439 this = self.sql(expression, "this") 440 substr = self.sql(expression, "substr") 441 position = self.sql(expression, "position") 442 if position: 443 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 444 return f"STRPOS({this}, {substr})" 445 446 447def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 448 return ( 449 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 450 ) 451 452 453def var_map_sql( 454 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 455) -> str: 456 keys = expression.args["keys"] 457 values = expression.args["values"] 458 459 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 460 self.unsupported("Cannot convert array columns into map.") 461 return self.func(map_func_name, keys, values) 462 463 args = [] 464 for key, value in zip(keys.expressions, values.expressions): 465 args.append(self.sql(key)) 466 args.append(self.sql(value)) 467 468 return self.func(map_func_name, *args) 469 470 471def format_time_lambda( 472 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 473) -> t.Callable[[t.List], E]: 474 """Helper used for time expressions. 475 476 Args: 477 exp_class: the expression class to instantiate. 478 dialect: target sql dialect. 479 default: the default format, True being time. 480 481 Returns: 482 A callable that can be used to return the appropriately formatted time expression. 483 """ 484 485 def _format_time(args: t.List): 486 return exp_class( 487 this=seq_get(args, 0), 488 format=Dialect[dialect].format_time( 489 seq_get(args, 1) 490 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 491 ), 492 ) 493 494 return _format_time 495 496 497def time_format( 498 dialect: DialectType = None, 499) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 500 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 501 """ 502 Returns the time format for a given expression, unless it's equivalent 503 to the default time format of the dialect of interest. 504 """ 505 time_format = self.format_time(expression) 506 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 507 508 return _time_format 509 510 511def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 512 """ 513 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 514 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 515 columns are removed from the create statement. 516 """ 517 has_schema = isinstance(expression.this, exp.Schema) 518 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 519 520 if has_schema and is_partitionable: 521 expression = expression.copy() 522 prop = expression.find(exp.PartitionedByProperty) 523 if prop and prop.this and not isinstance(prop.this, exp.Schema): 524 schema = expression.this 525 columns = {v.name.upper() for v in prop.this.expressions} 526 partitions = [col for col in schema.expressions if col.name.upper() in columns] 527 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 528 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 529 expression.set("this", schema) 530 531 return self.create_sql(expression) 532 533 534def parse_date_delta( 535 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 536) -> t.Callable[[t.List], E]: 537 def inner_func(args: t.List) -> E: 538 unit_based = len(args) == 3 539 this = args[2] if unit_based else seq_get(args, 0) 540 unit = args[0] if unit_based else exp.Literal.string("DAY") 541 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 542 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 543 544 return inner_func 545 546 547def parse_date_delta_with_interval( 548 expression_class: t.Type[E], 549) -> t.Callable[[t.List], t.Optional[E]]: 550 def func(args: t.List) -> t.Optional[E]: 551 if len(args) < 2: 552 return None 553 554 interval = args[1] 555 556 if not isinstance(interval, exp.Interval): 557 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 558 559 expression = interval.this 560 if expression and expression.is_string: 561 expression = exp.Literal.number(expression.this) 562 563 return expression_class( 564 this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) 565 ) 566 567 return func 568 569 570def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 571 unit = seq_get(args, 0) 572 this = seq_get(args, 1) 573 574 if isinstance(this, exp.Cast) and this.is_type("date"): 575 return exp.DateTrunc(unit=unit, this=this) 576 return exp.TimestampTrunc(this=this, unit=unit) 577 578 579def date_add_interval_sql( 580 data_type: str, kind: str 581) -> t.Callable[[Generator, exp.Expression], str]: 582 def func(self: Generator, expression: exp.Expression) -> str: 583 this = self.sql(expression, "this") 584 unit = expression.args.get("unit") 585 unit = exp.var(unit.name.upper() if unit else "DAY") 586 interval = exp.Interval(this=expression.expression.copy(), unit=unit) 587 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 588 589 return func 590 591 592def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 593 return self.func( 594 "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this 595 ) 596 597 598def locate_to_strposition(args: t.List) -> exp.Expression: 599 return exp.StrPosition( 600 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 601 ) 602 603 604def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 605 return self.func( 606 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 607 ) 608 609 610def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 611 expression = expression.copy() 612 return self.sql( 613 exp.Substring( 614 this=expression.this, start=exp.Literal.number(1), length=expression.expression 615 ) 616 ) 617 618 619def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 620 expression = expression.copy() 621 return self.sql( 622 exp.Substring( 623 this=expression.this, 624 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 625 ) 626 ) 627 628 629def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 630 return self.sql(exp.cast(expression.this, "timestamp")) 631 632 633def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 634 return self.sql(exp.cast(expression.this, "date")) 635 636 637# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 638def encode_decode_sql( 639 self: Generator, expression: exp.Expression, name: str, replace: bool = True 640) -> str: 641 charset = expression.args.get("charset") 642 if charset and charset.name.lower() != "utf-8": 643 self.unsupported(f"Expected utf-8 character set, got {charset}.") 644 645 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 646 647 648def min_or_least(self: Generator, expression: exp.Min) -> str: 649 name = "LEAST" if expression.expressions else "MIN" 650 return rename_func(name)(self, expression) 651 652 653def max_or_greatest(self: Generator, expression: exp.Max) -> str: 654 name = "GREATEST" if expression.expressions else "MAX" 655 return rename_func(name)(self, expression) 656 657 658def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 659 cond = expression.this 660 661 if isinstance(expression.this, exp.Distinct): 662 cond = expression.this.expressions[0] 663 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 664 665 return self.func("sum", exp.func("if", cond.copy(), 1, 0)) 666 667 668def trim_sql(self: Generator, expression: exp.Trim) -> str: 669 target = self.sql(expression, "this") 670 trim_type = self.sql(expression, "position") 671 remove_chars = self.sql(expression, "expression") 672 collation = self.sql(expression, "collation") 673 674 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 675 if not remove_chars and not collation: 676 return self.trim_sql(expression) 677 678 trim_type = f"{trim_type} " if trim_type else "" 679 remove_chars = f"{remove_chars} " if remove_chars else "" 680 from_part = "FROM " if trim_type or remove_chars else "" 681 collation = f" COLLATE {collation}" if collation else "" 682 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 683 684 685def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 686 return self.func("STRPTIME", expression.this, self.format_time(expression)) 687 688 689def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: 690 def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: 691 _dialect = Dialect.get_or_raise(dialect) 692 time_format = self.format_time(expression) 693 if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): 694 return self.sql(exp.cast(str_to_time_sql(self, expression), "date")) 695 696 return self.sql(exp.cast(self.sql(expression, "this"), "date")) 697 698 return _ts_or_ds_to_date_sql 699 700 701def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: 702 expression = expression.copy() 703 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 704 705 706def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 707 expression = expression.copy() 708 delim, *rest_args = expression.expressions 709 return self.sql( 710 reduce( 711 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 712 rest_args, 713 ) 714 ) 715 716 717def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 718 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 719 if bad_args: 720 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 721 722 return self.func( 723 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 724 ) 725 726 727def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 728 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 729 if bad_args: 730 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 731 732 return self.func( 733 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 734 ) 735 736 737def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 738 names = [] 739 for agg in aggregations: 740 if isinstance(agg, exp.Alias): 741 names.append(agg.alias) 742 else: 743 """ 744 This case corresponds to aggregations without aliases being used as suffixes 745 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 746 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 747 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 748 """ 749 agg_all_unquoted = agg.transform( 750 lambda node: exp.Identifier(this=node.name, quoted=False) 751 if isinstance(node, exp.Identifier) 752 else node 753 ) 754 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 755 756 return names 757 758 759def simplify_literal(expression: E) -> E: 760 if not isinstance(expression.expression, exp.Literal): 761 from sqlglot.optimizer.simplify import simplify 762 763 simplify(expression.expression) 764 765 return expression 766 767 768def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 769 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 770 771 772# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 773def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 774 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 775 776 777def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 778 return self.func("MAX", expression.this) 779 780 781def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 782 a = self.sql(expression.left) 783 b = self.sql(expression.right) 784 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 785 786 787# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon 788def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str: 789 return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}" 790 791 792def is_parse_json(expression: exp.Expression) -> bool: 793 return isinstance(expression, exp.ParseJSON) or ( 794 isinstance(expression, exp.Cast) and expression.is_type("json") 795 ) 796 797 798def isnull_to_is_null(args: t.List) -> exp.Expression: 799 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 800 801 802def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str: 803 if expression.expression.args.get("with"): 804 expression = expression.copy() 805 expression.set("with", expression.expression.args["with"].pop()) 806 return self.insert_sql(expression)
21class Dialects(str, Enum): 22 DIALECT = "" 23 24 BIGQUERY = "bigquery" 25 CLICKHOUSE = "clickhouse" 26 DATABRICKS = "databricks" 27 DRILL = "drill" 28 DUCKDB = "duckdb" 29 HIVE = "hive" 30 MYSQL = "mysql" 31 ORACLE = "oracle" 32 POSTGRES = "postgres" 33 PRESTO = "presto" 34 REDSHIFT = "redshift" 35 SNOWFLAKE = "snowflake" 36 SPARK = "spark" 37 SPARK2 = "spark2" 38 SQLITE = "sqlite" 39 STARROCKS = "starrocks" 40 TABLEAU = "tableau" 41 TERADATA = "teradata" 42 TRINO = "trino" 43 TSQL = "tsql" 44 Doris = "doris"
An enumeration.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
142class Dialect(metaclass=_Dialect): 143 # Determines the base index offset for arrays 144 INDEX_OFFSET = 0 145 146 # If true unnest table aliases are considered only as column aliases 147 UNNEST_COLUMN_ONLY = False 148 149 # Determines whether or not the table alias comes after tablesample 150 ALIAS_POST_TABLESAMPLE = False 151 152 # Determines whether or not unquoted identifiers are resolved as uppercase 153 # When set to None, it means that the dialect treats all identifiers as case-insensitive 154 RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False 155 156 # Determines whether or not an unquoted identifier can start with a digit 157 IDENTIFIERS_CAN_START_WITH_DIGIT = False 158 159 # Determines whether or not the DPIPE token ('||') is a string concatenation operator 160 DPIPE_IS_STRING_CONCAT = True 161 162 # Determines whether or not CONCAT's arguments must be strings 163 STRICT_STRING_CONCAT = False 164 165 # Determines whether or not user-defined data types are supported 166 SUPPORTS_USER_DEFINED_TYPES = True 167 168 # Determines whether or not SEMI/ANTI JOINs are supported 169 SUPPORTS_SEMI_ANTI_JOIN = True 170 171 # Determines how function names are going to be normalized 172 NORMALIZE_FUNCTIONS: bool | str = "upper" 173 174 # Determines whether the base comes first in the LOG function 175 LOG_BASE_FIRST = True 176 177 # Indicates the default null ordering method to use if not explicitly set 178 # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" 179 NULL_ORDERING = "nulls_are_small" 180 181 DATE_FORMAT = "'%Y-%m-%d'" 182 DATEINT_FORMAT = "'%Y%m%d'" 183 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 184 185 # Custom time mappings in which the key represents dialect time format 186 # and the value represents a python time format 187 TIME_MAPPING: t.Dict[str, str] = {} 188 189 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 190 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 191 # special syntax cast(x as date format 'yyyy') defaults to time_mapping 192 FORMAT_MAPPING: t.Dict[str, str] = {} 193 194 # Mapping of an unescaped escape sequence to the corresponding character 195 ESCAPE_SEQUENCES: t.Dict[str, str] = {} 196 197 # Columns that are auto-generated by the engine corresponding to this dialect 198 # Such columns may be excluded from SELECT * queries, for example 199 PSEUDOCOLUMNS: t.Set[str] = set() 200 201 # Autofilled 202 tokenizer_class = Tokenizer 203 parser_class = Parser 204 generator_class = Generator 205 206 # A trie of the time_mapping keys 207 TIME_TRIE: t.Dict = {} 208 FORMAT_TRIE: t.Dict = {} 209 210 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 211 INVERSE_TIME_TRIE: t.Dict = {} 212 213 INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} 214 215 def __eq__(self, other: t.Any) -> bool: 216 return type(self) == other 217 218 def __hash__(self) -> int: 219 return hash(type(self)) 220 221 @classmethod 222 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 223 if not dialect: 224 return cls 225 if isinstance(dialect, _Dialect): 226 return dialect 227 if isinstance(dialect, Dialect): 228 return dialect.__class__ 229 230 result = cls.get(dialect) 231 if not result: 232 raise ValueError(f"Unknown dialect '{dialect}'") 233 234 return result 235 236 @classmethod 237 def format_time( 238 cls, expression: t.Optional[str | exp.Expression] 239 ) -> t.Optional[exp.Expression]: 240 if isinstance(expression, str): 241 return exp.Literal.string( 242 # the time formats are quoted 243 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 244 ) 245 246 if expression and expression.is_string: 247 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 248 249 return expression 250 251 @classmethod 252 def normalize_identifier(cls, expression: E) -> E: 253 """ 254 Normalizes an unquoted identifier to either lower or upper case, thus essentially 255 making it case-insensitive. If a dialect treats all identifiers as case-insensitive, 256 they will be normalized to lowercase regardless of being quoted or not. 257 """ 258 if isinstance(expression, exp.Identifier) and ( 259 not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None 260 ): 261 expression.set( 262 "this", 263 expression.this.upper() 264 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE 265 else expression.this.lower(), 266 ) 267 268 return expression 269 270 @classmethod 271 def case_sensitive(cls, text: str) -> bool: 272 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 273 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None: 274 return False 275 276 unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper 277 return any(unsafe(char) for char in text) 278 279 @classmethod 280 def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: 281 """Checks if text can be identified given an identify option. 282 283 Args: 284 text: The text to check. 285 identify: 286 "always" or `True`: Always returns true. 287 "safe": True if the identifier is case-insensitive. 288 289 Returns: 290 Whether or not the given text can be identified. 291 """ 292 if identify is True or identify == "always": 293 return True 294 295 if identify == "safe": 296 return not cls.case_sensitive(text) 297 298 return False 299 300 @classmethod 301 def quote_identifier(cls, expression: E, identify: bool = True) -> E: 302 if isinstance(expression, exp.Identifier): 303 name = expression.this 304 expression.set( 305 "quoted", 306 identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 307 ) 308 309 return expression 310 311 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 312 return self.parser(**opts).parse(self.tokenize(sql), sql) 313 314 def parse_into( 315 self, expression_type: exp.IntoType, sql: str, **opts 316 ) -> t.List[t.Optional[exp.Expression]]: 317 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 318 319 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 320 return self.generator(**opts).generate(expression) 321 322 def transpile(self, sql: str, **opts) -> t.List[str]: 323 return [self.generate(expression, **opts) for expression in self.parse(sql)] 324 325 def tokenize(self, sql: str) -> t.List[Token]: 326 return self.tokenizer.tokenize(sql) 327 328 @property 329 def tokenizer(self) -> Tokenizer: 330 if not hasattr(self, "_tokenizer"): 331 self._tokenizer = self.tokenizer_class() 332 return self._tokenizer 333 334 def parser(self, **opts) -> Parser: 335 return self.parser_class(**opts) 336 337 def generator(self, **opts) -> Generator: 338 return self.generator_class(**opts)
221 @classmethod 222 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 223 if not dialect: 224 return cls 225 if isinstance(dialect, _Dialect): 226 return dialect 227 if isinstance(dialect, Dialect): 228 return dialect.__class__ 229 230 result = cls.get(dialect) 231 if not result: 232 raise ValueError(f"Unknown dialect '{dialect}'") 233 234 return result
236 @classmethod 237 def format_time( 238 cls, expression: t.Optional[str | exp.Expression] 239 ) -> t.Optional[exp.Expression]: 240 if isinstance(expression, str): 241 return exp.Literal.string( 242 # the time formats are quoted 243 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 244 ) 245 246 if expression and expression.is_string: 247 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 248 249 return expression
251 @classmethod 252 def normalize_identifier(cls, expression: E) -> E: 253 """ 254 Normalizes an unquoted identifier to either lower or upper case, thus essentially 255 making it case-insensitive. If a dialect treats all identifiers as case-insensitive, 256 they will be normalized to lowercase regardless of being quoted or not. 257 """ 258 if isinstance(expression, exp.Identifier) and ( 259 not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None 260 ): 261 expression.set( 262 "this", 263 expression.this.upper() 264 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE 265 else expression.this.lower(), 266 ) 267 268 return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized to lowercase regardless of being quoted or not.
270 @classmethod 271 def case_sensitive(cls, text: str) -> bool: 272 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 273 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None: 274 return False 275 276 unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper 277 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
279 @classmethod 280 def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: 281 """Checks if text can be identified given an identify option. 282 283 Args: 284 text: The text to check. 285 identify: 286 "always" or `True`: Always returns true. 287 "safe": True if the identifier is case-insensitive. 288 289 Returns: 290 Whether or not the given text can be identified. 291 """ 292 if identify is True or identify == "always": 293 return True 294 295 if identify == "safe": 296 return not cls.case_sensitive(text) 297 298 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or
True
: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
300 @classmethod 301 def quote_identifier(cls, expression: E, identify: bool = True) -> E: 302 if isinstance(expression, exp.Identifier): 303 name = expression.this 304 expression.set( 305 "quoted", 306 identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 307 ) 308 309 return expression
354def if_sql( 355 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 356) -> t.Callable[[Generator, exp.If], str]: 357 def _if_sql(self: Generator, expression: exp.If) -> str: 358 return self.func( 359 name, 360 expression.this, 361 expression.args.get("true"), 362 expression.args.get("false") or false_value, 363 ) 364 365 return _if_sql
439def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 440 this = self.sql(expression, "this") 441 substr = self.sql(expression, "substr") 442 position = self.sql(expression, "position") 443 if position: 444 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 445 return f"STRPOS({this}, {substr})"
454def var_map_sql( 455 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 456) -> str: 457 keys = expression.args["keys"] 458 values = expression.args["values"] 459 460 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 461 self.unsupported("Cannot convert array columns into map.") 462 return self.func(map_func_name, keys, values) 463 464 args = [] 465 for key, value in zip(keys.expressions, values.expressions): 466 args.append(self.sql(key)) 467 args.append(self.sql(value)) 468 469 return self.func(map_func_name, *args)
472def format_time_lambda( 473 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 474) -> t.Callable[[t.List], E]: 475 """Helper used for time expressions. 476 477 Args: 478 exp_class: the expression class to instantiate. 479 dialect: target sql dialect. 480 default: the default format, True being time. 481 482 Returns: 483 A callable that can be used to return the appropriately formatted time expression. 484 """ 485 486 def _format_time(args: t.List): 487 return exp_class( 488 this=seq_get(args, 0), 489 format=Dialect[dialect].format_time( 490 seq_get(args, 1) 491 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 492 ), 493 ) 494 495 return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
498def time_format( 499 dialect: DialectType = None, 500) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 501 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 502 """ 503 Returns the time format for a given expression, unless it's equivalent 504 to the default time format of the dialect of interest. 505 """ 506 time_format = self.format_time(expression) 507 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 508 509 return _time_format
512def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 513 """ 514 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 515 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 516 columns are removed from the create statement. 517 """ 518 has_schema = isinstance(expression.this, exp.Schema) 519 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 520 521 if has_schema and is_partitionable: 522 expression = expression.copy() 523 prop = expression.find(exp.PartitionedByProperty) 524 if prop and prop.this and not isinstance(prop.this, exp.Schema): 525 schema = expression.this 526 columns = {v.name.upper() for v in prop.this.expressions} 527 partitions = [col for col in schema.expressions if col.name.upper() in columns] 528 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 529 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 530 expression.set("this", schema) 531 532 return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
535def parse_date_delta( 536 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 537) -> t.Callable[[t.List], E]: 538 def inner_func(args: t.List) -> E: 539 unit_based = len(args) == 3 540 this = args[2] if unit_based else seq_get(args, 0) 541 unit = args[0] if unit_based else exp.Literal.string("DAY") 542 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 543 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 544 545 return inner_func
548def parse_date_delta_with_interval( 549 expression_class: t.Type[E], 550) -> t.Callable[[t.List], t.Optional[E]]: 551 def func(args: t.List) -> t.Optional[E]: 552 if len(args) < 2: 553 return None 554 555 interval = args[1] 556 557 if not isinstance(interval, exp.Interval): 558 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 559 560 expression = interval.this 561 if expression and expression.is_string: 562 expression = exp.Literal.number(expression.this) 563 564 return expression_class( 565 this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) 566 ) 567 568 return func
580def date_add_interval_sql( 581 data_type: str, kind: str 582) -> t.Callable[[Generator, exp.Expression], str]: 583 def func(self: Generator, expression: exp.Expression) -> str: 584 this = self.sql(expression, "this") 585 unit = expression.args.get("unit") 586 unit = exp.var(unit.name.upper() if unit else "DAY") 587 interval = exp.Interval(this=expression.expression.copy(), unit=unit) 588 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 589 590 return func
639def encode_decode_sql( 640 self: Generator, expression: exp.Expression, name: str, replace: bool = True 641) -> str: 642 charset = expression.args.get("charset") 643 if charset and charset.name.lower() != "utf-8": 644 self.unsupported(f"Expected utf-8 character set, got {charset}.") 645 646 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
659def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 660 cond = expression.this 661 662 if isinstance(expression.this, exp.Distinct): 663 cond = expression.this.expressions[0] 664 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 665 666 return self.func("sum", exp.func("if", cond.copy(), 1, 0))
669def trim_sql(self: Generator, expression: exp.Trim) -> str: 670 target = self.sql(expression, "this") 671 trim_type = self.sql(expression, "position") 672 remove_chars = self.sql(expression, "expression") 673 collation = self.sql(expression, "collation") 674 675 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 676 if not remove_chars and not collation: 677 return self.trim_sql(expression) 678 679 trim_type = f"{trim_type} " if trim_type else "" 680 remove_chars = f"{remove_chars} " if remove_chars else "" 681 from_part = "FROM " if trim_type or remove_chars else "" 682 collation = f" COLLATE {collation}" if collation else "" 683 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
690def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: 691 def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: 692 _dialect = Dialect.get_or_raise(dialect) 693 time_format = self.format_time(expression) 694 if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): 695 return self.sql(exp.cast(str_to_time_sql(self, expression), "date")) 696 697 return self.sql(exp.cast(self.sql(expression, "this"), "date")) 698 699 return _ts_or_ds_to_date_sql
707def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 708 expression = expression.copy() 709 delim, *rest_args = expression.expressions 710 return self.sql( 711 reduce( 712 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 713 rest_args, 714 ) 715 )
718def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 719 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 720 if bad_args: 721 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 722 723 return self.func( 724 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 725 )
728def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 729 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 730 if bad_args: 731 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 732 733 return self.func( 734 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 735 )
738def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 739 names = [] 740 for agg in aggregations: 741 if isinstance(agg, exp.Alias): 742 names.append(agg.alias) 743 else: 744 """ 745 This case corresponds to aggregations without aliases being used as suffixes 746 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 747 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 748 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 749 """ 750 agg_all_unquoted = agg.transform( 751 lambda node: exp.Identifier(this=node.name, quoted=False) 752 if isinstance(node, exp.Identifier) 753 else node 754 ) 755 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 756 757 return names