Skip to content

Commit

Permalink
fix: add oneof fields to generated protoplus init (#485)
Browse files Browse the repository at this point in the history
Fixes: #484
  • Loading branch information
crwilcox authored Jul 7, 2020
1 parent 9076362 commit be5a847
Show file tree
Hide file tree
Showing 13 changed files with 193 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class {{ message.name }}({{ p }}.Message):
{% else -%}
{{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
{{- p }}.{{ field.proto_type }}, number={{ field.number }}
{% if field.oneof %}, oneof='{{ field.oneof }}'{% endif %}
{%- if field.enum or field.message %},
{{ field.proto_type.lower() }}={{ field.type.ident.rel(message.ident) }},
{% endif %})
Expand Down
51 changes: 47 additions & 4 deletions gapic/schema/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from gapic.schema import wrappers
from gapic.schema import naming as api_naming
from gapic.utils import cached_property
from gapic.utils import nth
from gapic.utils import to_snake_case
from gapic.utils import RESERVED_NAMES

Expand Down Expand Up @@ -556,14 +557,42 @@ def _load_children(self,
answer[wrapped.name] = wrapped
return answer

def _get_oneofs(self,
oneof_pbs: Sequence[descriptor_pb2.OneofDescriptorProto],
address: metadata.Address, path: Tuple[int, ...],
) -> Dict[str, wrappers.Oneof]:
"""Return a dictionary of wrapped oneofs for the given message.
Args:
oneof_fields (Sequence[~.descriptor_pb2.OneofDescriptorProto]): A
sequence of protobuf field objects.
address (~.metadata.Address): An address object denoting the
location of these oneofs.
path (Tuple[int]): The source location path thus far, as
understood by ``SourceCodeInfo.Location``.
Returns:
Mapping[str, ~.wrappers.Oneof]: A ordered mapping of
:class:`~.wrappers.Oneof` objects.
"""
# Iterate over the oneofs and collect them into a dictionary.
answer = collections.OrderedDict(
(oneof_pb.name, wrappers.Oneof(oneof_pb=oneof_pb))
for i, oneof_pb in enumerate(oneof_pbs)
)

# Done; return the answer.
return answer

def _get_fields(self,
field_pbs: Sequence[descriptor_pb2.FieldDescriptorProto],
address: metadata.Address, path: Tuple[int, ...],
oneofs: Optional[Dict[str, wrappers.Oneof]] = None
) -> Dict[str, wrappers.Field]:
"""Return a dictionary of wrapped fields for the given message.
Args:
fields (Sequence[~.descriptor_pb2.FieldDescriptorProto]): A
field_pbs (Sequence[~.descriptor_pb2.FieldDescriptorProto]): A
sequence of protobuf field objects.
address (~.metadata.Address): An address object denoting the
location of these fields.
Expand All @@ -585,7 +614,13 @@ def _get_fields(self,
# first) and this will be None. This case is addressed in the
# `_load_message` method.
answer: Dict[str, wrappers.Field] = collections.OrderedDict()
for field_pb, i in zip(field_pbs, range(0, sys.maxsize)):
for i, field_pb in enumerate(field_pbs):
is_oneof = oneofs and field_pb.oneof_index > 0
oneof_name = nth(
(oneofs or {}).keys(),
field_pb.oneof_index
) if is_oneof else None

answer[field_pb.name] = wrappers.Field(
field_pb=field_pb,
enum=self.api_enums.get(field_pb.type_name.lstrip('.')),
Expand All @@ -594,6 +629,7 @@ def _get_fields(self,
address=address.child(field_pb.name, path + (i,)),
documentation=self.docs.get(path + (i,), self.EMPTY),
),
oneof=oneof_name,
)

# Done; return the answer.
Expand Down Expand Up @@ -779,19 +815,25 @@ def _load_message(self,
loader=self._load_message,
path=path + (3,),
)
# self._load_children(message.oneof_decl, loader=self._load_field,
# address=nested_addr, info=info.get(8, {}))

oneofs = self._get_oneofs(
message_pb.oneof_decl,
address=address,
path=path + (7,),
)

# Create a dictionary of all the fields for this message.
fields = self._get_fields(
message_pb.field,
address=address,
path=path + (2,),
oneofs=oneofs,
)
fields.update(self._get_fields(
message_pb.extension,
address=address,
path=path + (6,),
oneofs=oneofs,
))

# Create a message correspoding to this descriptor.
Expand All @@ -804,6 +846,7 @@ def _load_message(self,
address=address,
documentation=self.docs.get(path, self.EMPTY),
),
oneofs=oneofs,
)
return self.proto_messages[address.proto]

Expand Down
11 changes: 11 additions & 0 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Field:
meta: metadata.Metadata = dataclasses.field(
default_factory=metadata.Metadata,
)
oneof: Optional[str] = None

def __getattr__(self, name):
return getattr(self.field_pb, name)
Expand Down Expand Up @@ -206,6 +207,15 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Field':
)


@dataclasses.dataclass(frozen=True)
class Oneof:
"""Description of a field."""
oneof_pb: descriptor_pb2.OneofDescriptorProto

def __getattr__(self, name):
return getattr(self.oneof_pb, name)


@dataclasses.dataclass(frozen=True)
class MessageType:
"""Description of a message (defined with the ``message`` keyword)."""
Expand All @@ -220,6 +230,7 @@ class MessageType:
meta: metadata.Metadata = dataclasses.field(
default_factory=metadata.Metadata,
)
oneofs: Optional[Mapping[str, 'Oneof']] = None

def __getattr__(self, name):
return getattr(self.message_pb, name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ class {{ message.name }}({{ p }}.Message):
{{- p }}.{{ key_field.proto_type }}, {{ p }}.{{ value_field.proto_type }}, number={{ field.number }}
{%- if value_field.enum or value_field.message %},
{{ value_field.proto_type.lower() }}={{ value_field.type.ident.rel(message.ident) }},
{% endif %})
{% endif %}) {# enum or message#}
{% endwith -%}
{% else -%}
{% else -%} {# field.map #}
{{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
{{- p }}.{{ field.proto_type }}, number={{ field.number }}
{% if field.oneof %}, oneof='{{ field.oneof }}'{% endif %}
{%- if field.enum or field.message %},
{{ field.proto_type.lower() }}={{ field.type.ident.rel(message.ident) }},
{% endif %})
{% endif -%}
{% endfor -%}
{% endif %}) {# enum or message #}
{% endif -%} {# field.map #}
{% endfor -%} {# for field in message.fields.values#}
{{ '\n\n' }}
2 changes: 1 addition & 1 deletion gapic/templates/noxfile.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def unit(session):
'--cov-config=.coveragerc',
'--cov-report=term',
'--cov-report=html',
os.path.join('tests', 'unit', '{{ api.naming.versioned_module_name }}'),
os.path.join('tests', 'unit',)
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,9 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
call.return_value = iter([{{ method.output.ident }}()])
{% else -%}
call.return_value = {{ method.output.ident }}(
{%- for field in method.output.fields.values() | rejectattr('message') %}
{%- for field in method.output.fields.values() | rejectattr('message')%}{% if not (field.oneof and not field.proto3_optional) %}
{{ field.name }}={{ field.mock_value }},
{%- endfor %}
{% endif %}{%- endfor %}
)
{% endif -%}
{% if method.client_streaming %}
Expand Down Expand Up @@ -318,14 +318,15 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
assert isinstance(message, {{ method.output.ident }})
{% else -%}
assert isinstance(response, {{ method.client_output.ident }})
{% for field in method.output.fields.values() | rejectattr('message') -%}
{% for field in method.output.fields.values() | rejectattr('message') -%}{% if not (field.oneof and not field.proto3_optional) %}
{% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#}
assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6)
{% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #}
assert response.{{ field.name }} is {{ field.mock_value }}
{% else -%}
assert response.{{ field.name }} == {{ field.mock_value }}
{% endif -%}
{% endif -%} {# end oneof/optional #}
{% endfor %}
{% endif %}

Expand Down Expand Up @@ -368,8 +369,9 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio
{%- else -%}
grpc_helpers_async.FakeStreamUnaryCall
{%- endif -%}({{ method.output.ident }}(
{%- for field in method.output.fields.values() | rejectattr('message') %}
{%- for field in method.output.fields.values() | rejectattr('message') %}{% if not (field.oneof and not field.proto3_optional) %}
{{ field.name }}={{ field.mock_value }},
{%- endif %}
{%- endfor %}
))
{% endif -%}
Expand Down Expand Up @@ -400,14 +402,15 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio
assert isinstance(message, {{ method.output.ident }})
{% else -%}
assert isinstance(response, {{ method.client_output_async.ident }})
{% for field in method.output.fields.values() | rejectattr('message') -%}
{% for field in method.output.fields.values() | rejectattr('message') -%}{% if not (field.oneof and not field.proto3_optional) %}
{% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#}
assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6)
{% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #}
assert response.{{ field.name }} is {{ field.mock_value }}
{% else -%}
assert response.{{ field.name }} == {{ field.mock_value }}
{% endif -%}
{% endif -%} {# oneof/optional #}
{% endfor %}
{% endif %}

Expand Down
2 changes: 2 additions & 0 deletions gapic/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from gapic.utils.cache import cached_property
from gapic.utils.case import to_snake_case
from gapic.utils.code import empty
from gapic.utils.code import nth
from gapic.utils.code import partition
from gapic.utils.doc import doc
from gapic.utils.filename import to_valid_filename
Expand All @@ -29,6 +30,7 @@
'cached_property',
'doc',
'empty',
'nth',
'partition',
'RESERVED_NAMES',
'rst',
Expand Down
15 changes: 14 additions & 1 deletion gapic/utils/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import (Callable, Iterable, List, Tuple, TypeVar)
from typing import (Callable, Iterable, List, Optional, Tuple, TypeVar)
import itertools


def empty(content: str) -> bool:
Expand Down Expand Up @@ -50,3 +51,15 @@ def partition(predicate: Callable[[T], bool],

# Returns trueList, falseList
return results[1], results[0]


def nth(iterable: Iterable[T], n: int, default: Optional[T] = None) -> Optional[T]:
"""Return the nth element of an iterable or a default value.
Args
iterable (Iterable(T)): An iterable on any type.
n (int): The 'index' of the lement to retrieve.
default (Optional(T)): An optional default elemnt if the iterable has
fewer than n elements.
"""
return next(itertools.islice(iterable, n, None), default)
13 changes: 12 additions & 1 deletion test_utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def make_field(
message: wrappers.MessageType = None,
enum: wrappers.EnumType = None,
meta: metadata.Metadata = None,
oneof: str = None,
**kwargs
) -> wrappers.Field:
T = desc.FieldDescriptorProto.Type
Expand All @@ -223,11 +224,13 @@ def make_field(
number=number,
**kwargs
)

return wrappers.Field(
field_pb=field_pb,
enum=enum,
message=message,
meta=meta or metadata.Metadata(),
oneof=oneof,
)


Expand Down Expand Up @@ -322,20 +325,28 @@ def make_enum_pb2(
def make_message_pb2(
name: str,
fields: tuple = (),
oneof_decl: tuple = (),
**kwargs
) -> desc.DescriptorProto:
return desc.DescriptorProto(name=name, field=fields, **kwargs)
return desc.DescriptorProto(name=name, field=fields, oneof_decl=oneof_decl, **kwargs)


def make_field_pb2(name: str, number: int,
type: int = 11, # 11 == message
type_name: str = None,
oneof_index: int = None
) -> desc.FieldDescriptorProto:
return desc.FieldDescriptorProto(
name=name,
number=number,
type=type,
type_name=type_name,
oneof_index=oneof_index,
)

def make_oneof_pb2(name: str) -> desc.OneofDescriptorProto:
return desc.OneofDescriptorProto(
name=name,
)


Expand Down
40 changes: 40 additions & 0 deletions tests/unit/schema/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
make_file_pb2,
make_message_pb2,
make_naming,
make_oneof_pb2,
)


Expand Down Expand Up @@ -239,6 +240,45 @@ def test_proto_keyword_fname():
}


def test_proto_oneof():
# Put together a couple of minimal protos.
fd = (
make_file_pb2(
name='dep.proto',
package='google.dep',
messages=(make_message_pb2(name='ImportedMessage', fields=()),),
),
make_file_pb2(
name='foo.proto',
package='google.example.v1',
messages=(
make_message_pb2(name='Foo', fields=()),
make_message_pb2(
name='Bar',
fields=(
make_field_pb2(name='imported_message', number=1,
type_name='.google.dep.ImportedMessage',
oneof_index=0),
make_field_pb2(
name='primitive', number=2, type=1, oneof_index=0),
),
oneof_decl=(
make_oneof_pb2(name="value_type"),
)
)
)
)
)

# Create an API with those protos.
api_schema = api.API.build(fd, package='google.example.v1')
proto = api_schema.protos['foo.proto']
assert proto.names == {'imported_message', 'Bar', 'primitive', 'Foo'}
oneofs = proto.messages["google.example.v1.Bar"].oneofs
assert len(oneofs) == 1
assert "value_type" in oneofs.keys()


def test_proto_names_import_collision():
# Put together a couple of minimal protos.
fd = (
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/schema/wrappers/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def test_mock_value_int():
assert field.mock_value == '728'


def test_oneof():
REP = descriptor_pb2.FieldDescriptorProto.Label.Value('LABEL_REPEATED')

field = make_field(oneof="oneof_name")
assert field.oneof == "oneof_name"


def test_mock_value_float():
field = make_field(name='foo_bar', type='TYPE_DOUBLE')
assert field.mock_value == '0.728'
Expand Down
Loading

0 comments on commit be5a847

Please sign in to comment.