sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
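
# Illustrative examples (a sketch, not part of the module): each enum value
# above doubles as the registry key for the corresponding `Dialect` subclass,
# and `AutoName` gives every `NormalizationStrategy` member a value equal to
# its name.
#
#     >>> Dialects.DUCKDB.value
#     'duckdb'
#     >>> NormalizationStrategy("UPPERCASE") is NormalizationStrategy.UPPERCASE
#     True
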
class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
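
# Illustrative sketch (hypothetical dialect, not part of sqlglot): simply
# defining a subclass runs `_Dialect.__new__`, which registers it under its
# lowercased class name and autofills tries, tokenizer/parser/generator
# classes and delimiter metadata.
#
#     >>> class Custom(Dialect):
#     ...     NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_SENSITIVE
#     >>> Dialect.get("custom") is Custom
#     True
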
class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Whether COPY statement parameters are separated by comma or whitespace."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id
        HAVING
            my_id = 1

    In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it
          resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query, i.e. it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }
"WEEKDAY_ISO": "DAYOFWEEKISO", 419 "DOW_ISO": "DAYOFWEEKISO", 420 "DW_ISO": "DAYOFWEEKISO", 421 "DAY OF YEAR": "DAYOFYEAR", 422 "DOY": "DAYOFYEAR", 423 "DY": "DAYOFYEAR", 424 "W": "WEEK", 425 "WK": "WEEK", 426 "WEEKOFYEAR": "WEEK", 427 "WOY": "WEEK", 428 "WY": "WEEK", 429 "WEEK_ISO": "WEEKISO", 430 "WEEKOFYEARISO": "WEEKISO", 431 "WEEKOFYEAR_ISO": "WEEKISO", 432 "Q": "QUARTER", 433 "QTR": "QUARTER", 434 "QTRS": "QUARTER", 435 "QUARTERS": "QUARTER", 436 "H": "HOUR", 437 "HH": "HOUR", 438 "HR": "HOUR", 439 "HOURS": "HOUR", 440 "HRS": "HOUR", 441 "M": "MINUTE", 442 "MI": "MINUTE", 443 "MIN": "MINUTE", 444 "MINUTES": "MINUTE", 445 "MINS": "MINUTE", 446 "S": "SECOND", 447 "SEC": "SECOND", 448 "SECONDS": "SECOND", 449 "SECS": "SECOND", 450 "MS": "MILLISECOND", 451 "MSEC": "MILLISECOND", 452 "MSECS": "MILLISECOND", 453 "MSECOND": "MILLISECOND", 454 "MSECONDS": "MILLISECOND", 455 "MILLISEC": "MILLISECOND", 456 "MILLISECS": "MILLISECOND", 457 "MILLISECON": "MILLISECOND", 458 "MILLISECONDS": "MILLISECOND", 459 "US": "MICROSECOND", 460 "USEC": "MICROSECOND", 461 "USECS": "MICROSECOND", 462 "MICROSEC": "MICROSECOND", 463 "MICROSECS": "MICROSECOND", 464 "USECOND": "MICROSECOND", 465 "USECONDS": "MICROSECOND", 466 "MICROSECONDS": "MICROSECOND", 467 "NS": "NANOSECOND", 468 "NSEC": "NANOSECOND", 469 "NANOSEC": "NANOSECOND", 470 "NSECOND": "NANOSECOND", 471 "NSECONDS": "NANOSECOND", 472 "NANOSECS": "NANOSECOND", 473 "EPOCH_SECOND": "EPOCH", 474 "EPOCH_SECONDS": "EPOCH", 475 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 476 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 477 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 478 "TZH": "TIMEZONE_HOUR", 479 "TZM": "TIMEZONE_MINUTE", 480 "DEC": "DECADE", 481 "DECS": "DECADE", 482 "DECADES": "DECADE", 483 "MIL": "MILLENIUM", 484 "MILS": "MILLENIUM", 485 "MILLENIA": "MILLENIUM", 486 "C": "CENTURY", 487 "CENT": "CENTURY", 488 "CENTS": "CENTURY", 489 "CENTURIES": "CENTURY", 490 } 491 492 @classmethod 493 def get_or_raise(cls, dialect: DialectType) -> Dialect: 494 """ 495 Look up a dialect in the global dialect registry and return it if it exists. 496 497 Args: 498 dialect: The target dialect. If this is a string, it can be optionally followed by 499 additional key-value pairs that are separated by commas and are used to specify 500 dialect settings, such as whether the dialect's identifiers are case-sensitive. 501 502 Example: 503 >>> dialect = dialect_class = get_or_raise("duckdb") 504 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 505 506 Returns: 507 The corresponding Dialect instance. 508 """ 509 510 if not dialect: 511 return cls() 512 if isinstance(dialect, _Dialect): 513 return dialect() 514 if isinstance(dialect, Dialect): 515 return dialect 516 if isinstance(dialect, str): 517 try: 518 dialect_name, *kv_strings = dialect.split(",") 519 kv_pairs = (kv.split("=") for kv in kv_strings) 520 kwargs = {} 521 for pair in kv_pairs: 522 key = pair[0].strip() 523 value: t.Union[bool | str | None] = None 524 525 if len(pair) == 1: 526 # Default initialize standalone settings to True 527 value = True 528 elif len(pair) == 2: 529 value = pair[1].strip() 530 531 # Coerce the value to boolean if it matches to the truthy/falsy values below 532 value_lower = value.lower() 533 if value_lower in ("true", "1"): 534 value = True 535 elif value_lower in ("false", "0"): 536 value = False 537 538 kwargs[key] = value 539 540 except ValueError: 541 raise ValueError( 542 f"Invalid dialect format: '{dialect}'. 
" 543 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 544 ) 545 546 result = cls.get(dialect_name.strip()) 547 if not result: 548 from difflib import get_close_matches 549 550 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 551 if similar: 552 similar = f" Did you mean {similar}?" 553 554 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 555 556 return result(**kwargs) 557 558 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 559 560 @classmethod 561 def format_time( 562 cls, expression: t.Optional[str | exp.Expression] 563 ) -> t.Optional[exp.Expression]: 564 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 565 if isinstance(expression, str): 566 return exp.Literal.string( 567 # the time formats are quoted 568 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 569 ) 570 571 if expression and expression.is_string: 572 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 573 574 return expression 575 576 def __init__(self, **kwargs) -> None: 577 normalization_strategy = kwargs.pop("normalization_strategy", None) 578 579 if normalization_strategy is None: 580 self.normalization_strategy = self.NORMALIZATION_STRATEGY 581 else: 582 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 583 584 self.settings = kwargs 585 586 def __eq__(self, other: t.Any) -> bool: 587 # Does not currently take dialect state into account 588 return type(self) == other 589 590 def __hash__(self) -> int: 591 # Does not currently take dialect state into account 592 return hash(type(self)) 593 594 def normalize_identifier(self, expression: E) -> E: 595 """ 596 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 597 598 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 599 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 600 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 601 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 602 603 There are also dialects like Spark, which are case-insensitive even when quotes are 604 present, and dialects like MySQL, whose resolution rules match those employed by the 605 underlying operating system, for example they may always be case-sensitive in Linux. 606 607 Finally, the normalization behavior of some engines can even be controlled through flags, 608 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 609 610 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 611 that it can analyze queries in the optimizer and successfully capture their semantics. 
612 """ 613 if ( 614 isinstance(expression, exp.Identifier) 615 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 616 and ( 617 not expression.quoted 618 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 619 ) 620 ): 621 expression.set( 622 "this", 623 ( 624 expression.this.upper() 625 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 626 else expression.this.lower() 627 ), 628 ) 629 630 return expression 631 632 def case_sensitive(self, text: str) -> bool: 633 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 634 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 635 return False 636 637 unsafe = ( 638 str.islower 639 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 640 else str.isupper 641 ) 642 return any(unsafe(char) for char in text) 643 644 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 645 """Checks if text can be identified given an identify option. 646 647 Args: 648 text: The text to check. 649 identify: 650 `"always"` or `True`: Always returns `True`. 651 `"safe"`: Only returns `True` if the identifier is case-insensitive. 652 653 Returns: 654 Whether the given text can be identified. 655 """ 656 if identify is True or identify == "always": 657 return True 658 659 if identify == "safe": 660 return not self.case_sensitive(text) 661 662 return False 663 664 def quote_identifier(self, expression: E, identify: bool = True) -> E: 665 """ 666 Adds quotes to a given identifier. 667 668 Args: 669 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 670 identify: If set to `False`, the quotes will only be added if the identifier is deemed 671 "unsafe", with respect to its characters and this dialect's normalization strategy. 672 """ 673 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 674 name = expression.this 675 expression.set( 676 "quoted", 677 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 678 ) 679 680 return expression 681 682 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 683 if isinstance(path, exp.Literal): 684 path_text = path.name 685 if path.is_number: 686 path_text = f"[{path_text}]" 687 try: 688 return parse_json_path(path_text, self) 689 except ParseError as e: 690 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 691 692 return path 693 694 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 695 return self.parser(**opts).parse(self.tokenize(sql), sql) 696 697 def parse_into( 698 self, expression_type: exp.IntoType, sql: str, **opts 699 ) -> t.List[t.Optional[exp.Expression]]: 700 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 701 702 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 703 return self.generator(**opts).generate(expression, copy=copy) 704 705 def transpile(self, sql: str, **opts) -> t.List[str]: 706 return [ 707 self.generate(expression, copy=False, **opts) if expression else "" 708 for expression in self.parse(sql) 709 ] 710 711 def tokenize(self, sql: str) -> t.List[Token]: 712 return self.tokenizer.tokenize(sql) 713 714 @property 715 def tokenizer(self) -> Tokenizer: 716 return self.tokenizer_class(dialect=self) 717 718 @property 719 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 720 return self.jsonpath_tokenizer_class(dialect=self) 721 722 def parser(self, **opts) -> Parser: 723 return self.parser_class(dialect=self, **opts) 724 725 def generator(self, **opts) -> Generator: 726 return self.generator_class(dialect=self, **opts) 727 728 729DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 730 731 732def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 733 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 734 735 736def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 737 if expression.args.get("accuracy"): 738 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 739 return self.func("APPROX_COUNT_DISTINCT", expression.this) 740 741 742def if_sql( 743 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 744) -> t.Callable[[Generator, exp.If], str]: 745 def _if_sql(self: Generator, expression: exp.If) -> str: 746 return self.func( 747 name, 748 expression.this, 749 expression.args.get("true"), 750 expression.args.get("false") or false_value, 751 ) 752 753 return _if_sql 754 755 756def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 757 this = expression.this 758 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 759 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 760 761 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 762 763 764def inline_array_sql(self: Generator, expression: exp.Array) -> str: 765 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 766 767 768def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 769 elem = seq_get(expression.expressions, 0) 770 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 771 return self.func("ARRAY", elem) 772 return inline_array_sql(self, expression) 773 774 775def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 776 return self.like_sql( 777 exp.Like( 778 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 779 ) 780 ) 781 782 783def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 784 zone = self.sql(expression, "this") 785 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 786 787 788def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 789 if expression.args.get("recursive"): 790 
self.unsupported("Recursive CTEs are unsupported") 791 expression.args["recursive"] = False 792 return self.with_sql(expression) 793 794 795def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 796 n = self.sql(expression, "this") 797 d = self.sql(expression, "expression") 798 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 799 800 801def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 802 self.unsupported("TABLESAMPLE unsupported") 803 return self.sql(expression.this) 804 805 806def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 807 self.unsupported("PIVOT unsupported") 808 return "" 809 810 811def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 812 return self.cast_sql(expression) 813 814 815def no_comment_column_constraint_sql( 816 self: Generator, expression: exp.CommentColumnConstraint 817) -> str: 818 self.unsupported("CommentColumnConstraint unsupported") 819 return "" 820 821 822def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 823 self.unsupported("MAP_FROM_ENTRIES unsupported") 824 return "" 825 826 827def str_position_sql( 828 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 829) -> str: 830 this = self.sql(expression, "this") 831 substr = self.sql(expression, "substr") 832 position = self.sql(expression, "position") 833 instance = expression.args.get("instance") if generate_instance else None 834 position_offset = "" 835 836 if position: 837 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 838 this = self.func("SUBSTR", this, position) 839 position_offset = f" + {position} - 1" 840 841 return self.func("STRPOS", this, substr, instance) + position_offset 842 843 844def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 845 return ( 846 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 847 ) 848 849 850def var_map_sql( 851 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 852) -> str: 853 keys = expression.args["keys"] 854 values = expression.args["values"] 855 856 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 857 self.unsupported("Cannot convert array columns into map.") 858 return self.func(map_func_name, keys, values) 859 860 args = [] 861 for key, value in zip(keys.expressions, values.expressions): 862 args.append(self.sql(key)) 863 args.append(self.sql(value)) 864 865 return self.func(map_func_name, *args) 866 867 868def build_formatted_time( 869 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 870) -> t.Callable[[t.List], E]: 871 """Helper used for time expressions. 872 873 Args: 874 exp_class: the expression class to instantiate. 875 dialect: target sql dialect. 876 default: the default format, True being time. 877 878 Returns: 879 A callable that can be used to return the appropriately formatted time expression. 
880 """ 881 882 def _builder(args: t.List): 883 return exp_class( 884 this=seq_get(args, 0), 885 format=Dialect[dialect].format_time( 886 seq_get(args, 1) 887 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 888 ), 889 ) 890 891 return _builder 892 893 894def time_format( 895 dialect: DialectType = None, 896) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 897 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 898 """ 899 Returns the time format for a given expression, unless it's equivalent 900 to the default time format of the dialect of interest. 901 """ 902 time_format = self.format_time(expression) 903 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 904 905 return _time_format 906 907 908def build_date_delta( 909 exp_class: t.Type[E], 910 unit_mapping: t.Optional[t.Dict[str, str]] = None, 911 default_unit: t.Optional[str] = "DAY", 912) -> t.Callable[[t.List], E]: 913 def _builder(args: t.List) -> E: 914 unit_based = len(args) == 3 915 this = args[2] if unit_based else seq_get(args, 0) 916 unit = None 917 if unit_based or default_unit: 918 unit = args[0] if unit_based else exp.Literal.string(default_unit) 919 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 920 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 921 922 return _builder 923 924 925def build_date_delta_with_interval( 926 expression_class: t.Type[E], 927) -> t.Callable[[t.List], t.Optional[E]]: 928 def _builder(args: t.List) -> t.Optional[E]: 929 if len(args) < 2: 930 return None 931 932 interval = args[1] 933 934 if not isinstance(interval, exp.Interval): 935 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 936 937 expression = interval.this 938 if expression and expression.is_string: 939 expression = exp.Literal.number(expression.this) 940 941 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 942 943 return _builder 944 945 946def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 947 unit = seq_get(args, 0) 948 this = seq_get(args, 1) 949 950 if isinstance(this, exp.Cast) and this.is_type("date"): 951 return exp.DateTrunc(unit=unit, this=this) 952 return exp.TimestampTrunc(this=this, unit=unit) 953 954 955def date_add_interval_sql( 956 data_type: str, kind: str 957) -> t.Callable[[Generator, exp.Expression], str]: 958 def func(self: Generator, expression: exp.Expression) -> str: 959 this = self.sql(expression, "this") 960 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 961 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 962 963 return func 964 965 966def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 967 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 968 args = [unit_to_str(expression), expression.this] 969 if zone: 970 args.append(expression.args.get("zone")) 971 return self.func("DATE_TRUNC", *args) 972 973 return _timestamptrunc_sql 974 975 976def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 977 zone = expression.args.get("zone") 978 if not zone: 979 from sqlglot.optimizer.annotate_types import annotate_types 980 981 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 982 return self.sql(exp.cast(expression.this, target_type)) 983 if zone.name.lower() in TIMEZONES: 984 
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)
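
# Illustrative sketch (not part of the module): `locate_to_strposition` above
# is a parser-side builder and `strposition_to_locate_sql` its generator-side
# inverse, so a MySQL-like dialect can round-trip LOCATE(substr, str, pos):
#
#     FUNCTIONS = {"LOCATE": locate_to_strposition}
#     TRANSFORMS = {exp.StrPosition: strposition_to_locate_sql}
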
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested backticks).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql
expression.args.get("unit") 1252 1253 if isinstance(unit, (exp.Var, exp.Placeholder)): 1254 return unit 1255 return exp.Var(this=default) if default else None 1256 1257 1258@t.overload 1259def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1260 pass 1261 1262 1263@t.overload 1264def map_date_part( 1265 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1266) -> t.Optional[exp.Expression]: 1267 pass 1268 1269 1270def map_date_part(part, dialect: DialectType = Dialect): 1271 mapped = ( 1272 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1273 ) 1274 return exp.var(mapped) if mapped else part 1275 1276 1277def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1278 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1279 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1280 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1281 1282 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1283 1284 1285def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1286 """Remove table refs from columns in when statements.""" 1287 alias = expression.this.args.get("alias") 1288 1289 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1290 return self.dialect.normalize_identifier(identifier).name if identifier else None 1291 1292 targets = {normalize(expression.this.this)} 1293 1294 if alias: 1295 targets.add(normalize(alias.this)) 1296 1297 for when in expression.expressions: 1298 when.transform( 1299 lambda node: ( 1300 exp.column(node.this) 1301 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1302 else node 1303 ), 1304 copy=False, 1305 ) 1306 1307 return self.merge_sql(expression) 1308 1309 1310def build_json_extract_path( 1311 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1312) -> t.Callable[[t.List], F]: 1313 def _builder(args: t.List) -> F: 1314 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1315 for arg in args[1:]: 1316 if not isinstance(arg, exp.Literal): 1317 # We use the fallback parser because we can't really transpile non-literals safely 1318 return expr_type.from_arg_list(args) 1319 1320 text = arg.name 1321 if is_int(text): 1322 index = int(text) 1323 segments.append( 1324 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1325 ) 1326 else: 1327 segments.append(exp.JSONPathKey(this=text)) 1328 1329 # This is done to avoid failing in the expression validator due to the arg count 1330 del args[2:] 1331 return expr_type( 1332 this=seq_get(args, 0), 1333 expression=exp.JSONPath(expressions=segments), 1334 only_json_types=arrow_req_json_type, 1335 ) 1336 1337 return _builder 1338 1339 1340def json_extract_segments( 1341 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1342) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1343 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1344 path = expression.expression 1345 if not isinstance(path, exp.JSONPath): 1346 return rename_func(name)(self, expression) 1347 1348 segments = [] 1349 for segment in path.expressions: 1350 path = self.sql(segment) 1351 if path: 1352 if isinstance(segment, exp.JSONPathPart) and ( 1353 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1354 ): 1355 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1356 1357 segments.append(path) 1358 
1359 if op: 1360 return f" {op} ".join([self.sql(expression.this), *segments]) 1361 return self.func(name, expression.this, *segments) 1362 1363 return _json_extract_segments 1364 1365 1366def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1367 if isinstance(expression.this, exp.JSONPathWildcard): 1368 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1369 1370 return expression.name 1371 1372 1373def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1374 cond = expression.expression 1375 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1376 alias = cond.expressions[0] 1377 cond = cond.this 1378 elif isinstance(cond, exp.Predicate): 1379 alias = "_u" 1380 else: 1381 self.unsupported("Unsupported filter condition") 1382 return "" 1383 1384 unnest = exp.Unnest(expressions=[expression.this]) 1385 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1386 return self.sql(exp.Array(expressions=[filtered])) 1387 1388 1389def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1390 return self.func( 1391 "TO_NUMBER", 1392 expression.this, 1393 expression.args.get("format"), 1394 expression.args.get("nlsparam"), 1395 ) 1396 1397 1398def build_default_decimal_type( 1399 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1400) -> t.Callable[[exp.DataType], exp.DataType]: 1401 def _builder(dtype: exp.DataType) -> exp.DataType: 1402 if dtype.expressions or precision is None: 1403 return dtype 1404 1405 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1406 return exp.DataType.build(f"DECIMAL({params})") 1407 1408 return _builder 1409 1410 1411def build_timestamp_from_parts(args: t.List) -> exp.Func: 1412 if len(args) == 2: 1413 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1414 # so we parse this into Anonymous for now instead of introducing complexity 1415 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1416 1417 return exp.TimestampFromParts.from_arg_list(args) 1418 1419 1420def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1421 return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
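
# Illustrative example (not part of the module): `map_date_part` canonicalizes
# the date part aliases listed in `Dialect.DATE_PART_MAPPING`.
#
#     >>> map_date_part(exp.var("dow")).name
#     'DAYOFWEEK'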
613 """ 614 if ( 615 isinstance(expression, exp.Identifier) 616 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 617 and ( 618 not expression.quoted 619 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 620 ) 621 ): 622 expression.set( 623 "this", 624 ( 625 expression.this.upper() 626 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 627 else expression.this.lower() 628 ), 629 ) 630 631 return expression 632 633 def case_sensitive(self, text: str) -> bool: 634 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 635 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 636 return False 637 638 unsafe = ( 639 str.islower 640 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 641 else str.isupper 642 ) 643 return any(unsafe(char) for char in text) 644 645 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 646 """Checks if text can be identified given an identify option. 647 648 Args: 649 text: The text to check. 650 identify: 651 `"always"` or `True`: Always returns `True`. 652 `"safe"`: Only returns `True` if the identifier is case-insensitive. 653 654 Returns: 655 Whether the given text can be identified. 656 """ 657 if identify is True or identify == "always": 658 return True 659 660 if identify == "safe": 661 return not self.case_sensitive(text) 662 663 return False 664 665 def quote_identifier(self, expression: E, identify: bool = True) -> E: 666 """ 667 Adds quotes to a given identifier. 668 669 Args: 670 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 671 identify: If set to `False`, the quotes will only be added if the identifier is deemed 672 "unsafe", with respect to its characters and this dialect's normalization strategy. 673 """ 674 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 675 name = expression.this 676 expression.set( 677 "quoted", 678 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 679 ) 680 681 return expression 682 683 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 684 if isinstance(path, exp.Literal): 685 path_text = path.name 686 if path.is_number: 687 path_text = f"[{path_text}]" 688 try: 689 return parse_json_path(path_text, self) 690 except ParseError as e: 691 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 692 693 return path 694 695 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 696 return self.parser(**opts).parse(self.tokenize(sql), sql) 697 698 def parse_into( 699 self, expression_type: exp.IntoType, sql: str, **opts 700 ) -> t.List[t.Optional[exp.Expression]]: 701 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 702 703 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 704 return self.generator(**opts).generate(expression, copy=copy) 705 706 def transpile(self, sql: str, **opts) -> t.List[str]: 707 return [ 708 self.generate(expression, copy=False, **opts) if expression else "" 709 for expression in self.parse(sql) 710 ] 711 712 def tokenize(self, sql: str) -> t.List[Token]: 713 return self.tokenizer.tokenize(sql) 714 715 @property 716 def tokenizer(self) -> Tokenizer: 717 return self.tokenizer_class(dialect=self) 718 719 @property 720 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 721 return self.jsonpath_tokenizer_class(dialect=self) 722 723 def parser(self, **opts) -> Parser: 724 return self.parser_class(dialect=self, **opts) 725 726 def generator(self, **opts) -> Generator: 727 return self.generator_class(dialect=self, **opts)
577 def __init__(self, **kwargs) -> None: 578 normalization_strategy = kwargs.pop("normalization_strategy", None) 579 580 if normalization_strategy is None: 581 self.normalization_strategy = self.NORMALIZATION_STRATEGY 582 else: 583 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 584 585 self.settings = kwargs
WEEK_OFFSET: First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

TABLESAMPLE_SIZE_IS_PERCENT: Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY: Specifies the strategy according to which identifiers should be normalized.

NORMALIZE_FUNCTIONS: Determines how function names are going to be normalized. Possible values: "upper" or True: convert names to uppercase; "lower": convert names to lowercase; False: disables function name normalization.

LOG_BASE_FIRST: Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG).

NULL_ORDERING: Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".

TYPED_DIVISION: Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division; True means a / b is integer division if both a and b are integers.

CONCAT_COALESCE: A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

TIME_MAPPING: Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

UNESCAPED_SEQUENCES: Mapping of an escaped sequence (e.g. \\n, i.e. backslash-n) to its unescaped version (the actual newline character).

PSEUDOCOLUMNS: Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN: Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery. For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;

FORCE_EARLY_ALIAS_REF_EXPANSION: Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). For example:

    WITH data AS (
        SELECT
            1 AS id,
            2 AS my_id
    )
    SELECT
        id AS my_id
    FROM
        data
    WHERE
        my_id = 1
    GROUP BY
        my_id
    HAVING
        my_id = 1

In most dialects "my_id" would refer to "data.my_id" (which is what _qualify_columns() resolves it to) across the whole query, except:
- BigQuery, which forwards the alias to the GROUP BY and HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- ClickHouse, which forwards the alias across the whole query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"

EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY: Whether alias reference expansion before qualification should only happen for the GROUP BY clause.

SUPPORTS_ORDER_BY_ALL: Whether ORDER BY ALL is supported (expands to all the selected columns), as in DuckDB and Spark3/Databricks.
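Taken together, these flags are what a dialect subclass overrides to describe its engine. A minimal sketch of a custom dialect (the name MyDialect is made up for illustration; the _Dialect metaclass registers the class under the key "mydialect" as soon as the class body executes):

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

class MyDialect(Dialect):
    INDEX_OFFSET = 1                  # this engine uses 1-based arrays
    NULL_ORDERING = "nulls_are_large"
    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE

# The subclass is now discoverable through the registry:
assert Dialect.get_or_raise("mydialect").INDEX_OFFSET == 1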
493 @classmethod 494 def get_or_raise(cls, dialect: DialectType) -> Dialect: 495 """ 496 Look up a dialect in the global dialect registry and return it if it exists. 497 498 Args: 499 dialect: The target dialect. If this is a string, it can be optionally followed by 500 additional key-value pairs that are separated by commas and are used to specify 501 dialect settings, such as whether the dialect's identifiers are case-sensitive. 502 503 Example: 504 >>> dialect = dialect_class = get_or_raise("duckdb") 505 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 506 507 Returns: 508 The corresponding Dialect instance. 509 """ 510 511 if not dialect: 512 return cls() 513 if isinstance(dialect, _Dialect): 514 return dialect() 515 if isinstance(dialect, Dialect): 516 return dialect 517 if isinstance(dialect, str): 518 try: 519 dialect_name, *kv_strings = dialect.split(",") 520 kv_pairs = (kv.split("=") for kv in kv_strings) 521 kwargs = {} 522 for pair in kv_pairs: 523 key = pair[0].strip() 524 value: t.Union[bool | str | None] = None 525 526 if len(pair) == 1: 527 # Default initialize standalone settings to True 528 value = True 529 elif len(pair) == 2: 530 value = pair[1].strip() 531 532 # Coerce the value to boolean if it matches to the truthy/falsy values below 533 value_lower = value.lower() 534 if value_lower in ("true", "1"): 535 value = True 536 elif value_lower in ("false", "0"): 537 value = False 538 539 kwargs[key] = value 540 541 except ValueError: 542 raise ValueError( 543 f"Invalid dialect format: '{dialect}'. " 544 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 545 ) 546 547 result = cls.get(dialect_name.strip()) 548 if not result: 549 from difflib import get_close_matches 550 551 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 552 if similar: 553 similar = f" Did you mean {similar}?" 554 555 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 556 557 return result(**kwargs) 558 559 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
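A short sketch of both lookup forms, assuming the stock dialect registry:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Plain lookup: returns an instance with default settings.
duckdb = Dialect.get_or_raise("duckdb")

# Key-value pairs after the dialect name become constructor kwargs; bare keys
# default to True, and "true"/"false"/"1"/"0" values are coerced to booleans.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE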
561 @classmethod 562 def format_time( 563 cls, expression: t.Optional[str | exp.Expression] 564 ) -> t.Optional[exp.Expression]: 565 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 566 if isinstance(expression, str): 567 return exp.Literal.string( 568 # the time formats are quoted 569 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 570 ) 571 572 if expression and expression.is_string: 573 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 574 575 return expression
Converts a time format in this dialect to its equivalent Python strftime format.
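For instance, Hive declares Java-style tokens such as yyyy and dd in its TIME_MAPPING. A small sketch; the printed output assumes Hive's stock mapping:

from sqlglot.dialects.hive import Hive

fmt = Hive.format_time("'yyyy-MM-dd'")  # input is quoted, like a SQL literal
print(fmt)  # '%Y-%m-%d'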
595 def normalize_identifier(self, expression: E) -> E: 596 """ 597 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 598 599 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 600 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 601 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 602 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 603 604 There are also dialects like Spark, which are case-insensitive even when quotes are 605 present, and dialects like MySQL, whose resolution rules match those employed by the 606 underlying operating system, for example they may always be case-sensitive in Linux. 607 608 Finally, the normalization behavior of some engines can even be controlled through flags, 609 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 610 611 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 612 that it can analyze queries in the optimizer and successfully capture their semantics. 613 """ 614 if ( 615 isinstance(expression, exp.Identifier) 616 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 617 and ( 618 not expression.quoted 619 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 620 ) 621 ): 622 expression.set( 623 "this", 624 ( 625 expression.this.upper() 626 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 627 else expression.this.lower() 628 ), 629 ) 630 631 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
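A quick sketch of the behaviors described above:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted

print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

# Quoted identifiers are preserved unless the dialect is case-insensitive.
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)         # FoO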
633 def case_sensitive(self, text: str) -> bool: 634 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 635 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 636 return False 637 638 unsafe = ( 639 str.islower 640 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 641 else str.isupper 642 ) 643 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
645 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 646 """Checks if text can be identified given an identify option. 647 648 Args: 649 text: The text to check. 650 identify: 651 `"always"` or `True`: Always returns `True`. 652 `"safe"`: Only returns `True` if the identifier is case-insensitive. 653 654 Returns: 655 Whether the given text can be identified. 656 """ 657 if identify is True or identify == "always": 658 return True 659 660 if identify == "safe": 661 return not self.case_sensitive(text) 662 663 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
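For example, under Snowflake's uppercase normalization (a small sketch):

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("snowflake")
print(d.case_sensitive("FOO"))          # False: already in the engine's case
print(d.case_sensitive("Foo"))          # True: lowercase characters are "unsafe"
print(d.can_identify("Foo", "safe"))    # False: quoting could change resolution
print(d.can_identify("Foo", "always"))  # True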
665 def quote_identifier(self, expression: E, identify: bool = True) -> E: 666 """ 667 Adds quotes to a given identifier. 668 669 Args: 670 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 671 identify: If set to `False`, the quotes will only be added if the identifier is deemed 672 "unsafe", with respect to its characters and this dialect's normalization strategy. 673 """ 674 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 675 name = expression.this 676 expression.set( 677 "quoted", 678 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 679 ) 680 681 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an Identifier, this method is a no-op.
- identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
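A sketch with DuckDB, which lowercases unquoted identifiers:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("duckdb")
print(d.quote_identifier(exp.to_identifier("foo")).sql())                  # "foo"
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo
print(d.quote_identifier(exp.to_identifier("FoO"), identify=False).sql())  # "FoO"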
683 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 684 if isinstance(path, exp.Literal): 685 path_text = path.name 686 if path.is_number: 687 path_text = f"[{path_text}]" 688 try: 689 return parse_json_path(path_text, self) 690 except ParseError as e: 691 logger.warning(f"Invalid JSON path syntax. {str(e)}") 692 693 return path
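to_json_path is used when transpiling JSON extraction functions: string literals are parsed into structured exp.JSONPath nodes, and anything unparseable is passed through with a warning. A small sketch using the base dialect:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect()
path = d.to_json_path(exp.Literal.string("$.a[0]"))
assert isinstance(path, exp.JSONPath)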
743def if_sql( 744 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 745) -> t.Callable[[Generator, exp.If], str]: 746 def _if_sql(self: Generator, expression: exp.If) -> str: 747 return self.func( 748 name, 749 expression.this, 750 expression.args.get("true"), 751 expression.args.get("false") or false_value, 752 ) 753 754 return _if_sql
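if_sql returns a generator transform, so a dialect can plug it into its Generator's TRANSFORMS table. A hypothetical sketch (ToyGenerator is a made-up name) that renders exp.If as IIF, the way T-SQL spells it:

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import if_sql
from sqlglot.generator import Generator

class ToyGenerator(Generator):
    TRANSFORMS = {**Generator.TRANSFORMS, exp.If: if_sql(name="IIF")}

node = parse_one("SELECT IF(x > 0, 'pos', 'neg')")
print(ToyGenerator().generate(node))  # SELECT IIF(x > 0, 'pos', 'neg')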
757def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 758 this = expression.this 759 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 760 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 761 762 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
828def str_position_sql( 829 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 830) -> str: 831 this = self.sql(expression, "this") 832 substr = self.sql(expression, "substr") 833 position = self.sql(expression, "position") 834 instance = expression.args.get("instance") if generate_instance else None 835 position_offset = "" 836 837 if position: 838 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 839 this = self.func("SUBSTR", this, position) 840 position_offset = f" + {position} - 1" 841 842 return self.func("STRPOS", this, substr, instance) + position_offset
851def var_map_sql( 852 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 853) -> str: 854 keys = expression.args["keys"] 855 values = expression.args["values"] 856 857 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 858 self.unsupported("Cannot convert array columns into map.") 859 return self.func(map_func_name, keys, values) 860 861 args = [] 862 for key, value in zip(keys.expressions, values.expressions): 863 args.append(self.sql(key)) 864 args.append(self.sql(value)) 865 866 return self.func(map_func_name, *args)
869def build_formatted_time( 870 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 871) -> t.Callable[[t.List], E]: 872 """Helper used for time expressions. 873 874 Args: 875 exp_class: the expression class to instantiate. 876 dialect: target sql dialect. 877 default: the default format, True being time. 878 879 Returns: 880 A callable that can be used to return the appropriately formatted time expression. 881 """ 882 883 def _builder(args: t.List): 884 return exp_class( 885 this=seq_get(args, 0), 886 format=Dialect[dialect].format_time( 887 seq_get(args, 1) 888 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 889 ), 890 ) 891 892 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; passing True selects the dialect's TIME_FORMAT.
Returns:
A callable that can be used to return the appropriately formatted time expression.
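This is the builder behind parser entries like Hive's DATE_FORMAT. A sketch of calling one directly; the argument list mimics what the parser would pass:

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.TimeToStr, "hive")
node = builder([exp.column("dt"), exp.Literal.string("yyyy-MM-dd")])
print(node.args["format"])  # '%Y-%m-%d', converted via Hive's TIME_MAPPING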
895def time_format( 896 dialect: DialectType = None, 897) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 898 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 899 """ 900 Returns the time format for a given expression, unless it's equivalent 901 to the default time format of the dialect of interest. 902 """ 903 time_format = self.format_time(expression) 904 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 905 906 return _time_format
909def build_date_delta( 910 exp_class: t.Type[E], 911 unit_mapping: t.Optional[t.Dict[str, str]] = None, 912 default_unit: t.Optional[str] = "DAY", 913) -> t.Callable[[t.List], E]: 914 def _builder(args: t.List) -> E: 915 unit_based = len(args) == 3 916 this = args[2] if unit_based else seq_get(args, 0) 917 unit = None 918 if unit_based or default_unit: 919 unit = args[0] if unit_based else exp.Literal.string(default_unit) 920 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 921 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 922 923 return _builder
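A sketch showing both argument shapes the resulting builder accepts (again, the argument lists mimic what the parser would pass):

from sqlglot import exp
from sqlglot.dialects.dialect import build_date_delta

builder = build_date_delta(exp.DateAdd, default_unit="DAY")

# Two-argument form: the unit falls back to the default.
node = builder([exp.column("d"), exp.Literal.number(7)])
print(node.args["unit"])  # 'DAY'

# Unit-based, three-argument form, as in DATEADD(WEEK, 1, d).
node = builder([exp.var("WEEK"), exp.Literal.number(1), exp.column("d")])
print(node.args["unit"])  # WEEK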
926def build_date_delta_with_interval( 927 expression_class: t.Type[E], 928) -> t.Callable[[t.List], t.Optional[E]]: 929 def _builder(args: t.List) -> t.Optional[E]: 930 if len(args) < 2: 931 return None 932 933 interval = args[1] 934 935 if not isinstance(interval, exp.Interval): 936 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 937 938 expression = interval.this 939 if expression and expression.is_string: 940 expression = exp.Literal.number(expression.this) 941 942 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 943 944 return _builder
956def date_add_interval_sql( 957 data_type: str, kind: str 958) -> t.Callable[[Generator, exp.Expression], str]: 959 def func(self: Generator, expression: exp.Expression) -> str: 960 this = self.sql(expression, "this") 961 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 962 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 963 964 return func
967def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 968 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 969 args = [unit_to_str(expression), expression.this] 970 if zone: 971 args.append(expression.args.get("zone")) 972 return self.func("DATE_TRUNC", *args) 973 974 return _timestamptrunc_sql
977def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 978 zone = expression.args.get("zone") 979 if not zone: 980 from sqlglot.optimizer.annotate_types import annotate_types 981 982 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 983 return self.sql(exp.cast(expression.this, target_type)) 984 if zone.name.lower() in TIMEZONES: 985 return self.sql( 986 exp.AtTimeZone( 987 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 988 zone=zone, 989 ) 990 ) 991 return self.func("TIMESTAMP", expression.this, zone)
994def no_time_sql(self: Generator, expression: exp.Time) -> str: 995 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 996 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 997 expr = exp.cast( 998 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 999 ) 1000 return self.sql(expr)
1003def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1004 this = expression.this 1005 expr = expression.expression 1006 1007 if expr.name.lower() in TIMEZONES: 1008 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1009 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1010 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1011 return self.sql(this) 1012 1013 this = exp.cast(this, exp.DataType.Type.DATE) 1014 expr = exp.cast(expr, exp.DataType.Type.TIME) 1015 1016 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1057def encode_decode_sql( 1058 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1059) -> str: 1060 charset = expression.args.get("charset") 1061 if charset and charset.name.lower() != "utf-8": 1062 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1063 1064 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1077def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1078 cond = expression.this 1079 1080 if isinstance(expression.this, exp.Distinct): 1081 cond = expression.this.expressions[0] 1082 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1083 1084 return self.func("sum", exp.func("if", cond, 1, 0))
1087def trim_sql(self: Generator, expression: exp.Trim) -> str: 1088 target = self.sql(expression, "this") 1089 trim_type = self.sql(expression, "position") 1090 remove_chars = self.sql(expression, "expression") 1091 collation = self.sql(expression, "collation") 1092 1093 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1094 if not remove_chars and not collation: 1095 return self.trim_sql(expression) 1096 1097 trim_type = f"{trim_type} " if trim_type else "" 1098 remove_chars = f"{remove_chars} " if remove_chars else "" 1099 from_part = "FROM " if trim_type or remove_chars else "" 1100 collation = f" COLLATE {collation}" if collation else "" 1101 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1122def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1123 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1124 if bad_args: 1125 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1126 1127 return self.func( 1128 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1129 )
1132def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1133 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1134 if bad_args: 1135 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1136 1137 return self.func( 1138 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1139 )
1142def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1143 names = [] 1144 for agg in aggregations: 1145 if isinstance(agg, exp.Alias): 1146 names.append(agg.alias) 1147 else: 1148 """ 1149 This case corresponds to aggregations without aliases being used as suffixes 1150 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1151 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1152 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1153 """ 1154 agg_all_unquoted = agg.transform( 1155 lambda node: ( 1156 exp.Identifier(this=node.name, quoted=False) 1157 if isinstance(node, exp.Identifier) 1158 else node 1159 ) 1160 ) 1161 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1162 1163 return names
1203def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1204 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1205 if expression.args.get("count"): 1206 self.unsupported(f"Only two arguments are supported in function {name}.") 1207 1208 return self.func(name, expression.this, expression.expression) 1209 1210 return _arg_max_or_min_sql
1213def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1214 this = expression.this.copy() 1215 1216 return_type = expression.return_type 1217 if return_type.is_type(exp.DataType.Type.DATE): 1218 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1219 # can truncate timestamp strings, because some dialects can't cast them to DATE 1220 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1221 1222 expression.this.replace(exp.cast(this, return_type)) 1223 return expression
1226def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1227 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1228 if cast and isinstance(expression, exp.TsOrDsAdd): 1229 expression = ts_or_ds_add_cast(expression) 1230 1231 return self.func( 1232 name, 1233 unit_to_var(expression), 1234 expression.expression, 1235 expression.this, 1236 ) 1237 1238 return _delta_sql
1241def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1242 unit = expression.args.get("unit") 1243 1244 if isinstance(unit, exp.Placeholder): 1245 return unit 1246 if unit: 1247 return exp.Literal.string(unit.name) 1248 return exp.Literal.string(default) if default else None
1278def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1279 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1280 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1281 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1282 1283 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1286def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1287 """Remove table refs from columns in when statements.""" 1288 alias = expression.this.args.get("alias") 1289 1290 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1291 return self.dialect.normalize_identifier(identifier).name if identifier else None 1292 1293 targets = {normalize(expression.this.this)} 1294 1295 if alias: 1296 targets.add(normalize(alias.this)) 1297 1298 for when in expression.expressions: 1299 when.transform( 1300 lambda node: ( 1301 exp.column(node.this) 1302 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1303 else node 1304 ), 1305 copy=False, 1306 ) 1307 1308 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1311def build_json_extract_path( 1312 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1313) -> t.Callable[[t.List], F]: 1314 def _builder(args: t.List) -> F: 1315 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1316 for arg in args[1:]: 1317 if not isinstance(arg, exp.Literal): 1318 # We use the fallback parser because we can't really transpile non-literals safely 1319 return expr_type.from_arg_list(args) 1320 1321 text = arg.name 1322 if is_int(text): 1323 index = int(text) 1324 segments.append( 1325 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1326 ) 1327 else: 1328 segments.append(exp.JSONPathKey(this=text)) 1329 1330 # This is done to avoid failing in the expression validator due to the arg count 1331 del args[2:] 1332 return expr_type( 1333 this=seq_get(args, 0), 1334 expression=exp.JSONPath(expressions=segments), 1335 only_json_types=arrow_req_json_type, 1336 ) 1337 1338 return _builder
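A sketch of the happy path, where every segment is a literal and gets folded into a structured JSON path:

from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
assert isinstance(node.expression, exp.JSONPath)  # roughly $.a[0]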
1341def json_extract_segments( 1342 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1343) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1344 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1345 path = expression.expression 1346 if not isinstance(path, exp.JSONPath): 1347 return rename_func(name)(self, expression) 1348 1349 segments = [] 1350 for segment in path.expressions: 1351 path = self.sql(segment) 1352 if path: 1353 if isinstance(segment, exp.JSONPathPart) and ( 1354 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1355 ): 1356 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1357 1358 segments.append(path) 1359 1360 if op: 1361 return f" {op} ".join([self.sql(expression.this), *segments]) 1362 return self.func(name, expression.this, *segments) 1363 1364 return _json_extract_segments
1374def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1375 cond = expression.expression 1376 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1377 alias = cond.expressions[0] 1378 cond = cond.this 1379 elif isinstance(cond, exp.Predicate): 1380 alias = "_u" 1381 else: 1382 self.unsupported("Unsupported filter condition") 1383 return "" 1384 1385 unnest = exp.Unnest(expressions=[expression.this]) 1386 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1387 return self.sql(exp.Array(expressions=[filtered]))
1399def build_default_decimal_type( 1400 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1401) -> t.Callable[[exp.DataType], exp.DataType]: 1402 def _builder(dtype: exp.DataType) -> exp.DataType: 1403 if dtype.expressions or precision is None: 1404 return dtype 1405 1406 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1407 return exp.DataType.build(f"DECIMAL({params})") 1408 1409 return _builder
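A sketch: the returned builder only fills in defaults when the type has no parameters of its own:

from sqlglot import exp
from sqlglot.dialects.dialect import build_default_decimal_type

builder = build_default_decimal_type(precision=38, scale=9)
print(builder(exp.DataType.build("DECIMAL")).sql())        # DECIMAL(38, 9)
print(builder(exp.DataType.build("DECIMAL(10, 2)")).sql()) # DECIMAL(10, 2), unchanged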
1412def build_timestamp_from_parts(args: t.List) -> exp.Func: 1413 if len(args) == 2: 1414 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1415 # so we parse this into Anonymous for now instead of introducing complexity 1416 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1417 1418 return exp.TimestampFromParts.from_arg_list(args)