2025-02-13 15:02:59 +01:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import typing as t
|
|
|
|
|
|
|
|
from sqlglot import expressions as exp
|
|
|
|
|
|
|
|
if t.TYPE_CHECKING:
    # These aliases exist only for static type checkers; the module's
    # annotations are lazy (`from __future__ import annotations`), so the
    # names are never looked up at runtime.

    # Any plain value that can appear in a dumped payload.
    JSON = t.Union[dict, list, str, float, int, bool, None]

    # Anything `dump` accepts / `load` returns: a (possibly nested) list of
    # nodes, a data type enum member, an expression, or a plain JSON value.
    Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON]
|
|
|
|
|
|
|
|
|
|
|
|
def dump(node: Node) -> JSON:
    """
    Recursively dump an AST into a JSON-serializable dict.

    Args:
        node: The value to dump: an expression, a `DataType.Type` member, a
            list of nodes, or any other value.

    Returns:
        - For a list: a new list with every element dumped.
        - For a `DataType.Type`: `{"class": "DataType.Type", "value": ...}`.
        - For an expression: a dict with "class" and "args" keys, plus
          "type", "comments" and "meta" when they are set on the node.
        - Anything else is returned unchanged.
    """
    if isinstance(node, list):
        return [dump(i) for i in node]

    if isinstance(node, exp.DataType.Type):
        return {
            "class": "DataType.Type",
            "value": node.value,
        }

    if isinstance(node, exp.Expression):
        cls = node.__class__
        klass = cls.__qualname__
        # Classes defined outside the expressions module are stored with their
        # module path so that `load` knows where to import them from.
        if cls.__module__ != exp.__name__:
            klass = f"{cls.__module__}.{klass}"

        obj: t.Dict = {
            "class": klass,
            # Unset args (None) and empty lists are left out of the payload.
            "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
        }

        if node.type:
            obj["type"] = dump(node.type)
        if node.comments:
            # Copy so that mutating the payload cannot alter the node's comments.
            obj["comments"] = list(node.comments)
        if node._meta is not None:
            # Shallow copy for the same reason as the comments above.
            obj["meta"] = dict(node._meta)

        return obj

    return node
|
|
|
|
|
|
|
|
|
|
|
|
def load(obj: JSON) -> Node:
    """
    Recursively load a dict (as returned by `dump`) into an AST.

    Args:
        obj: A payload produced by `dump`: a list, a dict carrying a "class"
            key, or a plain value.

    Returns:
        - For a list: a new list with every element loaded.
        - For `{"class": "DataType.Type", ...}`: the matching `DataType.Type`.
        - For any other dict with a "class" key: an instance of that class
          built from the loaded "args", with `type`, `comments` and `_meta`
          set from the payload (None when the key is absent).
        - Anything else, including dicts without a "class" key, is returned
          unchanged.

    Raises:
        KeyError: If a class dict lacks "args" (or "value" for a data type).
        ImportError: If the module part of a dotted class name can't be imported.
        AttributeError: If the class is not found in the resolved module.
    """
    if isinstance(obj, list):
        return [load(i) for i in obj]

    # `dump` passes plain dicts through untouched, so only dicts tagged with
    # a "class" key are treated as serialized nodes.
    if isinstance(obj, dict) and "class" in obj:
        class_name = obj["class"]

        if class_name == "DataType.Type":
            return exp.DataType.Type(obj["value"])

        if "." in class_name:
            # NOTE(security): the module path comes straight from the payload,
            # so this imports and instantiates whatever class it names. Only
            # call `load` on payloads from a trusted source.
            import importlib

            module_path, class_name = class_name.rsplit(".", maxsplit=1)
            module = importlib.import_module(module_path)
        else:
            module = exp

        klass = getattr(module, class_name)

        expression = klass(**{k: load(v) for k, v in obj["args"].items()})
        # The cast target is a string so that it needs no runtime lookup.
        expression.type = t.cast("exp.DataType", load(obj.get("type")))

        # Copy the mutable containers so that the loaded node does not share
        # state with the payload it was built from.
        comments = obj.get("comments")
        expression.comments = None if comments is None else list(comments)
        meta = obj.get("meta")
        expression._meta = None if meta is None else dict(meta)

        return expression

    return obj
|