sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST = True
    """Whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))
The Dialects enum (class Dialects(str, Enum)) lists the dialects supported by SQLGlot, keyed by their registry names.
The NormalizationStrategy enum (class NormalizationStrategy(str, AutoName)) specifies the strategy according to which identifiers should be normalized: LOWERCASE lowercases unquoted identifiers, UPPERCASE uppercases them, CASE_SENSITIVE treats identifiers as always case-sensitive, and CASE_INSENSITIVE as always case-insensitive, regardless of quotes.
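A small, illustrative check of how the strategy is picked up when a dialect is instantiated (the constructor accepts it as a keyword argument, as shown in Dialect.__init__ above):

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    assert Dialect().normalization_strategy is NormalizationStrategy.LOWERCASE

    # The string is uppercased and coerced into the enum member
    d = Dialect(normalization_strategy="case_sensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE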
The Dialect class (class Dialect(metaclass=_Dialect)) is the base of every dialect implementation. Its documented class-level settings are:

- WEEK_OFFSET: First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday); -1 would be Sunday.
- TABLESAMPLE_SIZE_IS_PERCENT: Whether a size in the table sample clause represents a percentage.
- NORMALIZATION_STRATEGY: Specifies the strategy according to which identifiers should be normalized.
- NORMALIZE_FUNCTIONS: Determines how function names are going to be normalized. Possible values: "upper" or True (convert names to uppercase), "lower" (convert names to lowercase), False (disable function name normalization).
- NULL_ORDERING: Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".
- TYPED_DIVISION: Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division; True means a / b is integer division if both a and b are integers.
- CONCAT_COALESCE: A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.
- TIME_MAPPING: Associates this dialect's time formats with their equivalent Python strftime formats.
- FORMAT_MAPPING: Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.
- ESCAPE_SEQUENCES: Mapping of an unescaped escape sequence to the corresponding character.
- PSEUDOCOLUMNS: Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.
- PREFER_CTE_ALIAS_COLUMN: Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery. For example,

      WITH y(c) AS (
          SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
      ) SELECT c FROM y;

  will be rewritten as

      WITH y(c) AS (
          SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
      ) SELECT c FROM y;
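These settings are class-level, so a new dialect typically overrides them directly on a subclass. MiniDialect below is a hypothetical sketch of that pattern, with nested Tokenizer and Generator classes that the _Dialect metaclass wires up automatically:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect, rename_func
    from sqlglot.generator import Generator
    from sqlglot.tokens import Tokenizer

    class MiniDialect(Dialect):
        NULL_ORDERING = "nulls_are_large"
        TIME_MAPPING = {"yyyy": "%Y", "MM": "%m", "dd": "%d"}

        class Tokenizer(Tokenizer):
            IDENTIFIERS = ["`"]

        class Generator(Generator):
            TRANSFORMS = {
                **Generator.TRANSFORMS,
                exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            }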
Dialect.get_or_raise(dialect) looks up a dialect in the global dialect registry and returns it if it exists.

Arguments:
- dialect: The target dialect. If this is a string, it can optionally be followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.

Example:
    >>> dialect = dialect_class = get_or_raise("duckdb")
    >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

Returns:
    The corresponding Dialect instance.
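A brief behavior sketch of the lookup paths above (the unknown-dialect message is produced by the code shown earlier, and likely suggests the closest registered name):

    import sqlglot.dialects  # importing the package registers the bundled dialects
    from sqlglot.dialects.dialect import Dialect

    Dialect.get_or_raise(None)      # no dialect: returns an instance of the base Dialect
    Dialect.get_or_raise("duckdb")  # registered name: instantiates the DuckDB dialect

    try:
        Dialect.get_or_raise("duckd")
    except ValueError as e:
        print(e)  # Unknown dialect 'duckd'. Did you mean duckdb?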
Dialect.format_time(expression) converts a time format in this dialect to its equivalent Python strftime format.
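For instance, assuming the bundled dialects are registered and that Hive's TIME_MAPPING translates yyyy/MM/dd into their strftime equivalents (a sketch, not a guaranteed output):

    import sqlglot.dialects  # registers the bundled dialects
    from sqlglot.dialects.dialect import Dialect

    fmt = Dialect["hive"].format_time("'yyyy-MM-dd'")  # note: string formats are quoted
    print(fmt.this)  # %Y-%m-%d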
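A minimal sketch of these behaviors (the expected values in comments assume each dialect's default normalization strategy):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted

print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

# Quoted identifiers are left untouched under case-sensitive resolution rules.
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)         # FoO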
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)
    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False
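For instance, under Postgres' lowercase normalization (a sketch; the outputs follow directly from the rules above):

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")  # lowercases unquoted identifiers

print(d.case_sensitive("foo"))        # False: no uppercase characters to fold
print(d.case_sensitive("Foo"))        # True: 'F' would be folded to 'f'
print(d.can_identify("Foo", "safe"))  # False: quoting would change resolution
print(d.can_identify("Foo", True))    # True: always identify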
    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression
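A sketch of the two modes (again assuming Postgres' lowercase normalization):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")

# Safe identifier: with identify=False, no quotes are added.
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo

# Mixed case is "unsafe" under lowercase normalization, so it gets quoted.
print(d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql())  # "Foo"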
    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
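Factories like this are normally installed in a generator's TRANSFORMS map; calling the returned function directly makes the behavior visible (a sketch, with IIF as an arbitrary target name):

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import Dialect, if_sql

gen = Dialect.get_or_raise("duckdb").generator()
node = parse_one("SELECT IF(x > 1, 'a', 'b')").find(exp.If)

print(if_sql(name="IIF")(gen, node))  # expected: IIF(x > 1, 'a', 'b')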
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
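A pure-Python sanity check of the offset arithmetic above (assuming the usual 1-based STRPOS/SUBSTR semantics):

# Find 'b' in 'abcb', starting the search at 1-based position 3.
s, sub, pos = "abcb", "b", 3
tail = s[pos - 1:]            # SUBSTR(s, 3)      -> "cb"
strpos = tail.index(sub) + 1  # STRPOS("cb", "b") -> 2
assert strpos + pos - 1 == 4  # 'b' sits at 1-based position 4 in "abcb"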
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
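A sketch of feeding the builder an args list, the same shape a parser's FUNCTIONS map would hand it (the DuckDB rendering shown is an expectation, not a guarantee):

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "mysql")
node = builder([exp.column("x"), exp.Literal.string("%Y-%m-%d")])
print(node.sql(dialect="duckdb"))  # expected: STRPTIME(x, '%Y-%m-%d')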
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
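For example, building the three-argument, unit-first form (the default rendering shown is an expectation):

from sqlglot import exp
from sqlglot.dialects.dialect import build_date_delta

builder = build_date_delta(exp.DateAdd)
# Unit-based call shape: (unit, increment, this), e.g. DATEADD(day, 5, x).
node = builder([exp.var("day"), exp.Literal.number(5), exp.column("x")])
print(node.sql())  # expected: DATE_ADD(x, 5, day)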
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
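A sketch of the rewrite this produces (Presto is chosen arbitrarily as the target generator):

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import Dialect, count_if_to_sum

gen = Dialect.get_or_raise("presto").generator()
node = parse_one("SELECT COUNT_IF(x > 0)").find(exp.CountIf)
print(count_if_to_sum(gen, node))  # expected: SUM(IF(x > 0, 1, 0))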
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
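For instance, a TRIM that carries both a position and removal characters keeps the full form, because remove_chars is set (a sketch):

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import Dialect, trim_sql

gen = Dialect.get_or_raise("postgres").generator()
node = parse_one("SELECT TRIM(LEADING 'x' FROM col)").find(exp.Trim)
print(trim_sql(gen, node))  # expected: TRIM(LEADING 'x' FROM col)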
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with col_avg(`foo`) (notice the nested backticks).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql
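Calling the returned function directly shows the unit-first argument order (DATEADD here is just an illustrative name):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, date_delta_sql

gen = Dialect.get_or_raise("snowflake").generator()
node = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(3), unit=exp.var("month"))
print(date_delta_sql("DATEADD")(gen, node))  # expected: DATEADD(MONTH, 3, d)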
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _builder
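A sketch of how literal segments become a JSONPath (the rendered path in the comment is an expectation and may differ slightly across versions):

from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
print(node.expression.sql())  # expected: '$.a[0]'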
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))