sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
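

# Illustrative sketch (editor's example; `_example_dialects_enum` is a
# hypothetical helper, not part of the module): both enums above are `str`
# subclasses, so members compare equal to their plain-string values, which is
# what lets the rest of the library accept either form.
def _example_dialects_enum() -> None:
    assert Dialects.DUCKDB == "duckdb"  # str-valued enum member
    assert Dialects("tsql") is Dialects.TSQL  # lookup by value
    # AutoName gives each member its own name as its value
    assert NormalizationStrategy.LOWERCASE == "LOWERCASE"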


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
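

# Illustrative sketch (hypothetical helper, not part of the module): the
# metaclass above registers every `Dialect` subclass in `_Dialect.classes`,
# keyed by its `Dialects` value, so dialects can be looked up by name.
def _example_dialect_registry() -> None:
    import sqlglot.dialects  # importing the package registers the bundled dialects

    duckdb = Dialect["duckdb"]  # metaclass __getitem__
    assert duckdb == "duckdb"  # metaclass __eq__ accepts registry names
    assert Dialect.get("no-such-dialect") is None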


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Separator of COPY statement parameters
    COPY_PARAMS_ARE_CSV = True

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"
                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
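

# Illustrative sketch (hypothetical helper, not part of the module):
# end-to-end use of a `Dialect` instance. Settings may be passed inline in
# the name string, as documented in `get_or_raise` above.
def _example_dialect_roundtrip() -> None:
    import sqlglot.dialects  # ensure "snowflake" is registered

    dialect = Dialect.get_or_raise("snowflake, normalization_strategy = case_sensitive")
    ident = exp.to_identifier("FoO")
    # CASE_SENSITIVE leaves the identifier untouched; Snowflake's default
    # strategy (UPPERCASE) would have produced FOO instead.
    assert dialect.normalize_identifier(ident).name == "FoO"
    assert dialect.transpile("SELECT 1")[0] == "SELECT 1"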


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
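

# Illustrative sketch (hypothetical helper, not part of the module): `if_sql`
# builds a transform that dialects typically plug into their generator's
# TRANSFORMS mapping; here it is invoked directly for demonstration.
def _example_if_sql() -> None:
    import sqlglot.dialects

    gen = Dialect.get_or_raise("duckdb").generator()
    node = exp.If(this=exp.column("cond"), true=exp.Literal.number(1))
    assert if_sql(name="IFF", false_value="NULL")(gen, node) == "IFF(cond, 1, NULL)"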


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
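

# Illustrative sketch (hypothetical helper, not part of the module): with a
# third `position` argument, `str_position_sql` shifts the haystack with
# SUBSTR and corrects the result by the offset.
def _example_str_position() -> None:
    import sqlglot.dialects

    gen = Dialect.get_or_raise("postgres").generator()
    node = exp.StrPosition(
        this=exp.column("haystack"),
        substr=exp.Literal.string("x"),
        position=exp.Literal.number(3),
    )
    assert str_position_sql(gen, node) == "STRPOS(SUBSTR(haystack, 3), 'x') + 3 - 1"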


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
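

# Illustrative sketch (hypothetical helper, not part of the module): a
# builder returned by `build_formatted_time` converts the dialect's format
# string into Python strftime tokens. Tokens MySQL shares with strftime pass
# through unchanged here.
def _example_build_formatted_time() -> None:
    import sqlglot.dialects  # ensure "mysql" is registered

    builder = build_formatted_time(exp.StrToTime, "mysql")
    node = builder([exp.column("dt"), exp.Literal.string("%Y-%m-%d")])
    assert node.text("format") == "%Y-%m-%d"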


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )
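

# Illustrative sketch (hypothetical helper, not part of the module):
# `timestamptrunc_sql` emits DATE_TRUNC with the unit rendered as a string
# literal via `unit_to_str`.
def _example_timestamptrunc() -> None:
    import sqlglot.dialects

    gen = Dialect.get_or_raise("postgres").generator()
    node = exp.TimestampTrunc(this=exp.column("ts"), unit=exp.var("DAY"))
    assert timestamptrunc_sql()(gen, node) == "DATE_TRUNC('DAY', ts)"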


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
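

# Illustrative sketch (hypothetical helper, not part of the module): a
# `date_delta_sql` transform renders DATEADD-style calls with the unit first,
# using `unit_to_var` for the unit argument.
def _example_date_delta() -> None:
    import sqlglot.dialects

    gen = Dialect.get_or_raise("duckdb").generator()
    node = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(7), unit=exp.var("WEEK"))
    assert date_delta_sql("DATE_ADD")(gen, node) == "DATE_ADD(WEEK, 7, d)"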


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder
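

# Illustrative sketch (hypothetical helper, not part of the module): a
# builder from `build_json_extract_path` collapses literal path arguments
# into a single `exp.JSONPath` expression.
def _example_json_extract_path() -> None:
    builder = build_json_extract_path(exp.JSONExtract)
    node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
    assert node.expression.sql() == "$.a[0]"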
188class Dialect(metaclass=_Dialect): 189 INDEX_OFFSET = 0 190 """The base index offset for arrays.""" 191 192 WEEK_OFFSET = 0 193 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 194 195 UNNEST_COLUMN_ONLY = False 196 """Whether `UNNEST` table aliases are treated as column aliases.""" 197 198 ALIAS_POST_TABLESAMPLE = False 199 """Whether the table alias comes after tablesample.""" 200 201 TABLESAMPLE_SIZE_IS_PERCENT = False 202 """Whether a size in the table sample clause represents percentage.""" 203 204 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 205 """Specifies the strategy according to which identifiers should be normalized.""" 206 207 IDENTIFIERS_CAN_START_WITH_DIGIT = False 208 """Whether an unquoted identifier can start with a digit.""" 209 210 DPIPE_IS_STRING_CONCAT = True 211 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 212 213 STRICT_STRING_CONCAT = False 214 """Whether `CONCAT`'s arguments must be strings.""" 215 216 SUPPORTS_USER_DEFINED_TYPES = True 217 """Whether user-defined data types are supported.""" 218 219 SUPPORTS_SEMI_ANTI_JOIN = True 220 """Whether `SEMI` or `ANTI` joins are supported.""" 221 222 NORMALIZE_FUNCTIONS: bool | str = "upper" 223 """ 224 Determines how function names are going to be normalized. 225 Possible values: 226 "upper" or True: Convert names to uppercase. 227 "lower": Convert names to lowercase. 228 False: Disables function name normalization. 229 """ 230 231 LOG_BASE_FIRST: t.Optional[bool] = True 232 """ 233 Whether the base comes first in the `LOG` function. 234 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 235 """ 236 237 NULL_ORDERING = "nulls_are_small" 238 """ 239 Default `NULL` ordering method to use if not explicitly set. 240 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 241 """ 242 243 TYPED_DIVISION = False 244 """ 245 Whether the behavior of `a / b` depends on the types of `a` and `b`. 246 False means `a / b` is always float division. 247 True means `a / b` is integer division if both `a` and `b` are integers. 248 """ 249 250 SAFE_DIVISION = False 251 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 252 253 CONCAT_COALESCE = False 254 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 255 256 HEX_LOWERCASE = False 257 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 258 259 DATE_FORMAT = "'%Y-%m-%d'" 260 DATEINT_FORMAT = "'%Y%m%d'" 261 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 262 263 TIME_MAPPING: t.Dict[str, str] = {} 264 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 265 266 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 267 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 268 FORMAT_MAPPING: t.Dict[str, str] = {} 269 """ 270 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 271 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
272 """ 273 274 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 275 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 276 277 PSEUDOCOLUMNS: t.Set[str] = set() 278 """ 279 Columns that are auto-generated by the engine corresponding to this dialect. 280 For example, such columns may be excluded from `SELECT *` queries. 281 """ 282 283 PREFER_CTE_ALIAS_COLUMN = False 284 """ 285 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 286 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 287 any projection aliases in the subquery. 288 289 For example, 290 WITH y(c) AS ( 291 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 292 ) SELECT c FROM y; 293 294 will be rewritten as 295 296 WITH y(c) AS ( 297 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 298 ) SELECT c FROM y; 299 """ 300 301 # --- Autofilled --- 302 303 tokenizer_class = Tokenizer 304 parser_class = Parser 305 generator_class = Generator 306 307 # A trie of the time_mapping keys 308 TIME_TRIE: t.Dict = {} 309 FORMAT_TRIE: t.Dict = {} 310 311 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 312 INVERSE_TIME_TRIE: t.Dict = {} 313 314 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 315 316 # Delimiters for string literals and identifiers 317 QUOTE_START = "'" 318 QUOTE_END = "'" 319 IDENTIFIER_START = '"' 320 IDENTIFIER_END = '"' 321 322 # Delimiters for bit, hex, byte and unicode literals 323 BIT_START: t.Optional[str] = None 324 BIT_END: t.Optional[str] = None 325 HEX_START: t.Optional[str] = None 326 HEX_END: t.Optional[str] = None 327 BYTE_START: t.Optional[str] = None 328 BYTE_END: t.Optional[str] = None 329 UNICODE_START: t.Optional[str] = None 330 UNICODE_END: t.Optional[str] = None 331 332 # Separator of COPY statement parameters 333 COPY_PARAMS_ARE_CSV = True 334 335 @classmethod 336 def get_or_raise(cls, dialect: DialectType) -> Dialect: 337 """ 338 Look up a dialect in the global dialect registry and return it if it exists. 339 340 Args: 341 dialect: The target dialect. If this is a string, it can be optionally followed by 342 additional key-value pairs that are separated by commas and are used to specify 343 dialect settings, such as whether the dialect's identifiers are case-sensitive. 344 345 Example: 346 >>> dialect = dialect_class = get_or_raise("duckdb") 347 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 348 349 Returns: 350 The corresponding Dialect instance. 351 """ 352 353 if not dialect: 354 return cls() 355 if isinstance(dialect, _Dialect): 356 return dialect() 357 if isinstance(dialect, Dialect): 358 return dialect 359 if isinstance(dialect, str): 360 try: 361 dialect_name, *kv_pairs = dialect.split(",") 362 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 363 except ValueError: 364 raise ValueError( 365 f"Invalid dialect format: '{dialect}'. " 366 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 367 ) 368 369 result = cls.get(dialect_name.strip()) 370 if not result: 371 from difflib import get_close_matches 372 373 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 374 if similar: 375 similar = f" Did you mean {similar}?" 
376 377 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 378 379 return result(**kwargs) 380 381 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 382 383 @classmethod 384 def format_time( 385 cls, expression: t.Optional[str | exp.Expression] 386 ) -> t.Optional[exp.Expression]: 387 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 388 if isinstance(expression, str): 389 return exp.Literal.string( 390 # the time formats are quoted 391 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 392 ) 393 394 if expression and expression.is_string: 395 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 396 397 return expression 398 399 def __init__(self, **kwargs) -> None: 400 normalization_strategy = kwargs.get("normalization_strategy") 401 402 if normalization_strategy is None: 403 self.normalization_strategy = self.NORMALIZATION_STRATEGY 404 else: 405 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 406 407 def __eq__(self, other: t.Any) -> bool: 408 # Does not currently take dialect state into account 409 return type(self) == other 410 411 def __hash__(self) -> int: 412 # Does not currently take dialect state into account 413 return hash(type(self)) 414 415 def normalize_identifier(self, expression: E) -> E: 416 """ 417 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 418 419 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 420 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 421 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 422 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 423 424 There are also dialects like Spark, which are case-insensitive even when quotes are 425 present, and dialects like MySQL, whose resolution rules match those employed by the 426 underlying operating system, for example they may always be case-sensitive in Linux. 427 428 Finally, the normalization behavior of some engines can even be controlled through flags, 429 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 430 431 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 432 that it can analyze queries in the optimizer and successfully capture their semantics. 433 """ 434 if ( 435 isinstance(expression, exp.Identifier) 436 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 437 and ( 438 not expression.quoted 439 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 440 ) 441 ): 442 expression.set( 443 "this", 444 ( 445 expression.this.upper() 446 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 447 else expression.this.lower() 448 ), 449 ) 450 451 return expression 452 453 def case_sensitive(self, text: str) -> bool: 454 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 455 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 456 return False 457 458 unsafe = ( 459 str.islower 460 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 461 else str.isupper 462 ) 463 return any(unsafe(char) for char in text) 464 465 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 466 """Checks if text can be identified given an identify option. 
467 468 Args: 469 text: The text to check. 470 identify: 471 `"always"` or `True`: Always returns `True`. 472 `"safe"`: Only returns `True` if the identifier is case-insensitive. 473 474 Returns: 475 Whether the given text can be identified. 476 """ 477 if identify is True or identify == "always": 478 return True 479 480 if identify == "safe": 481 return not self.case_sensitive(text) 482 483 return False 484 485 def quote_identifier(self, expression: E, identify: bool = True) -> E: 486 """ 487 Adds quotes to a given identifier. 488 489 Args: 490 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 491 identify: If set to `False`, the quotes will only be added if the identifier is deemed 492 "unsafe", with respect to its characters and this dialect's normalization strategy. 493 """ 494 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 495 name = expression.this 496 expression.set( 497 "quoted", 498 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 499 ) 500 501 return expression 502 503 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 504 if isinstance(path, exp.Literal): 505 path_text = path.name 506 if path.is_number: 507 path_text = f"[{path_text}]" 508 509 try: 510 return parse_json_path(path_text) 511 except ParseError as e: 512 logger.warning(f"Invalid JSON path syntax. {str(e)}") 513 514 return path 515 516 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 517 return self.parser(**opts).parse(self.tokenize(sql), sql) 518 519 def parse_into( 520 self, expression_type: exp.IntoType, sql: str, **opts 521 ) -> t.List[t.Optional[exp.Expression]]: 522 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 523 524 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 525 return self.generator(**opts).generate(expression, copy=copy) 526 527 def transpile(self, sql: str, **opts) -> t.List[str]: 528 return [ 529 self.generate(expression, copy=False, **opts) if expression else "" 530 for expression in self.parse(sql) 531 ] 532 533 def tokenize(self, sql: str) -> t.List[Token]: 534 return self.tokenizer.tokenize(sql) 535 536 @property 537 def tokenizer(self) -> Tokenizer: 538 if not hasattr(self, "_tokenizer"): 539 self._tokenizer = self.tokenizer_class(dialect=self) 540 return self._tokenizer 541 542 def parser(self, **opts) -> Parser: 543 return self.parser_class(dialect=self, **opts) 544 545 def generator(self, **opts) -> Generator: 546 return self.generator_class(dialect=self, **opts)
399 def __init__(self, **kwargs) -> None: 400 normalization_strategy = kwargs.get("normalization_strategy") 401 402 if normalization_strategy is None: 403 self.normalization_strategy = self.NORMALIZATION_STRATEGY 404 else: 405 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
`TYPED_DIVISION`: Whether the behavior of `a / b` depends on the types of `a` and `b`. `False` means `a / b` is always float division; `True` means `a / b` is integer division if both `a` and `b` are integers.
`CONCAT_COALESCE`: A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
`TIME_MAPPING`: Associates this dialect's time formats with their equivalent Python `strftime` formats.
`FORMAT_MAPPING`: Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
`UNESCAPED_SEQUENCES`: Mapping of an escaped sequence (`\n`) to its unescaped version (a literal newline character).
`PSEUDOCOLUMNS`: Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from `SELECT *` queries.
`PREFER_CTE_ALIAS_COLUMN`: Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery. For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
    """
    Look up a dialect in the global dialect registry and return it if it exists.

    Args:
        dialect: The target dialect. If this is a string, it can be optionally followed by
            additional key-value pairs that are separated by commas and are used to specify
            dialect settings, such as whether the dialect's identifiers are case-sensitive.

    Example:
        >>> dialect = dialect_class = get_or_raise("duckdb")
        >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

    Returns:
        The corresponding Dialect instance.
    """

    if not dialect:
        return cls()
    if isinstance(dialect, _Dialect):
        return dialect()
    if isinstance(dialect, Dialect):
        return dialect
    if isinstance(dialect, str):
        try:
            dialect_name, *kv_pairs = dialect.split(",")
            kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
        except ValueError:
            raise ValueError(
                f"Invalid dialect format: '{dialect}'. "
                "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
            )

        result = cls.get(dialect_name.strip())
        if not result:
            from difflib import get_close_matches

            similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
            if similar:
                similar = f" Did you mean {similar}?"

            raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

        return result(**kwargs)

    raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Converts a time format in this dialect to its equivalent Python `strftime` format."""
    if isinstance(expression, str):
        return exp.Literal.string(
            # the time formats are quoted
            format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
        )

    if expression and expression.is_string:
        return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

    return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
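For instance, a quick sketch of converting a Hive-style format. The outputs assume Hive's stock TIME_MAPPING (e.g. yyyy maps to %Y), so treat them as illustrative.

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

hive = Dialect["hive"]

# Plain strings are expected to still carry their quotes, hence the [1:-1] slice above.
print(hive.format_time("'yyyy-MM-dd'"))                         # '%Y-%m-%d'
print(hive.format_time(exp.Literal.string("yyyy-MM-dd")).this)  # %Y-%m-%d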
def normalize_identifier(self, expression: E) -> E:
    """
    Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

    For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
    lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
    it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
    and so any normalization would be prohibited in order to avoid "breaking" the identifier.

    There are also dialects like Spark, which are case-insensitive even when quotes are
    present, and dialects like MySQL, whose resolution rules match those employed by the
    underlying operating system, for example they may always be case-sensitive in Linux.

    Finally, the normalization behavior of some engines can even be controlled through flags,
    like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

    SQLGlot aims to understand and handle all of these different behaviors gracefully, so
    that it can analyze queries in the optimizer and successfully capture their semantics.
    """
    if (
        isinstance(expression, exp.Identifier)
        and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
        and (
            not expression.quoted
            or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
        )
    ):
        expression.set(
            "this",
            (
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower()
            ),
        )

    return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set `enable_case_sensitive_identifier`.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
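A small sketch of the behaviors described above. Note that normalize_identifier mutates its argument, so a fresh node is built for each call; the results assume the stock strategies of Postgres (LOWERCASE) and Snowflake (UPPERCASE).

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

pg = Dialect.get_or_raise("postgres")   # LOWERCASE strategy
sf = Dialect.get_or_raise("snowflake")  # UPPERCASE strategy

print(pg.normalize_identifier(exp.to_identifier("FoO")).name)               # foo
print(sf.normalize_identifier(exp.to_identifier("FoO")).name)               # FOO
print(pg.normalize_identifier(exp.to_identifier("FoO", quoted=True)).name)  # FoO (quoted, left alone)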
def case_sensitive(self, text: str) -> bool:
    """Checks if text contains any case sensitive characters, based on the dialect's rules."""
    if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    unsafe = (
        str.islower
        if self.normalization_strategy is NormalizationStrategy.UPPERCASE
        else str.isupper
    )
    return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
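For instance, under a lowercase-normalizing dialect, any uppercase character makes the text case-sensitive. A minimal sketch, assuming Postgres's stock LOWERCASE strategy:

from sqlglot.dialects.dialect import Dialect

pg = Dialect.get_or_raise("postgres")
print(pg.case_sensitive("foo"))  # False: survives lowercasing unchanged
print(pg.case_sensitive("FoO"))  # True: contains uppercase characters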
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: Always returns `True`.
            `"safe"`: Only returns `True` if the identifier is case-insensitive.

    Returns:
        Whether the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    if identify == "safe":
        return not self.case_sensitive(text)

    return False
Checks if text can be identified given an identify option.

Arguments:
- text: The text to check.
- identify:
  `"always"` or `True`: Always returns `True`.
  `"safe"`: Only returns `True` if the identifier is case-insensitive.

Returns:
Whether the given text can be identified.
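A short sketch of the two options, again assuming Postgres's stock LOWERCASE strategy:

from sqlglot.dialects.dialect import Dialect

pg = Dialect.get_or_raise("postgres")
print(pg.can_identify("FoO", identify="always"))  # True: quoting is unconditional
print(pg.can_identify("FoO", identify="safe"))    # False: "FoO" is case-sensitive here
print(pg.can_identify("foo", identify="safe"))    # True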
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Adds quotes to a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.
    """
    if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
Adds quotes to a given identifier.

Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
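A minimal sketch of the identify=False path, using the default generator's double-quote style:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

pg = Dialect.get_or_raise("postgres")

# With identify=False, only "unsafe" names get quoted.
print(pg.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo
print(pg.quote_identifier(exp.to_identifier("FoO"), identify=False).sql())  # "FoO"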
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    if isinstance(path, exp.Literal):
        path_text = path.name
        if path.is_number:
            path_text = f"[{path_text}]"

        try:
            return parse_json_path(path_text)
        except ParseError as e:
            logger.warning(f"Invalid JSON path syntax. {str(e)}")

    return path
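A brief sketch of the literal handling above (illustrative; on a parse error the original literal is returned unchanged and a warning is logged):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect()
print(repr(d.to_json_path(exp.Literal.string("$.a.b"))))  # an exp.JSONPath node on success
print(repr(d.to_json_path(exp.Literal.number(0))))        # numbers are wrapped as the subscript [0]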
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
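Helpers like this are factories: they return a callable with the (self, expression) signature that dialect generators expect in their TRANSFORMS tables. A minimal sketch of calling one directly, assuming default Generator settings (the IIF name is illustrative):

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.dialects.dialect import if_sql

render_iif = if_sql(name="IIF")
node = exp.If(this=exp.column("x"), true=exp.Literal.number(1), false=exp.Literal.number(0))
print(render_iif(Generator(), node))  # IIF(x, 1, 0)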
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
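In practice these builders are registered in a parser's FUNCTIONS map. A hedged sketch of the pattern, built around the "hive" TIME_MAPPING (treat the exact output as illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

# Build a parser callback that turns TO_DATE-style args into exp.StrToDate,
# converting the format string through the "hive" dialect's TIME_MAPPING.
build_str_to_date = build_formatted_time(exp.StrToDate, "hive")

node = build_str_to_date([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
print(node.sql())  # roughly: STR_TO_DATE(x, '%Y-%m-%d')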
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        # With three args, the unit comes first, e.g. DATEADD(DAY, 1, x)
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
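The unit-first, three-argument branch matches functions like Snowflake's DATEADD(unit, value, expr); the two-argument branch falls back to a DAY unit. A small sketch of calling the builder directly:

from sqlglot import exp
from sqlglot.dialects.dialect import build_date_delta

build_date_add = build_date_delta(exp.DateAdd)

# Three args => the unit is given first: DATEADD(DAY, 1, x)
node = build_date_add([exp.var("DAY"), exp.Literal.number(1), exp.column("x")])
print(node.this, node.expression, node.args["unit"])  # x 1 DAY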
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table refs from columns in when statements.
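To make the rewrite concrete, a small sketch using the base dialect; the ON clause is left untouched, since only the WHEN clauses are transformed, and the exact output formatting is the generator's, so treat the comment as approximate:

from sqlglot import parse_one
from sqlglot.dialects.dialect import Dialect, merge_without_target_sql

d = Dialect()
merge = parse_one(
    "MERGE INTO t AS target USING s ON target.id = s.id "
    "WHEN MATCHED THEN UPDATE SET target.x = s.x"
)
print(merge_without_target_sql(d.generator(), merge))
# roughly: MERGE INTO t AS target USING s ON target.id = s.id
#          WHEN MATCHED THEN UPDATE SET x = s.x   (the target qualifier is gone)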
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])

        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))
def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder
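A dialect can use this to give a bare DECIMAL a concrete default precision and scale. A minimal sketch; the 38/9 values are illustrative, not a prescribed default:

from sqlglot import exp
from sqlglot.dialects.dialect import build_default_decimal_type

builder = build_default_decimal_type(precision=38, scale=9)

print(builder(exp.DataType.build("DECIMAL")).sql())         # DECIMAL(38, 9)
print(builder(exp.DataType.build("DECIMAL(10, 2)")).sql())  # DECIMAL(10, 2), already parameterized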