sqlglot.dialects.dialect
1from __future__ import annotations 2 3import logging 4import typing as t 5from enum import Enum, auto 6from functools import reduce 7 8from sqlglot import exp 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator, unsupported_args 11from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses, to_bool 12from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path 13from sqlglot.parser import Parser 14from sqlglot.time import TIMEZONES, format_time, subsecond_precision 15from sqlglot.tokens import Token, Tokenizer, TokenType 16from sqlglot.trie import new_trie 17 18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] 19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] 20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] 21 22 23if t.TYPE_CHECKING: 24 from sqlglot._typing import B, E, F 25 26 from sqlglot.optimizer.annotate_types import TypeAnnotator 27 28 AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]] 29 30logger = logging.getLogger("sqlglot") 31 32UNESCAPED_SEQUENCES = { 33 "\\a": "\a", 34 "\\b": "\b", 35 "\\f": "\f", 36 "\\n": "\n", 37 "\\r": "\r", 38 "\\t": "\t", 39 "\\v": "\v", 40 "\\\\": "\\", 41} 42 43 44def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: 45 return lambda self, e: self._annotate_with_type(e, data_type) 46 47 48class Dialects(str, Enum): 49 """Dialects supported by SQLGLot.""" 50 51 DIALECT = "" 52 53 ATHENA = "athena" 54 BIGQUERY = "bigquery" 55 CLICKHOUSE = "clickhouse" 56 DATABRICKS = "databricks" 57 DORIS = "doris" 58 DRILL = "drill" 59 DUCKDB = "duckdb" 60 HIVE = "hive" 61 MATERIALIZE = "materialize" 62 MYSQL = "mysql" 63 ORACLE = "oracle" 64 POSTGRES = "postgres" 65 PRESTO = "presto" 66 PRQL = "prql" 67 REDSHIFT = "redshift" 68 RISINGWAVE = "risingwave" 69 SNOWFLAKE = "snowflake" 70 SPARK = "spark" 71 SPARK2 = "spark2" 72 SQLITE = "sqlite" 73 STARROCKS = "starrocks" 74 TABLEAU = "tableau" 75 TERADATA = "teradata" 76 TRINO = "trino" 77 TSQL = "tsql" 78 79 80class NormalizationStrategy(str, AutoName): 81 """Specifies the strategy according to which identifiers should be normalized.""" 82 83 LOWERCASE = auto() 84 """Unquoted identifiers are lowercased.""" 85 86 UPPERCASE = auto() 87 """Unquoted identifiers are uppercased.""" 88 89 CASE_SENSITIVE = auto() 90 """Always case-sensitive, regardless of quotes.""" 91 92 CASE_INSENSITIVE = auto() 93 """Always case-insensitive, regardless of quotes.""" 94 95 96class _Dialect(type): 97 classes: t.Dict[str, t.Type[Dialect]] = {} 98 99 def __eq__(cls, other: t.Any) -> bool: 100 if cls is other: 101 return True 102 if isinstance(other, str): 103 return cls is cls.get(other) 104 if isinstance(other, Dialect): 105 return cls is type(other) 106 107 return False 108 109 def __hash__(cls) -> int: 110 return hash(cls.__name__.lower()) 111 112 @classmethod 113 def __getitem__(cls, key: str) -> t.Type[Dialect]: 114 return cls.classes[key] 115 116 @classmethod 117 def get( 118 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 119 ) -> t.Optional[t.Type[Dialect]]: 120 return cls.classes.get(key, default) 121 122 def __new__(cls, clsname, bases, attrs): 123 klass = super().__new__(cls, clsname, bases, attrs) 124 enum = Dialects.__members__.get(clsname.upper()) 125 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 126 127 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 128 klass.FORMAT_TRIE = ( 129 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 130 ) 131 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 132 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 133 klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()} 134 klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING) 135 136 klass.INVERSE_CREATABLE_KIND_MAPPING = { 137 v: k for k, v in klass.CREATABLE_KIND_MAPPING.items() 138 } 139 140 base = seq_get(bases, 0) 141 base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),) 142 base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),) 143 base_parser = (getattr(base, "parser_class", Parser),) 144 base_generator = (getattr(base, "generator_class", Generator),) 145 146 klass.tokenizer_class = klass.__dict__.get( 147 "Tokenizer", type("Tokenizer", base_tokenizer, {}) 148 ) 149 klass.jsonpath_tokenizer_class = klass.__dict__.get( 150 "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {}) 151 ) 152 klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {})) 153 klass.generator_class = klass.__dict__.get( 154 "Generator", type("Generator", base_generator, {}) 155 ) 156 157 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 158 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 159 klass.tokenizer_class._IDENTIFIERS.items() 160 )[0] 161 162 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 163 return next( 164 ( 165 (s, e) 166 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 167 if t == token_type 168 ), 169 (None, None), 170 ) 171 172 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 173 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 174 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 175 klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) 176 177 if "\\" in klass.tokenizer_class.STRING_ESCAPES: 178 klass.UNESCAPED_SEQUENCES = { 179 **UNESCAPED_SEQUENCES, 180 **klass.UNESCAPED_SEQUENCES, 181 } 182 183 klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()} 184 185 klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS 186 187 if enum not in ("", "bigquery"): 188 klass.generator_class.SELECT_KINDS = () 189 190 if enum not in ("", "athena", "presto", "trino"): 191 klass.generator_class.TRY_SUPPORTED = False 192 klass.generator_class.SUPPORTS_UESCAPE = False 193 194 if enum not in ("", "databricks", "hive", "spark", "spark2"): 195 modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy() 196 for modifier in ("cluster", "distribute", "sort"): 197 modifier_transforms.pop(modifier, None) 198 199 klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms 200 201 if enum not in ("", "doris", "mysql"): 202 klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { 203 TokenType.STRAIGHT_JOIN, 204 } 205 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 206 TokenType.STRAIGHT_JOIN, 207 } 208 209 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 210 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 211 TokenType.ANTI, 212 TokenType.SEMI, 213 } 214 215 return klass 216 217 218class Dialect(metaclass=_Dialect): 219 INDEX_OFFSET = 0 220 """The base index offset for arrays.""" 221 222 WEEK_OFFSET = 0 223 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 224 225 UNNEST_COLUMN_ONLY = False 226 """Whether `UNNEST` table aliases are treated as column aliases.""" 227 228 ALIAS_POST_TABLESAMPLE = False 229 """Whether the table alias comes after tablesample.""" 230 231 TABLESAMPLE_SIZE_IS_PERCENT = False 232 """Whether a size in the table sample clause represents percentage.""" 233 234 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 235 """Specifies the strategy according to which identifiers should be normalized.""" 236 237 IDENTIFIERS_CAN_START_WITH_DIGIT = False 238 """Whether an unquoted identifier can start with a digit.""" 239 240 DPIPE_IS_STRING_CONCAT = True 241 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 242 243 STRICT_STRING_CONCAT = False 244 """Whether `CONCAT`'s arguments must be strings.""" 245 246 SUPPORTS_USER_DEFINED_TYPES = True 247 """Whether user-defined data types are supported.""" 248 249 SUPPORTS_SEMI_ANTI_JOIN = True 250 """Whether `SEMI` or `ANTI` joins are supported.""" 251 252 SUPPORTS_COLUMN_JOIN_MARKS = False 253 """Whether the old-style outer join (+) syntax is supported.""" 254 255 COPY_PARAMS_ARE_CSV = True 256 """Separator of COPY statement parameters.""" 257 258 NORMALIZE_FUNCTIONS: bool | str = "upper" 259 """ 260 Determines how function names are going to be normalized. 261 Possible values: 262 "upper" or True: Convert names to uppercase. 263 "lower": Convert names to lowercase. 264 False: Disables function name normalization. 265 """ 266 267 PRESERVE_ORIGINAL_NAMES: bool = False 268 """ 269 Whether the name of the function should be preserved inside the node's metadata, 270 can be useful for roundtripping deprecated vs new functions that share an AST node 271 e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery 272 """ 273 274 LOG_BASE_FIRST: t.Optional[bool] = True 275 """ 276 Whether the base comes first in the `LOG` function. 277 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 278 """ 279 280 NULL_ORDERING = "nulls_are_small" 281 """ 282 Default `NULL` ordering method to use if not explicitly set. 283 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 284 """ 285 286 TYPED_DIVISION = False 287 """ 288 Whether the behavior of `a / b` depends on the types of `a` and `b`. 289 False means `a / b` is always float division. 290 True means `a / b` is integer division if both `a` and `b` are integers. 291 """ 292 293 SAFE_DIVISION = False 294 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 295 296 CONCAT_COALESCE = False 297 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 298 299 HEX_LOWERCASE = False 300 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 301 302 DATE_FORMAT = "'%Y-%m-%d'" 303 DATEINT_FORMAT = "'%Y%m%d'" 304 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 305 306 TIME_MAPPING: t.Dict[str, str] = {} 307 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 308 309 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 310 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 311 FORMAT_MAPPING: t.Dict[str, str] = {} 312 """ 313 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 314 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 315 """ 316 317 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 318 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 319 320 PSEUDOCOLUMNS: t.Set[str] = set() 321 """ 322 Columns that are auto-generated by the engine corresponding to this dialect. 323 For example, such columns may be excluded from `SELECT *` queries. 324 """ 325 326 PREFER_CTE_ALIAS_COLUMN = False 327 """ 328 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 329 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 330 any projection aliases in the subquery. 331 332 For example, 333 WITH y(c) AS ( 334 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 335 ) SELECT c FROM y; 336 337 will be rewritten as 338 339 WITH y(c) AS ( 340 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 341 ) SELECT c FROM y; 342 """ 343 344 COPY_PARAMS_ARE_CSV = True 345 """ 346 Whether COPY statement parameters are separated by comma or whitespace 347 """ 348 349 FORCE_EARLY_ALIAS_REF_EXPANSION = False 350 """ 351 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 352 353 For example: 354 WITH data AS ( 355 SELECT 356 1 AS id, 357 2 AS my_id 358 ) 359 SELECT 360 id AS my_id 361 FROM 362 data 363 WHERE 364 my_id = 1 365 GROUP BY 366 my_id, 367 HAVING 368 my_id = 1 369 370 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 371 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 372 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 373 - Clickhouse, which will forward the alias across the query i.e it resolves 374 to "WHERE id = 1 GROUP BY id HAVING id = 1" 375 """ 376 377 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 378 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 379 380 SUPPORTS_ORDER_BY_ALL = False 381 """ 382 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 383 """ 384 385 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 386 """ 387 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 388 as the former is of type INT[] vs the latter which is SUPER 389 """ 390 391 SUPPORTS_FIXED_SIZE_ARRAYS = False 392 """ 393 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 394 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 395 be interpreted as a subscript/index operator. 396 """ 397 398 STRICT_JSON_PATH_SYNTAX = True 399 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 400 401 ON_CONDITION_EMPTY_BEFORE_ERROR = True 402 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 403 404 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 405 """Whether ArrayAgg needs to filter NULL values.""" 406 407 PROMOTE_TO_INFERRED_DATETIME_TYPE = False 408 """ 409 This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted 410 to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal 411 is cast to x's type to match it instead. 412 """ 413 414 SUPPORTS_VALUES_DEFAULT = True 415 """Whether the DEFAULT keyword is supported in the VALUES clause.""" 416 417 REGEXP_EXTRACT_DEFAULT_GROUP = 0 418 """The default value for the capturing group.""" 419 420 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 421 exp.Except: True, 422 exp.Intersect: True, 423 exp.Union: True, 424 } 425 """ 426 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 427 must be explicitly specified. 428 """ 429 430 CREATABLE_KIND_MAPPING: dict[str, str] = {} 431 """ 432 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 433 equivalent of CREATE SCHEMA is CREATE DATABASE. 434 """ 435 436 # --- Autofilled --- 437 438 tokenizer_class = Tokenizer 439 jsonpath_tokenizer_class = JSONPathTokenizer 440 parser_class = Parser 441 generator_class = Generator 442 443 # A trie of the time_mapping keys 444 TIME_TRIE: t.Dict = {} 445 FORMAT_TRIE: t.Dict = {} 446 447 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 448 INVERSE_TIME_TRIE: t.Dict = {} 449 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 450 INVERSE_FORMAT_TRIE: t.Dict = {} 451 452 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 453 454 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 455 456 # Delimiters for string literals and identifiers 457 QUOTE_START = "'" 458 QUOTE_END = "'" 459 IDENTIFIER_START = '"' 460 IDENTIFIER_END = '"' 461 462 # Delimiters for bit, hex, byte and unicode literals 463 BIT_START: t.Optional[str] = None 464 BIT_END: t.Optional[str] = None 465 HEX_START: t.Optional[str] = None 466 HEX_END: t.Optional[str] = None 467 BYTE_START: t.Optional[str] = None 468 BYTE_END: t.Optional[str] = None 469 UNICODE_START: t.Optional[str] = None 470 UNICODE_END: t.Optional[str] = None 471 472 DATE_PART_MAPPING = { 473 "Y": "YEAR", 474 "YY": "YEAR", 475 "YYY": "YEAR", 476 "YYYY": "YEAR", 477 "YR": "YEAR", 478 "YEARS": "YEAR", 479 "YRS": "YEAR", 480 "MM": "MONTH", 481 "MON": "MONTH", 482 "MONS": "MONTH", 483 "MONTHS": "MONTH", 484 "D": "DAY", 485 "DD": "DAY", 486 "DAYS": "DAY", 487 "DAYOFMONTH": "DAY", 488 "DAY OF WEEK": "DAYOFWEEK", 489 "WEEKDAY": "DAYOFWEEK", 490 "DOW": "DAYOFWEEK", 491 "DW": "DAYOFWEEK", 492 "WEEKDAY_ISO": "DAYOFWEEKISO", 493 "DOW_ISO": "DAYOFWEEKISO", 494 "DW_ISO": "DAYOFWEEKISO", 495 "DAY OF YEAR": "DAYOFYEAR", 496 "DOY": "DAYOFYEAR", 497 "DY": "DAYOFYEAR", 498 "W": "WEEK", 499 "WK": "WEEK", 500 "WEEKOFYEAR": "WEEK", 501 "WOY": "WEEK", 502 "WY": "WEEK", 503 "WEEK_ISO": "WEEKISO", 504 "WEEKOFYEARISO": "WEEKISO", 505 "WEEKOFYEAR_ISO": "WEEKISO", 506 "Q": "QUARTER", 507 "QTR": "QUARTER", 508 "QTRS": "QUARTER", 509 "QUARTERS": "QUARTER", 510 "H": "HOUR", 511 "HH": "HOUR", 512 "HR": "HOUR", 513 "HOURS": "HOUR", 514 "HRS": "HOUR", 515 "M": "MINUTE", 516 "MI": "MINUTE", 517 "MIN": "MINUTE", 518 "MINUTES": "MINUTE", 519 "MINS": "MINUTE", 520 "S": "SECOND", 521 "SEC": "SECOND", 522 "SECONDS": "SECOND", 523 "SECS": "SECOND", 524 "MS": "MILLISECOND", 525 "MSEC": "MILLISECOND", 526 "MSECS": "MILLISECOND", 527 "MSECOND": "MILLISECOND", 528 "MSECONDS": "MILLISECOND", 529 "MILLISEC": "MILLISECOND", 530 "MILLISECS": "MILLISECOND", 531 "MILLISECON": "MILLISECOND", 532 "MILLISECONDS": "MILLISECOND", 533 "US": "MICROSECOND", 534 "USEC": "MICROSECOND", 535 "USECS": "MICROSECOND", 536 "MICROSEC": "MICROSECOND", 537 "MICROSECS": "MICROSECOND", 538 "USECOND": "MICROSECOND", 539 "USECONDS": "MICROSECOND", 540 "MICROSECONDS": "MICROSECOND", 541 "NS": "NANOSECOND", 542 "NSEC": "NANOSECOND", 543 "NANOSEC": "NANOSECOND", 544 "NSECOND": "NANOSECOND", 545 "NSECONDS": "NANOSECOND", 546 "NANOSECS": "NANOSECOND", 547 "EPOCH_SECOND": "EPOCH", 548 "EPOCH_SECONDS": "EPOCH", 549 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 550 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 551 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 552 "TZH": "TIMEZONE_HOUR", 553 "TZM": "TIMEZONE_MINUTE", 554 "DEC": "DECADE", 555 "DECS": "DECADE", 556 "DECADES": "DECADE", 557 "MIL": "MILLENIUM", 558 "MILS": "MILLENIUM", 559 "MILLENIA": "MILLENIUM", 560 "C": "CENTURY", 561 "CENT": "CENTURY", 562 "CENTS": "CENTURY", 563 "CENTURIES": "CENTURY", 564 } 565 566 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 567 exp.DataType.Type.BIGINT: { 568 exp.ApproxDistinct, 569 exp.ArraySize, 570 exp.Length, 571 }, 572 exp.DataType.Type.BOOLEAN: { 573 exp.Between, 574 exp.Boolean, 575 exp.In, 576 exp.RegexpLike, 577 }, 578 exp.DataType.Type.DATE: { 579 exp.CurrentDate, 580 exp.Date, 581 exp.DateFromParts, 582 exp.DateStrToDate, 583 exp.DiToDate, 584 exp.StrToDate, 585 exp.TimeStrToDate, 586 exp.TsOrDsToDate, 587 }, 588 exp.DataType.Type.DATETIME: { 589 exp.CurrentDatetime, 590 exp.Datetime, 591 exp.DatetimeAdd, 592 exp.DatetimeSub, 593 }, 594 exp.DataType.Type.DOUBLE: { 595 exp.ApproxQuantile, 596 exp.Avg, 597 exp.Exp, 598 exp.Ln, 599 exp.Log, 600 exp.Pow, 601 exp.Quantile, 602 exp.Round, 603 exp.SafeDivide, 604 exp.Sqrt, 605 exp.Stddev, 606 exp.StddevPop, 607 exp.StddevSamp, 608 exp.ToDouble, 609 exp.Variance, 610 exp.VariancePop, 611 }, 612 exp.DataType.Type.INT: { 613 exp.Ceil, 614 exp.DatetimeDiff, 615 exp.DateDiff, 616 exp.TimestampDiff, 617 exp.TimeDiff, 618 exp.DateToDi, 619 exp.Levenshtein, 620 exp.Sign, 621 exp.StrPosition, 622 exp.TsOrDiToDi, 623 }, 624 exp.DataType.Type.JSON: { 625 exp.ParseJSON, 626 }, 627 exp.DataType.Type.TIME: { 628 exp.Time, 629 }, 630 exp.DataType.Type.TIMESTAMP: { 631 exp.CurrentTime, 632 exp.CurrentTimestamp, 633 exp.StrToTime, 634 exp.TimeAdd, 635 exp.TimeStrToTime, 636 exp.TimeSub, 637 exp.TimestampAdd, 638 exp.TimestampSub, 639 exp.UnixToTime, 640 }, 641 exp.DataType.Type.TINYINT: { 642 exp.Day, 643 exp.Month, 644 exp.Week, 645 exp.Year, 646 exp.Quarter, 647 }, 648 exp.DataType.Type.VARCHAR: { 649 exp.ArrayConcat, 650 exp.Concat, 651 exp.ConcatWs, 652 exp.DateToDateStr, 653 exp.GroupConcat, 654 exp.Initcap, 655 exp.Lower, 656 exp.Substring, 657 exp.String, 658 exp.TimeToStr, 659 exp.TimeToTimeStr, 660 exp.Trim, 661 exp.TsOrDsToDateStr, 662 exp.UnixToStr, 663 exp.UnixToTimeStr, 664 exp.Upper, 665 }, 666 } 667 668 ANNOTATORS: AnnotatorsType = { 669 **{ 670 expr_type: lambda self, e: self._annotate_unary(e) 671 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 672 }, 673 **{ 674 expr_type: lambda self, e: self._annotate_binary(e) 675 for expr_type in subclasses(exp.__name__, exp.Binary) 676 }, 677 **{ 678 expr_type: _annotate_with_type_lambda(data_type) 679 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 680 for expr_type in expressions 681 }, 682 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 683 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 684 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 685 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 686 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 687 exp.Bracket: lambda self, e: self._annotate_bracket(e), 688 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 689 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 690 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 691 exp.Count: lambda self, e: self._annotate_with_type( 692 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 693 ), 694 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 695 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 696 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 697 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 698 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 699 exp.Div: lambda self, e: self._annotate_div(e), 700 exp.Dot: lambda self, e: self._annotate_dot(e), 701 exp.Explode: lambda self, e: self._annotate_explode(e), 702 exp.Extract: lambda self, e: self._annotate_extract(e), 703 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 704 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 705 e, exp.DataType.build("ARRAY<DATE>") 706 ), 707 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 708 e, exp.DataType.build("ARRAY<TIMESTAMP>") 709 ), 710 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 711 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 712 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 713 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 714 exp.Literal: lambda self, e: self._annotate_literal(e), 715 exp.Map: lambda self, e: self._annotate_map(e), 716 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 717 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 718 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 719 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 720 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 721 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 722 exp.Struct: lambda self, e: self._annotate_struct(e), 723 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 724 exp.Timestamp: lambda self, e: self._annotate_with_type( 725 e, 726 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 727 ), 728 exp.ToMap: lambda self, e: self._annotate_to_map(e), 729 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 730 exp.Unnest: lambda self, e: self._annotate_unnest(e), 731 exp.VarMap: lambda self, e: self._annotate_map(e), 732 } 733 734 @classmethod 735 def get_or_raise(cls, dialect: DialectType) -> Dialect: 736 """ 737 Look up a dialect in the global dialect registry and return it if it exists. 738 739 Args: 740 dialect: The target dialect. If this is a string, it can be optionally followed by 741 additional key-value pairs that are separated by commas and are used to specify 742 dialect settings, such as whether the dialect's identifiers are case-sensitive. 743 744 Example: 745 >>> dialect = dialect_class = get_or_raise("duckdb") 746 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 747 748 Returns: 749 The corresponding Dialect instance. 750 """ 751 752 if not dialect: 753 return cls() 754 if isinstance(dialect, _Dialect): 755 return dialect() 756 if isinstance(dialect, Dialect): 757 return dialect 758 if isinstance(dialect, str): 759 try: 760 dialect_name, *kv_strings = dialect.split(",") 761 kv_pairs = (kv.split("=") for kv in kv_strings) 762 kwargs = {} 763 for pair in kv_pairs: 764 key = pair[0].strip() 765 value: t.Union[bool | str | None] = None 766 767 if len(pair) == 1: 768 # Default initialize standalone settings to True 769 value = True 770 elif len(pair) == 2: 771 value = pair[1].strip() 772 773 kwargs[key] = to_bool(value) 774 775 except ValueError: 776 raise ValueError( 777 f"Invalid dialect format: '{dialect}'. " 778 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 779 ) 780 781 result = cls.get(dialect_name.strip()) 782 if not result: 783 from difflib import get_close_matches 784 785 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 786 if similar: 787 similar = f" Did you mean {similar}?" 788 789 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 790 791 return result(**kwargs) 792 793 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 794 795 @classmethod 796 def format_time( 797 cls, expression: t.Optional[str | exp.Expression] 798 ) -> t.Optional[exp.Expression]: 799 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 800 if isinstance(expression, str): 801 return exp.Literal.string( 802 # the time formats are quoted 803 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 804 ) 805 806 if expression and expression.is_string: 807 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 808 809 return expression 810 811 def __init__(self, **kwargs) -> None: 812 normalization_strategy = kwargs.pop("normalization_strategy", None) 813 814 if normalization_strategy is None: 815 self.normalization_strategy = self.NORMALIZATION_STRATEGY 816 else: 817 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 818 819 self.settings = kwargs 820 821 def __eq__(self, other: t.Any) -> bool: 822 # Does not currently take dialect state into account 823 return type(self) == other 824 825 def __hash__(self) -> int: 826 # Does not currently take dialect state into account 827 return hash(type(self)) 828 829 def normalize_identifier(self, expression: E) -> E: 830 """ 831 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 832 833 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 834 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 835 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 836 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 837 838 There are also dialects like Spark, which are case-insensitive even when quotes are 839 present, and dialects like MySQL, whose resolution rules match those employed by the 840 underlying operating system, for example they may always be case-sensitive in Linux. 841 842 Finally, the normalization behavior of some engines can even be controlled through flags, 843 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 844 845 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 846 that it can analyze queries in the optimizer and successfully capture their semantics. 847 """ 848 if ( 849 isinstance(expression, exp.Identifier) 850 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 851 and ( 852 not expression.quoted 853 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 854 ) 855 ): 856 expression.set( 857 "this", 858 ( 859 expression.this.upper() 860 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 861 else expression.this.lower() 862 ), 863 ) 864 865 return expression 866 867 def case_sensitive(self, text: str) -> bool: 868 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 869 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 870 return False 871 872 unsafe = ( 873 str.islower 874 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 875 else str.isupper 876 ) 877 return any(unsafe(char) for char in text) 878 879 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 880 """Checks if text can be identified given an identify option. 881 882 Args: 883 text: The text to check. 884 identify: 885 `"always"` or `True`: Always returns `True`. 886 `"safe"`: Only returns `True` if the identifier is case-insensitive. 887 888 Returns: 889 Whether the given text can be identified. 890 """ 891 if identify is True or identify == "always": 892 return True 893 894 if identify == "safe": 895 return not self.case_sensitive(text) 896 897 return False 898 899 def quote_identifier(self, expression: E, identify: bool = True) -> E: 900 """ 901 Adds quotes to a given identifier. 902 903 Args: 904 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 905 identify: If set to `False`, the quotes will only be added if the identifier is deemed 906 "unsafe", with respect to its characters and this dialect's normalization strategy. 907 """ 908 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 909 name = expression.this 910 expression.set( 911 "quoted", 912 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 913 ) 914 915 return expression 916 917 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 918 if isinstance(path, exp.Literal): 919 path_text = path.name 920 if path.is_number: 921 path_text = f"[{path_text}]" 922 try: 923 return parse_json_path(path_text, self) 924 except ParseError as e: 925 if self.STRICT_JSON_PATH_SYNTAX: 926 logger.warning(f"Invalid JSON path syntax. {str(e)}") 927 928 return path 929 930 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 931 return self.parser(**opts).parse(self.tokenize(sql), sql) 932 933 def parse_into( 934 self, expression_type: exp.IntoType, sql: str, **opts 935 ) -> t.List[t.Optional[exp.Expression]]: 936 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 937 938 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 939 return self.generator(**opts).generate(expression, copy=copy) 940 941 def transpile(self, sql: str, **opts) -> t.List[str]: 942 return [ 943 self.generate(expression, copy=False, **opts) if expression else "" 944 for expression in self.parse(sql) 945 ] 946 947 def tokenize(self, sql: str) -> t.List[Token]: 948 return self.tokenizer.tokenize(sql) 949 950 @property 951 def tokenizer(self) -> Tokenizer: 952 return self.tokenizer_class(dialect=self) 953 954 @property 955 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 956 return self.jsonpath_tokenizer_class(dialect=self) 957 958 def parser(self, **opts) -> Parser: 959 return self.parser_class(dialect=self, **opts) 960 961 def generator(self, **opts) -> Generator: 962 return self.generator_class(dialect=self, **opts) 963 964 965DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 966 967 968def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 969 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 970 971 972@unsupported_args("accuracy") 973def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 974 return self.func("APPROX_COUNT_DISTINCT", expression.this) 975 976 977def if_sql( 978 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 979) -> t.Callable[[Generator, exp.If], str]: 980 def _if_sql(self: Generator, expression: exp.If) -> str: 981 return self.func( 982 name, 983 expression.this, 984 expression.args.get("true"), 985 expression.args.get("false") or false_value, 986 ) 987 988 return _if_sql 989 990 991def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 992 this = expression.this 993 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 994 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 995 996 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 997 998 999def inline_array_sql(self: Generator, expression: exp.Array) -> str: 1000 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 1001 1002 1003def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 1004 elem = seq_get(expression.expressions, 0) 1005 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 1006 return self.func("ARRAY", elem) 1007 return inline_array_sql(self, expression) 1008 1009 1010def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 1011 return self.like_sql( 1012 exp.Like( 1013 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1014 ) 1015 ) 1016 1017 1018def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1019 zone = self.sql(expression, "this") 1020 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1021 1022 1023def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1024 if expression.args.get("recursive"): 1025 self.unsupported("Recursive CTEs are unsupported") 1026 expression.args["recursive"] = False 1027 return self.with_sql(expression) 1028 1029 1030def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide, if_sql: str = "IF") -> str: 1031 n = self.sql(expression, "this") 1032 d = self.sql(expression, "expression") 1033 return f"{if_sql}(({d}) <> 0, ({n}) / ({d}), NULL)" 1034 1035 1036def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 1037 self.unsupported("TABLESAMPLE unsupported") 1038 return self.sql(expression.this) 1039 1040 1041def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 1042 self.unsupported("PIVOT unsupported") 1043 return "" 1044 1045 1046def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 1047 return self.cast_sql(expression) 1048 1049 1050def no_comment_column_constraint_sql( 1051 self: Generator, expression: exp.CommentColumnConstraint 1052) -> str: 1053 self.unsupported("CommentColumnConstraint unsupported") 1054 return "" 1055 1056 1057def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1058 self.unsupported("MAP_FROM_ENTRIES unsupported") 1059 return "" 1060 1061 1062def property_sql(self: Generator, expression: exp.Property) -> str: 1063 return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" 1064 1065 1066def str_position_sql( 1067 self: Generator, 1068 expression: exp.StrPosition, 1069 generate_instance: bool = False, 1070 str_position_func_name: str = "STRPOS", 1071) -> str: 1072 this = self.sql(expression, "this") 1073 substr = self.sql(expression, "substr") 1074 position = self.sql(expression, "position") 1075 instance = expression.args.get("instance") if generate_instance else None 1076 position_offset = "" 1077 1078 if position: 1079 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1080 this = self.func("SUBSTR", this, position) 1081 position_offset = f" + {position} - 1" 1082 1083 return self.func(str_position_func_name, this, substr, instance) + position_offset 1084 1085 1086def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1087 return ( 1088 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1089 ) 1090 1091 1092def var_map_sql( 1093 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1094) -> str: 1095 keys = expression.args["keys"] 1096 values = expression.args["values"] 1097 1098 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1099 self.unsupported("Cannot convert array columns into map.") 1100 return self.func(map_func_name, keys, values) 1101 1102 args = [] 1103 for key, value in zip(keys.expressions, values.expressions): 1104 args.append(self.sql(key)) 1105 args.append(self.sql(value)) 1106 1107 return self.func(map_func_name, *args) 1108 1109 1110def build_formatted_time( 1111 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1112) -> t.Callable[[t.List], E]: 1113 """Helper used for time expressions. 1114 1115 Args: 1116 exp_class: the expression class to instantiate. 1117 dialect: target sql dialect. 1118 default: the default format, True being time. 1119 1120 Returns: 1121 A callable that can be used to return the appropriately formatted time expression. 1122 """ 1123 1124 def _builder(args: t.List): 1125 return exp_class( 1126 this=seq_get(args, 0), 1127 format=Dialect[dialect].format_time( 1128 seq_get(args, 1) 1129 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1130 ), 1131 ) 1132 1133 return _builder 1134 1135 1136def time_format( 1137 dialect: DialectType = None, 1138) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1139 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1140 """ 1141 Returns the time format for a given expression, unless it's equivalent 1142 to the default time format of the dialect of interest. 1143 """ 1144 time_format = self.format_time(expression) 1145 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1146 1147 return _time_format 1148 1149 1150def build_date_delta( 1151 exp_class: t.Type[E], 1152 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1153 default_unit: t.Optional[str] = "DAY", 1154) -> t.Callable[[t.List], E]: 1155 def _builder(args: t.List) -> E: 1156 unit_based = len(args) == 3 1157 this = args[2] if unit_based else seq_get(args, 0) 1158 unit = None 1159 if unit_based or default_unit: 1160 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1161 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1162 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1163 1164 return _builder 1165 1166 1167def build_date_delta_with_interval( 1168 expression_class: t.Type[E], 1169) -> t.Callable[[t.List], t.Optional[E]]: 1170 def _builder(args: t.List) -> t.Optional[E]: 1171 if len(args) < 2: 1172 return None 1173 1174 interval = args[1] 1175 1176 if not isinstance(interval, exp.Interval): 1177 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1178 1179 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1180 1181 return _builder 1182 1183 1184def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1185 unit = seq_get(args, 0) 1186 this = seq_get(args, 1) 1187 1188 if isinstance(this, exp.Cast) and this.is_type("date"): 1189 return exp.DateTrunc(unit=unit, this=this) 1190 return exp.TimestampTrunc(this=this, unit=unit) 1191 1192 1193def date_add_interval_sql( 1194 data_type: str, kind: str 1195) -> t.Callable[[Generator, exp.Expression], str]: 1196 def func(self: Generator, expression: exp.Expression) -> str: 1197 this = self.sql(expression, "this") 1198 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1199 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1200 1201 return func 1202 1203 1204def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1205 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1206 args = [unit_to_str(expression), expression.this] 1207 if zone: 1208 args.append(expression.args.get("zone")) 1209 return self.func("DATE_TRUNC", *args) 1210 1211 return _timestamptrunc_sql 1212 1213 1214def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1215 zone = expression.args.get("zone") 1216 if not zone: 1217 from sqlglot.optimizer.annotate_types import annotate_types 1218 1219 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1220 return self.sql(exp.cast(expression.this, target_type)) 1221 if zone.name.lower() in TIMEZONES: 1222 return self.sql( 1223 exp.AtTimeZone( 1224 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1225 zone=zone, 1226 ) 1227 ) 1228 return self.func("TIMESTAMP", expression.this, zone) 1229 1230 1231def no_time_sql(self: Generator, expression: exp.Time) -> str: 1232 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1233 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1234 expr = exp.cast( 1235 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1236 ) 1237 return self.sql(expr) 1238 1239 1240def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1241 this = expression.this 1242 expr = expression.expression 1243 1244 if expr.name.lower() in TIMEZONES: 1245 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1246 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1247 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1248 return self.sql(this) 1249 1250 this = exp.cast(this, exp.DataType.Type.DATE) 1251 expr = exp.cast(expr, exp.DataType.Type.TIME) 1252 1253 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) 1254 1255 1256def locate_to_strposition(args: t.List) -> exp.Expression: 1257 return exp.StrPosition( 1258 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 1259 ) 1260 1261 1262def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 1263 return self.func( 1264 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 1265 ) 1266 1267 1268def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1269 return self.sql( 1270 exp.Substring( 1271 this=expression.this, start=exp.Literal.number(1), length=expression.expression 1272 ) 1273 ) 1274 1275 1276def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1277 return self.sql( 1278 exp.Substring( 1279 this=expression.this, 1280 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 1281 ) 1282 ) 1283 1284 1285def timestrtotime_sql( 1286 self: Generator, 1287 expression: exp.TimeStrToTime, 1288 include_precision: bool = False, 1289) -> str: 1290 datatype = exp.DataType.build( 1291 exp.DataType.Type.TIMESTAMPTZ 1292 if expression.args.get("zone") 1293 else exp.DataType.Type.TIMESTAMP 1294 ) 1295 1296 if isinstance(expression.this, exp.Literal) and include_precision: 1297 precision = subsecond_precision(expression.this.name) 1298 if precision > 0: 1299 datatype = exp.DataType.build( 1300 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1301 ) 1302 1303 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect)) 1304 1305 1306def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 1307 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 1308 1309 1310# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 1311def encode_decode_sql( 1312 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1313) -> str: 1314 charset = expression.args.get("charset") 1315 if charset and charset.name.lower() != "utf-8": 1316 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1317 1318 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 1319 1320 1321def min_or_least(self: Generator, expression: exp.Min) -> str: 1322 name = "LEAST" if expression.expressions else "MIN" 1323 return rename_func(name)(self, expression) 1324 1325 1326def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1327 name = "GREATEST" if expression.expressions else "MAX" 1328 return rename_func(name)(self, expression) 1329 1330 1331def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1332 cond = expression.this 1333 1334 if isinstance(expression.this, exp.Distinct): 1335 cond = expression.this.expressions[0] 1336 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1337 1338 return self.func("sum", exp.func("if", cond, 1, 0)) 1339 1340 1341def trim_sql(self: Generator, expression: exp.Trim) -> str: 1342 target = self.sql(expression, "this") 1343 trim_type = self.sql(expression, "position") 1344 remove_chars = self.sql(expression, "expression") 1345 collation = self.sql(expression, "collation") 1346 1347 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1348 if not remove_chars: 1349 return self.trim_sql(expression) 1350 1351 trim_type = f"{trim_type} " if trim_type else "" 1352 remove_chars = f"{remove_chars} " if remove_chars else "" 1353 from_part = "FROM " if trim_type or remove_chars else "" 1354 collation = f" COLLATE {collation}" if collation else "" 1355 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1356 1357 1358def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1359 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1360 1361 1362def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1363 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1364 1365 1366def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1367 delim, *rest_args = expression.expressions 1368 return self.sql( 1369 reduce( 1370 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1371 rest_args, 1372 ) 1373 ) 1374 1375 1376@unsupported_args("position", "occurrence", "parameters") 1377def regexp_extract_sql( 1378 self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll 1379) -> str: 1380 group = expression.args.get("group") 1381 1382 # Do not render group if it's the default value for this dialect 1383 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1384 group = None 1385 1386 return self.func(expression.sql_name(), expression.this, expression.expression, group) 1387 1388 1389@unsupported_args("position", "occurrence", "modifiers") 1390def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1391 return self.func( 1392 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1393 ) 1394 1395 1396def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1397 names = [] 1398 for agg in aggregations: 1399 if isinstance(agg, exp.Alias): 1400 names.append(agg.alias) 1401 else: 1402 """ 1403 This case corresponds to aggregations without aliases being used as suffixes 1404 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1405 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1406 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1407 """ 1408 agg_all_unquoted = agg.transform( 1409 lambda node: ( 1410 exp.Identifier(this=node.name, quoted=False) 1411 if isinstance(node, exp.Identifier) 1412 else node 1413 ) 1414 ) 1415 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1416 1417 return names 1418 1419 1420def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1421 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1422 1423 1424# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1425def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1426 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1427 1428 1429def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1430 return self.func("MAX", expression.this) 1431 1432 1433def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1434 a = self.sql(expression.left) 1435 b = self.sql(expression.right) 1436 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1437 1438 1439def is_parse_json(expression: exp.Expression) -> bool: 1440 return isinstance(expression, exp.ParseJSON) or ( 1441 isinstance(expression, exp.Cast) and expression.is_type("json") 1442 ) 1443 1444 1445def isnull_to_is_null(args: t.List) -> exp.Expression: 1446 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1447 1448 1449def generatedasidentitycolumnconstraint_sql( 1450 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1451) -> str: 1452 start = self.sql(expression, "start") or "1" 1453 increment = self.sql(expression, "increment") or "1" 1454 return f"IDENTITY({start}, {increment})" 1455 1456 1457def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1458 @unsupported_args("count") 1459 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1460 return self.func(name, expression.this, expression.expression) 1461 1462 return _arg_max_or_min_sql 1463 1464 1465def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1466 this = expression.this.copy() 1467 1468 return_type = expression.return_type 1469 if return_type.is_type(exp.DataType.Type.DATE): 1470 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1471 # can truncate timestamp strings, because some dialects can't cast them to DATE 1472 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1473 1474 expression.this.replace(exp.cast(this, return_type)) 1475 return expression 1476 1477 1478def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1479 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1480 if cast and isinstance(expression, exp.TsOrDsAdd): 1481 expression = ts_or_ds_add_cast(expression) 1482 1483 return self.func( 1484 name, 1485 unit_to_var(expression), 1486 expression.expression, 1487 expression.this, 1488 ) 1489 1490 return _delta_sql 1491 1492 1493def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1494 unit = expression.args.get("unit") 1495 1496 if isinstance(unit, exp.Placeholder): 1497 return unit 1498 if unit: 1499 return exp.Literal.string(unit.name) 1500 return exp.Literal.string(default) if default else None 1501 1502 1503def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1504 unit = expression.args.get("unit") 1505 1506 if isinstance(unit, (exp.Var, exp.Placeholder)): 1507 return unit 1508 return exp.Var(this=default) if default else None 1509 1510 1511@t.overload 1512def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1513 pass 1514 1515 1516@t.overload 1517def map_date_part( 1518 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1519) -> t.Optional[exp.Expression]: 1520 pass 1521 1522 1523def map_date_part(part, dialect: DialectType = Dialect): 1524 mapped = ( 1525 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1526 ) 1527 return exp.var(mapped) if mapped else part 1528 1529 1530def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1531 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1532 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1533 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1534 1535 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1536 1537 1538def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1539 """Remove table refs from columns in when statements.""" 1540 alias = expression.this.args.get("alias") 1541 1542 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1543 return self.dialect.normalize_identifier(identifier).name if identifier else None 1544 1545 targets = {normalize(expression.this.this)} 1546 1547 if alias: 1548 targets.add(normalize(alias.this)) 1549 1550 for when in expression.expressions: 1551 # only remove the target names from the THEN clause 1552 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1553 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1554 then = when.args.get("then") 1555 if then: 1556 then.transform( 1557 lambda node: ( 1558 exp.column(node.this) 1559 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1560 else node 1561 ), 1562 copy=False, 1563 ) 1564 1565 return self.merge_sql(expression) 1566 1567 1568def build_json_extract_path( 1569 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1570) -> t.Callable[[t.List], F]: 1571 def _builder(args: t.List) -> F: 1572 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1573 for arg in args[1:]: 1574 if not isinstance(arg, exp.Literal): 1575 # We use the fallback parser because we can't really transpile non-literals safely 1576 return expr_type.from_arg_list(args) 1577 1578 text = arg.name 1579 if is_int(text): 1580 index = int(text) 1581 segments.append( 1582 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1583 ) 1584 else: 1585 segments.append(exp.JSONPathKey(this=text)) 1586 1587 # This is done to avoid failing in the expression validator due to the arg count 1588 del args[2:] 1589 return expr_type( 1590 this=seq_get(args, 0), 1591 expression=exp.JSONPath(expressions=segments), 1592 only_json_types=arrow_req_json_type, 1593 ) 1594 1595 return _builder 1596 1597 1598def json_extract_segments( 1599 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1600) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1601 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1602 path = expression.expression 1603 if not isinstance(path, exp.JSONPath): 1604 return rename_func(name)(self, expression) 1605 1606 escape = path.args.get("escape") 1607 1608 segments = [] 1609 for segment in path.expressions: 1610 path = self.sql(segment) 1611 if path: 1612 if isinstance(segment, exp.JSONPathPart) and ( 1613 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1614 ): 1615 if escape: 1616 path = self.escape_str(path) 1617 1618 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1619 1620 segments.append(path) 1621 1622 if op: 1623 return f" {op} ".join([self.sql(expression.this), *segments]) 1624 return self.func(name, expression.this, *segments) 1625 1626 return _json_extract_segments 1627 1628 1629def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1630 if isinstance(expression.this, exp.JSONPathWildcard): 1631 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1632 1633 return expression.name 1634 1635 1636def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1637 cond = expression.expression 1638 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1639 alias = cond.expressions[0] 1640 cond = cond.this 1641 elif isinstance(cond, exp.Predicate): 1642 alias = "_u" 1643 else: 1644 self.unsupported("Unsupported filter condition") 1645 return "" 1646 1647 unnest = exp.Unnest(expressions=[expression.this]) 1648 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1649 return self.sql(exp.Array(expressions=[filtered])) 1650 1651 1652def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1653 return self.func( 1654 "TO_NUMBER", 1655 expression.this, 1656 expression.args.get("format"), 1657 expression.args.get("nlsparam"), 1658 ) 1659 1660 1661def build_default_decimal_type( 1662 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1663) -> t.Callable[[exp.DataType], exp.DataType]: 1664 def _builder(dtype: exp.DataType) -> exp.DataType: 1665 if dtype.expressions or precision is None: 1666 return dtype 1667 1668 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1669 return exp.DataType.build(f"DECIMAL({params})") 1670 1671 return _builder 1672 1673 1674def build_timestamp_from_parts(args: t.List) -> exp.Func: 1675 if len(args) == 2: 1676 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1677 # so we parse this into Anonymous for now instead of introducing complexity 1678 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1679 1680 return exp.TimestampFromParts.from_arg_list(args) 1681 1682 1683def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1684 return self.func(f"SHA{expression.text('length') or '256'}", expression.this) 1685 1686 1687def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1688 start = expression.args.get("start") 1689 end = expression.args.get("end") 1690 step = expression.args.get("step") 1691 1692 if isinstance(start, exp.Cast): 1693 target_type = start.to 1694 elif isinstance(end, exp.Cast): 1695 target_type = end.to 1696 else: 1697 target_type = None 1698 1699 if start and end and target_type and target_type.is_type("date", "timestamp"): 1700 if isinstance(start, exp.Cast) and target_type is start.to: 1701 end = exp.cast(end, target_type) 1702 else: 1703 start = exp.cast(start, target_type) 1704 1705 return self.func("SEQUENCE", start, end, step) 1706 1707 1708def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: 1709 def _builder(args: t.List, dialect: Dialect) -> E: 1710 return expr_type( 1711 this=seq_get(args, 0), 1712 expression=seq_get(args, 1), 1713 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1714 parameters=seq_get(args, 3), 1715 ) 1716 1717 return _builder 1718 1719 1720def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1721 if isinstance(expression.this, exp.Explode): 1722 return self.sql( 1723 exp.Join( 1724 this=exp.Unnest( 1725 expressions=[expression.this.this], 1726 alias=expression.args.get("alias"), 1727 offset=isinstance(expression.this, exp.Posexplode), 1728 ), 1729 kind="cross", 1730 ) 1731 ) 1732 return self.lateral_sql(expression) 1733 1734 1735def timestampdiff_sql(self: Generator, expression: exp.DatetimeDiff | exp.TimestampDiff) -> str: 1736 return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this) 1737 1738 1739def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1740 args = [] 1741 for unit, value in expression.args.items(): 1742 if isinstance(value, exp.Kwarg): 1743 value = value.expression 1744 1745 args.append(f"{value} {unit}") 1746 1747 return f"INTERVAL '{self.format_args(*args, sep=sep)}'"
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
Dialects supported by SQLGLot.
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
219class Dialect(metaclass=_Dialect): 220 INDEX_OFFSET = 0 221 """The base index offset for arrays.""" 222 223 WEEK_OFFSET = 0 224 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 225 226 UNNEST_COLUMN_ONLY = False 227 """Whether `UNNEST` table aliases are treated as column aliases.""" 228 229 ALIAS_POST_TABLESAMPLE = False 230 """Whether the table alias comes after tablesample.""" 231 232 TABLESAMPLE_SIZE_IS_PERCENT = False 233 """Whether a size in the table sample clause represents percentage.""" 234 235 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 236 """Specifies the strategy according to which identifiers should be normalized.""" 237 238 IDENTIFIERS_CAN_START_WITH_DIGIT = False 239 """Whether an unquoted identifier can start with a digit.""" 240 241 DPIPE_IS_STRING_CONCAT = True 242 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 243 244 STRICT_STRING_CONCAT = False 245 """Whether `CONCAT`'s arguments must be strings.""" 246 247 SUPPORTS_USER_DEFINED_TYPES = True 248 """Whether user-defined data types are supported.""" 249 250 SUPPORTS_SEMI_ANTI_JOIN = True 251 """Whether `SEMI` or `ANTI` joins are supported.""" 252 253 SUPPORTS_COLUMN_JOIN_MARKS = False 254 """Whether the old-style outer join (+) syntax is supported.""" 255 256 COPY_PARAMS_ARE_CSV = True 257 """Separator of COPY statement parameters.""" 258 259 NORMALIZE_FUNCTIONS: bool | str = "upper" 260 """ 261 Determines how function names are going to be normalized. 262 Possible values: 263 "upper" or True: Convert names to uppercase. 264 "lower": Convert names to lowercase. 265 False: Disables function name normalization. 266 """ 267 268 PRESERVE_ORIGINAL_NAMES: bool = False 269 """ 270 Whether the name of the function should be preserved inside the node's metadata, 271 can be useful for roundtripping deprecated vs new functions that share an AST node 272 e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery 273 """ 274 275 LOG_BASE_FIRST: t.Optional[bool] = True 276 """ 277 Whether the base comes first in the `LOG` function. 278 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 279 """ 280 281 NULL_ORDERING = "nulls_are_small" 282 """ 283 Default `NULL` ordering method to use if not explicitly set. 284 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 285 """ 286 287 TYPED_DIVISION = False 288 """ 289 Whether the behavior of `a / b` depends on the types of `a` and `b`. 290 False means `a / b` is always float division. 291 True means `a / b` is integer division if both `a` and `b` are integers. 292 """ 293 294 SAFE_DIVISION = False 295 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 296 297 CONCAT_COALESCE = False 298 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 299 300 HEX_LOWERCASE = False 301 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 302 303 DATE_FORMAT = "'%Y-%m-%d'" 304 DATEINT_FORMAT = "'%Y%m%d'" 305 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 306 307 TIME_MAPPING: t.Dict[str, str] = {} 308 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 309 310 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 311 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 312 FORMAT_MAPPING: t.Dict[str, str] = {} 313 """ 314 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 315 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 316 """ 317 318 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 319 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 320 321 PSEUDOCOLUMNS: t.Set[str] = set() 322 """ 323 Columns that are auto-generated by the engine corresponding to this dialect. 324 For example, such columns may be excluded from `SELECT *` queries. 325 """ 326 327 PREFER_CTE_ALIAS_COLUMN = False 328 """ 329 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 330 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 331 any projection aliases in the subquery. 332 333 For example, 334 WITH y(c) AS ( 335 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 336 ) SELECT c FROM y; 337 338 will be rewritten as 339 340 WITH y(c) AS ( 341 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 342 ) SELECT c FROM y; 343 """ 344 345 COPY_PARAMS_ARE_CSV = True 346 """ 347 Whether COPY statement parameters are separated by comma or whitespace 348 """ 349 350 FORCE_EARLY_ALIAS_REF_EXPANSION = False 351 """ 352 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 353 354 For example: 355 WITH data AS ( 356 SELECT 357 1 AS id, 358 2 AS my_id 359 ) 360 SELECT 361 id AS my_id 362 FROM 363 data 364 WHERE 365 my_id = 1 366 GROUP BY 367 my_id, 368 HAVING 369 my_id = 1 370 371 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 372 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 373 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 374 - Clickhouse, which will forward the alias across the query i.e it resolves 375 to "WHERE id = 1 GROUP BY id HAVING id = 1" 376 """ 377 378 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 379 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 380 381 SUPPORTS_ORDER_BY_ALL = False 382 """ 383 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 384 """ 385 386 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 387 """ 388 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 389 as the former is of type INT[] vs the latter which is SUPER 390 """ 391 392 SUPPORTS_FIXED_SIZE_ARRAYS = False 393 """ 394 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 395 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 396 be interpreted as a subscript/index operator. 397 """ 398 399 STRICT_JSON_PATH_SYNTAX = True 400 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 401 402 ON_CONDITION_EMPTY_BEFORE_ERROR = True 403 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 404 405 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 406 """Whether ArrayAgg needs to filter NULL values.""" 407 408 PROMOTE_TO_INFERRED_DATETIME_TYPE = False 409 """ 410 This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted 411 to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal 412 is cast to x's type to match it instead. 413 """ 414 415 SUPPORTS_VALUES_DEFAULT = True 416 """Whether the DEFAULT keyword is supported in the VALUES clause.""" 417 418 REGEXP_EXTRACT_DEFAULT_GROUP = 0 419 """The default value for the capturing group.""" 420 421 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 422 exp.Except: True, 423 exp.Intersect: True, 424 exp.Union: True, 425 } 426 """ 427 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 428 must be explicitly specified. 429 """ 430 431 CREATABLE_KIND_MAPPING: dict[str, str] = {} 432 """ 433 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 434 equivalent of CREATE SCHEMA is CREATE DATABASE. 435 """ 436 437 # --- Autofilled --- 438 439 tokenizer_class = Tokenizer 440 jsonpath_tokenizer_class = JSONPathTokenizer 441 parser_class = Parser 442 generator_class = Generator 443 444 # A trie of the time_mapping keys 445 TIME_TRIE: t.Dict = {} 446 FORMAT_TRIE: t.Dict = {} 447 448 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 449 INVERSE_TIME_TRIE: t.Dict = {} 450 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 451 INVERSE_FORMAT_TRIE: t.Dict = {} 452 453 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 454 455 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 456 457 # Delimiters for string literals and identifiers 458 QUOTE_START = "'" 459 QUOTE_END = "'" 460 IDENTIFIER_START = '"' 461 IDENTIFIER_END = '"' 462 463 # Delimiters for bit, hex, byte and unicode literals 464 BIT_START: t.Optional[str] = None 465 BIT_END: t.Optional[str] = None 466 HEX_START: t.Optional[str] = None 467 HEX_END: t.Optional[str] = None 468 BYTE_START: t.Optional[str] = None 469 BYTE_END: t.Optional[str] = None 470 UNICODE_START: t.Optional[str] = None 471 UNICODE_END: t.Optional[str] = None 472 473 DATE_PART_MAPPING = { 474 "Y": "YEAR", 475 "YY": "YEAR", 476 "YYY": "YEAR", 477 "YYYY": "YEAR", 478 "YR": "YEAR", 479 "YEARS": "YEAR", 480 "YRS": "YEAR", 481 "MM": "MONTH", 482 "MON": "MONTH", 483 "MONS": "MONTH", 484 "MONTHS": "MONTH", 485 "D": "DAY", 486 "DD": "DAY", 487 "DAYS": "DAY", 488 "DAYOFMONTH": "DAY", 489 "DAY OF WEEK": "DAYOFWEEK", 490 "WEEKDAY": "DAYOFWEEK", 491 "DOW": "DAYOFWEEK", 492 "DW": "DAYOFWEEK", 493 "WEEKDAY_ISO": "DAYOFWEEKISO", 494 "DOW_ISO": "DAYOFWEEKISO", 495 "DW_ISO": "DAYOFWEEKISO", 496 "DAY OF YEAR": "DAYOFYEAR", 497 "DOY": "DAYOFYEAR", 498 "DY": "DAYOFYEAR", 499 "W": "WEEK", 500 "WK": "WEEK", 501 "WEEKOFYEAR": "WEEK", 502 "WOY": "WEEK", 503 "WY": "WEEK", 504 "WEEK_ISO": "WEEKISO", 505 "WEEKOFYEARISO": "WEEKISO", 506 "WEEKOFYEAR_ISO": "WEEKISO", 507 "Q": "QUARTER", 508 "QTR": "QUARTER", 509 "QTRS": "QUARTER", 510 "QUARTERS": "QUARTER", 511 "H": "HOUR", 512 "HH": "HOUR", 513 "HR": "HOUR", 514 "HOURS": "HOUR", 515 "HRS": "HOUR", 516 "M": "MINUTE", 517 "MI": "MINUTE", 518 "MIN": "MINUTE", 519 "MINUTES": "MINUTE", 520 "MINS": "MINUTE", 521 "S": "SECOND", 522 "SEC": "SECOND", 523 "SECONDS": "SECOND", 524 "SECS": "SECOND", 525 "MS": "MILLISECOND", 526 "MSEC": "MILLISECOND", 527 "MSECS": "MILLISECOND", 528 "MSECOND": "MILLISECOND", 529 "MSECONDS": "MILLISECOND", 530 "MILLISEC": "MILLISECOND", 531 "MILLISECS": "MILLISECOND", 532 "MILLISECON": "MILLISECOND", 533 "MILLISECONDS": "MILLISECOND", 534 "US": "MICROSECOND", 535 "USEC": "MICROSECOND", 536 "USECS": "MICROSECOND", 537 "MICROSEC": "MICROSECOND", 538 "MICROSECS": "MICROSECOND", 539 "USECOND": "MICROSECOND", 540 "USECONDS": "MICROSECOND", 541 "MICROSECONDS": "MICROSECOND", 542 "NS": "NANOSECOND", 543 "NSEC": "NANOSECOND", 544 "NANOSEC": "NANOSECOND", 545 "NSECOND": "NANOSECOND", 546 "NSECONDS": "NANOSECOND", 547 "NANOSECS": "NANOSECOND", 548 "EPOCH_SECOND": "EPOCH", 549 "EPOCH_SECONDS": "EPOCH", 550 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 551 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 552 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 553 "TZH": "TIMEZONE_HOUR", 554 "TZM": "TIMEZONE_MINUTE", 555 "DEC": "DECADE", 556 "DECS": "DECADE", 557 "DECADES": "DECADE", 558 "MIL": "MILLENIUM", 559 "MILS": "MILLENIUM", 560 "MILLENIA": "MILLENIUM", 561 "C": "CENTURY", 562 "CENT": "CENTURY", 563 "CENTS": "CENTURY", 564 "CENTURIES": "CENTURY", 565 } 566 567 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 568 exp.DataType.Type.BIGINT: { 569 exp.ApproxDistinct, 570 exp.ArraySize, 571 exp.Length, 572 }, 573 exp.DataType.Type.BOOLEAN: { 574 exp.Between, 575 exp.Boolean, 576 exp.In, 577 exp.RegexpLike, 578 }, 579 exp.DataType.Type.DATE: { 580 exp.CurrentDate, 581 exp.Date, 582 exp.DateFromParts, 583 exp.DateStrToDate, 584 exp.DiToDate, 585 exp.StrToDate, 586 exp.TimeStrToDate, 587 exp.TsOrDsToDate, 588 }, 589 exp.DataType.Type.DATETIME: { 590 exp.CurrentDatetime, 591 exp.Datetime, 592 exp.DatetimeAdd, 593 exp.DatetimeSub, 594 }, 595 exp.DataType.Type.DOUBLE: { 596 exp.ApproxQuantile, 597 exp.Avg, 598 exp.Exp, 599 exp.Ln, 600 exp.Log, 601 exp.Pow, 602 exp.Quantile, 603 exp.Round, 604 exp.SafeDivide, 605 exp.Sqrt, 606 exp.Stddev, 607 exp.StddevPop, 608 exp.StddevSamp, 609 exp.ToDouble, 610 exp.Variance, 611 exp.VariancePop, 612 }, 613 exp.DataType.Type.INT: { 614 exp.Ceil, 615 exp.DatetimeDiff, 616 exp.DateDiff, 617 exp.TimestampDiff, 618 exp.TimeDiff, 619 exp.DateToDi, 620 exp.Levenshtein, 621 exp.Sign, 622 exp.StrPosition, 623 exp.TsOrDiToDi, 624 }, 625 exp.DataType.Type.JSON: { 626 exp.ParseJSON, 627 }, 628 exp.DataType.Type.TIME: { 629 exp.Time, 630 }, 631 exp.DataType.Type.TIMESTAMP: { 632 exp.CurrentTime, 633 exp.CurrentTimestamp, 634 exp.StrToTime, 635 exp.TimeAdd, 636 exp.TimeStrToTime, 637 exp.TimeSub, 638 exp.TimestampAdd, 639 exp.TimestampSub, 640 exp.UnixToTime, 641 }, 642 exp.DataType.Type.TINYINT: { 643 exp.Day, 644 exp.Month, 645 exp.Week, 646 exp.Year, 647 exp.Quarter, 648 }, 649 exp.DataType.Type.VARCHAR: { 650 exp.ArrayConcat, 651 exp.Concat, 652 exp.ConcatWs, 653 exp.DateToDateStr, 654 exp.GroupConcat, 655 exp.Initcap, 656 exp.Lower, 657 exp.Substring, 658 exp.String, 659 exp.TimeToStr, 660 exp.TimeToTimeStr, 661 exp.Trim, 662 exp.TsOrDsToDateStr, 663 exp.UnixToStr, 664 exp.UnixToTimeStr, 665 exp.Upper, 666 }, 667 } 668 669 ANNOTATORS: AnnotatorsType = { 670 **{ 671 expr_type: lambda self, e: self._annotate_unary(e) 672 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 673 }, 674 **{ 675 expr_type: lambda self, e: self._annotate_binary(e) 676 for expr_type in subclasses(exp.__name__, exp.Binary) 677 }, 678 **{ 679 expr_type: _annotate_with_type_lambda(data_type) 680 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 681 for expr_type in expressions 682 }, 683 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 684 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 685 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 686 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 687 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 688 exp.Bracket: lambda self, e: self._annotate_bracket(e), 689 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 690 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 691 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 692 exp.Count: lambda self, e: self._annotate_with_type( 693 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 694 ), 695 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 696 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 697 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 698 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 699 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 700 exp.Div: lambda self, e: self._annotate_div(e), 701 exp.Dot: lambda self, e: self._annotate_dot(e), 702 exp.Explode: lambda self, e: self._annotate_explode(e), 703 exp.Extract: lambda self, e: self._annotate_extract(e), 704 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 705 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 706 e, exp.DataType.build("ARRAY<DATE>") 707 ), 708 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 709 e, exp.DataType.build("ARRAY<TIMESTAMP>") 710 ), 711 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 712 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 713 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 714 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 715 exp.Literal: lambda self, e: self._annotate_literal(e), 716 exp.Map: lambda self, e: self._annotate_map(e), 717 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 718 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 719 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 720 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 721 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 722 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 723 exp.Struct: lambda self, e: self._annotate_struct(e), 724 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 725 exp.Timestamp: lambda self, e: self._annotate_with_type( 726 e, 727 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 728 ), 729 exp.ToMap: lambda self, e: self._annotate_to_map(e), 730 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 731 exp.Unnest: lambda self, e: self._annotate_unnest(e), 732 exp.VarMap: lambda self, e: self._annotate_map(e), 733 } 734 735 @classmethod 736 def get_or_raise(cls, dialect: DialectType) -> Dialect: 737 """ 738 Look up a dialect in the global dialect registry and return it if it exists. 739 740 Args: 741 dialect: The target dialect. If this is a string, it can be optionally followed by 742 additional key-value pairs that are separated by commas and are used to specify 743 dialect settings, such as whether the dialect's identifiers are case-sensitive. 744 745 Example: 746 >>> dialect = dialect_class = get_or_raise("duckdb") 747 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 748 749 Returns: 750 The corresponding Dialect instance. 751 """ 752 753 if not dialect: 754 return cls() 755 if isinstance(dialect, _Dialect): 756 return dialect() 757 if isinstance(dialect, Dialect): 758 return dialect 759 if isinstance(dialect, str): 760 try: 761 dialect_name, *kv_strings = dialect.split(",") 762 kv_pairs = (kv.split("=") for kv in kv_strings) 763 kwargs = {} 764 for pair in kv_pairs: 765 key = pair[0].strip() 766 value: t.Union[bool | str | None] = None 767 768 if len(pair) == 1: 769 # Default initialize standalone settings to True 770 value = True 771 elif len(pair) == 2: 772 value = pair[1].strip() 773 774 kwargs[key] = to_bool(value) 775 776 except ValueError: 777 raise ValueError( 778 f"Invalid dialect format: '{dialect}'. " 779 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 780 ) 781 782 result = cls.get(dialect_name.strip()) 783 if not result: 784 from difflib import get_close_matches 785 786 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 787 if similar: 788 similar = f" Did you mean {similar}?" 789 790 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 791 792 return result(**kwargs) 793 794 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 795 796 @classmethod 797 def format_time( 798 cls, expression: t.Optional[str | exp.Expression] 799 ) -> t.Optional[exp.Expression]: 800 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 801 if isinstance(expression, str): 802 return exp.Literal.string( 803 # the time formats are quoted 804 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 805 ) 806 807 if expression and expression.is_string: 808 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 809 810 return expression 811 812 def __init__(self, **kwargs) -> None: 813 normalization_strategy = kwargs.pop("normalization_strategy", None) 814 815 if normalization_strategy is None: 816 self.normalization_strategy = self.NORMALIZATION_STRATEGY 817 else: 818 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 819 820 self.settings = kwargs 821 822 def __eq__(self, other: t.Any) -> bool: 823 # Does not currently take dialect state into account 824 return type(self) == other 825 826 def __hash__(self) -> int: 827 # Does not currently take dialect state into account 828 return hash(type(self)) 829 830 def normalize_identifier(self, expression: E) -> E: 831 """ 832 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 833 834 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 835 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 836 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 837 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 838 839 There are also dialects like Spark, which are case-insensitive even when quotes are 840 present, and dialects like MySQL, whose resolution rules match those employed by the 841 underlying operating system, for example they may always be case-sensitive in Linux. 842 843 Finally, the normalization behavior of some engines can even be controlled through flags, 844 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 845 846 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 847 that it can analyze queries in the optimizer and successfully capture their semantics. 848 """ 849 if ( 850 isinstance(expression, exp.Identifier) 851 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 852 and ( 853 not expression.quoted 854 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 855 ) 856 ): 857 expression.set( 858 "this", 859 ( 860 expression.this.upper() 861 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 862 else expression.this.lower() 863 ), 864 ) 865 866 return expression 867 868 def case_sensitive(self, text: str) -> bool: 869 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 870 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 871 return False 872 873 unsafe = ( 874 str.islower 875 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 876 else str.isupper 877 ) 878 return any(unsafe(char) for char in text) 879 880 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 881 """Checks if text can be identified given an identify option. 882 883 Args: 884 text: The text to check. 885 identify: 886 `"always"` or `True`: Always returns `True`. 887 `"safe"`: Only returns `True` if the identifier is case-insensitive. 888 889 Returns: 890 Whether the given text can be identified. 891 """ 892 if identify is True or identify == "always": 893 return True 894 895 if identify == "safe": 896 return not self.case_sensitive(text) 897 898 return False 899 900 def quote_identifier(self, expression: E, identify: bool = True) -> E: 901 """ 902 Adds quotes to a given identifier. 903 904 Args: 905 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 906 identify: If set to `False`, the quotes will only be added if the identifier is deemed 907 "unsafe", with respect to its characters and this dialect's normalization strategy. 908 """ 909 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 910 name = expression.this 911 expression.set( 912 "quoted", 913 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 914 ) 915 916 return expression 917 918 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 919 if isinstance(path, exp.Literal): 920 path_text = path.name 921 if path.is_number: 922 path_text = f"[{path_text}]" 923 try: 924 return parse_json_path(path_text, self) 925 except ParseError as e: 926 if self.STRICT_JSON_PATH_SYNTAX: 927 logger.warning(f"Invalid JSON path syntax. {str(e)}") 928 929 return path 930 931 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 932 return self.parser(**opts).parse(self.tokenize(sql), sql) 933 934 def parse_into( 935 self, expression_type: exp.IntoType, sql: str, **opts 936 ) -> t.List[t.Optional[exp.Expression]]: 937 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 938 939 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 940 return self.generator(**opts).generate(expression, copy=copy) 941 942 def transpile(self, sql: str, **opts) -> t.List[str]: 943 return [ 944 self.generate(expression, copy=False, **opts) if expression else "" 945 for expression in self.parse(sql) 946 ] 947 948 def tokenize(self, sql: str) -> t.List[Token]: 949 return self.tokenizer.tokenize(sql) 950 951 @property 952 def tokenizer(self) -> Tokenizer: 953 return self.tokenizer_class(dialect=self) 954 955 @property 956 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 957 return self.jsonpath_tokenizer_class(dialect=self) 958 959 def parser(self, **opts) -> Parser: 960 return self.parser_class(dialect=self, **opts) 961 962 def generator(self, **opts) -> Generator: 963 return self.generator_class(dialect=self, **opts)
812 def __init__(self, **kwargs) -> None: 813 normalization_strategy = kwargs.pop("normalization_strategy", None) 814 815 if normalization_strategy is None: 816 self.normalization_strategy = self.NORMALIZATION_STRATEGY 817 else: 818 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 819 820 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the name of the function should be preserved inside the node's metadata, can be useful for roundtripping deprecated vs new functions that share an AST node e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
Whether the behavior of a / b
depends on the types of a
and b
.
False means a / b
is always float division.
True means a / b
is integer division if both a
and b
are integers.
A NULL
arg in CONCAT
yields NULL
by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (\n
) to its unescaped version (
).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS ( SELECT 1 AS id, 2 AS my_id ) SELECT id AS my_id FROM data WHERE my_id = 1 GROUP BY my_id, HAVING my_id = 1
In most dialects, "my_id" would refer to "data.my_id" across the query, except: - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) as the former is of type INT[] vs the latter which is SUPER
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal is cast to x's type to match it instead.
Whether a set operation uses DISTINCT by default. This is None
when either DISTINCT
or ALL
must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
735 @classmethod 736 def get_or_raise(cls, dialect: DialectType) -> Dialect: 737 """ 738 Look up a dialect in the global dialect registry and return it if it exists. 739 740 Args: 741 dialect: The target dialect. If this is a string, it can be optionally followed by 742 additional key-value pairs that are separated by commas and are used to specify 743 dialect settings, such as whether the dialect's identifiers are case-sensitive. 744 745 Example: 746 >>> dialect = dialect_class = get_or_raise("duckdb") 747 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 748 749 Returns: 750 The corresponding Dialect instance. 751 """ 752 753 if not dialect: 754 return cls() 755 if isinstance(dialect, _Dialect): 756 return dialect() 757 if isinstance(dialect, Dialect): 758 return dialect 759 if isinstance(dialect, str): 760 try: 761 dialect_name, *kv_strings = dialect.split(",") 762 kv_pairs = (kv.split("=") for kv in kv_strings) 763 kwargs = {} 764 for pair in kv_pairs: 765 key = pair[0].strip() 766 value: t.Union[bool | str | None] = None 767 768 if len(pair) == 1: 769 # Default initialize standalone settings to True 770 value = True 771 elif len(pair) == 2: 772 value = pair[1].strip() 773 774 kwargs[key] = to_bool(value) 775 776 except ValueError: 777 raise ValueError( 778 f"Invalid dialect format: '{dialect}'. " 779 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 780 ) 781 782 result = cls.get(dialect_name.strip()) 783 if not result: 784 from difflib import get_close_matches 785 786 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 787 if similar: 788 similar = f" Did you mean {similar}?" 789 790 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 791 792 return result(**kwargs) 793 794 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
796 @classmethod 797 def format_time( 798 cls, expression: t.Optional[str | exp.Expression] 799 ) -> t.Optional[exp.Expression]: 800 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 801 if isinstance(expression, str): 802 return exp.Literal.string( 803 # the time formats are quoted 804 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 805 ) 806 807 if expression and expression.is_string: 808 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 809 810 return expression
Converts a time format in this dialect to its equivalent Python strftime
format.
830 def normalize_identifier(self, expression: E) -> E: 831 """ 832 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 833 834 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 835 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 836 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 837 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 838 839 There are also dialects like Spark, which are case-insensitive even when quotes are 840 present, and dialects like MySQL, whose resolution rules match those employed by the 841 underlying operating system, for example they may always be case-sensitive in Linux. 842 843 Finally, the normalization behavior of some engines can even be controlled through flags, 844 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 845 846 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 847 that it can analyze queries in the optimizer and successfully capture their semantics. 848 """ 849 if ( 850 isinstance(expression, exp.Identifier) 851 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 852 and ( 853 not expression.quoted 854 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 855 ) 856 ): 857 expression.set( 858 "this", 859 ( 860 expression.this.upper() 861 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 862 else expression.this.lower() 863 ), 864 ) 865 866 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO
would be resolved as foo
in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as FOO
. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
868 def case_sensitive(self, text: str) -> bool: 869 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 870 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 871 return False 872 873 unsafe = ( 874 str.islower 875 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 876 else str.isupper 877 ) 878 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
880 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 881 """Checks if text can be identified given an identify option. 882 883 Args: 884 text: The text to check. 885 identify: 886 `"always"` or `True`: Always returns `True`. 887 `"safe"`: Only returns `True` if the identifier is case-insensitive. 888 889 Returns: 890 Whether the given text can be identified. 891 """ 892 if identify is True or identify == "always": 893 return True 894 895 if identify == "safe": 896 return not self.case_sensitive(text) 897 898 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
"always"
orTrue
: Always returnsTrue
."safe"
: Only returnsTrue
if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
900 def quote_identifier(self, expression: E, identify: bool = True) -> E: 901 """ 902 Adds quotes to a given identifier. 903 904 Args: 905 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 906 identify: If set to `False`, the quotes will only be added if the identifier is deemed 907 "unsafe", with respect to its characters and this dialect's normalization strategy. 908 """ 909 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 910 name = expression.this 911 expression.set( 912 "quoted", 913 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 914 ) 915 916 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an
Identifier
, this method is a no-op. - identify: If set to
False
, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
918 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 919 if isinstance(path, exp.Literal): 920 path_text = path.name 921 if path.is_number: 922 path_text = f"[{path_text}]" 923 try: 924 return parse_json_path(path_text, self) 925 except ParseError as e: 926 if self.STRICT_JSON_PATH_SYNTAX: 927 logger.warning(f"Invalid JSON path syntax. {str(e)}") 928 929 return path
978def if_sql( 979 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 980) -> t.Callable[[Generator, exp.If], str]: 981 def _if_sql(self: Generator, expression: exp.If) -> str: 982 return self.func( 983 name, 984 expression.this, 985 expression.args.get("true"), 986 expression.args.get("false") or false_value, 987 ) 988 989 return _if_sql
992def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 993 this = expression.this 994 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 995 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 996 997 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1067def str_position_sql( 1068 self: Generator, 1069 expression: exp.StrPosition, 1070 generate_instance: bool = False, 1071 str_position_func_name: str = "STRPOS", 1072) -> str: 1073 this = self.sql(expression, "this") 1074 substr = self.sql(expression, "substr") 1075 position = self.sql(expression, "position") 1076 instance = expression.args.get("instance") if generate_instance else None 1077 position_offset = "" 1078 1079 if position: 1080 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1081 this = self.func("SUBSTR", this, position) 1082 position_offset = f" + {position} - 1" 1083 1084 return self.func(str_position_func_name, this, substr, instance) + position_offset
1093def var_map_sql( 1094 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1095) -> str: 1096 keys = expression.args["keys"] 1097 values = expression.args["values"] 1098 1099 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1100 self.unsupported("Cannot convert array columns into map.") 1101 return self.func(map_func_name, keys, values) 1102 1103 args = [] 1104 for key, value in zip(keys.expressions, values.expressions): 1105 args.append(self.sql(key)) 1106 args.append(self.sql(value)) 1107 1108 return self.func(map_func_name, *args)
1111def build_formatted_time( 1112 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1113) -> t.Callable[[t.List], E]: 1114 """Helper used for time expressions. 1115 1116 Args: 1117 exp_class: the expression class to instantiate. 1118 dialect: target sql dialect. 1119 default: the default format, True being time. 1120 1121 Returns: 1122 A callable that can be used to return the appropriately formatted time expression. 1123 """ 1124 1125 def _builder(args: t.List): 1126 return exp_class( 1127 this=seq_get(args, 0), 1128 format=Dialect[dialect].format_time( 1129 seq_get(args, 1) 1130 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1131 ), 1132 ) 1133 1134 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
1137def time_format( 1138 dialect: DialectType = None, 1139) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1140 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1141 """ 1142 Returns the time format for a given expression, unless it's equivalent 1143 to the default time format of the dialect of interest. 1144 """ 1145 time_format = self.format_time(expression) 1146 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1147 1148 return _time_format
1151def build_date_delta( 1152 exp_class: t.Type[E], 1153 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1154 default_unit: t.Optional[str] = "DAY", 1155) -> t.Callable[[t.List], E]: 1156 def _builder(args: t.List) -> E: 1157 unit_based = len(args) == 3 1158 this = args[2] if unit_based else seq_get(args, 0) 1159 unit = None 1160 if unit_based or default_unit: 1161 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1162 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1163 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1164 1165 return _builder
1168def build_date_delta_with_interval( 1169 expression_class: t.Type[E], 1170) -> t.Callable[[t.List], t.Optional[E]]: 1171 def _builder(args: t.List) -> t.Optional[E]: 1172 if len(args) < 2: 1173 return None 1174 1175 interval = args[1] 1176 1177 if not isinstance(interval, exp.Interval): 1178 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1179 1180 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1181 1182 return _builder
1185def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1186 unit = seq_get(args, 0) 1187 this = seq_get(args, 1) 1188 1189 if isinstance(this, exp.Cast) and this.is_type("date"): 1190 return exp.DateTrunc(unit=unit, this=this) 1191 return exp.TimestampTrunc(this=this, unit=unit)
1194def date_add_interval_sql( 1195 data_type: str, kind: str 1196) -> t.Callable[[Generator, exp.Expression], str]: 1197 def func(self: Generator, expression: exp.Expression) -> str: 1198 this = self.sql(expression, "this") 1199 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1200 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1201 1202 return func
1205def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1206 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1207 args = [unit_to_str(expression), expression.this] 1208 if zone: 1209 args.append(expression.args.get("zone")) 1210 return self.func("DATE_TRUNC", *args) 1211 1212 return _timestamptrunc_sql
1215def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1216 zone = expression.args.get("zone") 1217 if not zone: 1218 from sqlglot.optimizer.annotate_types import annotate_types 1219 1220 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1221 return self.sql(exp.cast(expression.this, target_type)) 1222 if zone.name.lower() in TIMEZONES: 1223 return self.sql( 1224 exp.AtTimeZone( 1225 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1226 zone=zone, 1227 ) 1228 ) 1229 return self.func("TIMESTAMP", expression.this, zone)
1232def no_time_sql(self: Generator, expression: exp.Time) -> str: 1233 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1234 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1235 expr = exp.cast( 1236 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1237 ) 1238 return self.sql(expr)
1241def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1242 this = expression.this 1243 expr = expression.expression 1244 1245 if expr.name.lower() in TIMEZONES: 1246 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1247 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1248 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1249 return self.sql(this) 1250 1251 this = exp.cast(this, exp.DataType.Type.DATE) 1252 expr = exp.cast(expr, exp.DataType.Type.TIME) 1253 1254 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1286def timestrtotime_sql( 1287 self: Generator, 1288 expression: exp.TimeStrToTime, 1289 include_precision: bool = False, 1290) -> str: 1291 datatype = exp.DataType.build( 1292 exp.DataType.Type.TIMESTAMPTZ 1293 if expression.args.get("zone") 1294 else exp.DataType.Type.TIMESTAMP 1295 ) 1296 1297 if isinstance(expression.this, exp.Literal) and include_precision: 1298 precision = subsecond_precision(expression.this.name) 1299 if precision > 0: 1300 datatype = exp.DataType.build( 1301 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1302 ) 1303 1304 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1312def encode_decode_sql( 1313 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1314) -> str: 1315 charset = expression.args.get("charset") 1316 if charset and charset.name.lower() != "utf-8": 1317 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1318 1319 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1332def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1333 cond = expression.this 1334 1335 if isinstance(expression.this, exp.Distinct): 1336 cond = expression.this.expressions[0] 1337 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1338 1339 return self.func("sum", exp.func("if", cond, 1, 0))
1342def trim_sql(self: Generator, expression: exp.Trim) -> str: 1343 target = self.sql(expression, "this") 1344 trim_type = self.sql(expression, "position") 1345 remove_chars = self.sql(expression, "expression") 1346 collation = self.sql(expression, "collation") 1347 1348 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1349 if not remove_chars: 1350 return self.trim_sql(expression) 1351 1352 trim_type = f"{trim_type} " if trim_type else "" 1353 remove_chars = f"{remove_chars} " if remove_chars else "" 1354 from_part = "FROM " if trim_type or remove_chars else "" 1355 collation = f" COLLATE {collation}" if collation else "" 1356 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1377@unsupported_args("position", "occurrence", "parameters") 1378def regexp_extract_sql( 1379 self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll 1380) -> str: 1381 group = expression.args.get("group") 1382 1383 # Do not render group if it's the default value for this dialect 1384 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1385 group = None 1386 1387 return self.func(expression.sql_name(), expression.this, expression.expression, group)
1397def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1398 names = [] 1399 for agg in aggregations: 1400 if isinstance(agg, exp.Alias): 1401 names.append(agg.alias) 1402 else: 1403 """ 1404 This case corresponds to aggregations without aliases being used as suffixes 1405 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1406 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1407 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1408 """ 1409 agg_all_unquoted = agg.transform( 1410 lambda node: ( 1411 exp.Identifier(this=node.name, quoted=False) 1412 if isinstance(node, exp.Identifier) 1413 else node 1414 ) 1415 ) 1416 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1417 1418 return names
1458def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1459 @unsupported_args("count") 1460 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1461 return self.func(name, expression.this, expression.expression) 1462 1463 return _arg_max_or_min_sql
1466def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1467 this = expression.this.copy() 1468 1469 return_type = expression.return_type 1470 if return_type.is_type(exp.DataType.Type.DATE): 1471 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1472 # can truncate timestamp strings, because some dialects can't cast them to DATE 1473 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1474 1475 expression.this.replace(exp.cast(this, return_type)) 1476 return expression
1479def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1480 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1481 if cast and isinstance(expression, exp.TsOrDsAdd): 1482 expression = ts_or_ds_add_cast(expression) 1483 1484 return self.func( 1485 name, 1486 unit_to_var(expression), 1487 expression.expression, 1488 expression.this, 1489 ) 1490 1491 return _delta_sql
1494def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1495 unit = expression.args.get("unit") 1496 1497 if isinstance(unit, exp.Placeholder): 1498 return unit 1499 if unit: 1500 return exp.Literal.string(unit.name) 1501 return exp.Literal.string(default) if default else None
1531def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1532 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1533 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1534 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1535 1536 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1539def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1540 """Remove table refs from columns in when statements.""" 1541 alias = expression.this.args.get("alias") 1542 1543 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1544 return self.dialect.normalize_identifier(identifier).name if identifier else None 1545 1546 targets = {normalize(expression.this.this)} 1547 1548 if alias: 1549 targets.add(normalize(alias.this)) 1550 1551 for when in expression.expressions: 1552 # only remove the target names from the THEN clause 1553 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1554 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1555 then = when.args.get("then") 1556 if then: 1557 then.transform( 1558 lambda node: ( 1559 exp.column(node.this) 1560 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1561 else node 1562 ), 1563 copy=False, 1564 ) 1565 1566 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1569def build_json_extract_path( 1570 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1571) -> t.Callable[[t.List], F]: 1572 def _builder(args: t.List) -> F: 1573 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1574 for arg in args[1:]: 1575 if not isinstance(arg, exp.Literal): 1576 # We use the fallback parser because we can't really transpile non-literals safely 1577 return expr_type.from_arg_list(args) 1578 1579 text = arg.name 1580 if is_int(text): 1581 index = int(text) 1582 segments.append( 1583 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1584 ) 1585 else: 1586 segments.append(exp.JSONPathKey(this=text)) 1587 1588 # This is done to avoid failing in the expression validator due to the arg count 1589 del args[2:] 1590 return expr_type( 1591 this=seq_get(args, 0), 1592 expression=exp.JSONPath(expressions=segments), 1593 only_json_types=arrow_req_json_type, 1594 ) 1595 1596 return _builder
1599def json_extract_segments( 1600 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1601) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1602 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1603 path = expression.expression 1604 if not isinstance(path, exp.JSONPath): 1605 return rename_func(name)(self, expression) 1606 1607 escape = path.args.get("escape") 1608 1609 segments = [] 1610 for segment in path.expressions: 1611 path = self.sql(segment) 1612 if path: 1613 if isinstance(segment, exp.JSONPathPart) and ( 1614 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1615 ): 1616 if escape: 1617 path = self.escape_str(path) 1618 1619 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1620 1621 segments.append(path) 1622 1623 if op: 1624 return f" {op} ".join([self.sql(expression.this), *segments]) 1625 return self.func(name, expression.this, *segments) 1626 1627 return _json_extract_segments
1637def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1638 cond = expression.expression 1639 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1640 alias = cond.expressions[0] 1641 cond = cond.this 1642 elif isinstance(cond, exp.Predicate): 1643 alias = "_u" 1644 else: 1645 self.unsupported("Unsupported filter condition") 1646 return "" 1647 1648 unnest = exp.Unnest(expressions=[expression.this]) 1649 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1650 return self.sql(exp.Array(expressions=[filtered]))
1662def build_default_decimal_type( 1663 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1664) -> t.Callable[[exp.DataType], exp.DataType]: 1665 def _builder(dtype: exp.DataType) -> exp.DataType: 1666 if dtype.expressions or precision is None: 1667 return dtype 1668 1669 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1670 return exp.DataType.build(f"DECIMAL({params})") 1671 1672 return _builder
1675def build_timestamp_from_parts(args: t.List) -> exp.Func: 1676 if len(args) == 2: 1677 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1678 # so we parse this into Anonymous for now instead of introducing complexity 1679 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1680 1681 return exp.TimestampFromParts.from_arg_list(args)
1688def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1689 start = expression.args.get("start") 1690 end = expression.args.get("end") 1691 step = expression.args.get("step") 1692 1693 if isinstance(start, exp.Cast): 1694 target_type = start.to 1695 elif isinstance(end, exp.Cast): 1696 target_type = end.to 1697 else: 1698 target_type = None 1699 1700 if start and end and target_type and target_type.is_type("date", "timestamp"): 1701 if isinstance(start, exp.Cast) and target_type is start.to: 1702 end = exp.cast(end, target_type) 1703 else: 1704 start = exp.cast(start, target_type) 1705 1706 return self.func("SEQUENCE", start, end, step)
1709def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: 1710 def _builder(args: t.List, dialect: Dialect) -> E: 1711 return expr_type( 1712 this=seq_get(args, 0), 1713 expression=seq_get(args, 1), 1714 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1715 parameters=seq_get(args, 3), 1716 ) 1717 1718 return _builder
1721def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1722 if isinstance(expression.this, exp.Explode): 1723 return self.sql( 1724 exp.Join( 1725 this=exp.Unnest( 1726 expressions=[expression.this.this], 1727 alias=expression.args.get("alias"), 1728 offset=isinstance(expression.this, exp.Posexplode), 1729 ), 1730 kind="cross", 1731 ) 1732 ) 1733 return self.lateral_sql(expression)
1740def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1741 args = [] 1742 for unit, value in expression.args.items(): 1743 if isinstance(value, exp.Kwarg): 1744 value = value.expression 1745 1746 args.append(f"{value} {unit}") 1747 1748 return f"INTERVAL '{self.format_args(*args, sep=sep)}'"