Adding upstream version 1.2+20240521.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
6b2864e4b9
commit
8512f66c5a
229 changed files with 19561 additions and 0 deletions
1
src/aristaproto/plugin/__init__.py
Normal file
1
src/aristaproto/plugin/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .main import main
|
4
src/aristaproto/plugin/__main__.py
Normal file
4
src/aristaproto/plugin/__main__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
from .main import main
|
||||
|
||||
|
||||
main()
|
50
src/aristaproto/plugin/compiler.py
Normal file
50
src/aristaproto/plugin/compiler.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import os.path
|
||||
|
||||
|
||||
try:
|
||||
# aristaproto[compiler] specific dependencies
|
||||
import black
|
||||
import isort.api
|
||||
import jinja2
|
||||
except ImportError as err:
|
||||
print(
|
||||
"\033[31m"
|
||||
f"Unable to import `{err.name}` from aristaproto plugin! "
|
||||
"Please ensure that you've installed aristaproto as "
|
||||
'`pip install "aristaproto[compiler]"` so that compiler dependencies '
|
||||
"are included."
|
||||
"\033[0m"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
from .models import OutputTemplate
|
||||
|
||||
|
||||
def outputfile_compiler(output_file: OutputTemplate) -> str:
|
||||
templates_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "templates")
|
||||
)
|
||||
|
||||
env = jinja2.Environment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
loader=jinja2.FileSystemLoader(templates_folder),
|
||||
)
|
||||
template = env.get_template("template.py.j2")
|
||||
|
||||
code = template.render(output_file=output_file)
|
||||
code = isort.api.sort_code_string(
|
||||
code=code,
|
||||
show_diff=False,
|
||||
py_version=37,
|
||||
profile="black",
|
||||
combine_as_imports=True,
|
||||
lines_after_imports=2,
|
||||
quiet=True,
|
||||
force_grid_wrap=2,
|
||||
known_third_party=["grpclib", "aristaproto"],
|
||||
)
|
||||
return black.format_str(
|
||||
src_contents=code,
|
||||
mode=black.Mode(),
|
||||
)
|
52
src/aristaproto/plugin/main.py
Executable file
52
src/aristaproto/plugin/main.py
Executable file
|
@ -0,0 +1,52 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from aristaproto.lib.google.protobuf.compiler import (
|
||||
CodeGeneratorRequest,
|
||||
CodeGeneratorResponse,
|
||||
)
|
||||
from aristaproto.plugin.models import monkey_patch_oneof_index
|
||||
from aristaproto.plugin.parser import generate_code
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""The plugin's main entry point."""
|
||||
# Read request message from stdin
|
||||
data = sys.stdin.buffer.read()
|
||||
|
||||
# Apply Work around for proto2/3 difference in protoc messages
|
||||
monkey_patch_oneof_index()
|
||||
|
||||
# Parse request
|
||||
request = CodeGeneratorRequest()
|
||||
request.parse(data)
|
||||
|
||||
dump_file = os.getenv("ARISTAPROTO_DUMP")
|
||||
if dump_file:
|
||||
dump_request(dump_file, request)
|
||||
|
||||
# Generate code
|
||||
response = generate_code(request)
|
||||
|
||||
# Serialise response message
|
||||
output = response.SerializeToString()
|
||||
|
||||
# Write to stdout
|
||||
sys.stdout.buffer.write(output)
|
||||
|
||||
|
||||
def dump_request(dump_file: str, request: CodeGeneratorRequest) -> None:
|
||||
"""
|
||||
For developers: Supports running plugin.py standalone so its possible to debug it.
|
||||
Run protoc (or generate.py) with ARISTAPROTO_DUMP="yourfile.bin" to write the request to a file.
|
||||
Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file.
|
||||
"""
|
||||
with open(str(dump_file), "wb") as fh:
|
||||
sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n")
|
||||
fh.write(request.SerializeToString())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
851
src/aristaproto/plugin/models.py
Normal file
851
src/aristaproto/plugin/models.py
Normal file
|
@ -0,0 +1,851 @@
|
|||
"""Plugin model dataclasses.
|
||||
|
||||
These classes are meant to be an intermediate representation
|
||||
of protobuf objects. They are used to organize the data collected during parsing.
|
||||
|
||||
The general intention is to create a doubly-linked tree-like structure
|
||||
with the following types of references:
|
||||
- Downwards references: from message -> fields, from output package -> messages
|
||||
or from service -> service methods
|
||||
- Upwards references: from field -> message, message -> package.
|
||||
- Input/output message references: from a service method to it's corresponding
|
||||
input/output messages, which may even be in another package.
|
||||
|
||||
There are convenience methods to allow climbing up and down this tree, for
|
||||
example to retrieve the list of all messages that are in the same package as
|
||||
the current message.
|
||||
|
||||
Most of these classes take as inputs:
|
||||
- proto_obj: A reference to it's corresponding protobuf object as
|
||||
presented by the protoc plugin.
|
||||
- parent: a reference to the parent object in the tree.
|
||||
|
||||
With this information, the class is able to expose attributes,
|
||||
such as a pythonized name, that will be calculated from proto_obj.
|
||||
|
||||
The instantiation should also attach a reference to the new object
|
||||
into the corresponding place within it's parent object. For example,
|
||||
instantiating field `A` with parent message `B` should add a
|
||||
reference to `A` to `B`'s `fields` attribute.
|
||||
"""
|
||||
|
||||
|
||||
import builtins
|
||||
import re
|
||||
import textwrap
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
)
|
||||
from typing import (
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import aristaproto
|
||||
from aristaproto import which_one_of
|
||||
from aristaproto.casing import sanitize_name
|
||||
from aristaproto.compile.importing import (
|
||||
get_type_reference,
|
||||
parse_source_type_name,
|
||||
)
|
||||
from aristaproto.compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_field_name,
|
||||
pythonize_method_name,
|
||||
)
|
||||
from aristaproto.lib.google.protobuf import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
Field,
|
||||
FieldDescriptorProto,
|
||||
FieldDescriptorProtoLabel,
|
||||
FieldDescriptorProtoType,
|
||||
FileDescriptorProto,
|
||||
MethodDescriptorProto,
|
||||
)
|
||||
from aristaproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
||||
|
||||
from ..compile.importing import (
|
||||
get_type_reference,
|
||||
parse_source_type_name,
|
||||
)
|
||||
from ..compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_enum_member_name,
|
||||
pythonize_field_name,
|
||||
pythonize_method_name,
|
||||
)
|
||||
|
||||
|
||||
# Create a unique placeholder to deal with
|
||||
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
||||
PLACEHOLDER = object()
|
||||
|
||||
# Organize proto types into categories
|
||||
PROTO_FLOAT_TYPES = (
|
||||
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProtoType.TYPE_FLOAT, # 2
|
||||
)
|
||||
PROTO_INT_TYPES = (
|
||||
FieldDescriptorProtoType.TYPE_INT64, # 3
|
||||
FieldDescriptorProtoType.TYPE_UINT64, # 4
|
||||
FieldDescriptorProtoType.TYPE_INT32, # 5
|
||||
FieldDescriptorProtoType.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProtoType.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProtoType.TYPE_UINT32, # 13
|
||||
FieldDescriptorProtoType.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProtoType.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProtoType.TYPE_SINT32, # 17
|
||||
FieldDescriptorProtoType.TYPE_SINT64, # 18
|
||||
)
|
||||
PROTO_BOOL_TYPES = (FieldDescriptorProtoType.TYPE_BOOL,) # 8
|
||||
PROTO_STR_TYPES = (FieldDescriptorProtoType.TYPE_STRING,) # 9
|
||||
PROTO_BYTES_TYPES = (FieldDescriptorProtoType.TYPE_BYTES,) # 12
|
||||
PROTO_MESSAGE_TYPES = (
|
||||
FieldDescriptorProtoType.TYPE_MESSAGE, # 11
|
||||
FieldDescriptorProtoType.TYPE_ENUM, # 14
|
||||
)
|
||||
PROTO_MAP_TYPES = (FieldDescriptorProtoType.TYPE_MESSAGE,) # 11
|
||||
PROTO_PACKED_TYPES = (
|
||||
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProtoType.TYPE_FLOAT, # 2
|
||||
FieldDescriptorProtoType.TYPE_INT64, # 3
|
||||
FieldDescriptorProtoType.TYPE_UINT64, # 4
|
||||
FieldDescriptorProtoType.TYPE_INT32, # 5
|
||||
FieldDescriptorProtoType.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProtoType.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProtoType.TYPE_BOOL, # 8
|
||||
FieldDescriptorProtoType.TYPE_UINT32, # 13
|
||||
FieldDescriptorProtoType.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProtoType.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProtoType.TYPE_SINT32, # 17
|
||||
FieldDescriptorProtoType.TYPE_SINT64, # 18
|
||||
)
|
||||
|
||||
|
||||
def monkey_patch_oneof_index():
|
||||
"""
|
||||
The compiler message types are written for proto2, but we read them as proto3.
|
||||
For this to work in the case of the oneof_index fields, which depend on being able
|
||||
to tell whether they were set, we have to treat them as oneof fields. This method
|
||||
monkey patches the generated classes after the fact to force this behaviour.
|
||||
"""
|
||||
object.__setattr__(
|
||||
FieldDescriptorProto.__dataclass_fields__["oneof_index"].metadata[
|
||||
"aristaproto"
|
||||
],
|
||||
"group",
|
||||
"oneof_index",
|
||||
)
|
||||
object.__setattr__(
|
||||
Field.__dataclass_fields__["oneof_index"].metadata["aristaproto"],
|
||||
"group",
|
||||
"oneof_index",
|
||||
)
|
||||
|
||||
|
||||
def get_comment(
|
||||
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
|
||||
) -> str:
|
||||
pad = " " * indent
|
||||
for sci_loc in proto_file.source_code_info.location:
|
||||
if list(sci_loc.path) == path and sci_loc.leading_comments:
|
||||
lines = sci_loc.leading_comments.strip().replace("\t", " ").split("\n")
|
||||
# This is a field, message, enum, service, or method
|
||||
if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
|
||||
lines[0] = lines[0].strip('"')
|
||||
# rstrip to remove trailing spaces including whitespaces from empty lines.
|
||||
return f'{pad}"""{lines[0]}"""'
|
||||
else:
|
||||
# rstrip to remove trailing spaces including empty lines.
|
||||
padded = [f"\n{pad}{line}".rstrip(" ") for line in lines]
|
||||
joined = "".join(padded)
|
||||
return f'{pad}"""{joined}\n{pad}"""'
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
class ProtoContentBase:
|
||||
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
|
||||
|
||||
source_file: FileDescriptorProto
|
||||
path: List[int]
|
||||
comment_indent: int = 4
|
||||
parent: Union["aristaproto.Message", "OutputTemplate"]
|
||||
|
||||
__dataclass_fields__: Dict[str, object]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Checks that no fake default fields were left as placeholders."""
|
||||
for field_name, field_val in self.__dataclass_fields__.items():
|
||||
if field_val is PLACEHOLDER:
|
||||
raise ValueError(f"`{field_name}` is a required field.")
|
||||
|
||||
@property
|
||||
def output_file(self) -> "OutputTemplate":
|
||||
current = self
|
||||
while not isinstance(current, OutputTemplate):
|
||||
current = current.parent
|
||||
return current
|
||||
|
||||
@property
|
||||
def request(self) -> "PluginRequestCompiler":
|
||||
current = self
|
||||
while not isinstance(current, OutputTemplate):
|
||||
current = current.parent
|
||||
return current.parent_request
|
||||
|
||||
@property
|
||||
def comment(self) -> str:
|
||||
"""Crawl the proto source code and retrieve comments
|
||||
for this object.
|
||||
"""
|
||||
return get_comment(
|
||||
proto_file=self.source_file, path=self.path, indent=self.comment_indent
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginRequestCompiler:
|
||||
plugin_request_obj: CodeGeneratorRequest
|
||||
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def all_messages(self) -> List["MessageCompiler"]:
|
||||
"""All of the messages in this request.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[MessageCompiler]
|
||||
List of all of the messages in this request.
|
||||
"""
|
||||
return [
|
||||
msg for output in self.output_packages.values() for msg in output.messages
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputTemplate:
|
||||
"""Representation of an output .py file.
|
||||
|
||||
Each output file corresponds to a .proto input file,
|
||||
but may need references to other .proto files to be
|
||||
built.
|
||||
"""
|
||||
|
||||
parent_request: PluginRequestCompiler
|
||||
package_proto_obj: FileDescriptorProto
|
||||
input_files: List[str] = field(default_factory=list)
|
||||
imports: Set[str] = field(default_factory=set)
|
||||
datetime_imports: Set[str] = field(default_factory=set)
|
||||
typing_imports: Set[str] = field(default_factory=set)
|
||||
pydantic_imports: Set[str] = field(default_factory=set)
|
||||
builtins_import: bool = False
|
||||
messages: List["MessageCompiler"] = field(default_factory=list)
|
||||
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
||||
services: List["ServiceCompiler"] = field(default_factory=list)
|
||||
imports_type_checking_only: Set[str] = field(default_factory=set)
|
||||
pydantic_dataclasses: bool = False
|
||||
output: bool = True
|
||||
|
||||
@property
|
||||
def package(self) -> str:
|
||||
"""Name of input package.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Name of input package.
|
||||
"""
|
||||
return self.package_proto_obj.package
|
||||
|
||||
@property
|
||||
def input_filenames(self) -> Iterable[str]:
|
||||
"""Names of the input files used to build this output.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable[str]
|
||||
Names of the input files used to build this output.
|
||||
"""
|
||||
return sorted(f.name for f in self.input_files)
|
||||
|
||||
@property
|
||||
def python_module_imports(self) -> Set[str]:
|
||||
imports = set()
|
||||
if any(x for x in self.messages if any(x.deprecated_fields)):
|
||||
imports.add("warnings")
|
||||
if self.builtins_import:
|
||||
imports.add("builtins")
|
||||
return imports
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageCompiler(ProtoContentBase):
|
||||
"""Representation of a protobuf message."""
|
||||
|
||||
source_file: FileDescriptorProto
|
||||
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
|
||||
proto_obj: DescriptorProto = PLACEHOLDER
|
||||
path: List[int] = PLACEHOLDER
|
||||
fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(
|
||||
default_factory=list
|
||||
)
|
||||
deprecated: bool = field(default=False, init=False)
|
||||
builtins_types: Set[str] = field(default_factory=set)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Add message to output file
|
||||
if isinstance(self.parent, OutputTemplate):
|
||||
if isinstance(self, EnumDefinitionCompiler):
|
||||
self.output_file.enums.append(self)
|
||||
else:
|
||||
self.output_file.messages.append(self)
|
||||
self.deprecated = self.proto_obj.options.deprecated
|
||||
super().__post_init__()
|
||||
|
||||
@property
|
||||
def proto_name(self) -> str:
|
||||
return self.proto_obj.name
|
||||
|
||||
@property
|
||||
def py_name(self) -> str:
|
||||
return pythonize_class_name(self.proto_name)
|
||||
|
||||
@property
|
||||
def annotation(self) -> str:
|
||||
if self.repeated:
|
||||
return f"List[{self.py_name}]"
|
||||
return self.py_name
|
||||
|
||||
@property
|
||||
def deprecated_fields(self) -> Iterator[str]:
|
||||
for f in self.fields:
|
||||
if f.deprecated:
|
||||
yield f.py_name
|
||||
|
||||
@property
|
||||
def has_deprecated_fields(self) -> bool:
|
||||
return any(self.deprecated_fields)
|
||||
|
||||
@property
|
||||
def has_oneof_fields(self) -> bool:
|
||||
return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)
|
||||
|
||||
@property
|
||||
def has_message_field(self) -> bool:
|
||||
return any(
|
||||
(
|
||||
field.proto_obj.type in PROTO_MESSAGE_TYPES
|
||||
for field in self.fields
|
||||
if isinstance(field.proto_obj, FieldDescriptorProto)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def is_map(
|
||||
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
|
||||
) -> bool:
|
||||
"""True if proto_field_obj is a map, otherwise False."""
|
||||
if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE:
|
||||
if not hasattr(parent_message, "nested_type"):
|
||||
return False
|
||||
|
||||
# This might be a map...
|
||||
message_type = proto_field_obj.type_name.split(".").pop().lower()
|
||||
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
|
||||
if message_type == map_entry:
|
||||
for nested in parent_message.nested_type: # parent message
|
||||
if (
|
||||
nested.name.replace("_", "").lower() == map_entry
|
||||
and nested.options.map_entry
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
|
||||
"""
|
||||
True if proto_field_obj is a OneOf, otherwise False.
|
||||
|
||||
.. warning::
|
||||
Becuase the message from protoc is defined in proto2, and aristaproto works with
|
||||
proto3, and interpreting the FieldDescriptorProto.oneof_index field requires
|
||||
distinguishing between default and unset values (which proto3 doesn't support),
|
||||
we have to hack the generated FieldDescriptorProto class for this to work.
|
||||
The hack consists of setting group="oneof_index" in the field metadata,
|
||||
essentially making oneof_index the sole member of a one_of group, which allows
|
||||
us to tell whether it was set, via the which_one_of interface.
|
||||
"""
|
||||
|
||||
return (
|
||||
not proto_field_obj.proto3_optional
|
||||
and which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FieldCompiler(MessageCompiler):
|
||||
parent: MessageCompiler = PLACEHOLDER
|
||||
proto_obj: FieldDescriptorProto = PLACEHOLDER
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Add field to message
|
||||
self.parent.fields.append(self)
|
||||
# Check for new imports
|
||||
self.add_imports_to(self.output_file)
|
||||
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
|
||||
|
||||
def get_field_string(self, indent: int = 4) -> str:
|
||||
"""Construct string representation of this field as a field."""
|
||||
name = f"{self.py_name}"
|
||||
annotations = f": {self.annotation}"
|
||||
field_args = ", ".join(
|
||||
([""] + self.aristaproto_field_args) if self.aristaproto_field_args else []
|
||||
)
|
||||
aristaproto_field_type = (
|
||||
f"aristaproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
|
||||
)
|
||||
if self.py_name in dir(builtins):
|
||||
self.parent.builtins_types.add(self.py_name)
|
||||
return f"{name}{annotations} = {aristaproto_field_type}"
|
||||
|
||||
@property
|
||||
def aristaproto_field_args(self) -> List[str]:
|
||||
args = []
|
||||
if self.field_wraps:
|
||||
args.append(f"wraps={self.field_wraps}")
|
||||
if self.optional:
|
||||
args.append(f"optional=True")
|
||||
return args
|
||||
|
||||
@property
|
||||
def datetime_imports(self) -> Set[str]:
|
||||
imports = set()
|
||||
annotation = self.annotation
|
||||
# FIXME: false positives - e.g. `MyDatetimedelta`
|
||||
if "timedelta" in annotation:
|
||||
imports.add("timedelta")
|
||||
if "datetime" in annotation:
|
||||
imports.add("datetime")
|
||||
return imports
|
||||
|
||||
@property
|
||||
def typing_imports(self) -> Set[str]:
|
||||
imports = set()
|
||||
annotation = self.annotation
|
||||
if "Optional[" in annotation:
|
||||
imports.add("Optional")
|
||||
if "List[" in annotation:
|
||||
imports.add("List")
|
||||
if "Dict[" in annotation:
|
||||
imports.add("Dict")
|
||||
return imports
|
||||
|
||||
@property
|
||||
def pydantic_imports(self) -> Set[str]:
|
||||
return set()
|
||||
|
||||
@property
|
||||
def use_builtins(self) -> bool:
|
||||
return self.py_type in self.parent.builtins_types or (
|
||||
self.py_type == self.py_name and self.py_name in dir(builtins)
|
||||
)
|
||||
|
||||
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
||||
output_file.datetime_imports.update(self.datetime_imports)
|
||||
output_file.typing_imports.update(self.typing_imports)
|
||||
output_file.pydantic_imports.update(self.pydantic_imports)
|
||||
output_file.builtins_import = output_file.builtins_import or self.use_builtins
|
||||
|
||||
@property
|
||||
def field_wraps(self) -> Optional[str]:
|
||||
"""Returns aristaproto wrapped field type or None."""
|
||||
match_wrapper = re.match(
|
||||
r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name
|
||||
)
|
||||
if match_wrapper:
|
||||
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
|
||||
if hasattr(aristaproto, wrapped_type):
|
||||
return f"aristaproto.{wrapped_type}"
|
||||
return None
|
||||
|
||||
@property
|
||||
def repeated(self) -> bool:
|
||||
return (
|
||||
self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED
|
||||
and not is_map(self.proto_obj, self.parent)
|
||||
)
|
||||
|
||||
@property
|
||||
def optional(self) -> bool:
|
||||
return self.proto_obj.proto3_optional
|
||||
|
||||
@property
|
||||
def mutable(self) -> bool:
|
||||
"""True if the field is a mutable type, otherwise False."""
|
||||
return self.annotation.startswith(("List[", "Dict["))
|
||||
|
||||
@property
|
||||
def field_type(self) -> str:
|
||||
"""String representation of proto field type."""
|
||||
return (
|
||||
FieldDescriptorProtoType(self.proto_obj.type)
|
||||
.name.lower()
|
||||
.replace("type_", "")
|
||||
)
|
||||
|
||||
@property
|
||||
def default_value_string(self) -> str:
|
||||
"""Python representation of the default proto value."""
|
||||
if self.repeated:
|
||||
return "[]"
|
||||
if self.optional:
|
||||
return "None"
|
||||
if self.py_type == "int":
|
||||
return "0"
|
||||
if self.py_type == "float":
|
||||
return "0.0"
|
||||
elif self.py_type == "bool":
|
||||
return "False"
|
||||
elif self.py_type == "str":
|
||||
return '""'
|
||||
elif self.py_type == "bytes":
|
||||
return 'b""'
|
||||
elif self.field_type == "enum":
|
||||
enum_proto_obj_name = self.proto_obj.type_name.split(".").pop()
|
||||
enum = next(
|
||||
e
|
||||
for e in self.output_file.enums
|
||||
if e.proto_obj.name == enum_proto_obj_name
|
||||
)
|
||||
return enum.default_value_string
|
||||
else:
|
||||
# Message type
|
||||
return "None"
|
||||
|
||||
@property
|
||||
def packed(self) -> bool:
|
||||
"""True if the wire representation is a packed format."""
|
||||
return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES
|
||||
|
||||
@property
|
||||
def py_name(self) -> str:
|
||||
"""Pythonized name."""
|
||||
return pythonize_field_name(self.proto_name)
|
||||
|
||||
@property
|
||||
def proto_name(self) -> str:
|
||||
"""Original protobuf name."""
|
||||
return self.proto_obj.name
|
||||
|
||||
@property
|
||||
def py_type(self) -> str:
|
||||
"""String representation of Python type."""
|
||||
if self.proto_obj.type in PROTO_FLOAT_TYPES:
|
||||
return "float"
|
||||
elif self.proto_obj.type in PROTO_INT_TYPES:
|
||||
return "int"
|
||||
elif self.proto_obj.type in PROTO_BOOL_TYPES:
|
||||
return "bool"
|
||||
elif self.proto_obj.type in PROTO_STR_TYPES:
|
||||
return "str"
|
||||
elif self.proto_obj.type in PROTO_BYTES_TYPES:
|
||||
return "bytes"
|
||||
elif self.proto_obj.type in PROTO_MESSAGE_TYPES:
|
||||
# Type referencing another defined Message or a named enum
|
||||
return get_type_reference(
|
||||
package=self.output_file.package,
|
||||
imports=self.output_file.imports,
|
||||
source_type=self.proto_obj.type_name,
|
||||
pydantic=self.output_file.pydantic_dataclasses,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown type {self.proto_obj.type}")
|
||||
|
||||
@property
|
||||
def annotation(self) -> str:
|
||||
py_type = self.py_type
|
||||
if self.use_builtins:
|
||||
py_type = f"builtins.{py_type}"
|
||||
if self.repeated:
|
||||
return f"List[{py_type}]"
|
||||
if self.optional:
|
||||
return f"Optional[{py_type}]"
|
||||
return py_type
|
||||
|
||||
|
||||
@dataclass
|
||||
class OneOfFieldCompiler(FieldCompiler):
|
||||
@property
|
||||
def aristaproto_field_args(self) -> List[str]:
|
||||
args = super().aristaproto_field_args
|
||||
group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
|
||||
args.append(f'group="{group}"')
|
||||
return args
|
||||
|
||||
|
||||
@dataclass
|
||||
class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
|
||||
@property
|
||||
def optional(self) -> bool:
|
||||
# Force the optional to be True. This will allow the pydantic dataclass
|
||||
# to validate the object correctly by allowing the field to be let empty.
|
||||
# We add a pydantic validator later to ensure exactly one field is defined.
|
||||
return True
|
||||
|
||||
@property
|
||||
def pydantic_imports(self) -> Set[str]:
|
||||
return {"root_validator"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MapEntryCompiler(FieldCompiler):
|
||||
py_k_type: Type = PLACEHOLDER
|
||||
py_v_type: Type = PLACEHOLDER
|
||||
proto_k_type: str = PLACEHOLDER
|
||||
proto_v_type: str = PLACEHOLDER
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Explore nested types and set k_type and v_type if unset."""
|
||||
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
|
||||
for nested in self.parent.proto_obj.nested_type:
|
||||
if (
|
||||
nested.name.replace("_", "").lower() == map_entry
|
||||
and nested.options.map_entry
|
||||
):
|
||||
# Get Python types
|
||||
self.py_k_type = FieldCompiler(
|
||||
source_file=self.source_file,
|
||||
parent=self,
|
||||
proto_obj=nested.field[0], # key
|
||||
).py_type
|
||||
self.py_v_type = FieldCompiler(
|
||||
source_file=self.source_file,
|
||||
parent=self,
|
||||
proto_obj=nested.field[1], # value
|
||||
).py_type
|
||||
|
||||
# Get proto types
|
||||
self.proto_k_type = FieldDescriptorProtoType(nested.field[0].type).name
|
||||
self.proto_v_type = FieldDescriptorProtoType(nested.field[1].type).name
|
||||
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
|
||||
|
||||
@property
|
||||
def aristaproto_field_args(self) -> List[str]:
|
||||
return [f"aristaproto.{self.proto_k_type}", f"aristaproto.{self.proto_v_type}"]
|
||||
|
||||
@property
|
||||
def field_type(self) -> str:
|
||||
return "map"
|
||||
|
||||
@property
|
||||
def annotation(self) -> str:
|
||||
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
|
||||
|
||||
@property
|
||||
def repeated(self) -> bool:
|
||||
return False # maps cannot be repeated
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumDefinitionCompiler(MessageCompiler):
|
||||
"""Representation of a proto Enum definition."""
|
||||
|
||||
proto_obj: EnumDescriptorProto = PLACEHOLDER
|
||||
entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class EnumEntry:
|
||||
"""Representation of an Enum entry."""
|
||||
|
||||
name: str
|
||||
value: int
|
||||
comment: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Get entries/allowed values for this Enum
|
||||
self.entries = [
|
||||
self.EnumEntry(
|
||||
name=pythonize_enum_member_name(
|
||||
entry_proto_value.name, self.proto_obj.name
|
||||
),
|
||||
value=entry_proto_value.number,
|
||||
comment=get_comment(
|
||||
proto_file=self.source_file, path=self.path + [2, entry_number]
|
||||
),
|
||||
)
|
||||
for entry_number, entry_proto_value in enumerate(self.proto_obj.value)
|
||||
]
|
||||
super().__post_init__() # call MessageCompiler __post_init__
|
||||
|
||||
@property
|
||||
def default_value_string(self) -> str:
|
||||
"""Python representation of the default value for Enums.
|
||||
|
||||
As per the spec, this is the first value of the Enum.
|
||||
"""
|
||||
return str(self.entries[0].value) # ideally, should ALWAYS be int(0)!
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceCompiler(ProtoContentBase):
|
||||
parent: OutputTemplate = PLACEHOLDER
|
||||
proto_obj: DescriptorProto = PLACEHOLDER
|
||||
path: List[int] = PLACEHOLDER
|
||||
methods: List["ServiceMethodCompiler"] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Add service to output file
|
||||
self.output_file.services.append(self)
|
||||
self.output_file.typing_imports.add("Dict")
|
||||
super().__post_init__() # check for unset fields
|
||||
|
||||
@property
|
||||
def proto_name(self) -> str:
|
||||
return self.proto_obj.name
|
||||
|
||||
@property
|
||||
def py_name(self) -> str:
|
||||
return pythonize_class_name(self.proto_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceMethodCompiler(ProtoContentBase):
|
||||
parent: ServiceCompiler
|
||||
proto_obj: MethodDescriptorProto
|
||||
path: List[int] = PLACEHOLDER
|
||||
comment_indent: int = 8
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Add method to service
|
||||
self.parent.methods.append(self)
|
||||
|
||||
# Check for imports
|
||||
if "Optional" in self.py_output_message_type:
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
|
||||
# Check for Async imports
|
||||
if self.client_streaming:
|
||||
self.output_file.typing_imports.add("AsyncIterable")
|
||||
self.output_file.typing_imports.add("Iterable")
|
||||
self.output_file.typing_imports.add("Union")
|
||||
|
||||
# Required by both client and server
|
||||
if self.client_streaming or self.server_streaming:
|
||||
self.output_file.typing_imports.add("AsyncIterator")
|
||||
|
||||
# add imports required for request arguments timeout, deadline and metadata
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
self.output_file.imports_type_checking_only.add("import grpclib.server")
|
||||
self.output_file.imports_type_checking_only.add(
|
||||
"from aristaproto.grpc.grpclib_client import MetadataLike"
|
||||
)
|
||||
self.output_file.imports_type_checking_only.add(
|
||||
"from grpclib.metadata import Deadline"
|
||||
)
|
||||
|
||||
super().__post_init__() # check for unset fields
|
||||
|
||||
@property
|
||||
def py_name(self) -> str:
|
||||
"""Pythonized method name."""
|
||||
return pythonize_method_name(self.proto_obj.name)
|
||||
|
||||
@property
|
||||
def proto_name(self) -> str:
|
||||
"""Original protobuf name."""
|
||||
return self.proto_obj.name
|
||||
|
||||
@property
|
||||
def route(self) -> str:
|
||||
package_part = (
|
||||
f"{self.output_file.package}." if self.output_file.package else ""
|
||||
)
|
||||
return f"/{package_part}{self.parent.proto_name}/{self.proto_name}"
|
||||
|
||||
@property
|
||||
def py_input_message(self) -> Optional[MessageCompiler]:
|
||||
"""Find the input message object.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[MessageCompiler]
|
||||
Method instance representing the input message.
|
||||
If not input message could be found or there are no
|
||||
input messages, None is returned.
|
||||
"""
|
||||
package, name = parse_source_type_name(self.proto_obj.input_type)
|
||||
|
||||
# Nested types are currently flattened without dots.
|
||||
# Todo: keep a fully quantified name in types, that is
|
||||
# comparable with method.input_type
|
||||
for msg in self.request.all_messages:
|
||||
if (
|
||||
msg.py_name == pythonize_class_name(name.replace(".", ""))
|
||||
and msg.output_file.package == package
|
||||
):
|
||||
return msg
|
||||
return None
|
||||
|
||||
@property
|
||||
def py_input_message_type(self) -> str:
|
||||
"""String representation of the Python type corresponding to the
|
||||
input message.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
String representation of the Python type corresponding to the input message.
|
||||
"""
|
||||
return get_type_reference(
|
||||
package=self.output_file.package,
|
||||
imports=self.output_file.imports,
|
||||
source_type=self.proto_obj.input_type,
|
||||
unwrap=False,
|
||||
pydantic=self.output_file.pydantic_dataclasses,
|
||||
).strip('"')
|
||||
|
||||
@property
|
||||
def py_input_message_param(self) -> str:
|
||||
"""Param name corresponding to py_input_message_type.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Param name corresponding to py_input_message_type.
|
||||
"""
|
||||
return pythonize_field_name(self.py_input_message_type)
|
||||
|
||||
@property
|
||||
def py_output_message_type(self) -> str:
|
||||
"""String representation of the Python type corresponding to the
|
||||
output message.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
String representation of the Python type corresponding to the output message.
|
||||
"""
|
||||
return get_type_reference(
|
||||
package=self.output_file.package,
|
||||
imports=self.output_file.imports,
|
||||
source_type=self.proto_obj.output_type,
|
||||
unwrap=False,
|
||||
pydantic=self.output_file.pydantic_dataclasses,
|
||||
).strip('"')
|
||||
|
||||
@property
|
||||
def client_streaming(self) -> bool:
|
||||
return self.proto_obj.client_streaming
|
||||
|
||||
@property
|
||||
def server_streaming(self) -> bool:
|
||||
return self.proto_obj.server_streaming
|
221
src/aristaproto/plugin/parser.py
Normal file
221
src/aristaproto/plugin/parser.py
Normal file
|
@ -0,0 +1,221 @@
|
|||
import pathlib
|
||||
import sys
|
||||
from typing import (
|
||||
Generator,
|
||||
List,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from aristaproto.lib.google.protobuf import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
from aristaproto.lib.google.protobuf.compiler import (
|
||||
CodeGeneratorRequest,
|
||||
CodeGeneratorResponse,
|
||||
CodeGeneratorResponseFeature,
|
||||
CodeGeneratorResponseFile,
|
||||
)
|
||||
|
||||
from .compiler import outputfile_compiler
|
||||
from .models import (
|
||||
EnumDefinitionCompiler,
|
||||
FieldCompiler,
|
||||
MapEntryCompiler,
|
||||
MessageCompiler,
|
||||
OneOfFieldCompiler,
|
||||
OutputTemplate,
|
||||
PluginRequestCompiler,
|
||||
PydanticOneOfFieldCompiler,
|
||||
ServiceCompiler,
|
||||
ServiceMethodCompiler,
|
||||
is_map,
|
||||
is_oneof,
|
||||
)
|
||||
|
||||
|
||||
def traverse(
|
||||
proto_file: FileDescriptorProto,
|
||||
) -> Generator[
|
||||
Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None
|
||||
]:
|
||||
# Todo: Keep information about nested hierarchy
|
||||
def _traverse(
|
||||
path: List[int],
|
||||
items: Union[List[EnumDescriptorProto], List[DescriptorProto]],
|
||||
prefix: str = "",
|
||||
) -> Generator[
|
||||
Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None
|
||||
]:
|
||||
for i, item in enumerate(items):
|
||||
# Adjust the name since we flatten the hierarchy.
|
||||
# Todo: don't change the name, but include full name in returned tuple
|
||||
item.name = next_prefix = f"{prefix}_{item.name}"
|
||||
yield item, [*path, i]
|
||||
|
||||
if isinstance(item, DescriptorProto):
|
||||
# Get nested types.
|
||||
yield from _traverse([*path, i, 4], item.enum_type, next_prefix)
|
||||
yield from _traverse([*path, i, 3], item.nested_type, next_prefix)
|
||||
|
||||
yield from _traverse([5], proto_file.enum_type)
|
||||
yield from _traverse([4], proto_file.message_type)
|
||||
|
||||
|
||||
def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
||||
response = CodeGeneratorResponse()
|
||||
|
||||
plugin_options = request.parameter.split(",") if request.parameter else []
|
||||
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL
|
||||
|
||||
request_data = PluginRequestCompiler(plugin_request_obj=request)
|
||||
# Gather output packages
|
||||
for proto_file in request.proto_file:
|
||||
output_package_name = proto_file.package
|
||||
if output_package_name not in request_data.output_packages:
|
||||
# Create a new output if there is no output for this package
|
||||
request_data.output_packages[output_package_name] = OutputTemplate(
|
||||
parent_request=request_data, package_proto_obj=proto_file
|
||||
)
|
||||
# Add this input file to the output corresponding to this package
|
||||
request_data.output_packages[output_package_name].input_files.append(proto_file)
|
||||
|
||||
if (
|
||||
proto_file.package == "google.protobuf"
|
||||
and "INCLUDE_GOOGLE" not in plugin_options
|
||||
):
|
||||
# If not INCLUDE_GOOGLE,
|
||||
# skip outputting Google's well-known types
|
||||
request_data.output_packages[output_package_name].output = False
|
||||
|
||||
if "pydantic_dataclasses" in plugin_options:
|
||||
request_data.output_packages[
|
||||
output_package_name
|
||||
].pydantic_dataclasses = True
|
||||
|
||||
# Read Messages and Enums
|
||||
# We need to read Messages before Services in so that we can
|
||||
# get the references to input/output messages for each service
|
||||
for output_package_name, output_package in request_data.output_packages.items():
|
||||
for proto_input_file in output_package.input_files:
|
||||
for item, path in traverse(proto_input_file):
|
||||
read_protobuf_type(
|
||||
source_file=proto_input_file,
|
||||
item=item,
|
||||
path=path,
|
||||
output_package=output_package,
|
||||
)
|
||||
|
||||
# Read Services
|
||||
for output_package_name, output_package in request_data.output_packages.items():
|
||||
for proto_input_file in output_package.input_files:
|
||||
for index, service in enumerate(proto_input_file.service):
|
||||
read_protobuf_service(service, index, output_package)
|
||||
|
||||
# Generate output files
|
||||
output_paths: Set[pathlib.Path] = set()
|
||||
for output_package_name, output_package in request_data.output_packages.items():
|
||||
if not output_package.output:
|
||||
continue
|
||||
|
||||
# Add files to the response object
|
||||
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
|
||||
output_paths.add(output_path)
|
||||
|
||||
response.file.append(
|
||||
CodeGeneratorResponseFile(
|
||||
name=str(output_path),
|
||||
# Render and then format the output file
|
||||
content=outputfile_compiler(output_file=output_package),
|
||||
)
|
||||
)
|
||||
|
||||
# Make each output directory a package with __init__ file
|
||||
init_files = {
|
||||
directory.joinpath("__init__.py")
|
||||
for path in output_paths
|
||||
for directory in path.parents
|
||||
if not directory.joinpath("__init__.py").exists()
|
||||
} - output_paths
|
||||
|
||||
for init_file in init_files:
|
||||
response.file.append(CodeGeneratorResponseFile(name=str(init_file)))
|
||||
|
||||
for output_package_name in sorted(output_paths.union(init_files)):
|
||||
print(f"Writing {output_package_name}", file=sys.stderr)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def _make_one_of_field_compiler(
|
||||
output_package: OutputTemplate,
|
||||
source_file: "FileDescriptorProto",
|
||||
parent: MessageCompiler,
|
||||
proto_obj: "FieldDescriptorProto",
|
||||
path: List[int],
|
||||
) -> FieldCompiler:
|
||||
pydantic = output_package.pydantic_dataclasses
|
||||
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
|
||||
return Cls(
|
||||
source_file=source_file,
|
||||
parent=parent,
|
||||
proto_obj=proto_obj,
|
||||
path=path,
|
||||
)
|
||||
|
||||
|
||||
def read_protobuf_type(
|
||||
item: DescriptorProto,
|
||||
path: List[int],
|
||||
source_file: "FileDescriptorProto",
|
||||
output_package: OutputTemplate,
|
||||
) -> None:
|
||||
if isinstance(item, DescriptorProto):
|
||||
if item.options.map_entry:
|
||||
# Skip generated map entry messages since we just use dicts
|
||||
return
|
||||
# Process Message
|
||||
message_data = MessageCompiler(
|
||||
source_file=source_file, parent=output_package, proto_obj=item, path=path
|
||||
)
|
||||
for index, field in enumerate(item.field):
|
||||
if is_map(field, item):
|
||||
MapEntryCompiler(
|
||||
source_file=source_file,
|
||||
parent=message_data,
|
||||
proto_obj=field,
|
||||
path=path + [2, index],
|
||||
)
|
||||
elif is_oneof(field):
|
||||
_make_one_of_field_compiler(
|
||||
output_package, source_file, message_data, field, path + [2, index]
|
||||
)
|
||||
else:
|
||||
FieldCompiler(
|
||||
source_file=source_file,
|
||||
parent=message_data,
|
||||
proto_obj=field,
|
||||
path=path + [2, index],
|
||||
)
|
||||
elif isinstance(item, EnumDescriptorProto):
|
||||
# Enum
|
||||
EnumDefinitionCompiler(
|
||||
source_file=source_file, parent=output_package, proto_obj=item, path=path
|
||||
)
|
||||
|
||||
|
||||
def read_protobuf_service(
|
||||
service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
|
||||
) -> None:
|
||||
service_data = ServiceCompiler(
|
||||
parent=output_package, proto_obj=service, path=[6, index]
|
||||
)
|
||||
for j, method in enumerate(service.method):
|
||||
ServiceMethodCompiler(
|
||||
parent=service_data, proto_obj=method, path=[6, index, 2, j]
|
||||
)
|
2
src/aristaproto/plugin/plugin.bat
Normal file
2
src/aristaproto/plugin/plugin.bat
Normal file
|
@ -0,0 +1,2 @@
|
|||
@SET plugin_dir=%~dp0
|
||||
@python -m %plugin_dir% %*
|
Loading…
Add table
Add a link
Reference in a new issue