sqlglot.dialects.dialect
1from __future__ import annotations 2 3import typing as t 4from enum import Enum, auto 5from functools import reduce 6 7from sqlglot import exp 8from sqlglot._typing import E 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator 11from sqlglot.helper import AutoName, flatten, seq_get 12from sqlglot.parser import Parser 13from sqlglot.time import TIMEZONES, format_time 14from sqlglot.tokens import Token, Tokenizer, TokenType 15from sqlglot.trie import new_trie 16 17B = t.TypeVar("B", bound=exp.Binary) 18 19DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] 20DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] 21 22 23class Dialects(str, Enum): 24 """Dialects supported by SQLGLot.""" 25 26 DIALECT = "" 27 28 BIGQUERY = "bigquery" 29 CLICKHOUSE = "clickhouse" 30 DATABRICKS = "databricks" 31 DORIS = "doris" 32 DRILL = "drill" 33 DUCKDB = "duckdb" 34 HIVE = "hive" 35 MYSQL = "mysql" 36 ORACLE = "oracle" 37 POSTGRES = "postgres" 38 PRESTO = "presto" 39 REDSHIFT = "redshift" 40 SNOWFLAKE = "snowflake" 41 SPARK = "spark" 42 SPARK2 = "spark2" 43 SQLITE = "sqlite" 44 STARROCKS = "starrocks" 45 TABLEAU = "tableau" 46 TERADATA = "teradata" 47 TRINO = "trino" 48 TSQL = "tsql" 49 50 51class NormalizationStrategy(str, AutoName): 52 """Specifies the strategy according to which identifiers should be normalized.""" 53 54 LOWERCASE = auto() 55 """Unquoted identifiers are lowercased.""" 56 57 UPPERCASE = auto() 58 """Unquoted identifiers are uppercased.""" 59 60 CASE_SENSITIVE = auto() 61 """Always case-sensitive, regardless of quotes.""" 62 63 CASE_INSENSITIVE = auto() 64 """Always case-insensitive, regardless of quotes.""" 65 66 67class _Dialect(type): 68 classes: t.Dict[str, t.Type[Dialect]] = {} 69 70 def __eq__(cls, other: t.Any) -> bool: 71 if cls is other: 72 return True 73 if isinstance(other, str): 74 return cls is cls.get(other) 75 if isinstance(other, Dialect): 76 return cls is type(other) 77 78 return False 79 80 def __hash__(cls) -> int: 81 return hash(cls.__name__.lower()) 82 83 @classmethod 84 def __getitem__(cls, key: str) -> t.Type[Dialect]: 85 return cls.classes[key] 86 87 @classmethod 88 def get( 89 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 90 ) -> t.Optional[t.Type[Dialect]]: 91 return cls.classes.get(key, default) 92 93 def __new__(cls, clsname, bases, attrs): 94 klass = super().__new__(cls, clsname, bases, attrs) 95 enum = Dialects.__members__.get(clsname.upper()) 96 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 97 98 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 99 klass.FORMAT_TRIE = ( 100 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 101 ) 102 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 103 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 104 105 klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()} 106 107 klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) 108 klass.parser_class = getattr(klass, "Parser", Parser) 109 klass.generator_class = getattr(klass, "Generator", Generator) 110 111 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 112 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 113 klass.tokenizer_class._IDENTIFIERS.items() 114 )[0] 115 116 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 117 return next( 118 ( 119 (s, e) 120 for s, (e, t) in 
klass.tokenizer_class._FORMAT_STRINGS.items() 121 if t == token_type 122 ), 123 (None, None), 124 ) 125 126 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 127 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 128 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 129 klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) 130 131 if enum not in ("", "bigquery"): 132 klass.generator_class.SELECT_KINDS = () 133 134 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 135 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 136 TokenType.ANTI, 137 TokenType.SEMI, 138 } 139 140 return klass 141 142 143class Dialect(metaclass=_Dialect): 144 INDEX_OFFSET = 0 145 """Determines the base index offset for arrays.""" 146 147 WEEK_OFFSET = 0 148 """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 149 150 UNNEST_COLUMN_ONLY = False 151 """Determines whether or not `UNNEST` table aliases are treated as column aliases.""" 152 153 ALIAS_POST_TABLESAMPLE = False 154 """Determines whether or not the table alias comes after tablesample.""" 155 156 TABLESAMPLE_SIZE_IS_PERCENT = False 157 """Determines whether or not a size in the table sample clause represents percentage.""" 158 159 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 160 """Specifies the strategy according to which identifiers should be normalized.""" 161 162 IDENTIFIERS_CAN_START_WITH_DIGIT = False 163 """Determines whether or not an unquoted identifier can start with a digit.""" 164 165 DPIPE_IS_STRING_CONCAT = True 166 """Determines whether or not the DPIPE token (`||`) is a string concatenation operator.""" 167 168 STRICT_STRING_CONCAT = False 169 """Determines whether or not `CONCAT`'s arguments must be strings.""" 170 171 SUPPORTS_USER_DEFINED_TYPES = True 172 """Determines whether or not user-defined data types are supported.""" 173 174 SUPPORTS_SEMI_ANTI_JOIN = True 175 """Determines whether or not `SEMI` or `ANTI` joins are supported.""" 176 177 NORMALIZE_FUNCTIONS: bool | str = "upper" 178 """Determines how function names are going to be normalized.""" 179 180 LOG_BASE_FIRST = True 181 """Determines whether the base comes first in the `LOG` function.""" 182 183 NULL_ORDERING = "nulls_are_small" 184 """ 185 Indicates the default `NULL` ordering method to use if not explicitly set. 186 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 187 """ 188 189 TYPED_DIVISION = False 190 """ 191 Whether the behavior of `a / b` depends on the types of `a` and `b`. 192 False means `a / b` is always float division. 193 True means `a / b` is integer division if both `a` and `b` are integers. 
194 """ 195 196 SAFE_DIVISION = False 197 """Determines whether division by zero throws an error (`False`) or returns NULL (`True`).""" 198 199 CONCAT_COALESCE = False 200 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 201 202 DATE_FORMAT = "'%Y-%m-%d'" 203 DATEINT_FORMAT = "'%Y%m%d'" 204 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 205 206 TIME_MAPPING: t.Dict[str, str] = {} 207 """Associates this dialect's time formats with their equivalent Python `strftime` format.""" 208 209 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 210 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 211 FORMAT_MAPPING: t.Dict[str, str] = {} 212 """ 213 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 214 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 215 """ 216 217 ESCAPE_SEQUENCES: t.Dict[str, str] = {} 218 """Mapping of an unescaped escape sequence to the corresponding character.""" 219 220 PSEUDOCOLUMNS: t.Set[str] = set() 221 """ 222 Columns that are auto-generated by the engine corresponding to this dialect. 223 For example, such columns may be excluded from `SELECT *` queries. 224 """ 225 226 PREFER_CTE_ALIAS_COLUMN = False 227 """ 228 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 229 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 230 any projection aliases in the subquery. 231 232 For example, 233 WITH y(c) AS ( 234 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 235 ) SELECT c FROM y; 236 237 will be rewritten as 238 239 WITH y(c) AS ( 240 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 241 ) SELECT c FROM y; 242 """ 243 244 # --- Autofilled --- 245 246 tokenizer_class = Tokenizer 247 parser_class = Parser 248 generator_class = Generator 249 250 # A trie of the time_mapping keys 251 TIME_TRIE: t.Dict = {} 252 FORMAT_TRIE: t.Dict = {} 253 254 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 255 INVERSE_TIME_TRIE: t.Dict = {} 256 257 INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} 258 259 # Delimiters for quotes, identifiers and the corresponding escape characters 260 QUOTE_START = "'" 261 QUOTE_END = "'" 262 IDENTIFIER_START = '"' 263 IDENTIFIER_END = '"' 264 265 # Delimiters for bit, hex, byte and unicode literals 266 BIT_START: t.Optional[str] = None 267 BIT_END: t.Optional[str] = None 268 HEX_START: t.Optional[str] = None 269 HEX_END: t.Optional[str] = None 270 BYTE_START: t.Optional[str] = None 271 BYTE_END: t.Optional[str] = None 272 UNICODE_START: t.Optional[str] = None 273 UNICODE_END: t.Optional[str] = None 274 275 @classmethod 276 def get_or_raise(cls, dialect: DialectType) -> Dialect: 277 """ 278 Look up a dialect in the global dialect registry and return it if it exists. 279 280 Args: 281 dialect: The target dialect. If this is a string, it can be optionally followed by 282 additional key-value pairs that are separated by commas and are used to specify 283 dialect settings, such as whether the dialect's identifiers are case-sensitive. 284 285 Example: 286 >>> dialect = dialect_class = get_or_raise("duckdb") 287 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 288 289 Returns: 290 The corresponding Dialect instance. 
291 """ 292 293 if not dialect: 294 return cls() 295 if isinstance(dialect, _Dialect): 296 return dialect() 297 if isinstance(dialect, Dialect): 298 return dialect 299 if isinstance(dialect, str): 300 try: 301 dialect_name, *kv_pairs = dialect.split(",") 302 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 303 except ValueError: 304 raise ValueError( 305 f"Invalid dialect format: '{dialect}'. " 306 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 307 ) 308 309 result = cls.get(dialect_name.strip()) 310 if not result: 311 from difflib import get_close_matches 312 313 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 314 if similar: 315 similar = f" Did you mean {similar}?" 316 317 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 318 319 return result(**kwargs) 320 321 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 322 323 @classmethod 324 def format_time( 325 cls, expression: t.Optional[str | exp.Expression] 326 ) -> t.Optional[exp.Expression]: 327 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 328 if isinstance(expression, str): 329 return exp.Literal.string( 330 # the time formats are quoted 331 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 332 ) 333 334 if expression and expression.is_string: 335 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 336 337 return expression 338 339 def __init__(self, **kwargs) -> None: 340 normalization_strategy = kwargs.get("normalization_strategy") 341 342 if normalization_strategy is None: 343 self.normalization_strategy = self.NORMALIZATION_STRATEGY 344 else: 345 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 346 347 def __eq__(self, other: t.Any) -> bool: 348 # Does not currently take dialect state into account 349 return type(self) == other 350 351 def __hash__(self) -> int: 352 # Does not currently take dialect state into account 353 return hash(type(self)) 354 355 def normalize_identifier(self, expression: E) -> E: 356 """ 357 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 358 359 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 360 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 361 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 362 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 363 364 There are also dialects like Spark, which are case-insensitive even when quotes are 365 present, and dialects like MySQL, whose resolution rules match those employed by the 366 underlying operating system, for example they may always be case-sensitive in Linux. 367 368 Finally, the normalization behavior of some engines can even be controlled through flags, 369 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 370 371 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 372 that it can analyze queries in the optimizer and successfully capture their semantics. 
373 """ 374 if ( 375 isinstance(expression, exp.Identifier) 376 and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE 377 and ( 378 not expression.quoted 379 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 380 ) 381 ): 382 expression.set( 383 "this", 384 expression.this.upper() 385 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 386 else expression.this.lower(), 387 ) 388 389 return expression 390 391 def case_sensitive(self, text: str) -> bool: 392 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 393 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 394 return False 395 396 unsafe = ( 397 str.islower 398 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 399 else str.isupper 400 ) 401 return any(unsafe(char) for char in text) 402 403 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 404 """Checks if text can be identified given an identify option. 405 406 Args: 407 text: The text to check. 408 identify: 409 `"always"` or `True`: Always returns `True`. 410 `"safe"`: Only returns `True` if the identifier is case-insensitive. 411 412 Returns: 413 Whether or not the given text can be identified. 414 """ 415 if identify is True or identify == "always": 416 return True 417 418 if identify == "safe": 419 return not self.case_sensitive(text) 420 421 return False 422 423 def quote_identifier(self, expression: E, identify: bool = True) -> E: 424 """ 425 Adds quotes to a given identifier. 426 427 Args: 428 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 429 identify: If set to `False`, the quotes will only be added if the identifier is deemed 430 "unsafe", with respect to its characters and this dialect's normalization strategy. 
431 """ 432 if isinstance(expression, exp.Identifier): 433 name = expression.this 434 expression.set( 435 "quoted", 436 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 437 ) 438 439 return expression 440 441 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 442 return self.parser(**opts).parse(self.tokenize(sql), sql) 443 444 def parse_into( 445 self, expression_type: exp.IntoType, sql: str, **opts 446 ) -> t.List[t.Optional[exp.Expression]]: 447 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 448 449 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 450 return self.generator(**opts).generate(expression, copy=copy) 451 452 def transpile(self, sql: str, **opts) -> t.List[str]: 453 return [ 454 self.generate(expression, copy=False, **opts) if expression else "" 455 for expression in self.parse(sql) 456 ] 457 458 def tokenize(self, sql: str) -> t.List[Token]: 459 return self.tokenizer.tokenize(sql) 460 461 @property 462 def tokenizer(self) -> Tokenizer: 463 if not hasattr(self, "_tokenizer"): 464 self._tokenizer = self.tokenizer_class(dialect=self) 465 return self._tokenizer 466 467 def parser(self, **opts) -> Parser: 468 return self.parser_class(dialect=self, **opts) 469 470 def generator(self, **opts) -> Generator: 471 return self.generator_class(dialect=self, **opts) 472 473 474DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 475 476 477def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 478 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 479 480 481def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 482 if expression.args.get("accuracy"): 483 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 484 return self.func("APPROX_COUNT_DISTINCT", expression.this) 485 486 487def if_sql( 488 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 489) -> t.Callable[[Generator, exp.If], str]: 490 def _if_sql(self: Generator, expression: exp.If) -> str: 491 return self.func( 492 name, 493 expression.this, 494 expression.args.get("true"), 495 expression.args.get("false") or false_value, 496 ) 497 498 return _if_sql 499 500 501def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: 502 return self.binary(expression, "->") 503 504 505def arrow_json_extract_scalar_sql( 506 self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar 507) -> str: 508 return self.binary(expression, "->>") 509 510 511def inline_array_sql(self: Generator, expression: exp.Array) -> str: 512 return f"[{self.expressions(expression, flat=True)}]" 513 514 515def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 516 return self.like_sql( 517 exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) 518 ) 519 520 521def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 522 zone = self.sql(expression, "this") 523 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 524 525 526def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 527 if expression.args.get("recursive"): 528 self.unsupported("Recursive CTEs are unsupported") 529 expression.args["recursive"] = False 530 return self.with_sql(expression) 531 532 533def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 534 n = self.sql(expression, "this") 535 
d = self.sql(expression, "expression") 536 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 537 538 539def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 540 self.unsupported("TABLESAMPLE unsupported") 541 return self.sql(expression.this) 542 543 544def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 545 self.unsupported("PIVOT unsupported") 546 return "" 547 548 549def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 550 return self.cast_sql(expression) 551 552 553def no_properties_sql(self: Generator, expression: exp.Properties) -> str: 554 self.unsupported("Properties unsupported") 555 return "" 556 557 558def no_comment_column_constraint_sql( 559 self: Generator, expression: exp.CommentColumnConstraint 560) -> str: 561 self.unsupported("CommentColumnConstraint unsupported") 562 return "" 563 564 565def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 566 self.unsupported("MAP_FROM_ENTRIES unsupported") 567 return "" 568 569 570def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 571 this = self.sql(expression, "this") 572 substr = self.sql(expression, "substr") 573 position = self.sql(expression, "position") 574 if position: 575 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 576 return f"STRPOS({this}, {substr})" 577 578 579def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 580 return ( 581 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 582 ) 583 584 585def var_map_sql( 586 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 587) -> str: 588 keys = expression.args["keys"] 589 values = expression.args["values"] 590 591 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 592 self.unsupported("Cannot convert array columns into map.") 593 return self.func(map_func_name, keys, values) 594 595 args = [] 596 for key, value in zip(keys.expressions, values.expressions): 597 args.append(self.sql(key)) 598 args.append(self.sql(value)) 599 600 return self.func(map_func_name, *args) 601 602 603def format_time_lambda( 604 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 605) -> t.Callable[[t.List], E]: 606 """Helper used for time expressions. 607 608 Args: 609 exp_class: the expression class to instantiate. 610 dialect: target sql dialect. 611 default: the default format, True being time. 612 613 Returns: 614 A callable that can be used to return the appropriately formatted time expression. 615 """ 616 617 def _format_time(args: t.List): 618 return exp_class( 619 this=seq_get(args, 0), 620 format=Dialect[dialect].format_time( 621 seq_get(args, 1) 622 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 623 ), 624 ) 625 626 return _format_time 627 628 629def time_format( 630 dialect: DialectType = None, 631) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 632 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 633 """ 634 Returns the time format for a given expression, unless it's equivalent 635 to the default time format of the dialect of interest. 
636 """ 637 time_format = self.format_time(expression) 638 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 639 640 return _time_format 641 642 643def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 644 """ 645 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 646 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 647 columns are removed from the create statement. 648 """ 649 has_schema = isinstance(expression.this, exp.Schema) 650 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 651 652 if has_schema and is_partitionable: 653 prop = expression.find(exp.PartitionedByProperty) 654 if prop and prop.this and not isinstance(prop.this, exp.Schema): 655 schema = expression.this 656 columns = {v.name.upper() for v in prop.this.expressions} 657 partitions = [col for col in schema.expressions if col.name.upper() in columns] 658 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 659 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 660 expression.set("this", schema) 661 662 return self.create_sql(expression) 663 664 665def parse_date_delta( 666 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 667) -> t.Callable[[t.List], E]: 668 def inner_func(args: t.List) -> E: 669 unit_based = len(args) == 3 670 this = args[2] if unit_based else seq_get(args, 0) 671 unit = args[0] if unit_based else exp.Literal.string("DAY") 672 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 673 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 674 675 return inner_func 676 677 678def parse_date_delta_with_interval( 679 expression_class: t.Type[E], 680) -> t.Callable[[t.List], t.Optional[E]]: 681 def func(args: t.List) -> t.Optional[E]: 682 if len(args) < 2: 683 return None 684 685 interval = args[1] 686 687 if not isinstance(interval, exp.Interval): 688 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 689 690 expression = interval.this 691 if expression and expression.is_string: 692 expression = exp.Literal.number(expression.this) 693 694 return expression_class( 695 this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) 696 ) 697 698 return func 699 700 701def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 702 unit = seq_get(args, 0) 703 this = seq_get(args, 1) 704 705 if isinstance(this, exp.Cast) and this.is_type("date"): 706 return exp.DateTrunc(unit=unit, this=this) 707 return exp.TimestampTrunc(this=this, unit=unit) 708 709 710def date_add_interval_sql( 711 data_type: str, kind: str 712) -> t.Callable[[Generator, exp.Expression], str]: 713 def func(self: Generator, expression: exp.Expression) -> str: 714 this = self.sql(expression, "this") 715 unit = expression.args.get("unit") 716 unit = exp.var(unit.name.upper() if unit else "DAY") 717 interval = exp.Interval(this=expression.expression, unit=unit) 718 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 719 720 return func 721 722 723def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 724 return self.func( 725 "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this 726 ) 727 728 729def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 730 if not expression.expression: 731 return 
self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)) 732 if expression.text("expression").lower() in TIMEZONES: 733 return self.sql( 734 exp.AtTimeZone( 735 this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP), 736 zone=expression.expression, 737 ) 738 ) 739 return self.function_fallback_sql(expression) 740 741 742def locate_to_strposition(args: t.List) -> exp.Expression: 743 return exp.StrPosition( 744 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 745 ) 746 747 748def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 749 return self.func( 750 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 751 ) 752 753 754def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 755 return self.sql( 756 exp.Substring( 757 this=expression.this, start=exp.Literal.number(1), length=expression.expression 758 ) 759 ) 760 761 762def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 763 return self.sql( 764 exp.Substring( 765 this=expression.this, 766 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 767 ) 768 ) 769 770 771def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 772 return self.sql(exp.cast(expression.this, "timestamp")) 773 774 775def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 776 return self.sql(exp.cast(expression.this, "date")) 777 778 779# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 780def encode_decode_sql( 781 self: Generator, expression: exp.Expression, name: str, replace: bool = True 782) -> str: 783 charset = expression.args.get("charset") 784 if charset and charset.name.lower() != "utf-8": 785 self.unsupported(f"Expected utf-8 character set, got {charset}.") 786 787 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 788 789 790def min_or_least(self: Generator, expression: exp.Min) -> str: 791 name = "LEAST" if expression.expressions else "MIN" 792 return rename_func(name)(self, expression) 793 794 795def max_or_greatest(self: Generator, expression: exp.Max) -> str: 796 name = "GREATEST" if expression.expressions else "MAX" 797 return rename_func(name)(self, expression) 798 799 800def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 801 cond = expression.this 802 803 if isinstance(expression.this, exp.Distinct): 804 cond = expression.this.expressions[0] 805 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 806 807 return self.func("sum", exp.func("if", cond, 1, 0)) 808 809 810def trim_sql(self: Generator, expression: exp.Trim) -> str: 811 target = self.sql(expression, "this") 812 trim_type = self.sql(expression, "position") 813 remove_chars = self.sql(expression, "expression") 814 collation = self.sql(expression, "collation") 815 816 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 817 if not remove_chars and not collation: 818 return self.trim_sql(expression) 819 820 trim_type = f"{trim_type} " if trim_type else "" 821 remove_chars = f"{remove_chars} " if remove_chars else "" 822 from_part = "FROM " if trim_type or remove_chars else "" 823 collation = f" COLLATE {collation}" if collation else "" 824 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 825 826 827def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 828 return self.func("STRPTIME", 
expression.this, self.format_time(expression)) 829 830 831def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 832 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 833 834 835def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 836 delim, *rest_args = expression.expressions 837 return self.sql( 838 reduce( 839 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 840 rest_args, 841 ) 842 ) 843 844 845def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 846 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 847 if bad_args: 848 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 849 850 return self.func( 851 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 852 ) 853 854 855def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 856 bad_args = list( 857 filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers")) 858 ) 859 if bad_args: 860 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 861 862 return self.func( 863 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 864 ) 865 866 867def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 868 names = [] 869 for agg in aggregations: 870 if isinstance(agg, exp.Alias): 871 names.append(agg.alias) 872 else: 873 """ 874 This case corresponds to aggregations without aliases being used as suffixes 875 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 876 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 877 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 
878 """ 879 agg_all_unquoted = agg.transform( 880 lambda node: exp.Identifier(this=node.name, quoted=False) 881 if isinstance(node, exp.Identifier) 882 else node 883 ) 884 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 885 886 return names 887 888 889def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 890 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 891 892 893# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 894def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 895 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 896 897 898def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 899 return self.func("MAX", expression.this) 900 901 902def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 903 a = self.sql(expression.left) 904 b = self.sql(expression.right) 905 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 906 907 908def is_parse_json(expression: exp.Expression) -> bool: 909 return isinstance(expression, exp.ParseJSON) or ( 910 isinstance(expression, exp.Cast) and expression.is_type("json") 911 ) 912 913 914def isnull_to_is_null(args: t.List) -> exp.Expression: 915 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 916 917 918def generatedasidentitycolumnconstraint_sql( 919 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 920) -> str: 921 start = self.sql(expression, "start") or "1" 922 increment = self.sql(expression, "increment") or "1" 923 return f"IDENTITY({start}, {increment})" 924 925 926def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 927 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 928 if expression.args.get("count"): 929 self.unsupported(f"Only two arguments are supported in function {name}.") 930 931 return self.func(name, expression.this, expression.expression) 932 933 return _arg_max_or_min_sql 934 935 936def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 937 this = expression.this.copy() 938 939 return_type = expression.return_type 940 if return_type.is_type(exp.DataType.Type.DATE): 941 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 942 # can truncate timestamp strings, because some dialects can't cast them to DATE 943 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 944 945 expression.this.replace(exp.cast(this, return_type)) 946 return expression 947 948 949def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 950 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 951 if cast and isinstance(expression, exp.TsOrDsAdd): 952 expression = ts_or_ds_add_cast(expression) 953 954 return self.func( 955 name, 956 exp.var(expression.text("unit").upper() or "DAY"), 957 expression.expression, 958 expression.this, 959 ) 960 961 return _delta_sql 962 963 964def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath: 965 from sqlglot.optimizer.simplify import simplify 966 967 # Makes sure the path will be evaluated correctly at runtime to include the path root. 968 # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`. 
969 path = expression.expression 970 path = exp.func( 971 "if", 972 exp.func("startswith", path, "'['"), 973 exp.func("concat", "'$'", path), 974 exp.func("concat", "'$.'", path), 975 ) 976 977 expression.expression.replace(simplify(path)) 978 return expression 979 980 981def path_to_jsonpath( 982 name: str = "JSON_EXTRACT", 983) -> t.Callable[[Generator, exp.GetPath], str]: 984 def _transform(self: Generator, expression: exp.GetPath) -> str: 985 return rename_func(name)(self, prepend_dollar_to_path(expression)) 986 987 return _transform 988 989 990def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 991 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 992 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 993 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 994 995 return self.sql(exp.cast(minus_one_day, "date")) 996 997 998def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 999 """Remove table refs from columns in when statements.""" 1000 alias = expression.this.args.get("alias") 1001 1002 normalize = ( 1003 lambda identifier: self.dialect.normalize_identifier(identifier).name 1004 if identifier 1005 else None 1006 ) 1007 1008 targets = {normalize(expression.this.this)} 1009 1010 if alias: 1011 targets.add(normalize(alias.this)) 1012 1013 for when in expression.expressions: 1014 when.transform( 1015 lambda node: exp.column(node.this) 1016 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1017 else node, 1018 copy=False, 1019 ) 1020 1021 return self.merge_sql(expression)
```python
class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
```
Dialects supported by SQLGlot.
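Because `Dialects` subclasses `str`, its members compare equal to their plain string values, so either form can be passed wherever a dialect name is expected. A short sketch:

```python
from sqlglot.dialects.dialect import Dialects

# str-based enum: members and their values are interchangeable.
assert Dialects.DUCKDB == "duckdb"
assert Dialects.DUCKDB.value == "duckdb"
```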
```python
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
```
Specifies the strategy according to which identifiers should be normalized.
- `LOWERCASE`: Unquoted identifiers are lowercased.
- `UPPERCASE`: Unquoted identifiers are uppercased.
- `CASE_SENSITIVE`: Always case-sensitive, regardless of quotes.
- `CASE_INSENSITIVE`: Always case-insensitive, regardless of quotes.
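`AutoName` gives each member its own name as its value, which is what allows `Dialect.__init__` to coerce a user-supplied string with a plain `.upper()` call. A short sketch:

```python
from sqlglot.dialects.dialect import NormalizationStrategy

# auto() under AutoName yields the member name itself as the value.
assert NormalizationStrategy.LOWERCASE.value == "LOWERCASE"
assert NormalizationStrategy("case_sensitive".upper()) is NormalizationStrategy.CASE_SENSITIVE
```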
class Dialect(metaclass=_Dialect)

The base class that all dialects derive from. Its full definition appears verbatim in the module source at the top of this page; the class-level settings are summarized below, and its methods are documented in the sections that follow.
```python
def __init__(self, **kwargs) -> None:
    normalization_strategy = kwargs.get("normalization_strategy")

    if normalization_strategy is None:
        self.normalization_strategy = self.NORMALIZATION_STRATEGY
    else:
        self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
```
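The constructor accepts the strategy as a string and upper-cases it into the enum. A minimal sketch using the bundled MySQL dialect:

```python
from sqlglot.dialects.dialect import NormalizationStrategy
from sqlglot.dialects.mysql import MySQL

# The string is upper-cased and coerced to NormalizationStrategy by __init__.
d = MySQL(normalization_strategy="case_sensitive")
assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
```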
Class-level settings and their defaults:

- `INDEX_OFFSET = 0`: Determines the base index offset for arrays.
- `WEEK_OFFSET = 0`: Determines the day of week of `DATE_TRUNC(week)`. Defaults to 0 (Monday). -1 would be Sunday.
- `UNNEST_COLUMN_ONLY = False`: Determines whether or not `UNNEST` table aliases are treated as column aliases.
- `ALIAS_POST_TABLESAMPLE = False`: Determines whether or not the table alias comes after tablesample.
- `TABLESAMPLE_SIZE_IS_PERCENT = False`: Determines whether or not a size in the table sample clause represents percentage.
- `NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE`: Specifies the strategy according to which identifiers should be normalized.
- `IDENTIFIERS_CAN_START_WITH_DIGIT = False`: Determines whether or not an unquoted identifier can start with a digit.
- `DPIPE_IS_STRING_CONCAT = True`: Determines whether or not the DPIPE token (`||`) is a string concatenation operator.
- `STRICT_STRING_CONCAT = False`: Determines whether or not `CONCAT`'s arguments must be strings.
- `SUPPORTS_USER_DEFINED_TYPES = True`: Determines whether or not user-defined data types are supported.
- `SUPPORTS_SEMI_ANTI_JOIN = True`: Determines whether or not `SEMI` or `ANTI` joins are supported.
- `NORMALIZE_FUNCTIONS = "upper"`: Determines how function names are going to be normalized.
- `LOG_BASE_FIRST = True`: Determines whether the base comes first in the `LOG` function.
- `NULL_ORDERING = "nulls_are_small"`: Indicates the default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
- `TYPED_DIVISION = False`: Whether the behavior of `a / b` depends on the types of `a` and `b`. False means `a / b` is always float division; True means `a / b` is integer division if both `a` and `b` are integers.
- `SAFE_DIVISION = False`: Determines whether division by zero throws an error (`False`) or returns NULL (`True`).
- `CONCAT_COALESCE = False`: A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
- `DATE_FORMAT = "'%Y-%m-%d'"`, `DATEINT_FORMAT = "'%Y%m%d'"`, `TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"`: Default date and time format literals.
- `TIME_MAPPING = {}`: Associates this dialect's time formats with their equivalent Python `strftime` formats.
- `FORMAT_MAPPING = {}`: Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
- `ESCAPE_SEQUENCES = {}`: Mapping of an unescaped escape sequence to the corresponding character.
- `PSEUDOCOLUMNS = set()`: Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from `SELECT *` queries.
- `PREFER_CTE_ALIAS_COLUMN = False`: Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery. For example,

      WITH y(c) AS (
          SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
      ) SELECT c FROM y;

  will be rewritten as

      WITH y(c) AS (
          SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
      ) SELECT c FROM y;
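Because `_Dialect` is the metaclass of `Dialect`, defining a subclass is enough to register it in the global registry, keyed by the matching `Dialects` enum value or, failing that, the lowercased class name. A minimal sketch, assuming the bundled Postgres dialect as a base:

```python
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy
from sqlglot.dialects.postgres import Postgres

class UpperPostgres(Postgres):
    # Override a single class-level setting; everything else is inherited.
    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE

# The metaclass registered the subclass under "upperpostgres" at definition time.
d = Dialect.get_or_raise("upperpostgres")
assert isinstance(d, UpperPostgres)
```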
```python
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
    """
    Look up a dialect in the global dialect registry and return it if it exists.

    Args:
        dialect: The target dialect. If this is a string, it can be optionally followed by
            additional key-value pairs that are separated by commas and are used to specify
            dialect settings, such as whether the dialect's identifiers are case-sensitive.

    Example:
        >>> dialect = dialect_class = get_or_raise("duckdb")
        >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

    Returns:
        The corresponding Dialect instance.
    """

    if not dialect:
        return cls()
    if isinstance(dialect, _Dialect):
        return dialect()
    if isinstance(dialect, Dialect):
        return dialect
    if isinstance(dialect, str):
        try:
            dialect_name, *kv_pairs = dialect.split(",")
            kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
        except ValueError:
            raise ValueError(
                f"Invalid dialect format: '{dialect}'. "
                "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
            )

        result = cls.get(dialect_name.strip())
        if not result:
            from difflib import get_close_matches

            similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
            if similar:
                similar = f" Did you mean {similar}?"

            raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

        return result(**kwargs)

    raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
```
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
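The key-value pairs after the dialect name become keyword arguments for the dialect's constructor, so the settings string is just a shorthand for direct instantiation. A short sketch:

```python
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

duckdb = Dialect.get_or_raise("duckdb")

# Equivalent to constructing the registered class with the kwarg directly.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

# Unknown names raise a ValueError that suggests close matches, e.g.:
# Dialect.get_or_raise("duckdbb")  # ValueError: Unknown dialect 'duckdbb'. Did you mean duckdb?
```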
```python
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Converts a time format in this dialect to its equivalent Python `strftime` format."""
    if isinstance(expression, str):
        return exp.Literal.string(
            # the time formats are quoted
            format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
        )

    if expression and expression.is_string:
        return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

    return expression
```
Converts a time format in this dialect to its equivalent Python `strftime` format.
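A sketch of the conversion, assuming the bundled Hive dialect, whose `TIME_MAPPING` translates `yyyy`, `MM` and `dd` to `%Y`, `%m` and `%d`:

```python
from sqlglot.dialects.hive import Hive

# String inputs are assumed to carry their surrounding quotes, hence "'...'".
converted = Hive.format_time("'yyyy-MM-dd'")
assert converted.name == "%Y-%m-%d"
```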
```python
def normalize_identifier(self, expression: E) -> E:
    """
    Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

    For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
    lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
    it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
    and so any normalization would be prohibited in order to avoid "breaking" the identifier.

    There are also dialects like Spark, which are case-insensitive even when quotes are
    present, and dialects like MySQL, whose resolution rules match those employed by the
    underlying operating system, for example they may always be case-sensitive in Linux.

    Finally, the normalization behavior of some engines can even be controlled through flags,
    like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

    SQLGlot aims to understand and handle all of these different behaviors gracefully, so
    that it can analyze queries in the optimizer and successfully capture their semantics.
    """
    if (
        isinstance(expression, exp.Identifier)
        and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
        and (
            not expression.quoted
            or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
        )
    ):
        expression.set(
            "this",
            expression.this.upper()
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else expression.this.lower(),
        )

    return expression
```
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
def case_sensitive(self, text: str) -> bool:
    """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
    if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    unsafe = (
        str.islower
        if self.normalization_strategy is NormalizationStrategy.UPPERCASE
        else str.isupper
    )
    return any(unsafe(char) for char in text)
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: Always returns `True`.
            `"safe"`: Only returns `True` if the identifier is case-insensitive.

    Returns:
        Whether or not the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    if identify == "safe":
        return not self.case_sensitive(text)

    return False
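Together with case_sensitive above, this drives quoting decisions. A small sketch for a lowercase-normalizing dialect like Postgres:

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")  # lowercases unquoted identifiers

d.case_sensitive("foo")          # False: normalization wouldn't change it
d.case_sensitive("Foo")          # True: the uppercase char is "unsafe"

d.can_identify("Foo", "always")  # True, unconditionally
d.can_identify("Foo", "safe")    # False, since "Foo" is case-sensitive here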
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Adds quotes to a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.
    """
    if isinstance(expression, exp.Identifier):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
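A sketch of the three outcomes when identify=False, so quoting is only forced for "unsafe" names (the rendered output assumes the default double-quote identifiers):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")

d.quote_identifier(exp.to_identifier("foo"), identify=False).sql()    # foo
d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql()    # "Foo"
d.quote_identifier(exp.to_identifier("2fast"), identify=False).sql()  # "2fast"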
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
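Helpers like this one are factories meant to be dropped into a Generator's TRANSFORMS table. A hypothetical dialect that spells IF as IFF might wire it up as follows (a sketch; the dialect and function names are illustrative only):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, if_sql
from sqlglot.generator import Generator as BaseGenerator

class IffDialect(Dialect):
    class Generator(BaseGenerator):
        TRANSFORMS = {
            **BaseGenerator.TRANSFORMS,
            exp.If: if_sql(name="IFF"),  # IF(cond, t, f) -> IFF(cond, t, f)
        }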
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time
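On the parsing side, dialects typically register the returned callable in their Parser's FUNCTIONS map, roughly like the following (a sketch based on Hive-style usage; the dialect class is illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, format_time_lambda
from sqlglot.parser import Parser as BaseParser

class HiveLike(Dialect):
    class Parser(BaseParser):
        FUNCTIONS = {
            **BaseParser.FUNCTIONS,
            # TO_DATE(x[, fmt]) parses into TsOrDsToDate with a normalized format
            "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
        }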
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
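A rough end-to-end illustration. The exact output is a best guess and depends on the dialects involved, but the partition column should move out of the main schema:

import sqlglot

sql = "CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)"
print(sqlglot.transpile(sql, read="spark", write="hive")[0])
# Expected shape: CREATE TABLE t (a INT) PARTITIONED BY (b INT)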
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
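The returned callable accepts either the (this, increment) shape or the unit-first (unit, increment, this) shape used by T-SQL-style DATEADD. An illustrative call (the "dd" mapping here is hypothetical):

from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta

parse = parse_date_delta(exp.DateAdd, unit_mapping={"dd": "DAY"})

# DATEADD(dd, 5, d) -> DateAdd(this=d, expression=5, unit=DAY)
node = parse([exp.var("dd"), exp.Literal.number(5), exp.column("d")])
assert node.text("unit") == "DAY"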
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the quotes around foo).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
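Typical wiring, as seen in dialects that expose two-argument MAX_BY/MIN_BY instead of a counted ARG_MAX (a sketch of the TRANSFORMS fragment, not a complete dialect):

from sqlglot import exp
from sqlglot.dialects.dialect import arg_max_or_min_no_count

# Inside a Dialect.Generator subclass (illustrative):
TRANSFORMS = {
    exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
    exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
}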
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql
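And its typical TRANSFORMS wiring for DATEADD/DATEDIFF-style engines (a sketch; cast=True opts into the TsOrDsAdd cast from ts_or_ds_add_cast above):

from sqlglot import exp
from sqlglot.dialects.dialect import date_delta_sql

# Inside a Dialect.Generator subclass (illustrative):
TRANSFORMS = {
    exp.DateAdd: date_delta_sql("DATEADD"),
    exp.DateDiff: date_delta_sql("DATEDIFF"),
    exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
}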
def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
    from sqlglot.optimizer.simplify import simplify

    # Makes sure the path will be evaluated correctly at runtime to include the path root.
    # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
    path = expression.expression
    path = exp.func(
        "if",
        exp.func("startswith", path, "'['"),
        exp.func("concat", "'$'", path),
        exp.func("concat", "'$.'", path),
    )

    expression.expression.replace(simplify(path))
    return expression
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    normalize = (
        lambda identifier: self.dialect.normalize_identifier(identifier).name
        if identifier
        else None
    )

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: exp.column(node.this)
            if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
            else node,
            copy=False,
        )

    return self.merge_sql(expression)
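For instance, in a dialect wired to use this helper (BigQuery does at the time of writing), the effect is roughly the following; the exact rendering is a best guess:

import sqlglot

sql = "MERGE INTO t USING s ON t.id = s.id WHEN MATCHED THEN UPDATE SET t.x = s.x"
print(sqlglot.transpile(sql, write="bigquery")[0])
# Expected shape: ... WHEN MATCHED THEN UPDATE SET x = s.x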