sqlglot.dialects.dialect
1from __future__ import annotations 2 3import logging 4import typing as t 5from enum import Enum, auto 6from functools import reduce 7 8from sqlglot import exp 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator 11from sqlglot.helper import AutoName, flatten, is_int, seq_get 12from sqlglot.jsonpath import parse as parse_json_path 13from sqlglot.parser import Parser 14from sqlglot.time import TIMEZONES, format_time 15from sqlglot.tokens import Token, Tokenizer, TokenType 16from sqlglot.trie import new_trie 17 18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] 19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] 20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] 21 22 23if t.TYPE_CHECKING: 24 from sqlglot._typing import B, E, F 25 26logger = logging.getLogger("sqlglot") 27 28UNESCAPED_SEQUENCES = { 29 "\\a": "\a", 30 "\\b": "\b", 31 "\\f": "\f", 32 "\\n": "\n", 33 "\\r": "\r", 34 "\\t": "\t", 35 "\\v": "\v", 36 "\\\\": "\\", 37} 38 39 40class Dialects(str, Enum): 41 """Dialects supported by SQLGLot.""" 42 43 DIALECT = "" 44 45 ATHENA = "athena" 46 BIGQUERY = "bigquery" 47 CLICKHOUSE = "clickhouse" 48 DATABRICKS = "databricks" 49 DORIS = "doris" 50 DRILL = "drill" 51 DUCKDB = "duckdb" 52 HIVE = "hive" 53 MATERIALIZE = "materialize" 54 MYSQL = "mysql" 55 ORACLE = "oracle" 56 POSTGRES = "postgres" 57 PRESTO = "presto" 58 PRQL = "prql" 59 REDSHIFT = "redshift" 60 RISINGWAVE = "risingwave" 61 SNOWFLAKE = "snowflake" 62 SPARK = "spark" 63 SPARK2 = "spark2" 64 SQLITE = "sqlite" 65 STARROCKS = "starrocks" 66 TABLEAU = "tableau" 67 TERADATA = "teradata" 68 TRINO = "trino" 69 TSQL = "tsql" 70 71 72class NormalizationStrategy(str, AutoName): 73 """Specifies the strategy according to which identifiers should be normalized.""" 74 75 LOWERCASE = auto() 76 """Unquoted identifiers are lowercased.""" 77 78 UPPERCASE = auto() 79 """Unquoted identifiers are uppercased.""" 80 81 CASE_SENSITIVE = auto() 82 """Always case-sensitive, regardless of quotes.""" 83 84 CASE_INSENSITIVE = auto() 85 """Always case-insensitive, regardless of quotes.""" 86 87 88class _Dialect(type): 89 classes: t.Dict[str, t.Type[Dialect]] = {} 90 91 def __eq__(cls, other: t.Any) -> bool: 92 if cls is other: 93 return True 94 if isinstance(other, str): 95 return cls is cls.get(other) 96 if isinstance(other, Dialect): 97 return cls is type(other) 98 99 return False 100 101 def __hash__(cls) -> int: 102 return hash(cls.__name__.lower()) 103 104 @classmethod 105 def __getitem__(cls, key: str) -> t.Type[Dialect]: 106 return cls.classes[key] 107 108 @classmethod 109 def get( 110 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 111 ) -> t.Optional[t.Type[Dialect]]: 112 return cls.classes.get(key, default) 113 114 def __new__(cls, clsname, bases, attrs): 115 klass = super().__new__(cls, clsname, bases, attrs) 116 enum = Dialects.__members__.get(clsname.upper()) 117 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 118 119 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 120 klass.FORMAT_TRIE = ( 121 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 122 ) 123 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 124 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 125 126 base = seq_get(bases, 0) 127 base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),) 128 base_parser = (getattr(base, "parser_class", Parser),) 129 base_generator = (getattr(base, "generator_class", Generator),) 130 131 klass.tokenizer_class = klass.__dict__.get( 132 "Tokenizer", type("Tokenizer", base_tokenizer, {}) 133 ) 134 klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {})) 135 klass.generator_class = klass.__dict__.get( 136 "Generator", type("Generator", base_generator, {}) 137 ) 138 139 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 140 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 141 klass.tokenizer_class._IDENTIFIERS.items() 142 )[0] 143 144 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 145 return next( 146 ( 147 (s, e) 148 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 149 if t == token_type 150 ), 151 (None, None), 152 ) 153 154 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 155 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 156 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 157 klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) 158 159 if "\\" in klass.tokenizer_class.STRING_ESCAPES: 160 klass.UNESCAPED_SEQUENCES = { 161 **UNESCAPED_SEQUENCES, 162 **klass.UNESCAPED_SEQUENCES, 163 } 164 165 klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()} 166 167 if enum not in ("", "bigquery"): 168 klass.generator_class.SELECT_KINDS = () 169 170 if enum not in ("", "athena", "presto", "trino"): 171 klass.generator_class.TRY_SUPPORTED = False 172 173 if enum not in ("", "databricks", "hive", "spark", "spark2"): 174 modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy() 175 for modifier in ("cluster", "distribute", "sort"): 176 modifier_transforms.pop(modifier, None) 177 178 klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms 179 180 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 181 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 182 TokenType.ANTI, 183 TokenType.SEMI, 184 } 185 186 return klass 187 188 189class Dialect(metaclass=_Dialect): 190 INDEX_OFFSET = 0 191 """The base index offset for arrays.""" 192 193 WEEK_OFFSET = 0 194 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 195 196 UNNEST_COLUMN_ONLY = False 197 """Whether `UNNEST` table aliases are treated as column aliases.""" 198 199 ALIAS_POST_TABLESAMPLE = False 200 """Whether the table alias comes after tablesample.""" 201 202 TABLESAMPLE_SIZE_IS_PERCENT = False 203 """Whether a size in the table sample clause represents percentage.""" 204 205 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 206 """Specifies the strategy according to which identifiers should be normalized.""" 207 208 IDENTIFIERS_CAN_START_WITH_DIGIT = False 209 """Whether an unquoted identifier can start with a digit.""" 210 211 DPIPE_IS_STRING_CONCAT = True 212 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 213 214 STRICT_STRING_CONCAT = False 215 """Whether `CONCAT`'s arguments must be strings.""" 216 217 SUPPORTS_USER_DEFINED_TYPES = True 218 """Whether user-defined data types are supported.""" 219 220 SUPPORTS_SEMI_ANTI_JOIN = True 221 """Whether `SEMI` or `ANTI` joins are supported.""" 222 223 NORMALIZE_FUNCTIONS: bool | str = "upper" 224 """ 225 Determines how function names are going to be normalized. 226 Possible values: 227 "upper" or True: Convert names to uppercase. 228 "lower": Convert names to lowercase. 229 False: Disables function name normalization. 230 """ 231 232 LOG_BASE_FIRST: t.Optional[bool] = True 233 """ 234 Whether the base comes first in the `LOG` function. 235 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 236 """ 237 238 NULL_ORDERING = "nulls_are_small" 239 """ 240 Default `NULL` ordering method to use if not explicitly set. 241 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 242 """ 243 244 TYPED_DIVISION = False 245 """ 246 Whether the behavior of `a / b` depends on the types of `a` and `b`. 247 False means `a / b` is always float division. 248 True means `a / b` is integer division if both `a` and `b` are integers. 249 """ 250 251 SAFE_DIVISION = False 252 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 253 254 CONCAT_COALESCE = False 255 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 256 257 HEX_LOWERCASE = False 258 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 259 260 DATE_FORMAT = "'%Y-%m-%d'" 261 DATEINT_FORMAT = "'%Y%m%d'" 262 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 263 264 TIME_MAPPING: t.Dict[str, str] = {} 265 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 266 267 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 268 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 269 FORMAT_MAPPING: t.Dict[str, str] = {} 270 """ 271 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 272 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 273 """ 274 275 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 276 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 277 278 PSEUDOCOLUMNS: t.Set[str] = set() 279 """ 280 Columns that are auto-generated by the engine corresponding to this dialect. 281 For example, such columns may be excluded from `SELECT *` queries. 282 """ 283 284 PREFER_CTE_ALIAS_COLUMN = False 285 """ 286 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 287 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 288 any projection aliases in the subquery. 289 290 For example, 291 WITH y(c) AS ( 292 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 293 ) SELECT c FROM y; 294 295 will be rewritten as 296 297 WITH y(c) AS ( 298 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 299 ) SELECT c FROM y; 300 """ 301 302 # --- Autofilled --- 303 304 tokenizer_class = Tokenizer 305 parser_class = Parser 306 generator_class = Generator 307 308 # A trie of the time_mapping keys 309 TIME_TRIE: t.Dict = {} 310 FORMAT_TRIE: t.Dict = {} 311 312 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 313 INVERSE_TIME_TRIE: t.Dict = {} 314 315 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 316 317 # Delimiters for string literals and identifiers 318 QUOTE_START = "'" 319 QUOTE_END = "'" 320 IDENTIFIER_START = '"' 321 IDENTIFIER_END = '"' 322 323 # Delimiters for bit, hex, byte and unicode literals 324 BIT_START: t.Optional[str] = None 325 BIT_END: t.Optional[str] = None 326 HEX_START: t.Optional[str] = None 327 HEX_END: t.Optional[str] = None 328 BYTE_START: t.Optional[str] = None 329 BYTE_END: t.Optional[str] = None 330 UNICODE_START: t.Optional[str] = None 331 UNICODE_END: t.Optional[str] = None 332 333 # Separator of COPY statement parameters 334 COPY_PARAMS_ARE_CSV = True 335 336 @classmethod 337 def get_or_raise(cls, dialect: DialectType) -> Dialect: 338 """ 339 Look up a dialect in the global dialect registry and return it if it exists. 340 341 Args: 342 dialect: The target dialect. If this is a string, it can be optionally followed by 343 additional key-value pairs that are separated by commas and are used to specify 344 dialect settings, such as whether the dialect's identifiers are case-sensitive. 345 346 Example: 347 >>> dialect = dialect_class = get_or_raise("duckdb") 348 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 349 350 Returns: 351 The corresponding Dialect instance. 352 """ 353 354 if not dialect: 355 return cls() 356 if isinstance(dialect, _Dialect): 357 return dialect() 358 if isinstance(dialect, Dialect): 359 return dialect 360 if isinstance(dialect, str): 361 try: 362 dialect_name, *kv_pairs = dialect.split(",") 363 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 364 except ValueError: 365 raise ValueError( 366 f"Invalid dialect format: '{dialect}'. " 367 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 368 ) 369 370 result = cls.get(dialect_name.strip()) 371 if not result: 372 from difflib import get_close_matches 373 374 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 375 if similar: 376 similar = f" Did you mean {similar}?" 377 378 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 379 380 return result(**kwargs) 381 382 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 383 384 @classmethod 385 def format_time( 386 cls, expression: t.Optional[str | exp.Expression] 387 ) -> t.Optional[exp.Expression]: 388 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 389 if isinstance(expression, str): 390 return exp.Literal.string( 391 # the time formats are quoted 392 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 393 ) 394 395 if expression and expression.is_string: 396 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 397 398 return expression 399 400 def __init__(self, **kwargs) -> None: 401 normalization_strategy = kwargs.get("normalization_strategy") 402 403 if normalization_strategy is None: 404 self.normalization_strategy = self.NORMALIZATION_STRATEGY 405 else: 406 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 407 408 def __eq__(self, other: t.Any) -> bool: 409 # Does not currently take dialect state into account 410 return type(self) == other 411 412 def __hash__(self) -> int: 413 # Does not currently take dialect state into account 414 return hash(type(self)) 415 416 def normalize_identifier(self, expression: E) -> E: 417 """ 418 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 419 420 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 421 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 422 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 423 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 424 425 There are also dialects like Spark, which are case-insensitive even when quotes are 426 present, and dialects like MySQL, whose resolution rules match those employed by the 427 underlying operating system, for example they may always be case-sensitive in Linux. 428 429 Finally, the normalization behavior of some engines can even be controlled through flags, 430 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 431 432 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 433 that it can analyze queries in the optimizer and successfully capture their semantics. 434 """ 435 if ( 436 isinstance(expression, exp.Identifier) 437 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 438 and ( 439 not expression.quoted 440 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 441 ) 442 ): 443 expression.set( 444 "this", 445 ( 446 expression.this.upper() 447 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 448 else expression.this.lower() 449 ), 450 ) 451 452 return expression 453 454 def case_sensitive(self, text: str) -> bool: 455 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 456 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 457 return False 458 459 unsafe = ( 460 str.islower 461 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 462 else str.isupper 463 ) 464 return any(unsafe(char) for char in text) 465 466 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 467 """Checks if text can be identified given an identify option. 468 469 Args: 470 text: The text to check. 471 identify: 472 `"always"` or `True`: Always returns `True`. 473 `"safe"`: Only returns `True` if the identifier is case-insensitive. 474 475 Returns: 476 Whether the given text can be identified. 477 """ 478 if identify is True or identify == "always": 479 return True 480 481 if identify == "safe": 482 return not self.case_sensitive(text) 483 484 return False 485 486 def quote_identifier(self, expression: E, identify: bool = True) -> E: 487 """ 488 Adds quotes to a given identifier. 489 490 Args: 491 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 492 identify: If set to `False`, the quotes will only be added if the identifier is deemed 493 "unsafe", with respect to its characters and this dialect's normalization strategy. 494 """ 495 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 496 name = expression.this 497 expression.set( 498 "quoted", 499 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 500 ) 501 502 return expression 503 504 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 505 if isinstance(path, exp.Literal): 506 path_text = path.name 507 if path.is_number: 508 path_text = f"[{path_text}]" 509 510 try: 511 return parse_json_path(path_text) 512 except ParseError as e: 513 logger.warning(f"Invalid JSON path syntax. {str(e)}") 514 515 return path 516 517 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 518 return self.parser(**opts).parse(self.tokenize(sql), sql) 519 520 def parse_into( 521 self, expression_type: exp.IntoType, sql: str, **opts 522 ) -> t.List[t.Optional[exp.Expression]]: 523 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 524 525 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 526 return self.generator(**opts).generate(expression, copy=copy) 527 528 def transpile(self, sql: str, **opts) -> t.List[str]: 529 return [ 530 self.generate(expression, copy=False, **opts) if expression else "" 531 for expression in self.parse(sql) 532 ] 533 534 def tokenize(self, sql: str) -> t.List[Token]: 535 return self.tokenizer.tokenize(sql) 536 537 @property 538 def tokenizer(self) -> Tokenizer: 539 if not hasattr(self, "_tokenizer"): 540 self._tokenizer = self.tokenizer_class(dialect=self) 541 return self._tokenizer 542 543 def parser(self, **opts) -> Parser: 544 return self.parser_class(dialect=self, **opts) 545 546 def generator(self, **opts) -> Generator: 547 return self.generator_class(dialect=self, **opts) 548 549 550DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 551 552 553def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 554 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 555 556 557def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 558 if expression.args.get("accuracy"): 559 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 560 return self.func("APPROX_COUNT_DISTINCT", expression.this) 561 562 563def if_sql( 564 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 565) -> t.Callable[[Generator, exp.If], str]: 566 def _if_sql(self: Generator, expression: exp.If) -> str: 567 return self.func( 568 name, 569 expression.this, 570 expression.args.get("true"), 571 expression.args.get("false") or false_value, 572 ) 573 574 return _if_sql 575 576 577def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 578 this = expression.this 579 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 580 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 581 582 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 583 584 585def inline_array_sql(self: Generator, expression: exp.Array) -> str: 586 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 587 588 589def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 590 elem = seq_get(expression.expressions, 0) 591 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 592 return self.func("ARRAY", elem) 593 return inline_array_sql(self, expression) 594 595 596def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 597 return self.like_sql( 598 exp.Like( 599 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 600 ) 601 ) 602 603 604def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 605 zone = self.sql(expression, "this") 606 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 607 608 609def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 610 if expression.args.get("recursive"): 611 self.unsupported("Recursive CTEs are unsupported") 612 expression.args["recursive"] = False 613 return self.with_sql(expression) 614 615 616def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 617 n = self.sql(expression, "this") 618 d = self.sql(expression, "expression") 619 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 620 621 622def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 623 self.unsupported("TABLESAMPLE unsupported") 624 return self.sql(expression.this) 625 626 627def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 628 self.unsupported("PIVOT unsupported") 629 return "" 630 631 632def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 633 return self.cast_sql(expression) 634 635 636def no_comment_column_constraint_sql( 637 self: Generator, expression: exp.CommentColumnConstraint 638) -> str: 639 self.unsupported("CommentColumnConstraint unsupported") 640 return "" 641 642 643def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 644 self.unsupported("MAP_FROM_ENTRIES unsupported") 645 return "" 646 647 648def str_position_sql( 649 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 650) -> str: 651 this = self.sql(expression, "this") 652 substr = self.sql(expression, "substr") 653 position = self.sql(expression, "position") 654 instance = expression.args.get("instance") if generate_instance else None 655 position_offset = "" 656 657 if position: 658 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 659 this = self.func("SUBSTR", this, position) 660 position_offset = f" + {position} - 1" 661 662 return self.func("STRPOS", this, substr, instance) + position_offset 663 664 665def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 666 return ( 667 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 668 ) 669 670 671def var_map_sql( 672 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 673) -> str: 674 keys = expression.args["keys"] 675 values = expression.args["values"] 676 677 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 678 self.unsupported("Cannot convert array columns into map.") 679 return self.func(map_func_name, keys, values) 680 681 args = [] 682 for key, value in zip(keys.expressions, values.expressions): 683 args.append(self.sql(key)) 684 args.append(self.sql(value)) 685 686 return self.func(map_func_name, *args) 687 688 689def build_formatted_time( 690 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 691) -> t.Callable[[t.List], E]: 692 """Helper used for time expressions. 693 694 Args: 695 exp_class: the expression class to instantiate. 696 dialect: target sql dialect. 697 default: the default format, True being time. 698 699 Returns: 700 A callable that can be used to return the appropriately formatted time expression. 701 """ 702 703 def _builder(args: t.List): 704 return exp_class( 705 this=seq_get(args, 0), 706 format=Dialect[dialect].format_time( 707 seq_get(args, 1) 708 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 709 ), 710 ) 711 712 return _builder 713 714 715def time_format( 716 dialect: DialectType = None, 717) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 718 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 719 """ 720 Returns the time format for a given expression, unless it's equivalent 721 to the default time format of the dialect of interest. 722 """ 723 time_format = self.format_time(expression) 724 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 725 726 return _time_format 727 728 729def build_date_delta( 730 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 731) -> t.Callable[[t.List], E]: 732 def _builder(args: t.List) -> E: 733 unit_based = len(args) == 3 734 this = args[2] if unit_based else seq_get(args, 0) 735 unit = args[0] if unit_based else exp.Literal.string("DAY") 736 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 737 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 738 739 return _builder 740 741 742def build_date_delta_with_interval( 743 expression_class: t.Type[E], 744) -> t.Callable[[t.List], t.Optional[E]]: 745 def _builder(args: t.List) -> t.Optional[E]: 746 if len(args) < 2: 747 return None 748 749 interval = args[1] 750 751 if not isinstance(interval, exp.Interval): 752 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 753 754 expression = interval.this 755 if expression and expression.is_string: 756 expression = exp.Literal.number(expression.this) 757 758 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 759 760 return _builder 761 762 763def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 764 unit = seq_get(args, 0) 765 this = seq_get(args, 1) 766 767 if isinstance(this, exp.Cast) and this.is_type("date"): 768 return exp.DateTrunc(unit=unit, this=this) 769 return exp.TimestampTrunc(this=this, unit=unit) 770 771 772def date_add_interval_sql( 773 data_type: str, kind: str 774) -> t.Callable[[Generator, exp.Expression], str]: 775 def func(self: Generator, expression: exp.Expression) -> str: 776 this = self.sql(expression, "this") 777 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 778 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 779 780 return func 781 782 783def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 784 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 785 args = [unit_to_str(expression), expression.this] 786 if zone: 787 args.append(expression.args.get("zone")) 788 return self.func("DATE_TRUNC", *args) 789 790 return _timestamptrunc_sql 791 792 793def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 794 if not expression.expression: 795 from sqlglot.optimizer.annotate_types import annotate_types 796 797 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 798 return self.sql(exp.cast(expression.this, target_type)) 799 if expression.text("expression").lower() in TIMEZONES: 800 return self.sql( 801 exp.AtTimeZone( 802 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 803 zone=expression.expression, 804 ) 805 ) 806 return self.func("TIMESTAMP", expression.this, expression.expression) 807 808 809def locate_to_strposition(args: t.List) -> exp.Expression: 810 return exp.StrPosition( 811 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 812 ) 813 814 815def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 816 return self.func( 817 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 818 ) 819 820 821def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 822 return self.sql( 823 exp.Substring( 824 this=expression.this, start=exp.Literal.number(1), length=expression.expression 825 ) 826 ) 827 828 829def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 830 return self.sql( 831 exp.Substring( 832 this=expression.this, 833 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 834 ) 835 ) 836 837 838def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 839 return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)) 840 841 842def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 843 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 844 845 846# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 847def encode_decode_sql( 848 self: Generator, expression: exp.Expression, name: str, replace: bool = True 849) -> str: 850 charset = expression.args.get("charset") 851 if charset and charset.name.lower() != "utf-8": 852 self.unsupported(f"Expected utf-8 character set, got {charset}.") 853 854 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 855 856 857def min_or_least(self: Generator, expression: exp.Min) -> str: 858 name = "LEAST" if expression.expressions else "MIN" 859 return rename_func(name)(self, expression) 860 861 862def max_or_greatest(self: Generator, expression: exp.Max) -> str: 863 name = "GREATEST" if expression.expressions else "MAX" 864 return rename_func(name)(self, expression) 865 866 867def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 868 cond = expression.this 869 870 if isinstance(expression.this, exp.Distinct): 871 cond = expression.this.expressions[0] 872 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 873 874 return self.func("sum", exp.func("if", cond, 1, 0)) 875 876 877def trim_sql(self: Generator, expression: exp.Trim) -> str: 878 target = self.sql(expression, "this") 879 trim_type = self.sql(expression, "position") 880 remove_chars = self.sql(expression, "expression") 881 collation = self.sql(expression, "collation") 882 883 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 884 if not remove_chars and not collation: 885 return self.trim_sql(expression) 886 887 trim_type = f"{trim_type} " if trim_type else "" 888 remove_chars = f"{remove_chars} " if remove_chars else "" 889 from_part = "FROM " if trim_type or remove_chars else "" 890 collation = f" COLLATE {collation}" if collation else "" 891 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 892 893 894def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 895 return self.func("STRPTIME", expression.this, self.format_time(expression)) 896 897 898def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 899 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 900 901 902def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 903 delim, *rest_args = expression.expressions 904 return self.sql( 905 reduce( 906 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 907 rest_args, 908 ) 909 ) 910 911 912def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 913 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 914 if bad_args: 915 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 916 917 return self.func( 918 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 919 ) 920 921 922def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 923 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 924 if bad_args: 925 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 926 927 return self.func( 928 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 929 ) 930 931 932def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 933 names = [] 934 for agg in aggregations: 935 if isinstance(agg, exp.Alias): 936 names.append(agg.alias) 937 else: 938 """ 939 This case corresponds to aggregations without aliases being used as suffixes 940 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 941 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 942 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 943 """ 944 agg_all_unquoted = agg.transform( 945 lambda node: ( 946 exp.Identifier(this=node.name, quoted=False) 947 if isinstance(node, exp.Identifier) 948 else node 949 ) 950 ) 951 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 952 953 return names 954 955 956def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 957 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 958 959 960# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 961def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 962 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 963 964 965def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 966 return self.func("MAX", expression.this) 967 968 969def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 970 a = self.sql(expression.left) 971 b = self.sql(expression.right) 972 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 973 974 975def is_parse_json(expression: exp.Expression) -> bool: 976 return isinstance(expression, exp.ParseJSON) or ( 977 isinstance(expression, exp.Cast) and expression.is_type("json") 978 ) 979 980 981def isnull_to_is_null(args: t.List) -> exp.Expression: 982 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 983 984 985def generatedasidentitycolumnconstraint_sql( 986 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 987) -> str: 988 start = self.sql(expression, "start") or "1" 989 increment = self.sql(expression, "increment") or "1" 990 return f"IDENTITY({start}, {increment})" 991 992 993def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 994 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 995 if expression.args.get("count"): 996 self.unsupported(f"Only two arguments are supported in function {name}.") 997 998 return self.func(name, expression.this, expression.expression) 999 1000 return _arg_max_or_min_sql 1001 1002 1003def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1004 this = expression.this.copy() 1005 1006 return_type = expression.return_type 1007 if return_type.is_type(exp.DataType.Type.DATE): 1008 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1009 # can truncate timestamp strings, because some dialects can't cast them to DATE 1010 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1011 1012 expression.this.replace(exp.cast(this, return_type)) 1013 return expression 1014 1015 1016def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1017 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1018 if cast and isinstance(expression, exp.TsOrDsAdd): 1019 expression = ts_or_ds_add_cast(expression) 1020 1021 return self.func( 1022 name, 1023 unit_to_var(expression), 1024 expression.expression, 1025 expression.this, 1026 ) 1027 1028 return _delta_sql 1029 1030 1031def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1032 unit = expression.args.get("unit") 1033 1034 if isinstance(unit, exp.Placeholder): 1035 return unit 1036 if unit: 1037 return exp.Literal.string(unit.name) 1038 return exp.Literal.string(default) if default else None 1039 1040 1041def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1042 unit = expression.args.get("unit") 1043 1044 if isinstance(unit, (exp.Var, exp.Placeholder)): 1045 return unit 1046 return exp.Var(this=default) if default else None 1047 1048 1049def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1050 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1051 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1052 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1053 1054 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1055 1056 1057def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1058 """Remove table refs from columns in when statements.""" 1059 alias = expression.this.args.get("alias") 1060 1061 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1062 return self.dialect.normalize_identifier(identifier).name if identifier else None 1063 1064 targets = {normalize(expression.this.this)} 1065 1066 if alias: 1067 targets.add(normalize(alias.this)) 1068 1069 for when in expression.expressions: 1070 when.transform( 1071 lambda node: ( 1072 exp.column(node.this) 1073 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1074 else node 1075 ), 1076 copy=False, 1077 ) 1078 1079 return self.merge_sql(expression) 1080 1081 1082def build_json_extract_path( 1083 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1084) -> t.Callable[[t.List], F]: 1085 def _builder(args: t.List) -> F: 1086 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1087 for arg in args[1:]: 1088 if not isinstance(arg, exp.Literal): 1089 # We use the fallback parser because we can't really transpile non-literals safely 1090 return expr_type.from_arg_list(args) 1091 1092 text = arg.name 1093 if is_int(text): 1094 index = int(text) 1095 segments.append( 1096 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1097 ) 1098 else: 1099 segments.append(exp.JSONPathKey(this=text)) 1100 1101 # This is done to avoid failing in the expression validator due to the arg count 1102 del args[2:] 1103 return expr_type( 1104 this=seq_get(args, 0), 1105 expression=exp.JSONPath(expressions=segments), 1106 only_json_types=arrow_req_json_type, 1107 ) 1108 1109 return _builder 1110 1111 1112def json_extract_segments( 1113 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1114) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1115 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1116 path = expression.expression 1117 if not isinstance(path, exp.JSONPath): 1118 return rename_func(name)(self, expression) 1119 1120 segments = [] 1121 for segment in path.expressions: 1122 path = self.sql(segment) 1123 if path: 1124 if isinstance(segment, exp.JSONPathPart) and ( 1125 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1126 ): 1127 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1128 1129 segments.append(path) 1130 1131 if op: 1132 return f" {op} ".join([self.sql(expression.this), *segments]) 1133 return self.func(name, expression.this, *segments) 1134 1135 return _json_extract_segments 1136 1137 1138def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1139 if isinstance(expression.this, exp.JSONPathWildcard): 1140 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1141 1142 return expression.name 1143 1144 1145def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1146 cond = expression.expression 1147 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1148 alias = cond.expressions[0] 1149 cond = cond.this 1150 elif isinstance(cond, exp.Predicate): 1151 alias = "_u" 1152 else: 1153 self.unsupported("Unsupported filter condition") 1154 return "" 1155 1156 unnest = exp.Unnest(expressions=[expression.this]) 1157 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1158 return self.sql(exp.Array(expressions=[filtered])) 1159 1160 1161def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1162 return self.func( 1163 "TO_NUMBER", 1164 expression.this, 1165 expression.args.get("format"), 1166 expression.args.get("nlsparam"), 1167 ) 1168 1169 1170def build_default_decimal_type( 1171 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1172) -> t.Callable[[exp.DataType], exp.DataType]: 1173 def _builder(dtype: exp.DataType) -> exp.DataType: 1174 if dtype.expressions or precision is None: 1175 return dtype 1176 1177 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1178 return exp.DataType.build(f"DECIMAL({params})") 1179 1180 return _builder
41class Dialects(str, Enum): 42 """Dialects supported by SQLGLot.""" 43 44 DIALECT = "" 45 46 ATHENA = "athena" 47 BIGQUERY = "bigquery" 48 CLICKHOUSE = "clickhouse" 49 DATABRICKS = "databricks" 50 DORIS = "doris" 51 DRILL = "drill" 52 DUCKDB = "duckdb" 53 HIVE = "hive" 54 MATERIALIZE = "materialize" 55 MYSQL = "mysql" 56 ORACLE = "oracle" 57 POSTGRES = "postgres" 58 PRESTO = "presto" 59 PRQL = "prql" 60 REDSHIFT = "redshift" 61 RISINGWAVE = "risingwave" 62 SNOWFLAKE = "snowflake" 63 SPARK = "spark" 64 SPARK2 = "spark2" 65 SQLITE = "sqlite" 66 STARROCKS = "starrocks" 67 TABLEAU = "tableau" 68 TERADATA = "teradata" 69 TRINO = "trino" 70 TSQL = "tsql"
Dialects supported by SQLGLot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
73class NormalizationStrategy(str, AutoName): 74 """Specifies the strategy according to which identifiers should be normalized.""" 75 76 LOWERCASE = auto() 77 """Unquoted identifiers are lowercased.""" 78 79 UPPERCASE = auto() 80 """Unquoted identifiers are uppercased.""" 81 82 CASE_SENSITIVE = auto() 83 """Always case-sensitive, regardless of quotes.""" 84 85 CASE_INSENSITIVE = auto() 86 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
190class Dialect(metaclass=_Dialect): 191 INDEX_OFFSET = 0 192 """The base index offset for arrays.""" 193 194 WEEK_OFFSET = 0 195 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 196 197 UNNEST_COLUMN_ONLY = False 198 """Whether `UNNEST` table aliases are treated as column aliases.""" 199 200 ALIAS_POST_TABLESAMPLE = False 201 """Whether the table alias comes after tablesample.""" 202 203 TABLESAMPLE_SIZE_IS_PERCENT = False 204 """Whether a size in the table sample clause represents percentage.""" 205 206 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 207 """Specifies the strategy according to which identifiers should be normalized.""" 208 209 IDENTIFIERS_CAN_START_WITH_DIGIT = False 210 """Whether an unquoted identifier can start with a digit.""" 211 212 DPIPE_IS_STRING_CONCAT = True 213 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 214 215 STRICT_STRING_CONCAT = False 216 """Whether `CONCAT`'s arguments must be strings.""" 217 218 SUPPORTS_USER_DEFINED_TYPES = True 219 """Whether user-defined data types are supported.""" 220 221 SUPPORTS_SEMI_ANTI_JOIN = True 222 """Whether `SEMI` or `ANTI` joins are supported.""" 223 224 NORMALIZE_FUNCTIONS: bool | str = "upper" 225 """ 226 Determines how function names are going to be normalized. 227 Possible values: 228 "upper" or True: Convert names to uppercase. 229 "lower": Convert names to lowercase. 230 False: Disables function name normalization. 231 """ 232 233 LOG_BASE_FIRST: t.Optional[bool] = True 234 """ 235 Whether the base comes first in the `LOG` function. 236 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 237 """ 238 239 NULL_ORDERING = "nulls_are_small" 240 """ 241 Default `NULL` ordering method to use if not explicitly set. 242 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 243 """ 244 245 TYPED_DIVISION = False 246 """ 247 Whether the behavior of `a / b` depends on the types of `a` and `b`. 248 False means `a / b` is always float division. 249 True means `a / b` is integer division if both `a` and `b` are integers. 250 """ 251 252 SAFE_DIVISION = False 253 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 254 255 CONCAT_COALESCE = False 256 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 257 258 HEX_LOWERCASE = False 259 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 260 261 DATE_FORMAT = "'%Y-%m-%d'" 262 DATEINT_FORMAT = "'%Y%m%d'" 263 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 264 265 TIME_MAPPING: t.Dict[str, str] = {} 266 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 267 268 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 269 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 270 FORMAT_MAPPING: t.Dict[str, str] = {} 271 """ 272 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 273 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 274 """ 275 276 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 277 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 278 279 PSEUDOCOLUMNS: t.Set[str] = set() 280 """ 281 Columns that are auto-generated by the engine corresponding to this dialect. 282 For example, such columns may be excluded from `SELECT *` queries. 283 """ 284 285 PREFER_CTE_ALIAS_COLUMN = False 286 """ 287 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 288 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 289 any projection aliases in the subquery. 290 291 For example, 292 WITH y(c) AS ( 293 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 294 ) SELECT c FROM y; 295 296 will be rewritten as 297 298 WITH y(c) AS ( 299 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 300 ) SELECT c FROM y; 301 """ 302 303 # --- Autofilled --- 304 305 tokenizer_class = Tokenizer 306 parser_class = Parser 307 generator_class = Generator 308 309 # A trie of the time_mapping keys 310 TIME_TRIE: t.Dict = {} 311 FORMAT_TRIE: t.Dict = {} 312 313 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 314 INVERSE_TIME_TRIE: t.Dict = {} 315 316 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 317 318 # Delimiters for string literals and identifiers 319 QUOTE_START = "'" 320 QUOTE_END = "'" 321 IDENTIFIER_START = '"' 322 IDENTIFIER_END = '"' 323 324 # Delimiters for bit, hex, byte and unicode literals 325 BIT_START: t.Optional[str] = None 326 BIT_END: t.Optional[str] = None 327 HEX_START: t.Optional[str] = None 328 HEX_END: t.Optional[str] = None 329 BYTE_START: t.Optional[str] = None 330 BYTE_END: t.Optional[str] = None 331 UNICODE_START: t.Optional[str] = None 332 UNICODE_END: t.Optional[str] = None 333 334 # Separator of COPY statement parameters 335 COPY_PARAMS_ARE_CSV = True 336 337 @classmethod 338 def get_or_raise(cls, dialect: DialectType) -> Dialect: 339 """ 340 Look up a dialect in the global dialect registry and return it if it exists. 341 342 Args: 343 dialect: The target dialect. If this is a string, it can be optionally followed by 344 additional key-value pairs that are separated by commas and are used to specify 345 dialect settings, such as whether the dialect's identifiers are case-sensitive. 346 347 Example: 348 >>> dialect = dialect_class = get_or_raise("duckdb") 349 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 350 351 Returns: 352 The corresponding Dialect instance. 353 """ 354 355 if not dialect: 356 return cls() 357 if isinstance(dialect, _Dialect): 358 return dialect() 359 if isinstance(dialect, Dialect): 360 return dialect 361 if isinstance(dialect, str): 362 try: 363 dialect_name, *kv_pairs = dialect.split(",") 364 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 365 except ValueError: 366 raise ValueError( 367 f"Invalid dialect format: '{dialect}'. " 368 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 369 ) 370 371 result = cls.get(dialect_name.strip()) 372 if not result: 373 from difflib import get_close_matches 374 375 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 376 if similar: 377 similar = f" Did you mean {similar}?" 378 379 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 380 381 return result(**kwargs) 382 383 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 384 385 @classmethod 386 def format_time( 387 cls, expression: t.Optional[str | exp.Expression] 388 ) -> t.Optional[exp.Expression]: 389 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 390 if isinstance(expression, str): 391 return exp.Literal.string( 392 # the time formats are quoted 393 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 394 ) 395 396 if expression and expression.is_string: 397 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 398 399 return expression 400 401 def __init__(self, **kwargs) -> None: 402 normalization_strategy = kwargs.get("normalization_strategy") 403 404 if normalization_strategy is None: 405 self.normalization_strategy = self.NORMALIZATION_STRATEGY 406 else: 407 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 408 409 def __eq__(self, other: t.Any) -> bool: 410 # Does not currently take dialect state into account 411 return type(self) == other 412 413 def __hash__(self) -> int: 414 # Does not currently take dialect state into account 415 return hash(type(self)) 416 417 def normalize_identifier(self, expression: E) -> E: 418 """ 419 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 420 421 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 422 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 423 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 424 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 425 426 There are also dialects like Spark, which are case-insensitive even when quotes are 427 present, and dialects like MySQL, whose resolution rules match those employed by the 428 underlying operating system, for example they may always be case-sensitive in Linux. 429 430 Finally, the normalization behavior of some engines can even be controlled through flags, 431 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 432 433 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 434 that it can analyze queries in the optimizer and successfully capture their semantics. 435 """ 436 if ( 437 isinstance(expression, exp.Identifier) 438 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 439 and ( 440 not expression.quoted 441 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 442 ) 443 ): 444 expression.set( 445 "this", 446 ( 447 expression.this.upper() 448 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 449 else expression.this.lower() 450 ), 451 ) 452 453 return expression 454 455 def case_sensitive(self, text: str) -> bool: 456 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 457 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 458 return False 459 460 unsafe = ( 461 str.islower 462 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 463 else str.isupper 464 ) 465 return any(unsafe(char) for char in text) 466 467 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 468 """Checks if text can be identified given an identify option. 469 470 Args: 471 text: The text to check. 472 identify: 473 `"always"` or `True`: Always returns `True`. 474 `"safe"`: Only returns `True` if the identifier is case-insensitive. 475 476 Returns: 477 Whether the given text can be identified. 478 """ 479 if identify is True or identify == "always": 480 return True 481 482 if identify == "safe": 483 return not self.case_sensitive(text) 484 485 return False 486 487 def quote_identifier(self, expression: E, identify: bool = True) -> E: 488 """ 489 Adds quotes to a given identifier. 490 491 Args: 492 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 493 identify: If set to `False`, the quotes will only be added if the identifier is deemed 494 "unsafe", with respect to its characters and this dialect's normalization strategy. 495 """ 496 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 497 name = expression.this 498 expression.set( 499 "quoted", 500 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 501 ) 502 503 return expression 504 505 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 506 if isinstance(path, exp.Literal): 507 path_text = path.name 508 if path.is_number: 509 path_text = f"[{path_text}]" 510 511 try: 512 return parse_json_path(path_text) 513 except ParseError as e: 514 logger.warning(f"Invalid JSON path syntax. {str(e)}") 515 516 return path 517 518 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 519 return self.parser(**opts).parse(self.tokenize(sql), sql) 520 521 def parse_into( 522 self, expression_type: exp.IntoType, sql: str, **opts 523 ) -> t.List[t.Optional[exp.Expression]]: 524 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 525 526 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 527 return self.generator(**opts).generate(expression, copy=copy) 528 529 def transpile(self, sql: str, **opts) -> t.List[str]: 530 return [ 531 self.generate(expression, copy=False, **opts) if expression else "" 532 for expression in self.parse(sql) 533 ] 534 535 def tokenize(self, sql: str) -> t.List[Token]: 536 return self.tokenizer.tokenize(sql) 537 538 @property 539 def tokenizer(self) -> Tokenizer: 540 if not hasattr(self, "_tokenizer"): 541 self._tokenizer = self.tokenizer_class(dialect=self) 542 return self._tokenizer 543 544 def parser(self, **opts) -> Parser: 545 return self.parser_class(dialect=self, **opts) 546 547 def generator(self, **opts) -> Generator: 548 return self.generator_class(dialect=self, **opts)
401 def __init__(self, **kwargs) -> None: 402 normalization_strategy = kwargs.get("normalization_strategy") 403 404 if normalization_strategy is None: 405 self.normalization_strategy = self.NORMALIZATION_STRATEGY 406 else: 407 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
Whether the behavior of a / b
depends on the types of a
and b
.
False means a / b
is always float division.
True means a / b
is integer division if both a
and b
are integers.
A NULL
arg in CONCAT
yields NULL
by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (\n
) to its unescaped version (
).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
337 @classmethod 338 def get_or_raise(cls, dialect: DialectType) -> Dialect: 339 """ 340 Look up a dialect in the global dialect registry and return it if it exists. 341 342 Args: 343 dialect: The target dialect. If this is a string, it can be optionally followed by 344 additional key-value pairs that are separated by commas and are used to specify 345 dialect settings, such as whether the dialect's identifiers are case-sensitive. 346 347 Example: 348 >>> dialect = dialect_class = get_or_raise("duckdb") 349 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 350 351 Returns: 352 The corresponding Dialect instance. 353 """ 354 355 if not dialect: 356 return cls() 357 if isinstance(dialect, _Dialect): 358 return dialect() 359 if isinstance(dialect, Dialect): 360 return dialect 361 if isinstance(dialect, str): 362 try: 363 dialect_name, *kv_pairs = dialect.split(",") 364 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 365 except ValueError: 366 raise ValueError( 367 f"Invalid dialect format: '{dialect}'. " 368 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 369 ) 370 371 result = cls.get(dialect_name.strip()) 372 if not result: 373 from difflib import get_close_matches 374 375 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 376 if similar: 377 similar = f" Did you mean {similar}?" 378 379 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 380 381 return result(**kwargs) 382 383 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
385 @classmethod 386 def format_time( 387 cls, expression: t.Optional[str | exp.Expression] 388 ) -> t.Optional[exp.Expression]: 389 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 390 if isinstance(expression, str): 391 return exp.Literal.string( 392 # the time formats are quoted 393 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 394 ) 395 396 if expression and expression.is_string: 397 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 398 399 return expression
Converts a time format in this dialect to its equivalent Python strftime
format.
417 def normalize_identifier(self, expression: E) -> E: 418 """ 419 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 420 421 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 422 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 423 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 424 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 425 426 There are also dialects like Spark, which are case-insensitive even when quotes are 427 present, and dialects like MySQL, whose resolution rules match those employed by the 428 underlying operating system, for example they may always be case-sensitive in Linux. 429 430 Finally, the normalization behavior of some engines can even be controlled through flags, 431 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 432 433 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 434 that it can analyze queries in the optimizer and successfully capture their semantics. 435 """ 436 if ( 437 isinstance(expression, exp.Identifier) 438 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 439 and ( 440 not expression.quoted 441 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 442 ) 443 ): 444 expression.set( 445 "this", 446 ( 447 expression.this.upper() 448 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 449 else expression.this.lower() 450 ), 451 ) 452 453 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO
would be resolved as foo
in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as FOO
. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
455 def case_sensitive(self, text: str) -> bool: 456 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 457 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 458 return False 459 460 unsafe = ( 461 str.islower 462 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 463 else str.isupper 464 ) 465 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
467 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 468 """Checks if text can be identified given an identify option. 469 470 Args: 471 text: The text to check. 472 identify: 473 `"always"` or `True`: Always returns `True`. 474 `"safe"`: Only returns `True` if the identifier is case-insensitive. 475 476 Returns: 477 Whether the given text can be identified. 478 """ 479 if identify is True or identify == "always": 480 return True 481 482 if identify == "safe": 483 return not self.case_sensitive(text) 484 485 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
"always"
orTrue
: Always returnsTrue
."safe"
: Only returnsTrue
if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
487 def quote_identifier(self, expression: E, identify: bool = True) -> E: 488 """ 489 Adds quotes to a given identifier. 490 491 Args: 492 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 493 identify: If set to `False`, the quotes will only be added if the identifier is deemed 494 "unsafe", with respect to its characters and this dialect's normalization strategy. 495 """ 496 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 497 name = expression.this 498 expression.set( 499 "quoted", 500 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 501 ) 502 503 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an
Identifier
, this method is a no-op. - identify: If set to
False
, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
505 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 506 if isinstance(path, exp.Literal): 507 path_text = path.name 508 if path.is_number: 509 path_text = f"[{path_text}]" 510 511 try: 512 return parse_json_path(path_text) 513 except ParseError as e: 514 logger.warning(f"Invalid JSON path syntax. {str(e)}") 515 516 return path
564def if_sql( 565 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 566) -> t.Callable[[Generator, exp.If], str]: 567 def _if_sql(self: Generator, expression: exp.If) -> str: 568 return self.func( 569 name, 570 expression.this, 571 expression.args.get("true"), 572 expression.args.get("false") or false_value, 573 ) 574 575 return _if_sql
578def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 579 this = expression.this 580 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 581 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 582 583 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
649def str_position_sql( 650 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 651) -> str: 652 this = self.sql(expression, "this") 653 substr = self.sql(expression, "substr") 654 position = self.sql(expression, "position") 655 instance = expression.args.get("instance") if generate_instance else None 656 position_offset = "" 657 658 if position: 659 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 660 this = self.func("SUBSTR", this, position) 661 position_offset = f" + {position} - 1" 662 663 return self.func("STRPOS", this, substr, instance) + position_offset
672def var_map_sql( 673 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 674) -> str: 675 keys = expression.args["keys"] 676 values = expression.args["values"] 677 678 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 679 self.unsupported("Cannot convert array columns into map.") 680 return self.func(map_func_name, keys, values) 681 682 args = [] 683 for key, value in zip(keys.expressions, values.expressions): 684 args.append(self.sql(key)) 685 args.append(self.sql(value)) 686 687 return self.func(map_func_name, *args)
690def build_formatted_time( 691 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 692) -> t.Callable[[t.List], E]: 693 """Helper used for time expressions. 694 695 Args: 696 exp_class: the expression class to instantiate. 697 dialect: target sql dialect. 698 default: the default format, True being time. 699 700 Returns: 701 A callable that can be used to return the appropriately formatted time expression. 702 """ 703 704 def _builder(args: t.List): 705 return exp_class( 706 this=seq_get(args, 0), 707 format=Dialect[dialect].format_time( 708 seq_get(args, 1) 709 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 710 ), 711 ) 712 713 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
716def time_format( 717 dialect: DialectType = None, 718) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 719 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 720 """ 721 Returns the time format for a given expression, unless it's equivalent 722 to the default time format of the dialect of interest. 723 """ 724 time_format = self.format_time(expression) 725 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 726 727 return _time_format
730def build_date_delta( 731 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 732) -> t.Callable[[t.List], E]: 733 def _builder(args: t.List) -> E: 734 unit_based = len(args) == 3 735 this = args[2] if unit_based else seq_get(args, 0) 736 unit = args[0] if unit_based else exp.Literal.string("DAY") 737 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 738 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 739 740 return _builder
743def build_date_delta_with_interval( 744 expression_class: t.Type[E], 745) -> t.Callable[[t.List], t.Optional[E]]: 746 def _builder(args: t.List) -> t.Optional[E]: 747 if len(args) < 2: 748 return None 749 750 interval = args[1] 751 752 if not isinstance(interval, exp.Interval): 753 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 754 755 expression = interval.this 756 if expression and expression.is_string: 757 expression = exp.Literal.number(expression.this) 758 759 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 760 761 return _builder
773def date_add_interval_sql( 774 data_type: str, kind: str 775) -> t.Callable[[Generator, exp.Expression], str]: 776 def func(self: Generator, expression: exp.Expression) -> str: 777 this = self.sql(expression, "this") 778 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 779 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 780 781 return func
784def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 785 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 786 args = [unit_to_str(expression), expression.this] 787 if zone: 788 args.append(expression.args.get("zone")) 789 return self.func("DATE_TRUNC", *args) 790 791 return _timestamptrunc_sql
794def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 795 if not expression.expression: 796 from sqlglot.optimizer.annotate_types import annotate_types 797 798 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 799 return self.sql(exp.cast(expression.this, target_type)) 800 if expression.text("expression").lower() in TIMEZONES: 801 return self.sql( 802 exp.AtTimeZone( 803 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 804 zone=expression.expression, 805 ) 806 ) 807 return self.func("TIMESTAMP", expression.this, expression.expression)
848def encode_decode_sql( 849 self: Generator, expression: exp.Expression, name: str, replace: bool = True 850) -> str: 851 charset = expression.args.get("charset") 852 if charset and charset.name.lower() != "utf-8": 853 self.unsupported(f"Expected utf-8 character set, got {charset}.") 854 855 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
868def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 869 cond = expression.this 870 871 if isinstance(expression.this, exp.Distinct): 872 cond = expression.this.expressions[0] 873 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 874 875 return self.func("sum", exp.func("if", cond, 1, 0))
878def trim_sql(self: Generator, expression: exp.Trim) -> str: 879 target = self.sql(expression, "this") 880 trim_type = self.sql(expression, "position") 881 remove_chars = self.sql(expression, "expression") 882 collation = self.sql(expression, "collation") 883 884 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 885 if not remove_chars and not collation: 886 return self.trim_sql(expression) 887 888 trim_type = f"{trim_type} " if trim_type else "" 889 remove_chars = f"{remove_chars} " if remove_chars else "" 890 from_part = "FROM " if trim_type or remove_chars else "" 891 collation = f" COLLATE {collation}" if collation else "" 892 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
913def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 914 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 915 if bad_args: 916 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 917 918 return self.func( 919 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 920 )
923def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 924 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 925 if bad_args: 926 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 927 928 return self.func( 929 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 930 )
933def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 934 names = [] 935 for agg in aggregations: 936 if isinstance(agg, exp.Alias): 937 names.append(agg.alias) 938 else: 939 """ 940 This case corresponds to aggregations without aliases being used as suffixes 941 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 942 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 943 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 944 """ 945 agg_all_unquoted = agg.transform( 946 lambda node: ( 947 exp.Identifier(this=node.name, quoted=False) 948 if isinstance(node, exp.Identifier) 949 else node 950 ) 951 ) 952 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 953 954 return names
994def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 995 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 996 if expression.args.get("count"): 997 self.unsupported(f"Only two arguments are supported in function {name}.") 998 999 return self.func(name, expression.this, expression.expression) 1000 1001 return _arg_max_or_min_sql
1004def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1005 this = expression.this.copy() 1006 1007 return_type = expression.return_type 1008 if return_type.is_type(exp.DataType.Type.DATE): 1009 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1010 # can truncate timestamp strings, because some dialects can't cast them to DATE 1011 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1012 1013 expression.this.replace(exp.cast(this, return_type)) 1014 return expression
1017def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1018 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1019 if cast and isinstance(expression, exp.TsOrDsAdd): 1020 expression = ts_or_ds_add_cast(expression) 1021 1022 return self.func( 1023 name, 1024 unit_to_var(expression), 1025 expression.expression, 1026 expression.this, 1027 ) 1028 1029 return _delta_sql
1032def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1033 unit = expression.args.get("unit") 1034 1035 if isinstance(unit, exp.Placeholder): 1036 return unit 1037 if unit: 1038 return exp.Literal.string(unit.name) 1039 return exp.Literal.string(default) if default else None
1050def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1051 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1052 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1053 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1054 1055 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1058def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1059 """Remove table refs from columns in when statements.""" 1060 alias = expression.this.args.get("alias") 1061 1062 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1063 return self.dialect.normalize_identifier(identifier).name if identifier else None 1064 1065 targets = {normalize(expression.this.this)} 1066 1067 if alias: 1068 targets.add(normalize(alias.this)) 1069 1070 for when in expression.expressions: 1071 when.transform( 1072 lambda node: ( 1073 exp.column(node.this) 1074 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1075 else node 1076 ), 1077 copy=False, 1078 ) 1079 1080 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1083def build_json_extract_path( 1084 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1085) -> t.Callable[[t.List], F]: 1086 def _builder(args: t.List) -> F: 1087 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1088 for arg in args[1:]: 1089 if not isinstance(arg, exp.Literal): 1090 # We use the fallback parser because we can't really transpile non-literals safely 1091 return expr_type.from_arg_list(args) 1092 1093 text = arg.name 1094 if is_int(text): 1095 index = int(text) 1096 segments.append( 1097 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1098 ) 1099 else: 1100 segments.append(exp.JSONPathKey(this=text)) 1101 1102 # This is done to avoid failing in the expression validator due to the arg count 1103 del args[2:] 1104 return expr_type( 1105 this=seq_get(args, 0), 1106 expression=exp.JSONPath(expressions=segments), 1107 only_json_types=arrow_req_json_type, 1108 ) 1109 1110 return _builder
1113def json_extract_segments( 1114 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1115) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1116 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1117 path = expression.expression 1118 if not isinstance(path, exp.JSONPath): 1119 return rename_func(name)(self, expression) 1120 1121 segments = [] 1122 for segment in path.expressions: 1123 path = self.sql(segment) 1124 if path: 1125 if isinstance(segment, exp.JSONPathPart) and ( 1126 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1127 ): 1128 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1129 1130 segments.append(path) 1131 1132 if op: 1133 return f" {op} ".join([self.sql(expression.this), *segments]) 1134 return self.func(name, expression.this, *segments) 1135 1136 return _json_extract_segments
1146def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1147 cond = expression.expression 1148 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1149 alias = cond.expressions[0] 1150 cond = cond.this 1151 elif isinstance(cond, exp.Predicate): 1152 alias = "_u" 1153 else: 1154 self.unsupported("Unsupported filter condition") 1155 return "" 1156 1157 unnest = exp.Unnest(expressions=[expression.this]) 1158 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1159 return self.sql(exp.Array(expressions=[filtered]))
1171def build_default_decimal_type( 1172 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1173) -> t.Callable[[exp.DataType], exp.DataType]: 1174 def _builder(dtype: exp.DataType) -> exp.DataType: 1175 if dtype.expressions or precision is None: 1176 return dtype 1177 1178 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1179 return exp.DataType.build(f"DECIMAL({params})") 1180 1181 return _builder