sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                "\\a": "\a",
                "\\b": "\b",
                "\\f": "\f",
                "\\n": "\n",
                "\\r": "\r",
                "\\t": "\t",
                "\\v": "\v",
                "\\\\": "\\",
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )
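The helpers above are building blocks for concrete dialects rather than user-facing API. As a minimal sketch of how they are typically wired together (the `Custom` dialect below is hypothetical, not part of this module), note that merely defining a `Dialect` subclass registers it, courtesy of `_Dialect.__new__`:

    from sqlglot import exp, transpile
    from sqlglot.dialects.dialect import Dialect, rename_func
    from sqlglot.generator import Generator

    class Custom(Dialect):
        # Nested Tokenizer/Parser/Generator classes are picked up by the metaclass.
        class Generator(Generator):
            TRANSFORMS = {
                **Generator.TRANSFORMS,
                # Render exp.ApproxDistinct under a different function name.
                exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            }

    # The class statement alone registered the dialect under the key "custom".
    print(transpile("SELECT APPROX_DISTINCT(a) FROM t", write="custom")[0])
    # SELECT APPROX_COUNT_DISTINCT(a) FROM t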
class Dialects(str, Enum)
Dialects supported by SQLGlot.
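Since the enum also subclasses `str`, its members compare equal to their plain-string values, which is why strings and `Dialects` members can be used interchangeably wherever a dialect is expected. A small illustration:

    from sqlglot.dialects.dialect import Dialects

    assert Dialects.DUCKDB == "duckdb"                  # str-valued members equal their raw strings
    assert Dialects("snowflake") is Dialects.SNOWFLAKE  # lookup by value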
class NormalizationStrategy(str, AutoName)
Specifies the strategy according to which identifiers should be normalized.

LOWERCASE: Unquoted identifiers are lowercased.
UPPERCASE: Unquoted identifiers are uppercased.
CASE_SENSITIVE: Always case-sensitive, regardless of quotes.
CASE_INSENSITIVE: Always case-insensitive, regardless of quotes.
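The strategy normally comes from a dialect's NORMALIZATION_STRATEGY class attribute, but per `Dialect.__init__` it can also be overridden per instance. A sketch using the settings-string syntax accepted by `Dialect.get_or_raise`:

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    # The key-value pair is forwarded to Dialect.__init__, which upper-cases the
    # value and coerces it through the NormalizationStrategy enum.
    d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE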
class Dialect(metaclass=_Dialect)

INDEX_OFFSET: The base index offset for arrays.

WEEK_OFFSET: First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY: Whether `UNNEST` table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE: Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT: Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY: Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT: Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT: Whether the DPIPE token (`||`) is a string concatenation operator.

STRICT_STRING_CONCAT: Whether `CONCAT`'s arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES: Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN: Whether `SEMI` or `ANTI` joins are supported.

NORMALIZE_FUNCTIONS: Determines how function names are going to be normalized. Possible values: "upper" or True (convert names to uppercase), "lower" (convert names to lowercase), False (disable function name normalization).

LOG_BASE_FIRST: Whether the base comes first in the `LOG` function. Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).

NULL_ORDERING: Default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.

TYPED_DIVISION: Whether the behavior of `a / b` depends on the types of `a` and `b`. False means `a / b` is always float division; True means it is integer division if both `a` and `b` are integers.

SAFE_DIVISION: Whether division by zero throws an error (`False`) or returns NULL (`True`).

CONCAT_COALESCE: A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.

TIME_MAPPING: Associates this dialect's time formats with their equivalent Python `strftime` formats.

FORMAT_MAPPING: Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.

UNESCAPED_SEQUENCES: Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).

PSEUDOCOLUMNS: Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from `SELECT *` queries.

PREFER_CTE_ALIAS_COLUMN: Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery. For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
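All of these settings are plain class attributes, and the `_Dialect` metaclass autofills the derived ones (tries, quote and identifier delimiters, and so on) when a dialect class is defined. A quick way to inspect them on a registered dialect; the printed values are illustrative:

    from sqlglot.dialects.dialect import Dialect

    mysql = Dialect.get_or_raise("mysql")
    print(mysql.IDENTIFIER_START)  # "`", autofilled from MySQL's Tokenizer
    print(mysql.TIME_MAPPING)      # MySQL time tokens mapped to strftime tokens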
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
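A quick sketch of the settings string in practice, mirroring the example above:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE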
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Converts a time format in this dialect to its equivalent Python `strftime` format."""
    if isinstance(expression, str):
        return exp.Literal.string(
            # the time formats are quoted
            format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
        )

    if expression and expression.is_string:
        return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

    return expression
Converts a time format in this dialect to its equivalent Python strftime format.
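For example, assuming Hive's default TIME_MAPPING, a minimal sketch:

from sqlglot import exp
from sqlglot.dialects.hive import Hive

# Hive's 'yyyy-MM-dd' corresponds to Python's '%Y-%m-%d'
fmt = Hive.format_time(exp.Literal.string("yyyy-MM-dd"))
assert fmt.this == "%Y-%m-%d"
# plain strings are assumed to carry their surrounding quotes, e.g. "'yyyy-MM-dd'"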
def normalize_identifier(self, expression: E) -> E:
    """
    Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

    For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
    lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
    it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
    and so any normalization would be prohibited in order to avoid "breaking" the identifier.

    There are also dialects like Spark, which are case-insensitive even when quotes are
    present, and dialects like MySQL, whose resolution rules match those employed by the
    underlying operating system, for example they may always be case-sensitive in Linux.

    Finally, the normalization behavior of some engines can even be controlled through flags,
    like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

    SQLGlot aims to understand and handle all of these different behaviors gracefully, so
    that it can analyze queries in the optimizer and successfully capture their semantics.
    """
    if (
        isinstance(expression, exp.Identifier)
        and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
        and (
            not expression.quoted
            or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
        )
    ):
        expression.set(
            "this",
            (
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower()
            ),
        )

    return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
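A minimal sketch of the asymmetry described above:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted identifier
assert Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name == "foo"
assert Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name == "FOO"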
def case_sensitive(self, text: str) -> bool:
    """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
    if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    unsafe = (
        str.islower
        if self.normalization_strategy is NormalizationStrategy.UPPERCASE
        else str.isupper
    )
    return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: Always returns `True`.
            `"safe"`: Only returns `True` if the identifier is case-insensitive.

    Returns:
        Whether the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    if identify == "safe":
        return not self.case_sensitive(text)

    return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
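For instance, under Postgres' lowercasing strategy (a minimal sketch):

from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")
assert postgres.can_identify("foo", "safe")      # all-lowercase, so case-insensitive here
assert not postgres.can_identify("Foo", "safe")  # contains an uppercase character
assert postgres.can_identify("Foo", "always")    # "always" ignores case-sensitivity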
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Adds quotes to a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.
    """
    if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an Identifier, this method is a no-op.
- identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
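A minimal sketch, assuming the default lowercasing strategy:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

dialect = Dialect.get_or_raise("postgres")
ident = exp.column("FoO").this  # the column's underlying Identifier
# even with identify=False, mixed case is deemed "unsafe" and gets quoted
assert dialect.quote_identifier(ident, identify=False).quoted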
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    if isinstance(path, exp.Literal):
        path_text = path.name
        if path.is_number:
            path_text = f"[{path_text}]"

        try:
            return parse_json_path(path_text)
        except ParseError as e:
            logger.warning(f"Invalid JSON path syntax. {str(e)}")

    return path
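This method parses literal paths into a structured JSONPath expression; a quick sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

# a valid literal parses into an exp.JSONPath tree; invalid syntax logs a
# warning and the original literal is returned untouched
path = Dialect().to_json_path(exp.Literal.string("$.a[0].b"))
assert isinstance(path, exp.JSONPath)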
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
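Helpers like this are typically wired into a dialect's Generator.TRANSFORMS; a hypothetical sketch in the style of Snowflake's IFF:

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.dialects.dialect import if_sql

class MyGenerator(Generator):  # hypothetical dialect generator
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # render exp.If as IFF(cond, true, false), defaulting the false branch to NULL
        exp.If: if_sql(name="IFF", false_value="NULL"),
    }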
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format; if True, the dialect's default TIME_FORMAT is used.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; if True, the dialect's default TIME_FORMAT is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
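A hypothetical sketch of building such a callable and invoking it on a parsed argument list:

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

# build a parser callback for a MySQL-style STR_TO_DATE(x, fmt)
build_str_to_date = build_formatted_time(exp.StrToDate, "mysql")
node = build_str_to_date([exp.column("x"), exp.Literal.string("%Y-%m-%d")])
# node is an exp.StrToDate whose 'format' went through mysql's TIME_MAPPING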
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        # the unit-based form is FUNC(unit, increment, this); otherwise it's FUNC(this, increment)
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql
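A hypothetical registration, assuming a dialect whose delta functions take the unit first:

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.dialects.dialect import date_delta_sql

class MyGenerator(Generator):  # hypothetical dialect generator
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # e.g. exp.DateAdd(this=x, expression=1, unit=DAY) -> DATEADD(DAY, 1, x)
        exp.DateAdd: date_delta_sql("DATEADD"),
        exp.DateDiff: date_delta_sql("DATEDIFF"),
    }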
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table refs from columns in when statements.
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))