sqlglot.dialects.dialect
1from __future__ import annotations 2 3import logging 4import typing as t 5from enum import Enum, auto 6from functools import reduce 7 8from sqlglot import exp 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator, unsupported_args 11from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses 12from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path 13from sqlglot.parser import Parser 14from sqlglot.time import TIMEZONES, format_time, subsecond_precision 15from sqlglot.tokens import Token, Tokenizer, TokenType 16from sqlglot.trie import new_trie 17 18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] 19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] 20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] 21 22 23if t.TYPE_CHECKING: 24 from sqlglot._typing import B, E, F 25 26 from sqlglot.optimizer.annotate_types import TypeAnnotator 27 28 AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]] 29 30logger = logging.getLogger("sqlglot") 31 32UNESCAPED_SEQUENCES = { 33 "\\a": "\a", 34 "\\b": "\b", 35 "\\f": "\f", 36 "\\n": "\n", 37 "\\r": "\r", 38 "\\t": "\t", 39 "\\v": "\v", 40 "\\\\": "\\", 41} 42 43 44def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: 45 return lambda self, e: self._annotate_with_type(e, data_type) 46 47 48class Dialects(str, Enum): 49 """Dialects supported by SQLGLot.""" 50 51 DIALECT = "" 52 53 ATHENA = "athena" 54 BIGQUERY = "bigquery" 55 CLICKHOUSE = "clickhouse" 56 DATABRICKS = "databricks" 57 DORIS = "doris" 58 DRILL = "drill" 59 DUCKDB = "duckdb" 60 HIVE = "hive" 61 MATERIALIZE = "materialize" 62 MYSQL = "mysql" 63 ORACLE = "oracle" 64 POSTGRES = "postgres" 65 PRESTO = "presto" 66 PRQL = "prql" 67 REDSHIFT = "redshift" 68 RISINGWAVE = "risingwave" 69 SNOWFLAKE = "snowflake" 70 SPARK = "spark" 71 SPARK2 = "spark2" 72 SQLITE = "sqlite" 73 STARROCKS = "starrocks" 74 TABLEAU = "tableau" 75 TERADATA = "teradata" 76 TRINO = "trino" 77 TSQL = "tsql" 78 79 80class NormalizationStrategy(str, AutoName): 81 """Specifies the strategy according to which identifiers should be normalized.""" 82 83 LOWERCASE = auto() 84 """Unquoted identifiers are lowercased.""" 85 86 UPPERCASE = auto() 87 """Unquoted identifiers are uppercased.""" 88 89 CASE_SENSITIVE = auto() 90 """Always case-sensitive, regardless of quotes.""" 91 92 CASE_INSENSITIVE = auto() 93 """Always case-insensitive, regardless of quotes.""" 94 95 96class _Dialect(type): 97 classes: t.Dict[str, t.Type[Dialect]] = {} 98 99 def __eq__(cls, other: t.Any) -> bool: 100 if cls is other: 101 return True 102 if isinstance(other, str): 103 return cls is cls.get(other) 104 if isinstance(other, Dialect): 105 return cls is type(other) 106 107 return False 108 109 def __hash__(cls) -> int: 110 return hash(cls.__name__.lower()) 111 112 @classmethod 113 def __getitem__(cls, key: str) -> t.Type[Dialect]: 114 return cls.classes[key] 115 116 @classmethod 117 def get( 118 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 119 ) -> t.Optional[t.Type[Dialect]]: 120 return cls.classes.get(key, default) 121 122 def __new__(cls, clsname, bases, attrs): 123 klass = super().__new__(cls, clsname, bases, attrs) 124 enum = Dialects.__members__.get(clsname.upper()) 125 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 126 127 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 128 klass.FORMAT_TRIE = ( 129 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 130 ) 131 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 132 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 133 klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()} 134 klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING) 135 136 klass.INVERSE_CREATABLE_KIND_MAPPING = { 137 v: k for k, v in klass.CREATABLE_KIND_MAPPING.items() 138 } 139 140 base = seq_get(bases, 0) 141 base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),) 142 base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),) 143 base_parser = (getattr(base, "parser_class", Parser),) 144 base_generator = (getattr(base, "generator_class", Generator),) 145 146 klass.tokenizer_class = klass.__dict__.get( 147 "Tokenizer", type("Tokenizer", base_tokenizer, {}) 148 ) 149 klass.jsonpath_tokenizer_class = klass.__dict__.get( 150 "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {}) 151 ) 152 klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {})) 153 klass.generator_class = klass.__dict__.get( 154 "Generator", type("Generator", base_generator, {}) 155 ) 156 157 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 158 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 159 klass.tokenizer_class._IDENTIFIERS.items() 160 )[0] 161 162 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 163 return next( 164 ( 165 (s, e) 166 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 167 if t == token_type 168 ), 169 (None, None), 170 ) 171 172 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 173 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 174 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 175 klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) 176 177 if "\\" in klass.tokenizer_class.STRING_ESCAPES: 178 klass.UNESCAPED_SEQUENCES = { 179 **UNESCAPED_SEQUENCES, 180 **klass.UNESCAPED_SEQUENCES, 181 } 182 183 klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()} 184 185 klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS 186 187 if enum not in ("", "bigquery"): 188 klass.generator_class.SELECT_KINDS = () 189 190 if enum not in ("", "athena", "presto", "trino"): 191 klass.generator_class.TRY_SUPPORTED = False 192 klass.generator_class.SUPPORTS_UESCAPE = False 193 194 if enum not in ("", "databricks", "hive", "spark", "spark2"): 195 modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy() 196 for modifier in ("cluster", "distribute", "sort"): 197 modifier_transforms.pop(modifier, None) 198 199 klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms 200 201 if enum not in ("", "doris", "mysql"): 202 klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { 203 TokenType.STRAIGHT_JOIN, 204 } 205 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 206 TokenType.STRAIGHT_JOIN, 207 } 208 209 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 210 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 211 TokenType.ANTI, 212 TokenType.SEMI, 213 } 214 215 return klass 216 217 218class Dialect(metaclass=_Dialect): 219 INDEX_OFFSET = 0 220 """The base index offset for arrays.""" 221 222 WEEK_OFFSET = 0 223 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 224 225 UNNEST_COLUMN_ONLY = False 226 """Whether `UNNEST` table aliases are treated as column aliases.""" 227 228 ALIAS_POST_TABLESAMPLE = False 229 """Whether the table alias comes after tablesample.""" 230 231 TABLESAMPLE_SIZE_IS_PERCENT = False 232 """Whether a size in the table sample clause represents percentage.""" 233 234 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 235 """Specifies the strategy according to which identifiers should be normalized.""" 236 237 IDENTIFIERS_CAN_START_WITH_DIGIT = False 238 """Whether an unquoted identifier can start with a digit.""" 239 240 DPIPE_IS_STRING_CONCAT = True 241 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 242 243 STRICT_STRING_CONCAT = False 244 """Whether `CONCAT`'s arguments must be strings.""" 245 246 SUPPORTS_USER_DEFINED_TYPES = True 247 """Whether user-defined data types are supported.""" 248 249 SUPPORTS_SEMI_ANTI_JOIN = True 250 """Whether `SEMI` or `ANTI` joins are supported.""" 251 252 SUPPORTS_COLUMN_JOIN_MARKS = False 253 """Whether the old-style outer join (+) syntax is supported.""" 254 255 COPY_PARAMS_ARE_CSV = True 256 """Separator of COPY statement parameters.""" 257 258 NORMALIZE_FUNCTIONS: bool | str = "upper" 259 """ 260 Determines how function names are going to be normalized. 261 Possible values: 262 "upper" or True: Convert names to uppercase. 263 "lower": Convert names to lowercase. 264 False: Disables function name normalization. 265 """ 266 267 LOG_BASE_FIRST: t.Optional[bool] = True 268 """ 269 Whether the base comes first in the `LOG` function. 270 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 271 """ 272 273 NULL_ORDERING = "nulls_are_small" 274 """ 275 Default `NULL` ordering method to use if not explicitly set. 276 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 277 """ 278 279 TYPED_DIVISION = False 280 """ 281 Whether the behavior of `a / b` depends on the types of `a` and `b`. 282 False means `a / b` is always float division. 283 True means `a / b` is integer division if both `a` and `b` are integers. 284 """ 285 286 SAFE_DIVISION = False 287 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 288 289 CONCAT_COALESCE = False 290 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 291 292 HEX_LOWERCASE = False 293 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 294 295 DATE_FORMAT = "'%Y-%m-%d'" 296 DATEINT_FORMAT = "'%Y%m%d'" 297 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 298 299 TIME_MAPPING: t.Dict[str, str] = {} 300 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 301 302 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 303 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 304 FORMAT_MAPPING: t.Dict[str, str] = {} 305 """ 306 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 307 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 308 """ 309 310 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 311 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 312 313 PSEUDOCOLUMNS: t.Set[str] = set() 314 """ 315 Columns that are auto-generated by the engine corresponding to this dialect. 316 For example, such columns may be excluded from `SELECT *` queries. 317 """ 318 319 PREFER_CTE_ALIAS_COLUMN = False 320 """ 321 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 322 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 323 any projection aliases in the subquery. 324 325 For example, 326 WITH y(c) AS ( 327 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 328 ) SELECT c FROM y; 329 330 will be rewritten as 331 332 WITH y(c) AS ( 333 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 334 ) SELECT c FROM y; 335 """ 336 337 COPY_PARAMS_ARE_CSV = True 338 """ 339 Whether COPY statement parameters are separated by comma or whitespace 340 """ 341 342 FORCE_EARLY_ALIAS_REF_EXPANSION = False 343 """ 344 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 345 346 For example: 347 WITH data AS ( 348 SELECT 349 1 AS id, 350 2 AS my_id 351 ) 352 SELECT 353 id AS my_id 354 FROM 355 data 356 WHERE 357 my_id = 1 358 GROUP BY 359 my_id, 360 HAVING 361 my_id = 1 362 363 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 364 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 365 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 366 - Clickhouse, which will forward the alias across the query i.e it resolves 367 to "WHERE id = 1 GROUP BY id HAVING id = 1" 368 """ 369 370 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 371 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 372 373 SUPPORTS_ORDER_BY_ALL = False 374 """ 375 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 376 """ 377 378 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 379 """ 380 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 381 as the former is of type INT[] vs the latter which is SUPER 382 """ 383 384 SUPPORTS_FIXED_SIZE_ARRAYS = False 385 """ 386 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 387 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 388 be interpreted as a subscript/index operator. 389 """ 390 391 STRICT_JSON_PATH_SYNTAX = True 392 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 393 394 ON_CONDITION_EMPTY_BEFORE_ERROR = True 395 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 396 397 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 398 """Whether ArrayAgg needs to filter NULL values.""" 399 400 REGEXP_EXTRACT_DEFAULT_GROUP = 0 401 """The default value for the capturing group.""" 402 403 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 404 exp.Except: True, 405 exp.Intersect: True, 406 exp.Union: True, 407 } 408 """ 409 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 410 must be explicitly specified. 411 """ 412 413 CREATABLE_KIND_MAPPING: dict[str, str] = {} 414 """ 415 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 416 equivalent of CREATE SCHEMA is CREATE DATABASE. 417 """ 418 419 # --- Autofilled --- 420 421 tokenizer_class = Tokenizer 422 jsonpath_tokenizer_class = JSONPathTokenizer 423 parser_class = Parser 424 generator_class = Generator 425 426 # A trie of the time_mapping keys 427 TIME_TRIE: t.Dict = {} 428 FORMAT_TRIE: t.Dict = {} 429 430 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 431 INVERSE_TIME_TRIE: t.Dict = {} 432 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 433 INVERSE_FORMAT_TRIE: t.Dict = {} 434 435 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 436 437 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 438 439 # Delimiters for string literals and identifiers 440 QUOTE_START = "'" 441 QUOTE_END = "'" 442 IDENTIFIER_START = '"' 443 IDENTIFIER_END = '"' 444 445 # Delimiters for bit, hex, byte and unicode literals 446 BIT_START: t.Optional[str] = None 447 BIT_END: t.Optional[str] = None 448 HEX_START: t.Optional[str] = None 449 HEX_END: t.Optional[str] = None 450 BYTE_START: t.Optional[str] = None 451 BYTE_END: t.Optional[str] = None 452 UNICODE_START: t.Optional[str] = None 453 UNICODE_END: t.Optional[str] = None 454 455 DATE_PART_MAPPING = { 456 "Y": "YEAR", 457 "YY": "YEAR", 458 "YYY": "YEAR", 459 "YYYY": "YEAR", 460 "YR": "YEAR", 461 "YEARS": "YEAR", 462 "YRS": "YEAR", 463 "MM": "MONTH", 464 "MON": "MONTH", 465 "MONS": "MONTH", 466 "MONTHS": "MONTH", 467 "D": "DAY", 468 "DD": "DAY", 469 "DAYS": "DAY", 470 "DAYOFMONTH": "DAY", 471 "DAY OF WEEK": "DAYOFWEEK", 472 "WEEKDAY": "DAYOFWEEK", 473 "DOW": "DAYOFWEEK", 474 "DW": "DAYOFWEEK", 475 "WEEKDAY_ISO": "DAYOFWEEKISO", 476 "DOW_ISO": "DAYOFWEEKISO", 477 "DW_ISO": "DAYOFWEEKISO", 478 "DAY OF YEAR": "DAYOFYEAR", 479 "DOY": "DAYOFYEAR", 480 "DY": "DAYOFYEAR", 481 "W": "WEEK", 482 "WK": "WEEK", 483 "WEEKOFYEAR": "WEEK", 484 "WOY": "WEEK", 485 "WY": "WEEK", 486 "WEEK_ISO": "WEEKISO", 487 "WEEKOFYEARISO": "WEEKISO", 488 "WEEKOFYEAR_ISO": "WEEKISO", 489 "Q": "QUARTER", 490 "QTR": "QUARTER", 491 "QTRS": "QUARTER", 492 "QUARTERS": "QUARTER", 493 "H": "HOUR", 494 "HH": "HOUR", 495 "HR": "HOUR", 496 "HOURS": "HOUR", 497 "HRS": "HOUR", 498 "M": "MINUTE", 499 "MI": "MINUTE", 500 "MIN": "MINUTE", 501 "MINUTES": "MINUTE", 502 "MINS": "MINUTE", 503 "S": "SECOND", 504 "SEC": "SECOND", 505 "SECONDS": "SECOND", 506 "SECS": "SECOND", 507 "MS": "MILLISECOND", 508 "MSEC": "MILLISECOND", 509 "MSECS": "MILLISECOND", 510 "MSECOND": "MILLISECOND", 511 "MSECONDS": "MILLISECOND", 512 "MILLISEC": "MILLISECOND", 513 "MILLISECS": "MILLISECOND", 514 "MILLISECON": "MILLISECOND", 515 "MILLISECONDS": "MILLISECOND", 516 "US": "MICROSECOND", 517 "USEC": "MICROSECOND", 518 "USECS": "MICROSECOND", 519 "MICROSEC": "MICROSECOND", 520 "MICROSECS": "MICROSECOND", 521 "USECOND": "MICROSECOND", 522 "USECONDS": "MICROSECOND", 523 "MICROSECONDS": "MICROSECOND", 524 "NS": "NANOSECOND", 525 "NSEC": "NANOSECOND", 526 "NANOSEC": "NANOSECOND", 527 "NSECOND": "NANOSECOND", 528 "NSECONDS": "NANOSECOND", 529 "NANOSECS": "NANOSECOND", 530 "EPOCH_SECOND": "EPOCH", 531 "EPOCH_SECONDS": "EPOCH", 532 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 533 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 534 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 535 "TZH": "TIMEZONE_HOUR", 536 "TZM": "TIMEZONE_MINUTE", 537 "DEC": "DECADE", 538 "DECS": "DECADE", 539 "DECADES": "DECADE", 540 "MIL": "MILLENIUM", 541 "MILS": "MILLENIUM", 542 "MILLENIA": "MILLENIUM", 543 "C": "CENTURY", 544 "CENT": "CENTURY", 545 "CENTS": "CENTURY", 546 "CENTURIES": "CENTURY", 547 } 548 549 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 550 exp.DataType.Type.BIGINT: { 551 exp.ApproxDistinct, 552 exp.ArraySize, 553 exp.Length, 554 }, 555 exp.DataType.Type.BOOLEAN: { 556 exp.Between, 557 exp.Boolean, 558 exp.In, 559 exp.RegexpLike, 560 }, 561 exp.DataType.Type.DATE: { 562 exp.CurrentDate, 563 exp.Date, 564 exp.DateFromParts, 565 exp.DateStrToDate, 566 exp.DiToDate, 567 exp.StrToDate, 568 exp.TimeStrToDate, 569 exp.TsOrDsToDate, 570 }, 571 exp.DataType.Type.DATETIME: { 572 exp.CurrentDatetime, 573 exp.Datetime, 574 exp.DatetimeAdd, 575 exp.DatetimeSub, 576 }, 577 exp.DataType.Type.DOUBLE: { 578 exp.ApproxQuantile, 579 exp.Avg, 580 exp.Div, 581 exp.Exp, 582 exp.Ln, 583 exp.Log, 584 exp.Pow, 585 exp.Quantile, 586 exp.Round, 587 exp.SafeDivide, 588 exp.Sqrt, 589 exp.Stddev, 590 exp.StddevPop, 591 exp.StddevSamp, 592 exp.Variance, 593 exp.VariancePop, 594 }, 595 exp.DataType.Type.INT: { 596 exp.Ceil, 597 exp.DatetimeDiff, 598 exp.DateDiff, 599 exp.TimestampDiff, 600 exp.TimeDiff, 601 exp.DateToDi, 602 exp.Levenshtein, 603 exp.Sign, 604 exp.StrPosition, 605 exp.TsOrDiToDi, 606 }, 607 exp.DataType.Type.JSON: { 608 exp.ParseJSON, 609 }, 610 exp.DataType.Type.TIME: { 611 exp.Time, 612 }, 613 exp.DataType.Type.TIMESTAMP: { 614 exp.CurrentTime, 615 exp.CurrentTimestamp, 616 exp.StrToTime, 617 exp.TimeAdd, 618 exp.TimeStrToTime, 619 exp.TimeSub, 620 exp.TimestampAdd, 621 exp.TimestampSub, 622 exp.UnixToTime, 623 }, 624 exp.DataType.Type.TINYINT: { 625 exp.Day, 626 exp.Month, 627 exp.Week, 628 exp.Year, 629 exp.Quarter, 630 }, 631 exp.DataType.Type.VARCHAR: { 632 exp.ArrayConcat, 633 exp.Concat, 634 exp.ConcatWs, 635 exp.DateToDateStr, 636 exp.GroupConcat, 637 exp.Initcap, 638 exp.Lower, 639 exp.Substring, 640 exp.TimeToStr, 641 exp.TimeToTimeStr, 642 exp.Trim, 643 exp.TsOrDsToDateStr, 644 exp.UnixToStr, 645 exp.UnixToTimeStr, 646 exp.Upper, 647 }, 648 } 649 650 ANNOTATORS: AnnotatorsType = { 651 **{ 652 expr_type: lambda self, e: self._annotate_unary(e) 653 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 654 }, 655 **{ 656 expr_type: lambda self, e: self._annotate_binary(e) 657 for expr_type in subclasses(exp.__name__, exp.Binary) 658 }, 659 **{ 660 expr_type: _annotate_with_type_lambda(data_type) 661 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 662 for expr_type in expressions 663 }, 664 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 665 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 666 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 667 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 668 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 669 exp.Bracket: lambda self, e: self._annotate_bracket(e), 670 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 671 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 672 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 673 exp.Count: lambda self, e: self._annotate_with_type( 674 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 675 ), 676 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 677 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 678 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 679 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 680 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 681 exp.Div: lambda self, e: self._annotate_div(e), 682 exp.Dot: lambda self, e: self._annotate_dot(e), 683 exp.Explode: lambda self, e: self._annotate_explode(e), 684 exp.Extract: lambda self, e: self._annotate_extract(e), 685 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 686 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 687 e, exp.DataType.build("ARRAY<DATE>") 688 ), 689 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 690 e, exp.DataType.build("ARRAY<TIMESTAMP>") 691 ), 692 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 693 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 694 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 695 exp.Literal: lambda self, e: self._annotate_literal(e), 696 exp.Map: lambda self, e: self._annotate_map(e), 697 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 698 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 699 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 700 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 701 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 702 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 703 exp.Struct: lambda self, e: self._annotate_struct(e), 704 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 705 exp.Timestamp: lambda self, e: self._annotate_with_type( 706 e, 707 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 708 ), 709 exp.ToMap: lambda self, e: self._annotate_to_map(e), 710 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 711 exp.Unnest: lambda self, e: self._annotate_unnest(e), 712 exp.VarMap: lambda self, e: self._annotate_map(e), 713 } 714 715 @classmethod 716 def get_or_raise(cls, dialect: DialectType) -> Dialect: 717 """ 718 Look up a dialect in the global dialect registry and return it if it exists. 719 720 Args: 721 dialect: The target dialect. If this is a string, it can be optionally followed by 722 additional key-value pairs that are separated by commas and are used to specify 723 dialect settings, such as whether the dialect's identifiers are case-sensitive. 724 725 Example: 726 >>> dialect = dialect_class = get_or_raise("duckdb") 727 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 728 729 Returns: 730 The corresponding Dialect instance. 731 """ 732 733 if not dialect: 734 return cls() 735 if isinstance(dialect, _Dialect): 736 return dialect() 737 if isinstance(dialect, Dialect): 738 return dialect 739 if isinstance(dialect, str): 740 try: 741 dialect_name, *kv_strings = dialect.split(",") 742 kv_pairs = (kv.split("=") for kv in kv_strings) 743 kwargs = {} 744 for pair in kv_pairs: 745 key = pair[0].strip() 746 value: t.Union[bool | str | None] = None 747 748 if len(pair) == 1: 749 # Default initialize standalone settings to True 750 value = True 751 elif len(pair) == 2: 752 value = pair[1].strip() 753 754 # Coerce the value to boolean if it matches to the truthy/falsy values below 755 value_lower = value.lower() 756 if value_lower in ("true", "1"): 757 value = True 758 elif value_lower in ("false", "0"): 759 value = False 760 761 kwargs[key] = value 762 763 except ValueError: 764 raise ValueError( 765 f"Invalid dialect format: '{dialect}'. " 766 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 767 ) 768 769 result = cls.get(dialect_name.strip()) 770 if not result: 771 from difflib import get_close_matches 772 773 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 774 if similar: 775 similar = f" Did you mean {similar}?" 776 777 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 778 779 return result(**kwargs) 780 781 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 782 783 @classmethod 784 def format_time( 785 cls, expression: t.Optional[str | exp.Expression] 786 ) -> t.Optional[exp.Expression]: 787 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 788 if isinstance(expression, str): 789 return exp.Literal.string( 790 # the time formats are quoted 791 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 792 ) 793 794 if expression and expression.is_string: 795 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 796 797 return expression 798 799 def __init__(self, **kwargs) -> None: 800 normalization_strategy = kwargs.pop("normalization_strategy", None) 801 802 if normalization_strategy is None: 803 self.normalization_strategy = self.NORMALIZATION_STRATEGY 804 else: 805 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 806 807 self.settings = kwargs 808 809 def __eq__(self, other: t.Any) -> bool: 810 # Does not currently take dialect state into account 811 return type(self) == other 812 813 def __hash__(self) -> int: 814 # Does not currently take dialect state into account 815 return hash(type(self)) 816 817 def normalize_identifier(self, expression: E) -> E: 818 """ 819 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 820 821 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 822 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 823 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 824 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 825 826 There are also dialects like Spark, which are case-insensitive even when quotes are 827 present, and dialects like MySQL, whose resolution rules match those employed by the 828 underlying operating system, for example they may always be case-sensitive in Linux. 829 830 Finally, the normalization behavior of some engines can even be controlled through flags, 831 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 832 833 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 834 that it can analyze queries in the optimizer and successfully capture their semantics. 835 """ 836 if ( 837 isinstance(expression, exp.Identifier) 838 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 839 and ( 840 not expression.quoted 841 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 842 ) 843 ): 844 expression.set( 845 "this", 846 ( 847 expression.this.upper() 848 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 849 else expression.this.lower() 850 ), 851 ) 852 853 return expression 854 855 def case_sensitive(self, text: str) -> bool: 856 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 857 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 858 return False 859 860 unsafe = ( 861 str.islower 862 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 863 else str.isupper 864 ) 865 return any(unsafe(char) for char in text) 866 867 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 868 """Checks if text can be identified given an identify option. 869 870 Args: 871 text: The text to check. 872 identify: 873 `"always"` or `True`: Always returns `True`. 874 `"safe"`: Only returns `True` if the identifier is case-insensitive. 875 876 Returns: 877 Whether the given text can be identified. 878 """ 879 if identify is True or identify == "always": 880 return True 881 882 if identify == "safe": 883 return not self.case_sensitive(text) 884 885 return False 886 887 def quote_identifier(self, expression: E, identify: bool = True) -> E: 888 """ 889 Adds quotes to a given identifier. 890 891 Args: 892 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 893 identify: If set to `False`, the quotes will only be added if the identifier is deemed 894 "unsafe", with respect to its characters and this dialect's normalization strategy. 895 """ 896 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 897 name = expression.this 898 expression.set( 899 "quoted", 900 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 901 ) 902 903 return expression 904 905 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 906 if isinstance(path, exp.Literal): 907 path_text = path.name 908 if path.is_number: 909 path_text = f"[{path_text}]" 910 try: 911 return parse_json_path(path_text, self) 912 except ParseError as e: 913 if self.STRICT_JSON_PATH_SYNTAX: 914 logger.warning(f"Invalid JSON path syntax. {str(e)}") 915 916 return path 917 918 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 919 return self.parser(**opts).parse(self.tokenize(sql), sql) 920 921 def parse_into( 922 self, expression_type: exp.IntoType, sql: str, **opts 923 ) -> t.List[t.Optional[exp.Expression]]: 924 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 925 926 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 927 return self.generator(**opts).generate(expression, copy=copy) 928 929 def transpile(self, sql: str, **opts) -> t.List[str]: 930 return [ 931 self.generate(expression, copy=False, **opts) if expression else "" 932 for expression in self.parse(sql) 933 ] 934 935 def tokenize(self, sql: str) -> t.List[Token]: 936 return self.tokenizer.tokenize(sql) 937 938 @property 939 def tokenizer(self) -> Tokenizer: 940 return self.tokenizer_class(dialect=self) 941 942 @property 943 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 944 return self.jsonpath_tokenizer_class(dialect=self) 945 946 def parser(self, **opts) -> Parser: 947 return self.parser_class(dialect=self, **opts) 948 949 def generator(self, **opts) -> Generator: 950 return self.generator_class(dialect=self, **opts) 951 952 953DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 954 955 956def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 957 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 958 959 960@unsupported_args("accuracy") 961def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 962 return self.func("APPROX_COUNT_DISTINCT", expression.this) 963 964 965def if_sql( 966 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 967) -> t.Callable[[Generator, exp.If], str]: 968 def _if_sql(self: Generator, expression: exp.If) -> str: 969 return self.func( 970 name, 971 expression.this, 972 expression.args.get("true"), 973 expression.args.get("false") or false_value, 974 ) 975 976 return _if_sql 977 978 979def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 980 this = expression.this 981 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 982 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 983 984 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 985 986 987def inline_array_sql(self: Generator, expression: exp.Array) -> str: 988 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 989 990 991def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 992 elem = seq_get(expression.expressions, 0) 993 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 994 return self.func("ARRAY", elem) 995 return inline_array_sql(self, expression) 996 997 998def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 999 return self.like_sql( 1000 exp.Like( 1001 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1002 ) 1003 ) 1004 1005 1006def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1007 zone = self.sql(expression, "this") 1008 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1009 1010 1011def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1012 if expression.args.get("recursive"): 1013 self.unsupported("Recursive CTEs are unsupported") 1014 expression.args["recursive"] = False 1015 return self.with_sql(expression) 1016 1017 1018def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 1019 n = self.sql(expression, "this") 1020 d = self.sql(expression, "expression") 1021 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 1022 1023 1024def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 1025 self.unsupported("TABLESAMPLE unsupported") 1026 return self.sql(expression.this) 1027 1028 1029def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 1030 self.unsupported("PIVOT unsupported") 1031 return "" 1032 1033 1034def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 1035 return self.cast_sql(expression) 1036 1037 1038def no_comment_column_constraint_sql( 1039 self: Generator, expression: exp.CommentColumnConstraint 1040) -> str: 1041 self.unsupported("CommentColumnConstraint unsupported") 1042 return "" 1043 1044 1045def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1046 self.unsupported("MAP_FROM_ENTRIES unsupported") 1047 return "" 1048 1049 1050def property_sql(self: Generator, expression: exp.Property) -> str: 1051 return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" 1052 1053 1054def str_position_sql( 1055 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1056) -> str: 1057 this = self.sql(expression, "this") 1058 substr = self.sql(expression, "substr") 1059 position = self.sql(expression, "position") 1060 instance = expression.args.get("instance") if generate_instance else None 1061 position_offset = "" 1062 1063 if position: 1064 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1065 this = self.func("SUBSTR", this, position) 1066 position_offset = f" + {position} - 1" 1067 1068 return self.func("STRPOS", this, substr, instance) + position_offset 1069 1070 1071def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1072 return ( 1073 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1074 ) 1075 1076 1077def var_map_sql( 1078 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1079) -> str: 1080 keys = expression.args["keys"] 1081 values = expression.args["values"] 1082 1083 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1084 self.unsupported("Cannot convert array columns into map.") 1085 return self.func(map_func_name, keys, values) 1086 1087 args = [] 1088 for key, value in zip(keys.expressions, values.expressions): 1089 args.append(self.sql(key)) 1090 args.append(self.sql(value)) 1091 1092 return self.func(map_func_name, *args) 1093 1094 1095def build_formatted_time( 1096 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1097) -> t.Callable[[t.List], E]: 1098 """Helper used for time expressions. 1099 1100 Args: 1101 exp_class: the expression class to instantiate. 1102 dialect: target sql dialect. 1103 default: the default format, True being time. 1104 1105 Returns: 1106 A callable that can be used to return the appropriately formatted time expression. 1107 """ 1108 1109 def _builder(args: t.List): 1110 return exp_class( 1111 this=seq_get(args, 0), 1112 format=Dialect[dialect].format_time( 1113 seq_get(args, 1) 1114 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1115 ), 1116 ) 1117 1118 return _builder 1119 1120 1121def time_format( 1122 dialect: DialectType = None, 1123) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1124 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1125 """ 1126 Returns the time format for a given expression, unless it's equivalent 1127 to the default time format of the dialect of interest. 1128 """ 1129 time_format = self.format_time(expression) 1130 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1131 1132 return _time_format 1133 1134 1135def build_date_delta( 1136 exp_class: t.Type[E], 1137 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1138 default_unit: t.Optional[str] = "DAY", 1139) -> t.Callable[[t.List], E]: 1140 def _builder(args: t.List) -> E: 1141 unit_based = len(args) == 3 1142 this = args[2] if unit_based else seq_get(args, 0) 1143 unit = None 1144 if unit_based or default_unit: 1145 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1146 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1147 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1148 1149 return _builder 1150 1151 1152def build_date_delta_with_interval( 1153 expression_class: t.Type[E], 1154) -> t.Callable[[t.List], t.Optional[E]]: 1155 def _builder(args: t.List) -> t.Optional[E]: 1156 if len(args) < 2: 1157 return None 1158 1159 interval = args[1] 1160 1161 if not isinstance(interval, exp.Interval): 1162 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1163 1164 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1165 1166 return _builder 1167 1168 1169def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1170 unit = seq_get(args, 0) 1171 this = seq_get(args, 1) 1172 1173 if isinstance(this, exp.Cast) and this.is_type("date"): 1174 return exp.DateTrunc(unit=unit, this=this) 1175 return exp.TimestampTrunc(this=this, unit=unit) 1176 1177 1178def date_add_interval_sql( 1179 data_type: str, kind: str 1180) -> t.Callable[[Generator, exp.Expression], str]: 1181 def func(self: Generator, expression: exp.Expression) -> str: 1182 this = self.sql(expression, "this") 1183 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1184 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1185 1186 return func 1187 1188 1189def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1190 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1191 args = [unit_to_str(expression), expression.this] 1192 if zone: 1193 args.append(expression.args.get("zone")) 1194 return self.func("DATE_TRUNC", *args) 1195 1196 return _timestamptrunc_sql 1197 1198 1199def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1200 zone = expression.args.get("zone") 1201 if not zone: 1202 from sqlglot.optimizer.annotate_types import annotate_types 1203 1204 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1205 return self.sql(exp.cast(expression.this, target_type)) 1206 if zone.name.lower() in TIMEZONES: 1207 return self.sql( 1208 exp.AtTimeZone( 1209 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1210 zone=zone, 1211 ) 1212 ) 1213 return self.func("TIMESTAMP", expression.this, zone) 1214 1215 1216def no_time_sql(self: Generator, expression: exp.Time) -> str: 1217 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1218 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1219 expr = exp.cast( 1220 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1221 ) 1222 return self.sql(expr) 1223 1224 1225def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1226 this = expression.this 1227 expr = expression.expression 1228 1229 if expr.name.lower() in TIMEZONES: 1230 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1231 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1232 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1233 return self.sql(this) 1234 1235 this = exp.cast(this, exp.DataType.Type.DATE) 1236 expr = exp.cast(expr, exp.DataType.Type.TIME) 1237 1238 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) 1239 1240 1241def locate_to_strposition(args: t.List) -> exp.Expression: 1242 return exp.StrPosition( 1243 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 1244 ) 1245 1246 1247def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 1248 return self.func( 1249 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 1250 ) 1251 1252 1253def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1254 return self.sql( 1255 exp.Substring( 1256 this=expression.this, start=exp.Literal.number(1), length=expression.expression 1257 ) 1258 ) 1259 1260 1261def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1262 return self.sql( 1263 exp.Substring( 1264 this=expression.this, 1265 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 1266 ) 1267 ) 1268 1269 1270def timestrtotime_sql( 1271 self: Generator, 1272 expression: exp.TimeStrToTime, 1273 include_precision: bool = False, 1274) -> str: 1275 datatype = exp.DataType.build( 1276 exp.DataType.Type.TIMESTAMPTZ 1277 if expression.args.get("zone") 1278 else exp.DataType.Type.TIMESTAMP 1279 ) 1280 1281 if isinstance(expression.this, exp.Literal) and include_precision: 1282 precision = subsecond_precision(expression.this.name) 1283 if precision > 0: 1284 datatype = exp.DataType.build( 1285 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1286 ) 1287 1288 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect)) 1289 1290 1291def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 1292 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 1293 1294 1295# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 1296def encode_decode_sql( 1297 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1298) -> str: 1299 charset = expression.args.get("charset") 1300 if charset and charset.name.lower() != "utf-8": 1301 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1302 1303 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 1304 1305 1306def min_or_least(self: Generator, expression: exp.Min) -> str: 1307 name = "LEAST" if expression.expressions else "MIN" 1308 return rename_func(name)(self, expression) 1309 1310 1311def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1312 name = "GREATEST" if expression.expressions else "MAX" 1313 return rename_func(name)(self, expression) 1314 1315 1316def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1317 cond = expression.this 1318 1319 if isinstance(expression.this, exp.Distinct): 1320 cond = expression.this.expressions[0] 1321 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1322 1323 return self.func("sum", exp.func("if", cond, 1, 0)) 1324 1325 1326def trim_sql(self: Generator, expression: exp.Trim) -> str: 1327 target = self.sql(expression, "this") 1328 trim_type = self.sql(expression, "position") 1329 remove_chars = self.sql(expression, "expression") 1330 collation = self.sql(expression, "collation") 1331 1332 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1333 if not remove_chars: 1334 return self.trim_sql(expression) 1335 1336 trim_type = f"{trim_type} " if trim_type else "" 1337 remove_chars = f"{remove_chars} " if remove_chars else "" 1338 from_part = "FROM " if trim_type or remove_chars else "" 1339 collation = f" COLLATE {collation}" if collation else "" 1340 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1341 1342 1343def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1344 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1345 1346 1347def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1348 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1349 1350 1351def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1352 delim, *rest_args = expression.expressions 1353 return self.sql( 1354 reduce( 1355 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1356 rest_args, 1357 ) 1358 ) 1359 1360 1361@unsupported_args("position", "occurrence", "parameters") 1362def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1363 group = expression.args.get("group") 1364 1365 # Do not render group if it's the default value for this dialect 1366 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1367 group = None 1368 1369 return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group) 1370 1371 1372@unsupported_args("position", "occurrence", "modifiers") 1373def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1374 return self.func( 1375 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1376 ) 1377 1378 1379def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1380 names = [] 1381 for agg in aggregations: 1382 if isinstance(agg, exp.Alias): 1383 names.append(agg.alias) 1384 else: 1385 """ 1386 This case corresponds to aggregations without aliases being used as suffixes 1387 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1388 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1389 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1390 """ 1391 agg_all_unquoted = agg.transform( 1392 lambda node: ( 1393 exp.Identifier(this=node.name, quoted=False) 1394 if isinstance(node, exp.Identifier) 1395 else node 1396 ) 1397 ) 1398 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1399 1400 return names 1401 1402 1403def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1404 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1405 1406 1407# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1408def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1409 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1410 1411 1412def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1413 return self.func("MAX", expression.this) 1414 1415 1416def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1417 a = self.sql(expression.left) 1418 b = self.sql(expression.right) 1419 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1420 1421 1422def is_parse_json(expression: exp.Expression) -> bool: 1423 return isinstance(expression, exp.ParseJSON) or ( 1424 isinstance(expression, exp.Cast) and expression.is_type("json") 1425 ) 1426 1427 1428def isnull_to_is_null(args: t.List) -> exp.Expression: 1429 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1430 1431 1432def generatedasidentitycolumnconstraint_sql( 1433 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1434) -> str: 1435 start = self.sql(expression, "start") or "1" 1436 increment = self.sql(expression, "increment") or "1" 1437 return f"IDENTITY({start}, {increment})" 1438 1439 1440def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1441 @unsupported_args("count") 1442 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1443 return self.func(name, expression.this, expression.expression) 1444 1445 return _arg_max_or_min_sql 1446 1447 1448def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1449 this = expression.this.copy() 1450 1451 return_type = expression.return_type 1452 if return_type.is_type(exp.DataType.Type.DATE): 1453 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1454 # can truncate timestamp strings, because some dialects can't cast them to DATE 1455 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1456 1457 expression.this.replace(exp.cast(this, return_type)) 1458 return expression 1459 1460 1461def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1462 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1463 if cast and isinstance(expression, exp.TsOrDsAdd): 1464 expression = ts_or_ds_add_cast(expression) 1465 1466 return self.func( 1467 name, 1468 unit_to_var(expression), 1469 expression.expression, 1470 expression.this, 1471 ) 1472 1473 return _delta_sql 1474 1475 1476def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1477 unit = expression.args.get("unit") 1478 1479 if isinstance(unit, exp.Placeholder): 1480 return unit 1481 if unit: 1482 return exp.Literal.string(unit.name) 1483 return exp.Literal.string(default) if default else None 1484 1485 1486def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1487 unit = expression.args.get("unit") 1488 1489 if isinstance(unit, (exp.Var, exp.Placeholder)): 1490 return unit 1491 return exp.Var(this=default) if default else None 1492 1493 1494@t.overload 1495def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1496 pass 1497 1498 1499@t.overload 1500def map_date_part( 1501 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1502) -> t.Optional[exp.Expression]: 1503 pass 1504 1505 1506def map_date_part(part, dialect: DialectType = Dialect): 1507 mapped = ( 1508 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1509 ) 1510 return exp.var(mapped) if mapped else part 1511 1512 1513def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1514 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1515 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1516 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1517 1518 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1519 1520 1521def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1522 """Remove table refs from columns in when statements.""" 1523 alias = expression.this.args.get("alias") 1524 1525 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1526 return self.dialect.normalize_identifier(identifier).name if identifier else None 1527 1528 targets = {normalize(expression.this.this)} 1529 1530 if alias: 1531 targets.add(normalize(alias.this)) 1532 1533 for when in expression.expressions: 1534 # only remove the target names from the THEN clause 1535 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1536 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1537 then = when.args.get("then") 1538 if then: 1539 then.transform( 1540 lambda node: ( 1541 exp.column(node.this) 1542 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1543 else node 1544 ), 1545 copy=False, 1546 ) 1547 1548 return self.merge_sql(expression) 1549 1550 1551def build_json_extract_path( 1552 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1553) -> t.Callable[[t.List], F]: 1554 def _builder(args: t.List) -> F: 1555 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1556 for arg in args[1:]: 1557 if not isinstance(arg, exp.Literal): 1558 # We use the fallback parser because we can't really transpile non-literals safely 1559 return expr_type.from_arg_list(args) 1560 1561 text = arg.name 1562 if is_int(text): 1563 index = int(text) 1564 segments.append( 1565 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1566 ) 1567 else: 1568 segments.append(exp.JSONPathKey(this=text)) 1569 1570 # This is done to avoid failing in the expression validator due to the arg count 1571 del args[2:] 1572 return expr_type( 1573 this=seq_get(args, 0), 1574 expression=exp.JSONPath(expressions=segments), 1575 only_json_types=arrow_req_json_type, 1576 ) 1577 1578 return _builder 1579 1580 1581def json_extract_segments( 1582 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1583) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1584 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1585 path = expression.expression 1586 if not isinstance(path, exp.JSONPath): 1587 return rename_func(name)(self, expression) 1588 1589 escape = path.args.get("escape") 1590 1591 segments = [] 1592 for segment in path.expressions: 1593 path = self.sql(segment) 1594 if path: 1595 if isinstance(segment, exp.JSONPathPart) and ( 1596 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1597 ): 1598 if escape: 1599 path = self.escape_str(path) 1600 1601 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1602 1603 segments.append(path) 1604 1605 if op: 1606 return f" {op} ".join([self.sql(expression.this), *segments]) 1607 return self.func(name, expression.this, *segments) 1608 1609 return _json_extract_segments 1610 1611 1612def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1613 if isinstance(expression.this, exp.JSONPathWildcard): 1614 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1615 1616 return expression.name 1617 1618 1619def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1620 cond = expression.expression 1621 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1622 alias = cond.expressions[0] 1623 cond = cond.this 1624 elif isinstance(cond, exp.Predicate): 1625 alias = "_u" 1626 else: 1627 self.unsupported("Unsupported filter condition") 1628 return "" 1629 1630 unnest = exp.Unnest(expressions=[expression.this]) 1631 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1632 return self.sql(exp.Array(expressions=[filtered])) 1633 1634 1635def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1636 return self.func( 1637 "TO_NUMBER", 1638 expression.this, 1639 expression.args.get("format"), 1640 expression.args.get("nlsparam"), 1641 ) 1642 1643 1644def build_default_decimal_type( 1645 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1646) -> t.Callable[[exp.DataType], exp.DataType]: 1647 def _builder(dtype: exp.DataType) -> exp.DataType: 1648 if dtype.expressions or precision is None: 1649 return dtype 1650 1651 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1652 return exp.DataType.build(f"DECIMAL({params})") 1653 1654 return _builder 1655 1656 1657def build_timestamp_from_parts(args: t.List) -> exp.Func: 1658 if len(args) == 2: 1659 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1660 # so we parse this into Anonymous for now instead of introducing complexity 1661 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1662 1663 return exp.TimestampFromParts.from_arg_list(args) 1664 1665 1666def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1667 return self.func(f"SHA{expression.text('length') or '256'}", expression.this) 1668 1669 1670def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1671 start = expression.args.get("start") 1672 end = expression.args.get("end") 1673 step = expression.args.get("step") 1674 1675 if isinstance(start, exp.Cast): 1676 target_type = start.to 1677 elif isinstance(end, exp.Cast): 1678 target_type = end.to 1679 else: 1680 target_type = None 1681 1682 if start and end and target_type and target_type.is_type("date", "timestamp"): 1683 if isinstance(start, exp.Cast) and target_type is start.to: 1684 end = exp.cast(end, target_type) 1685 else: 1686 start = exp.cast(start, target_type) 1687 1688 return self.func("SEQUENCE", start, end, step) 1689 1690 1691def build_regexp_extract(args: t.List, dialect: Dialect) -> exp.RegexpExtract: 1692 return exp.RegexpExtract( 1693 this=seq_get(args, 0), 1694 expression=seq_get(args, 1), 1695 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1696 )
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
Dialects supported by SQLGLot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
219class Dialect(metaclass=_Dialect): 220 INDEX_OFFSET = 0 221 """The base index offset for arrays.""" 222 223 WEEK_OFFSET = 0 224 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 225 226 UNNEST_COLUMN_ONLY = False 227 """Whether `UNNEST` table aliases are treated as column aliases.""" 228 229 ALIAS_POST_TABLESAMPLE = False 230 """Whether the table alias comes after tablesample.""" 231 232 TABLESAMPLE_SIZE_IS_PERCENT = False 233 """Whether a size in the table sample clause represents percentage.""" 234 235 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 236 """Specifies the strategy according to which identifiers should be normalized.""" 237 238 IDENTIFIERS_CAN_START_WITH_DIGIT = False 239 """Whether an unquoted identifier can start with a digit.""" 240 241 DPIPE_IS_STRING_CONCAT = True 242 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 243 244 STRICT_STRING_CONCAT = False 245 """Whether `CONCAT`'s arguments must be strings.""" 246 247 SUPPORTS_USER_DEFINED_TYPES = True 248 """Whether user-defined data types are supported.""" 249 250 SUPPORTS_SEMI_ANTI_JOIN = True 251 """Whether `SEMI` or `ANTI` joins are supported.""" 252 253 SUPPORTS_COLUMN_JOIN_MARKS = False 254 """Whether the old-style outer join (+) syntax is supported.""" 255 256 COPY_PARAMS_ARE_CSV = True 257 """Separator of COPY statement parameters.""" 258 259 NORMALIZE_FUNCTIONS: bool | str = "upper" 260 """ 261 Determines how function names are going to be normalized. 262 Possible values: 263 "upper" or True: Convert names to uppercase. 264 "lower": Convert names to lowercase. 265 False: Disables function name normalization. 266 """ 267 268 LOG_BASE_FIRST: t.Optional[bool] = True 269 """ 270 Whether the base comes first in the `LOG` function. 271 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 272 """ 273 274 NULL_ORDERING = "nulls_are_small" 275 """ 276 Default `NULL` ordering method to use if not explicitly set. 277 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 278 """ 279 280 TYPED_DIVISION = False 281 """ 282 Whether the behavior of `a / b` depends on the types of `a` and `b`. 283 False means `a / b` is always float division. 284 True means `a / b` is integer division if both `a` and `b` are integers. 285 """ 286 287 SAFE_DIVISION = False 288 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 289 290 CONCAT_COALESCE = False 291 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 292 293 HEX_LOWERCASE = False 294 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 295 296 DATE_FORMAT = "'%Y-%m-%d'" 297 DATEINT_FORMAT = "'%Y%m%d'" 298 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 299 300 TIME_MAPPING: t.Dict[str, str] = {} 301 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 302 303 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 304 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 305 FORMAT_MAPPING: t.Dict[str, str] = {} 306 """ 307 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 308 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 309 """ 310 311 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 312 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 313 314 PSEUDOCOLUMNS: t.Set[str] = set() 315 """ 316 Columns that are auto-generated by the engine corresponding to this dialect. 317 For example, such columns may be excluded from `SELECT *` queries. 318 """ 319 320 PREFER_CTE_ALIAS_COLUMN = False 321 """ 322 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 323 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 324 any projection aliases in the subquery. 325 326 For example, 327 WITH y(c) AS ( 328 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 329 ) SELECT c FROM y; 330 331 will be rewritten as 332 333 WITH y(c) AS ( 334 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 335 ) SELECT c FROM y; 336 """ 337 338 COPY_PARAMS_ARE_CSV = True 339 """ 340 Whether COPY statement parameters are separated by comma or whitespace 341 """ 342 343 FORCE_EARLY_ALIAS_REF_EXPANSION = False 344 """ 345 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 346 347 For example: 348 WITH data AS ( 349 SELECT 350 1 AS id, 351 2 AS my_id 352 ) 353 SELECT 354 id AS my_id 355 FROM 356 data 357 WHERE 358 my_id = 1 359 GROUP BY 360 my_id, 361 HAVING 362 my_id = 1 363 364 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 365 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 366 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 367 - Clickhouse, which will forward the alias across the query i.e it resolves 368 to "WHERE id = 1 GROUP BY id HAVING id = 1" 369 """ 370 371 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 372 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 373 374 SUPPORTS_ORDER_BY_ALL = False 375 """ 376 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 377 """ 378 379 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 380 """ 381 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 382 as the former is of type INT[] vs the latter which is SUPER 383 """ 384 385 SUPPORTS_FIXED_SIZE_ARRAYS = False 386 """ 387 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 388 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 389 be interpreted as a subscript/index operator. 390 """ 391 392 STRICT_JSON_PATH_SYNTAX = True 393 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 394 395 ON_CONDITION_EMPTY_BEFORE_ERROR = True 396 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 397 398 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 399 """Whether ArrayAgg needs to filter NULL values.""" 400 401 REGEXP_EXTRACT_DEFAULT_GROUP = 0 402 """The default value for the capturing group.""" 403 404 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 405 exp.Except: True, 406 exp.Intersect: True, 407 exp.Union: True, 408 } 409 """ 410 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 411 must be explicitly specified. 412 """ 413 414 CREATABLE_KIND_MAPPING: dict[str, str] = {} 415 """ 416 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 417 equivalent of CREATE SCHEMA is CREATE DATABASE. 418 """ 419 420 # --- Autofilled --- 421 422 tokenizer_class = Tokenizer 423 jsonpath_tokenizer_class = JSONPathTokenizer 424 parser_class = Parser 425 generator_class = Generator 426 427 # A trie of the time_mapping keys 428 TIME_TRIE: t.Dict = {} 429 FORMAT_TRIE: t.Dict = {} 430 431 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 432 INVERSE_TIME_TRIE: t.Dict = {} 433 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 434 INVERSE_FORMAT_TRIE: t.Dict = {} 435 436 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 437 438 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 439 440 # Delimiters for string literals and identifiers 441 QUOTE_START = "'" 442 QUOTE_END = "'" 443 IDENTIFIER_START = '"' 444 IDENTIFIER_END = '"' 445 446 # Delimiters for bit, hex, byte and unicode literals 447 BIT_START: t.Optional[str] = None 448 BIT_END: t.Optional[str] = None 449 HEX_START: t.Optional[str] = None 450 HEX_END: t.Optional[str] = None 451 BYTE_START: t.Optional[str] = None 452 BYTE_END: t.Optional[str] = None 453 UNICODE_START: t.Optional[str] = None 454 UNICODE_END: t.Optional[str] = None 455 456 DATE_PART_MAPPING = { 457 "Y": "YEAR", 458 "YY": "YEAR", 459 "YYY": "YEAR", 460 "YYYY": "YEAR", 461 "YR": "YEAR", 462 "YEARS": "YEAR", 463 "YRS": "YEAR", 464 "MM": "MONTH", 465 "MON": "MONTH", 466 "MONS": "MONTH", 467 "MONTHS": "MONTH", 468 "D": "DAY", 469 "DD": "DAY", 470 "DAYS": "DAY", 471 "DAYOFMONTH": "DAY", 472 "DAY OF WEEK": "DAYOFWEEK", 473 "WEEKDAY": "DAYOFWEEK", 474 "DOW": "DAYOFWEEK", 475 "DW": "DAYOFWEEK", 476 "WEEKDAY_ISO": "DAYOFWEEKISO", 477 "DOW_ISO": "DAYOFWEEKISO", 478 "DW_ISO": "DAYOFWEEKISO", 479 "DAY OF YEAR": "DAYOFYEAR", 480 "DOY": "DAYOFYEAR", 481 "DY": "DAYOFYEAR", 482 "W": "WEEK", 483 "WK": "WEEK", 484 "WEEKOFYEAR": "WEEK", 485 "WOY": "WEEK", 486 "WY": "WEEK", 487 "WEEK_ISO": "WEEKISO", 488 "WEEKOFYEARISO": "WEEKISO", 489 "WEEKOFYEAR_ISO": "WEEKISO", 490 "Q": "QUARTER", 491 "QTR": "QUARTER", 492 "QTRS": "QUARTER", 493 "QUARTERS": "QUARTER", 494 "H": "HOUR", 495 "HH": "HOUR", 496 "HR": "HOUR", 497 "HOURS": "HOUR", 498 "HRS": "HOUR", 499 "M": "MINUTE", 500 "MI": "MINUTE", 501 "MIN": "MINUTE", 502 "MINUTES": "MINUTE", 503 "MINS": "MINUTE", 504 "S": "SECOND", 505 "SEC": "SECOND", 506 "SECONDS": "SECOND", 507 "SECS": "SECOND", 508 "MS": "MILLISECOND", 509 "MSEC": "MILLISECOND", 510 "MSECS": "MILLISECOND", 511 "MSECOND": "MILLISECOND", 512 "MSECONDS": "MILLISECOND", 513 "MILLISEC": "MILLISECOND", 514 "MILLISECS": "MILLISECOND", 515 "MILLISECON": "MILLISECOND", 516 "MILLISECONDS": "MILLISECOND", 517 "US": "MICROSECOND", 518 "USEC": "MICROSECOND", 519 "USECS": "MICROSECOND", 520 "MICROSEC": "MICROSECOND", 521 "MICROSECS": "MICROSECOND", 522 "USECOND": "MICROSECOND", 523 "USECONDS": "MICROSECOND", 524 "MICROSECONDS": "MICROSECOND", 525 "NS": "NANOSECOND", 526 "NSEC": "NANOSECOND", 527 "NANOSEC": "NANOSECOND", 528 "NSECOND": "NANOSECOND", 529 "NSECONDS": "NANOSECOND", 530 "NANOSECS": "NANOSECOND", 531 "EPOCH_SECOND": "EPOCH", 532 "EPOCH_SECONDS": "EPOCH", 533 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 534 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 535 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 536 "TZH": "TIMEZONE_HOUR", 537 "TZM": "TIMEZONE_MINUTE", 538 "DEC": "DECADE", 539 "DECS": "DECADE", 540 "DECADES": "DECADE", 541 "MIL": "MILLENIUM", 542 "MILS": "MILLENIUM", 543 "MILLENIA": "MILLENIUM", 544 "C": "CENTURY", 545 "CENT": "CENTURY", 546 "CENTS": "CENTURY", 547 "CENTURIES": "CENTURY", 548 } 549 550 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 551 exp.DataType.Type.BIGINT: { 552 exp.ApproxDistinct, 553 exp.ArraySize, 554 exp.Length, 555 }, 556 exp.DataType.Type.BOOLEAN: { 557 exp.Between, 558 exp.Boolean, 559 exp.In, 560 exp.RegexpLike, 561 }, 562 exp.DataType.Type.DATE: { 563 exp.CurrentDate, 564 exp.Date, 565 exp.DateFromParts, 566 exp.DateStrToDate, 567 exp.DiToDate, 568 exp.StrToDate, 569 exp.TimeStrToDate, 570 exp.TsOrDsToDate, 571 }, 572 exp.DataType.Type.DATETIME: { 573 exp.CurrentDatetime, 574 exp.Datetime, 575 exp.DatetimeAdd, 576 exp.DatetimeSub, 577 }, 578 exp.DataType.Type.DOUBLE: { 579 exp.ApproxQuantile, 580 exp.Avg, 581 exp.Div, 582 exp.Exp, 583 exp.Ln, 584 exp.Log, 585 exp.Pow, 586 exp.Quantile, 587 exp.Round, 588 exp.SafeDivide, 589 exp.Sqrt, 590 exp.Stddev, 591 exp.StddevPop, 592 exp.StddevSamp, 593 exp.Variance, 594 exp.VariancePop, 595 }, 596 exp.DataType.Type.INT: { 597 exp.Ceil, 598 exp.DatetimeDiff, 599 exp.DateDiff, 600 exp.TimestampDiff, 601 exp.TimeDiff, 602 exp.DateToDi, 603 exp.Levenshtein, 604 exp.Sign, 605 exp.StrPosition, 606 exp.TsOrDiToDi, 607 }, 608 exp.DataType.Type.JSON: { 609 exp.ParseJSON, 610 }, 611 exp.DataType.Type.TIME: { 612 exp.Time, 613 }, 614 exp.DataType.Type.TIMESTAMP: { 615 exp.CurrentTime, 616 exp.CurrentTimestamp, 617 exp.StrToTime, 618 exp.TimeAdd, 619 exp.TimeStrToTime, 620 exp.TimeSub, 621 exp.TimestampAdd, 622 exp.TimestampSub, 623 exp.UnixToTime, 624 }, 625 exp.DataType.Type.TINYINT: { 626 exp.Day, 627 exp.Month, 628 exp.Week, 629 exp.Year, 630 exp.Quarter, 631 }, 632 exp.DataType.Type.VARCHAR: { 633 exp.ArrayConcat, 634 exp.Concat, 635 exp.ConcatWs, 636 exp.DateToDateStr, 637 exp.GroupConcat, 638 exp.Initcap, 639 exp.Lower, 640 exp.Substring, 641 exp.TimeToStr, 642 exp.TimeToTimeStr, 643 exp.Trim, 644 exp.TsOrDsToDateStr, 645 exp.UnixToStr, 646 exp.UnixToTimeStr, 647 exp.Upper, 648 }, 649 } 650 651 ANNOTATORS: AnnotatorsType = { 652 **{ 653 expr_type: lambda self, e: self._annotate_unary(e) 654 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 655 }, 656 **{ 657 expr_type: lambda self, e: self._annotate_binary(e) 658 for expr_type in subclasses(exp.__name__, exp.Binary) 659 }, 660 **{ 661 expr_type: _annotate_with_type_lambda(data_type) 662 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 663 for expr_type in expressions 664 }, 665 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 666 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 667 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 668 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 669 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 670 exp.Bracket: lambda self, e: self._annotate_bracket(e), 671 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 672 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 673 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 674 exp.Count: lambda self, e: self._annotate_with_type( 675 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 676 ), 677 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 678 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 679 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 680 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 681 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 682 exp.Div: lambda self, e: self._annotate_div(e), 683 exp.Dot: lambda self, e: self._annotate_dot(e), 684 exp.Explode: lambda self, e: self._annotate_explode(e), 685 exp.Extract: lambda self, e: self._annotate_extract(e), 686 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 687 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 688 e, exp.DataType.build("ARRAY<DATE>") 689 ), 690 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 691 e, exp.DataType.build("ARRAY<TIMESTAMP>") 692 ), 693 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 694 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 695 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 696 exp.Literal: lambda self, e: self._annotate_literal(e), 697 exp.Map: lambda self, e: self._annotate_map(e), 698 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 699 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 700 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 701 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 702 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 703 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 704 exp.Struct: lambda self, e: self._annotate_struct(e), 705 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 706 exp.Timestamp: lambda self, e: self._annotate_with_type( 707 e, 708 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 709 ), 710 exp.ToMap: lambda self, e: self._annotate_to_map(e), 711 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 712 exp.Unnest: lambda self, e: self._annotate_unnest(e), 713 exp.VarMap: lambda self, e: self._annotate_map(e), 714 } 715 716 @classmethod 717 def get_or_raise(cls, dialect: DialectType) -> Dialect: 718 """ 719 Look up a dialect in the global dialect registry and return it if it exists. 720 721 Args: 722 dialect: The target dialect. If this is a string, it can be optionally followed by 723 additional key-value pairs that are separated by commas and are used to specify 724 dialect settings, such as whether the dialect's identifiers are case-sensitive. 725 726 Example: 727 >>> dialect = dialect_class = get_or_raise("duckdb") 728 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 729 730 Returns: 731 The corresponding Dialect instance. 732 """ 733 734 if not dialect: 735 return cls() 736 if isinstance(dialect, _Dialect): 737 return dialect() 738 if isinstance(dialect, Dialect): 739 return dialect 740 if isinstance(dialect, str): 741 try: 742 dialect_name, *kv_strings = dialect.split(",") 743 kv_pairs = (kv.split("=") for kv in kv_strings) 744 kwargs = {} 745 for pair in kv_pairs: 746 key = pair[0].strip() 747 value: t.Union[bool | str | None] = None 748 749 if len(pair) == 1: 750 # Default initialize standalone settings to True 751 value = True 752 elif len(pair) == 2: 753 value = pair[1].strip() 754 755 # Coerce the value to boolean if it matches to the truthy/falsy values below 756 value_lower = value.lower() 757 if value_lower in ("true", "1"): 758 value = True 759 elif value_lower in ("false", "0"): 760 value = False 761 762 kwargs[key] = value 763 764 except ValueError: 765 raise ValueError( 766 f"Invalid dialect format: '{dialect}'. " 767 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 768 ) 769 770 result = cls.get(dialect_name.strip()) 771 if not result: 772 from difflib import get_close_matches 773 774 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 775 if similar: 776 similar = f" Did you mean {similar}?" 777 778 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 779 780 return result(**kwargs) 781 782 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 783 784 @classmethod 785 def format_time( 786 cls, expression: t.Optional[str | exp.Expression] 787 ) -> t.Optional[exp.Expression]: 788 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 789 if isinstance(expression, str): 790 return exp.Literal.string( 791 # the time formats are quoted 792 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 793 ) 794 795 if expression and expression.is_string: 796 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 797 798 return expression 799 800 def __init__(self, **kwargs) -> None: 801 normalization_strategy = kwargs.pop("normalization_strategy", None) 802 803 if normalization_strategy is None: 804 self.normalization_strategy = self.NORMALIZATION_STRATEGY 805 else: 806 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 807 808 self.settings = kwargs 809 810 def __eq__(self, other: t.Any) -> bool: 811 # Does not currently take dialect state into account 812 return type(self) == other 813 814 def __hash__(self) -> int: 815 # Does not currently take dialect state into account 816 return hash(type(self)) 817 818 def normalize_identifier(self, expression: E) -> E: 819 """ 820 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 821 822 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 823 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 824 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 825 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 826 827 There are also dialects like Spark, which are case-insensitive even when quotes are 828 present, and dialects like MySQL, whose resolution rules match those employed by the 829 underlying operating system, for example they may always be case-sensitive in Linux. 830 831 Finally, the normalization behavior of some engines can even be controlled through flags, 832 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 833 834 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 835 that it can analyze queries in the optimizer and successfully capture their semantics. 836 """ 837 if ( 838 isinstance(expression, exp.Identifier) 839 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 840 and ( 841 not expression.quoted 842 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 843 ) 844 ): 845 expression.set( 846 "this", 847 ( 848 expression.this.upper() 849 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 850 else expression.this.lower() 851 ), 852 ) 853 854 return expression 855 856 def case_sensitive(self, text: str) -> bool: 857 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 858 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 859 return False 860 861 unsafe = ( 862 str.islower 863 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 864 else str.isupper 865 ) 866 return any(unsafe(char) for char in text) 867 868 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 869 """Checks if text can be identified given an identify option. 870 871 Args: 872 text: The text to check. 873 identify: 874 `"always"` or `True`: Always returns `True`. 875 `"safe"`: Only returns `True` if the identifier is case-insensitive. 876 877 Returns: 878 Whether the given text can be identified. 879 """ 880 if identify is True or identify == "always": 881 return True 882 883 if identify == "safe": 884 return not self.case_sensitive(text) 885 886 return False 887 888 def quote_identifier(self, expression: E, identify: bool = True) -> E: 889 """ 890 Adds quotes to a given identifier. 891 892 Args: 893 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 894 identify: If set to `False`, the quotes will only be added if the identifier is deemed 895 "unsafe", with respect to its characters and this dialect's normalization strategy. 896 """ 897 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 898 name = expression.this 899 expression.set( 900 "quoted", 901 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 902 ) 903 904 return expression 905 906 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 907 if isinstance(path, exp.Literal): 908 path_text = path.name 909 if path.is_number: 910 path_text = f"[{path_text}]" 911 try: 912 return parse_json_path(path_text, self) 913 except ParseError as e: 914 if self.STRICT_JSON_PATH_SYNTAX: 915 logger.warning(f"Invalid JSON path syntax. {str(e)}") 916 917 return path 918 919 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 920 return self.parser(**opts).parse(self.tokenize(sql), sql) 921 922 def parse_into( 923 self, expression_type: exp.IntoType, sql: str, **opts 924 ) -> t.List[t.Optional[exp.Expression]]: 925 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 926 927 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 928 return self.generator(**opts).generate(expression, copy=copy) 929 930 def transpile(self, sql: str, **opts) -> t.List[str]: 931 return [ 932 self.generate(expression, copy=False, **opts) if expression else "" 933 for expression in self.parse(sql) 934 ] 935 936 def tokenize(self, sql: str) -> t.List[Token]: 937 return self.tokenizer.tokenize(sql) 938 939 @property 940 def tokenizer(self) -> Tokenizer: 941 return self.tokenizer_class(dialect=self) 942 943 @property 944 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 945 return self.jsonpath_tokenizer_class(dialect=self) 946 947 def parser(self, **opts) -> Parser: 948 return self.parser_class(dialect=self, **opts) 949 950 def generator(self, **opts) -> Generator: 951 return self.generator_class(dialect=self, **opts)
800 def __init__(self, **kwargs) -> None: 801 normalization_strategy = kwargs.pop("normalization_strategy", None) 802 803 if normalization_strategy is None: 804 self.normalization_strategy = self.NORMALIZATION_STRATEGY 805 else: 806 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 807 808 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
Whether the behavior of a / b
depends on the types of a
and b
.
False means a / b
is always float division.
True means a / b
is integer division if both a
and b
are integers.
A NULL
arg in CONCAT
yields NULL
by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (\n
) to its unescaped version (
).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS ( SELECT 1 AS id, 2 AS my_id ) SELECT id AS my_id FROM data WHERE my_id = 1 GROUP BY my_id, HAVING my_id = 1
In most dialects, "my_id" would refer to "data.my_id" across the query, except: - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) as the former is of type INT[] vs the latter which is SUPER
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
Whether a set operation uses DISTINCT by default. This is None
when either DISTINCT
or ALL
must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
716 @classmethod 717 def get_or_raise(cls, dialect: DialectType) -> Dialect: 718 """ 719 Look up a dialect in the global dialect registry and return it if it exists. 720 721 Args: 722 dialect: The target dialect. If this is a string, it can be optionally followed by 723 additional key-value pairs that are separated by commas and are used to specify 724 dialect settings, such as whether the dialect's identifiers are case-sensitive. 725 726 Example: 727 >>> dialect = dialect_class = get_or_raise("duckdb") 728 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 729 730 Returns: 731 The corresponding Dialect instance. 732 """ 733 734 if not dialect: 735 return cls() 736 if isinstance(dialect, _Dialect): 737 return dialect() 738 if isinstance(dialect, Dialect): 739 return dialect 740 if isinstance(dialect, str): 741 try: 742 dialect_name, *kv_strings = dialect.split(",") 743 kv_pairs = (kv.split("=") for kv in kv_strings) 744 kwargs = {} 745 for pair in kv_pairs: 746 key = pair[0].strip() 747 value: t.Union[bool | str | None] = None 748 749 if len(pair) == 1: 750 # Default initialize standalone settings to True 751 value = True 752 elif len(pair) == 2: 753 value = pair[1].strip() 754 755 # Coerce the value to boolean if it matches to the truthy/falsy values below 756 value_lower = value.lower() 757 if value_lower in ("true", "1"): 758 value = True 759 elif value_lower in ("false", "0"): 760 value = False 761 762 kwargs[key] = value 763 764 except ValueError: 765 raise ValueError( 766 f"Invalid dialect format: '{dialect}'. " 767 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 768 ) 769 770 result = cls.get(dialect_name.strip()) 771 if not result: 772 from difflib import get_close_matches 773 774 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 775 if similar: 776 similar = f" Did you mean {similar}?" 777 778 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 779 780 return result(**kwargs) 781 782 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
784 @classmethod 785 def format_time( 786 cls, expression: t.Optional[str | exp.Expression] 787 ) -> t.Optional[exp.Expression]: 788 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 789 if isinstance(expression, str): 790 return exp.Literal.string( 791 # the time formats are quoted 792 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 793 ) 794 795 if expression and expression.is_string: 796 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 797 798 return expression
Converts a time format in this dialect to its equivalent Python strftime
format.
818 def normalize_identifier(self, expression: E) -> E: 819 """ 820 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 821 822 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 823 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 824 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 825 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 826 827 There are also dialects like Spark, which are case-insensitive even when quotes are 828 present, and dialects like MySQL, whose resolution rules match those employed by the 829 underlying operating system, for example they may always be case-sensitive in Linux. 830 831 Finally, the normalization behavior of some engines can even be controlled through flags, 832 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 833 834 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 835 that it can analyze queries in the optimizer and successfully capture their semantics. 836 """ 837 if ( 838 isinstance(expression, exp.Identifier) 839 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 840 and ( 841 not expression.quoted 842 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 843 ) 844 ): 845 expression.set( 846 "this", 847 ( 848 expression.this.upper() 849 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 850 else expression.this.lower() 851 ), 852 ) 853 854 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO
would be resolved as foo
in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as FOO
. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
856 def case_sensitive(self, text: str) -> bool: 857 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 858 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 859 return False 860 861 unsafe = ( 862 str.islower 863 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 864 else str.isupper 865 ) 866 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
868 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 869 """Checks if text can be identified given an identify option. 870 871 Args: 872 text: The text to check. 873 identify: 874 `"always"` or `True`: Always returns `True`. 875 `"safe"`: Only returns `True` if the identifier is case-insensitive. 876 877 Returns: 878 Whether the given text can be identified. 879 """ 880 if identify is True or identify == "always": 881 return True 882 883 if identify == "safe": 884 return not self.case_sensitive(text) 885 886 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
"always"
orTrue
: Always returnsTrue
."safe"
: Only returnsTrue
if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
888 def quote_identifier(self, expression: E, identify: bool = True) -> E: 889 """ 890 Adds quotes to a given identifier. 891 892 Args: 893 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 894 identify: If set to `False`, the quotes will only be added if the identifier is deemed 895 "unsafe", with respect to its characters and this dialect's normalization strategy. 896 """ 897 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 898 name = expression.this 899 expression.set( 900 "quoted", 901 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 902 ) 903 904 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an
Identifier
, this method is a no-op. - identify: If set to
False
, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
906 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 907 if isinstance(path, exp.Literal): 908 path_text = path.name 909 if path.is_number: 910 path_text = f"[{path_text}]" 911 try: 912 return parse_json_path(path_text, self) 913 except ParseError as e: 914 if self.STRICT_JSON_PATH_SYNTAX: 915 logger.warning(f"Invalid JSON path syntax. {str(e)}") 916 917 return path
966def if_sql( 967 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 968) -> t.Callable[[Generator, exp.If], str]: 969 def _if_sql(self: Generator, expression: exp.If) -> str: 970 return self.func( 971 name, 972 expression.this, 973 expression.args.get("true"), 974 expression.args.get("false") or false_value, 975 ) 976 977 return _if_sql
980def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 981 this = expression.this 982 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 983 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 984 985 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1055def str_position_sql( 1056 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1057) -> str: 1058 this = self.sql(expression, "this") 1059 substr = self.sql(expression, "substr") 1060 position = self.sql(expression, "position") 1061 instance = expression.args.get("instance") if generate_instance else None 1062 position_offset = "" 1063 1064 if position: 1065 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1066 this = self.func("SUBSTR", this, position) 1067 position_offset = f" + {position} - 1" 1068 1069 return self.func("STRPOS", this, substr, instance) + position_offset
1078def var_map_sql( 1079 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1080) -> str: 1081 keys = expression.args["keys"] 1082 values = expression.args["values"] 1083 1084 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1085 self.unsupported("Cannot convert array columns into map.") 1086 return self.func(map_func_name, keys, values) 1087 1088 args = [] 1089 for key, value in zip(keys.expressions, values.expressions): 1090 args.append(self.sql(key)) 1091 args.append(self.sql(value)) 1092 1093 return self.func(map_func_name, *args)
1096def build_formatted_time( 1097 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1098) -> t.Callable[[t.List], E]: 1099 """Helper used for time expressions. 1100 1101 Args: 1102 exp_class: the expression class to instantiate. 1103 dialect: target sql dialect. 1104 default: the default format, True being time. 1105 1106 Returns: 1107 A callable that can be used to return the appropriately formatted time expression. 1108 """ 1109 1110 def _builder(args: t.List): 1111 return exp_class( 1112 this=seq_get(args, 0), 1113 format=Dialect[dialect].format_time( 1114 seq_get(args, 1) 1115 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1116 ), 1117 ) 1118 1119 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
1122def time_format( 1123 dialect: DialectType = None, 1124) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1125 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1126 """ 1127 Returns the time format for a given expression, unless it's equivalent 1128 to the default time format of the dialect of interest. 1129 """ 1130 time_format = self.format_time(expression) 1131 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1132 1133 return _time_format
1136def build_date_delta( 1137 exp_class: t.Type[E], 1138 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1139 default_unit: t.Optional[str] = "DAY", 1140) -> t.Callable[[t.List], E]: 1141 def _builder(args: t.List) -> E: 1142 unit_based = len(args) == 3 1143 this = args[2] if unit_based else seq_get(args, 0) 1144 unit = None 1145 if unit_based or default_unit: 1146 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1147 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1148 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1149 1150 return _builder
1153def build_date_delta_with_interval( 1154 expression_class: t.Type[E], 1155) -> t.Callable[[t.List], t.Optional[E]]: 1156 def _builder(args: t.List) -> t.Optional[E]: 1157 if len(args) < 2: 1158 return None 1159 1160 interval = args[1] 1161 1162 if not isinstance(interval, exp.Interval): 1163 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1164 1165 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1166 1167 return _builder
1170def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1171 unit = seq_get(args, 0) 1172 this = seq_get(args, 1) 1173 1174 if isinstance(this, exp.Cast) and this.is_type("date"): 1175 return exp.DateTrunc(unit=unit, this=this) 1176 return exp.TimestampTrunc(this=this, unit=unit)
1179def date_add_interval_sql( 1180 data_type: str, kind: str 1181) -> t.Callable[[Generator, exp.Expression], str]: 1182 def func(self: Generator, expression: exp.Expression) -> str: 1183 this = self.sql(expression, "this") 1184 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1185 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1186 1187 return func
1190def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1191 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1192 args = [unit_to_str(expression), expression.this] 1193 if zone: 1194 args.append(expression.args.get("zone")) 1195 return self.func("DATE_TRUNC", *args) 1196 1197 return _timestamptrunc_sql
1200def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1201 zone = expression.args.get("zone") 1202 if not zone: 1203 from sqlglot.optimizer.annotate_types import annotate_types 1204 1205 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1206 return self.sql(exp.cast(expression.this, target_type)) 1207 if zone.name.lower() in TIMEZONES: 1208 return self.sql( 1209 exp.AtTimeZone( 1210 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1211 zone=zone, 1212 ) 1213 ) 1214 return self.func("TIMESTAMP", expression.this, zone)
1217def no_time_sql(self: Generator, expression: exp.Time) -> str: 1218 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1219 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1220 expr = exp.cast( 1221 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1222 ) 1223 return self.sql(expr)
1226def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1227 this = expression.this 1228 expr = expression.expression 1229 1230 if expr.name.lower() in TIMEZONES: 1231 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1232 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1233 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1234 return self.sql(this) 1235 1236 this = exp.cast(this, exp.DataType.Type.DATE) 1237 expr = exp.cast(expr, exp.DataType.Type.TIME) 1238 1239 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1271def timestrtotime_sql( 1272 self: Generator, 1273 expression: exp.TimeStrToTime, 1274 include_precision: bool = False, 1275) -> str: 1276 datatype = exp.DataType.build( 1277 exp.DataType.Type.TIMESTAMPTZ 1278 if expression.args.get("zone") 1279 else exp.DataType.Type.TIMESTAMP 1280 ) 1281 1282 if isinstance(expression.this, exp.Literal) and include_precision: 1283 precision = subsecond_precision(expression.this.name) 1284 if precision > 0: 1285 datatype = exp.DataType.build( 1286 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1287 ) 1288 1289 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1297def encode_decode_sql( 1298 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1299) -> str: 1300 charset = expression.args.get("charset") 1301 if charset and charset.name.lower() != "utf-8": 1302 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1303 1304 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1317def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1318 cond = expression.this 1319 1320 if isinstance(expression.this, exp.Distinct): 1321 cond = expression.this.expressions[0] 1322 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1323 1324 return self.func("sum", exp.func("if", cond, 1, 0))
1327def trim_sql(self: Generator, expression: exp.Trim) -> str: 1328 target = self.sql(expression, "this") 1329 trim_type = self.sql(expression, "position") 1330 remove_chars = self.sql(expression, "expression") 1331 collation = self.sql(expression, "collation") 1332 1333 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1334 if not remove_chars: 1335 return self.trim_sql(expression) 1336 1337 trim_type = f"{trim_type} " if trim_type else "" 1338 remove_chars = f"{remove_chars} " if remove_chars else "" 1339 from_part = "FROM " if trim_type or remove_chars else "" 1340 collation = f" COLLATE {collation}" if collation else "" 1341 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1362@unsupported_args("position", "occurrence", "parameters") 1363def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1364 group = expression.args.get("group") 1365 1366 # Do not render group if it's the default value for this dialect 1367 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1368 group = None 1369 1370 return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
1380def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1381 names = [] 1382 for agg in aggregations: 1383 if isinstance(agg, exp.Alias): 1384 names.append(agg.alias) 1385 else: 1386 """ 1387 This case corresponds to aggregations without aliases being used as suffixes 1388 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1389 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1390 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1391 """ 1392 agg_all_unquoted = agg.transform( 1393 lambda node: ( 1394 exp.Identifier(this=node.name, quoted=False) 1395 if isinstance(node, exp.Identifier) 1396 else node 1397 ) 1398 ) 1399 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1400 1401 return names
1441def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1442 @unsupported_args("count") 1443 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1444 return self.func(name, expression.this, expression.expression) 1445 1446 return _arg_max_or_min_sql
1449def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1450 this = expression.this.copy() 1451 1452 return_type = expression.return_type 1453 if return_type.is_type(exp.DataType.Type.DATE): 1454 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1455 # can truncate timestamp strings, because some dialects can't cast them to DATE 1456 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1457 1458 expression.this.replace(exp.cast(this, return_type)) 1459 return expression
1462def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1463 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1464 if cast and isinstance(expression, exp.TsOrDsAdd): 1465 expression = ts_or_ds_add_cast(expression) 1466 1467 return self.func( 1468 name, 1469 unit_to_var(expression), 1470 expression.expression, 1471 expression.this, 1472 ) 1473 1474 return _delta_sql
1477def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1478 unit = expression.args.get("unit") 1479 1480 if isinstance(unit, exp.Placeholder): 1481 return unit 1482 if unit: 1483 return exp.Literal.string(unit.name) 1484 return exp.Literal.string(default) if default else None
1514def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1515 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1516 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1517 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1518 1519 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1522def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1523 """Remove table refs from columns in when statements.""" 1524 alias = expression.this.args.get("alias") 1525 1526 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1527 return self.dialect.normalize_identifier(identifier).name if identifier else None 1528 1529 targets = {normalize(expression.this.this)} 1530 1531 if alias: 1532 targets.add(normalize(alias.this)) 1533 1534 for when in expression.expressions: 1535 # only remove the target names from the THEN clause 1536 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1537 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1538 then = when.args.get("then") 1539 if then: 1540 then.transform( 1541 lambda node: ( 1542 exp.column(node.this) 1543 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1544 else node 1545 ), 1546 copy=False, 1547 ) 1548 1549 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1552def build_json_extract_path( 1553 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1554) -> t.Callable[[t.List], F]: 1555 def _builder(args: t.List) -> F: 1556 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1557 for arg in args[1:]: 1558 if not isinstance(arg, exp.Literal): 1559 # We use the fallback parser because we can't really transpile non-literals safely 1560 return expr_type.from_arg_list(args) 1561 1562 text = arg.name 1563 if is_int(text): 1564 index = int(text) 1565 segments.append( 1566 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1567 ) 1568 else: 1569 segments.append(exp.JSONPathKey(this=text)) 1570 1571 # This is done to avoid failing in the expression validator due to the arg count 1572 del args[2:] 1573 return expr_type( 1574 this=seq_get(args, 0), 1575 expression=exp.JSONPath(expressions=segments), 1576 only_json_types=arrow_req_json_type, 1577 ) 1578 1579 return _builder
1582def json_extract_segments( 1583 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1584) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1585 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1586 path = expression.expression 1587 if not isinstance(path, exp.JSONPath): 1588 return rename_func(name)(self, expression) 1589 1590 escape = path.args.get("escape") 1591 1592 segments = [] 1593 for segment in path.expressions: 1594 path = self.sql(segment) 1595 if path: 1596 if isinstance(segment, exp.JSONPathPart) and ( 1597 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1598 ): 1599 if escape: 1600 path = self.escape_str(path) 1601 1602 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1603 1604 segments.append(path) 1605 1606 if op: 1607 return f" {op} ".join([self.sql(expression.this), *segments]) 1608 return self.func(name, expression.this, *segments) 1609 1610 return _json_extract_segments
1620def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1621 cond = expression.expression 1622 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1623 alias = cond.expressions[0] 1624 cond = cond.this 1625 elif isinstance(cond, exp.Predicate): 1626 alias = "_u" 1627 else: 1628 self.unsupported("Unsupported filter condition") 1629 return "" 1630 1631 unnest = exp.Unnest(expressions=[expression.this]) 1632 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1633 return self.sql(exp.Array(expressions=[filtered]))
1645def build_default_decimal_type( 1646 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1647) -> t.Callable[[exp.DataType], exp.DataType]: 1648 def _builder(dtype: exp.DataType) -> exp.DataType: 1649 if dtype.expressions or precision is None: 1650 return dtype 1651 1652 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1653 return exp.DataType.build(f"DECIMAL({params})") 1654 1655 return _builder
1658def build_timestamp_from_parts(args: t.List) -> exp.Func: 1659 if len(args) == 2: 1660 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1661 # so we parse this into Anonymous for now instead of introducing complexity 1662 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1663 1664 return exp.TimestampFromParts.from_arg_list(args)
1671def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1672 start = expression.args.get("start") 1673 end = expression.args.get("end") 1674 step = expression.args.get("step") 1675 1676 if isinstance(start, exp.Cast): 1677 target_type = start.to 1678 elif isinstance(end, exp.Cast): 1679 target_type = end.to 1680 else: 1681 target_type = None 1682 1683 if start and end and target_type and target_type.is_type("date", "timestamp"): 1684 if isinstance(start, exp.Cast) and target_type is start.to: 1685 end = exp.cast(end, target_type) 1686 else: 1687 start = exp.cast(start, target_type) 1688 1689 return self.func("SEQUENCE", start, end, step)