1
0
Fork 0

Merging upstream version 26.6.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 22:19:49 +01:00
parent 12333df27e
commit 3532bfd564
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
99 changed files with 40433 additions and 38803 deletions

View file

@ -1,11 +1,13 @@
from __future__ import annotations
import importlib
import logging
import typing as t
from enum import Enum, auto
from functools import reduce
from sqlglot import exp
from sqlglot.dialects import DIALECT_MODULE_NAMES
from sqlglot.errors import ParseError
from sqlglot.generator import Generator, unsupported_args
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses, to_bool
@ -64,6 +66,7 @@ class Dialects(str, Enum):
DRILL = "drill"
DRUID = "druid"
DUCKDB = "duckdb"
DUNE = "dune"
HIVE = "hive"
MATERIALIZE = "materialize"
MYSQL = "mysql"
@ -101,7 +104,7 @@ class NormalizationStrategy(str, AutoName):
class _Dialect(type):
classes: t.Dict[str, t.Type[Dialect]] = {}
_classes: t.Dict[str, t.Type[Dialect]] = {}
def __eq__(cls, other: t.Any) -> bool:
if cls is other:
@ -116,20 +119,46 @@ class _Dialect(type):
def __hash__(cls) -> int:
return hash(cls.__name__.lower())
@property
def classes(cls):
if len(DIALECT_MODULE_NAMES) != len(cls._classes):
for key in DIALECT_MODULE_NAMES:
cls._try_load(key)
return cls._classes
@classmethod
def _try_load(cls, key: str | Dialects) -> None:
if isinstance(key, Dialects):
key = key.value
# This import will lead to a new dialect being loaded, and hence, registered.
# We check that the key is an actual sqlglot module to avoid blindly importing
# files. Custom user dialects need to be imported at the top-level package, in
# order for them to be registered as soon as possible.
if key in DIALECT_MODULE_NAMES:
importlib.import_module(f"sqlglot.dialects.{key}")
@classmethod
def __getitem__(cls, key: str) -> t.Type[Dialect]:
return cls.classes[key]
if key not in cls._classes:
cls._try_load(key)
return cls._classes[key]
@classmethod
def get(
cls, key: str, default: t.Optional[t.Type[Dialect]] = None
) -> t.Optional[t.Type[Dialect]]:
return cls.classes.get(key, default)
if key not in cls._classes:
cls._try_load(key)
return cls._classes.get(key, default)
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
enum = Dialects.__members__.get(clsname.upper())
cls.classes[enum.value if enum is not None else clsname.lower()] = klass
cls._classes[enum.value if enum is not None else clsname.lower()] = klass
klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
klass.FORMAT_TRIE = (
@ -792,7 +821,9 @@ class Dialect(metaclass=_Dialect):
if not result:
from difflib import get_close_matches
similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
close_matches = get_close_matches(dialect_name, list(DIALECT_MODULE_NAMES), n=1)
similar = seq_get(close_matches, 0) or ""
if similar:
similar = f" Did you mean {similar}?"
@ -1119,8 +1150,8 @@ def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
def var_map_sql(
self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
keys = expression.args["keys"]
values = expression.args["values"]
keys = expression.args.get("keys")
values = expression.args.get("values")
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map.")