Skip to content

Commit 5c6131a

Browse files
authored
fix: add oneof fields to generated protoplus init (#485)
Fixes: #484
1 parent a10685a commit 5c6131a

File tree

13 files changed

+193
-17
lines changed

13 files changed

+193
-17
lines changed

packages/gapic-generator/gapic/ads-templates/%namespace/%name/%version/%sub/types/_message.py.j2

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class {{ message.name }}({{ p }}.Message):
4343
{% else -%}
4444
{{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
4545
{{- p }}.{{ field.proto_type }}, number={{ field.number }}
46+
{% if field.oneof %}, oneof='{{ field.oneof }}'{% endif %}
4647
{%- if field.enum or field.message %},
4748
{{ field.proto_type.lower() }}={{ field.type.ident.rel(message.ident) }},
4849
{% endif %})

packages/gapic-generator/gapic/schema/api.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from gapic.schema import wrappers
3535
from gapic.schema import naming as api_naming
3636
from gapic.utils import cached_property
37+
from gapic.utils import nth
3738
from gapic.utils import to_snake_case
3839
from gapic.utils import RESERVED_NAMES
3940

@@ -556,14 +557,42 @@ def _load_children(self,
556557
answer[wrapped.name] = wrapped
557558
return answer
558559

560+
def _get_oneofs(self,
561+
oneof_pbs: Sequence[descriptor_pb2.OneofDescriptorProto],
562+
address: metadata.Address, path: Tuple[int, ...],
563+
) -> Dict[str, wrappers.Oneof]:
564+
"""Return a dictionary of wrapped oneofs for the given message.
565+
566+
Args:
567+
oneof_fields (Sequence[~.descriptor_pb2.OneofDescriptorProto]): A
568+
sequence of protobuf field objects.
569+
address (~.metadata.Address): An address object denoting the
570+
location of these oneofs.
571+
path (Tuple[int]): The source location path thus far, as
572+
understood by ``SourceCodeInfo.Location``.
573+
574+
Returns:
575+
Mapping[str, ~.wrappers.Oneof]: A ordered mapping of
576+
:class:`~.wrappers.Oneof` objects.
577+
"""
578+
# Iterate over the oneofs and collect them into a dictionary.
579+
answer = collections.OrderedDict(
580+
(oneof_pb.name, wrappers.Oneof(oneof_pb=oneof_pb))
581+
for i, oneof_pb in enumerate(oneof_pbs)
582+
)
583+
584+
# Done; return the answer.
585+
return answer
586+
559587
def _get_fields(self,
560588
field_pbs: Sequence[descriptor_pb2.FieldDescriptorProto],
561589
address: metadata.Address, path: Tuple[int, ...],
590+
oneofs: Optional[Dict[str, wrappers.Oneof]] = None
562591
) -> Dict[str, wrappers.Field]:
563592
"""Return a dictionary of wrapped fields for the given message.
564593
565594
Args:
566-
fields (Sequence[~.descriptor_pb2.FieldDescriptorProto]): A
595+
field_pbs (Sequence[~.descriptor_pb2.FieldDescriptorProto]): A
567596
sequence of protobuf field objects.
568597
address (~.metadata.Address): An address object denoting the
569598
location of these fields.
@@ -585,7 +614,13 @@ def _get_fields(self,
585614
# first) and this will be None. This case is addressed in the
586615
# `_load_message` method.
587616
answer: Dict[str, wrappers.Field] = collections.OrderedDict()
588-
for field_pb, i in zip(field_pbs, range(0, sys.maxsize)):
617+
for i, field_pb in enumerate(field_pbs):
618+
is_oneof = oneofs and field_pb.oneof_index > 0
619+
oneof_name = nth(
620+
(oneofs or {}).keys(),
621+
field_pb.oneof_index
622+
) if is_oneof else None
623+
589624
answer[field_pb.name] = wrappers.Field(
590625
field_pb=field_pb,
591626
enum=self.api_enums.get(field_pb.type_name.lstrip('.')),
@@ -594,6 +629,7 @@ def _get_fields(self,
594629
address=address.child(field_pb.name, path + (i,)),
595630
documentation=self.docs.get(path + (i,), self.EMPTY),
596631
),
632+
oneof=oneof_name,
597633
)
598634

599635
# Done; return the answer.
@@ -779,19 +815,25 @@ def _load_message(self,
779815
loader=self._load_message,
780816
path=path + (3,),
781817
)
782-
# self._load_children(message.oneof_decl, loader=self._load_field,
783-
# address=nested_addr, info=info.get(8, {}))
818+
819+
oneofs = self._get_oneofs(
820+
message_pb.oneof_decl,
821+
address=address,
822+
path=path + (7,),
823+
)
784824

785825
# Create a dictionary of all the fields for this message.
786826
fields = self._get_fields(
787827
message_pb.field,
788828
address=address,
789829
path=path + (2,),
830+
oneofs=oneofs,
790831
)
791832
fields.update(self._get_fields(
792833
message_pb.extension,
793834
address=address,
794835
path=path + (6,),
836+
oneofs=oneofs,
795837
))
796838

797839
# Create a message correspoding to this descriptor.
@@ -804,6 +846,7 @@ def _load_message(self,
804846
address=address,
805847
documentation=self.docs.get(path, self.EMPTY),
806848
),
849+
oneofs=oneofs,
807850
)
808851
return self.proto_messages[address.proto]
809852

packages/gapic-generator/gapic/schema/wrappers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class Field:
5454
meta: metadata.Metadata = dataclasses.field(
5555
default_factory=metadata.Metadata,
5656
)
57+
oneof: Optional[str] = None
5758

5859
def __getattr__(self, name):
5960
return getattr(self.field_pb, name)
@@ -206,6 +207,15 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Field':
206207
)
207208

208209

210+
@dataclasses.dataclass(frozen=True)
211+
class Oneof:
212+
"""Description of a field."""
213+
oneof_pb: descriptor_pb2.OneofDescriptorProto
214+
215+
def __getattr__(self, name):
216+
return getattr(self.oneof_pb, name)
217+
218+
209219
@dataclasses.dataclass(frozen=True)
210220
class MessageType:
211221
"""Description of a message (defined with the ``message`` keyword)."""
@@ -220,6 +230,7 @@ class MessageType:
220230
meta: metadata.Metadata = dataclasses.field(
221231
default_factory=metadata.Metadata,
222232
)
233+
oneofs: Optional[Mapping[str, 'Oneof']] = None
223234

224235
def __getattr__(self, name):
225236
return getattr(self.message_pb, name)

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/types/_message.py.j2

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,15 @@ class {{ message.name }}({{ p }}.Message):
3838
{{- p }}.{{ key_field.proto_type }}, {{ p }}.{{ value_field.proto_type }}, number={{ field.number }}
3939
{%- if value_field.enum or value_field.message %},
4040
{{ value_field.proto_type.lower() }}={{ value_field.type.ident.rel(message.ident) }},
41-
{% endif %})
41+
{% endif %}) {# enum or message#}
4242
{% endwith -%}
43-
{% else -%}
43+
{% else -%} {# field.map #}
4444
{{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
4545
{{- p }}.{{ field.proto_type }}, number={{ field.number }}
46+
{% if field.oneof %}, oneof='{{ field.oneof }}'{% endif %}
4647
{%- if field.enum or field.message %},
4748
{{ field.proto_type.lower() }}={{ field.type.ident.rel(message.ident) }},
48-
{% endif %})
49-
{% endif -%}
50-
{% endfor -%}
49+
{% endif %}) {# enum or message #}
50+
{% endif -%} {# field.map #}
51+
{% endfor -%} {# for field in message.fields.values#}
5152
{{ '\n\n' }}

packages/gapic-generator/gapic/templates/noxfile.py.j2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def unit(session):
2020
'--cov-config=.coveragerc',
2121
'--cov-report=term',
2222
'--cov-report=html',
23-
os.path.join('tests', 'unit', '{{ api.naming.versioned_module_name }}'),
23+
os.path.join('tests', 'unit',)
2424
)
2525

2626

packages/gapic-generator/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,9 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
288288
call.return_value = iter([{{ method.output.ident }}()])
289289
{% else -%}
290290
call.return_value = {{ method.output.ident }}(
291-
{%- for field in method.output.fields.values() | rejectattr('message') %}
291+
{%- for field in method.output.fields.values() | rejectattr('message')%}{% if not (field.oneof and not field.proto3_optional) %}
292292
{{ field.name }}={{ field.mock_value }},
293-
{%- endfor %}
293+
{% endif %}{%- endfor %}
294294
)
295295
{% endif -%}
296296
{% if method.client_streaming %}
@@ -318,14 +318,15 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
318318
assert isinstance(message, {{ method.output.ident }})
319319
{% else -%}
320320
assert isinstance(response, {{ method.client_output.ident }})
321-
{% for field in method.output.fields.values() | rejectattr('message') -%}
321+
{% for field in method.output.fields.values() | rejectattr('message') -%}{% if not (field.oneof and not field.proto3_optional) %}
322322
{% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#}
323323
assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6)
324324
{% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #}
325325
assert response.{{ field.name }} is {{ field.mock_value }}
326326
{% else -%}
327327
assert response.{{ field.name }} == {{ field.mock_value }}
328328
{% endif -%}
329+
{% endif -%} {# end oneof/optional #}
329330
{% endfor %}
330331
{% endif %}
331332

@@ -368,8 +369,9 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio
368369
{%- else -%}
369370
grpc_helpers_async.FakeStreamUnaryCall
370371
{%- endif -%}({{ method.output.ident }}(
371-
{%- for field in method.output.fields.values() | rejectattr('message') %}
372+
{%- for field in method.output.fields.values() | rejectattr('message') %}{% if not (field.oneof and not field.proto3_optional) %}
372373
{{ field.name }}={{ field.mock_value }},
374+
{%- endif %}
373375
{%- endfor %}
374376
))
375377
{% endif -%}
@@ -400,14 +402,15 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio
400402
assert isinstance(message, {{ method.output.ident }})
401403
{% else -%}
402404
assert isinstance(response, {{ method.client_output_async.ident }})
403-
{% for field in method.output.fields.values() | rejectattr('message') -%}
405+
{% for field in method.output.fields.values() | rejectattr('message') -%}{% if not (field.oneof and not field.proto3_optional) %}
404406
{% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#}
405407
assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6)
406408
{% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #}
407409
assert response.{{ field.name }} is {{ field.mock_value }}
408410
{% else -%}
409411
assert response.{{ field.name }} == {{ field.mock_value }}
410412
{% endif -%}
413+
{% endif -%} {# oneof/optional #}
411414
{% endfor %}
412415
{% endif %}
413416

packages/gapic-generator/gapic/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from gapic.utils.cache import cached_property
1616
from gapic.utils.case import to_snake_case
1717
from gapic.utils.code import empty
18+
from gapic.utils.code import nth
1819
from gapic.utils.code import partition
1920
from gapic.utils.doc import doc
2021
from gapic.utils.filename import to_valid_filename
@@ -29,6 +30,7 @@
2930
'cached_property',
3031
'doc',
3132
'empty',
33+
'nth',
3234
'partition',
3335
'RESERVED_NAMES',
3436
'rst',

packages/gapic-generator/gapic/utils/code.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import (Callable, Iterable, List, Tuple, TypeVar)
15+
from typing import (Callable, Iterable, List, Optional, Tuple, TypeVar)
16+
import itertools
1617

1718

1819
def empty(content: str) -> bool:
@@ -50,3 +51,15 @@ def partition(predicate: Callable[[T], bool],
5051

5152
# Returns trueList, falseList
5253
return results[1], results[0]
54+
55+
56+
def nth(iterable: Iterable[T], n: int, default: Optional[T] = None) -> Optional[T]:
57+
"""Return the nth element of an iterable or a default value.
58+
59+
Args
60+
iterable (Iterable(T)): An iterable on any type.
61+
n (int): The 'index' of the lement to retrieve.
62+
default (Optional(T)): An optional default elemnt if the iterable has
63+
fewer than n elements.
64+
"""
65+
return next(itertools.islice(iterable, n, None), default)

packages/gapic-generator/test_utils/test_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def make_field(
200200
message: wrappers.MessageType = None,
201201
enum: wrappers.EnumType = None,
202202
meta: metadata.Metadata = None,
203+
oneof: str = None,
203204
**kwargs
204205
) -> wrappers.Field:
205206
T = desc.FieldDescriptorProto.Type
@@ -223,11 +224,13 @@ def make_field(
223224
number=number,
224225
**kwargs
225226
)
227+
226228
return wrappers.Field(
227229
field_pb=field_pb,
228230
enum=enum,
229231
message=message,
230232
meta=meta or metadata.Metadata(),
233+
oneof=oneof,
231234
)
232235

233236

@@ -322,20 +325,28 @@ def make_enum_pb2(
322325
def make_message_pb2(
323326
name: str,
324327
fields: tuple = (),
328+
oneof_decl: tuple = (),
325329
**kwargs
326330
) -> desc.DescriptorProto:
327-
return desc.DescriptorProto(name=name, field=fields, **kwargs)
331+
return desc.DescriptorProto(name=name, field=fields, oneof_decl=oneof_decl, **kwargs)
328332

329333

330334
def make_field_pb2(name: str, number: int,
331335
type: int = 11, # 11 == message
332336
type_name: str = None,
337+
oneof_index: int = None
333338
) -> desc.FieldDescriptorProto:
334339
return desc.FieldDescriptorProto(
335340
name=name,
336341
number=number,
337342
type=type,
338343
type_name=type_name,
344+
oneof_index=oneof_index,
345+
)
346+
347+
def make_oneof_pb2(name: str) -> desc.OneofDescriptorProto:
348+
return desc.OneofDescriptorProto(
349+
name=name,
339350
)
340351

341352

packages/gapic-generator/tests/unit/schema/test_api.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
make_file_pb2,
3535
make_message_pb2,
3636
make_naming,
37+
make_oneof_pb2,
3738
)
3839

3940

@@ -239,6 +240,45 @@ def test_proto_keyword_fname():
239240
}
240241

241242

243+
def test_proto_oneof():
244+
# Put together a couple of minimal protos.
245+
fd = (
246+
make_file_pb2(
247+
name='dep.proto',
248+
package='google.dep',
249+
messages=(make_message_pb2(name='ImportedMessage', fields=()),),
250+
),
251+
make_file_pb2(
252+
name='foo.proto',
253+
package='google.example.v1',
254+
messages=(
255+
make_message_pb2(name='Foo', fields=()),
256+
make_message_pb2(
257+
name='Bar',
258+
fields=(
259+
make_field_pb2(name='imported_message', number=1,
260+
type_name='.google.dep.ImportedMessage',
261+
oneof_index=0),
262+
make_field_pb2(
263+
name='primitive', number=2, type=1, oneof_index=0),
264+
),
265+
oneof_decl=(
266+
make_oneof_pb2(name="value_type"),
267+
)
268+
)
269+
)
270+
)
271+
)
272+
273+
# Create an API with those protos.
274+
api_schema = api.API.build(fd, package='google.example.v1')
275+
proto = api_schema.protos['foo.proto']
276+
assert proto.names == {'imported_message', 'Bar', 'primitive', 'Foo'}
277+
oneofs = proto.messages["google.example.v1.Bar"].oneofs
278+
assert len(oneofs) == 1
279+
assert "value_type" in oneofs.keys()
280+
281+
242282
def test_proto_names_import_collision():
243283
# Put together a couple of minimal protos.
244284
fd = (

0 commit comments

Comments
 (0)