Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add oneof fields to generated protoplus init #485

Merged
merged 32 commits into from
Jul 7, 2020
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
1cbae6a
fix: add oneof fields to genereated protoplus init
crwilcox Jun 25, 2020
f104ba0
Add a one of field to the Field dataclass
crwilcox Jun 25, 2020
1f8ee9c
style
crwilcox Jun 26, 2020
7438862
style
crwilcox Jun 26, 2020
593f647
str type for oneof
crwilcox Jun 26, 2020
114a277
add field oneof to ads templates
crwilcox Jun 26, 2020
589e914
mypy: specify oneof as optional
crwilcox Jun 26, 2020
c912534
add oneofs to schema
crwilcox Jun 30, 2020
afa3578
exit oops
crwilcox Jun 30, 2020
9a5e6c5
mypy and lint fixes
crwilcox Jun 30, 2020
086cd70
Merge branch 'master' into oneof-proto-templates
crwilcox Jun 30, 2020
cf31e2a
add tests for oneofs
crwilcox Jun 30, 2020
3b90426
lint
crwilcox Jun 30, 2020
70cfe15
typing and test
crwilcox Jun 30, 2020
ca19eee
Merge branch 'master' into oneof-proto-templates
crwilcox Jun 30, 2020
58d3f97
style
crwilcox Jun 30, 2020
b44317f
whitespace
crwilcox Jun 30, 2020
df33ed9
Merge remote-tracking branch 'upstream' into oneof-proto-templates
crwilcox Jun 30, 2020
c4d1b4e
Merge branch 'master' into oneof-proto-templates
crwilcox Jul 6, 2020
374bfbe
Add an nth utility
software-dov Jul 6, 2020
a329bd0
Tweak api.py to use nth
software-dov Jul 6, 2020
cbf73a3
Templates change for oneofs (NOT ADS TEMPLATES)
software-dov Jul 6, 2020
2fc85fa
Merge branch 'oneof-proto-templates' of https://github.com/googleapis…
software-dov Jul 6, 2020
5f861d3
Test nth
software-dov Jul 7, 2020
64ec32b
Style check
software-dov Jul 7, 2020
07709f8
Tweak
software-dov Jul 7, 2020
aa7ceb6
Thing
software-dov Jul 7, 2020
ec65b73
Thing in thing exception
software-dov Jul 7, 2020
05b05d5
Whitespace
software-dov Jul 7, 2020
e590edb
Fighting the type system
software-dov Jul 7, 2020
3def32f
Merge branch 'master' into oneof-proto-templates
software-dov Jul 7, 2020
8dbaa75
Merge branch 'master' into oneof-proto-templates
software-dov Jul 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class {{ message.name }}({{ p }}.Message):
{% else -%}
{{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
{{- p }}.{{ field.proto_type }}, number={{ field.number }}
{% if field.oneof %}, oneof='{{ field.oneof }}'{% endif %}
{%- if field.enum or field.message %},
{{ field.proto_type.lower() }}={{ field.type.ident.rel(message.ident) }},
{% endif %})
Expand Down
51 changes: 47 additions & 4 deletions gapic/schema/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from gapic.schema import wrappers
from gapic.schema import naming as api_naming
from gapic.utils import cached_property
from gapic.utils import nth
from gapic.utils import to_snake_case
from gapic.utils import RESERVED_NAMES

Expand Down Expand Up @@ -556,14 +557,42 @@ def _load_children(self,
answer[wrapped.name] = wrapped
return answer

def _get_oneofs(self,
oneof_pbs: Sequence[descriptor_pb2.OneofDescriptorProto],
address: metadata.Address, path: Tuple[int, ...],
) -> Dict[str, wrappers.Oneof]:
"""Return a dictionary of wrapped oneofs for the given message.

Args:
oneof_fields (Sequence[~.descriptor_pb2.OneofDescriptorProto]): A
sequence of protobuf field objects.
address (~.metadata.Address): An address object denoting the
location of these oneofs.
path (Tuple[int]): The source location path thus far, as
understood by ``SourceCodeInfo.Location``.

Returns:
Mapping[str, ~.wrappers.Oneof]: A ordered mapping of
:class:`~.wrappers.Oneof` objects.
"""
# Iterate over the oneofs and collect them into a dictionary.
answer = collections.OrderedDict(
(oneof_pb.name, wrappers.Oneof(oneof_pb=oneof_pb))
for i, oneof_pb in enumerate(oneof_pbs)
)

# Done; return the answer.
return answer

def _get_fields(self,
field_pbs: Sequence[descriptor_pb2.FieldDescriptorProto],
address: metadata.Address, path: Tuple[int, ...],
oneofs: Optional[Dict[str, wrappers.Oneof]] = None
) -> Dict[str, wrappers.Field]:
"""Return a dictionary of wrapped fields for the given message.

Args:
fields (Sequence[~.descriptor_pb2.FieldDescriptorProto]): A
field_pbs (Sequence[~.descriptor_pb2.FieldDescriptorProto]): A
sequence of protobuf field objects.
address (~.metadata.Address): An address object denoting the
location of these fields.
Expand All @@ -585,7 +614,13 @@ def _get_fields(self,
# first) and this will be None. This case is addressed in the
# `_load_message` method.
answer: Dict[str, wrappers.Field] = collections.OrderedDict()
for field_pb, i in zip(field_pbs, range(0, sys.maxsize)):
for i, field_pb in enumerate(field_pbs):
is_oneof = oneofs and field_pb.oneof_index > 0
oneof_name = nth(
(oneofs or {}).keys(),
field_pb.oneof_index
) if is_oneof else None

answer[field_pb.name] = wrappers.Field(
field_pb=field_pb,
enum=self.api_enums.get(field_pb.type_name.lstrip('.')),
Expand All @@ -594,6 +629,7 @@ def _get_fields(self,
address=address.child(field_pb.name, path + (i,)),
documentation=self.docs.get(path + (i,), self.EMPTY),
),
oneof=oneof_name,
)

# Done; return the answer.
Expand Down Expand Up @@ -779,19 +815,25 @@ def _load_message(self,
loader=self._load_message,
path=path + (3,),
)
# self._load_children(message.oneof_decl, loader=self._load_field,
# address=nested_addr, info=info.get(8, {}))

oneofs = self._get_oneofs(
message_pb.oneof_decl,
address=address,
path=path + (7,),
)

# Create a dictionary of all the fields for this message.
fields = self._get_fields(
message_pb.field,
address=address,
path=path + (2,),
oneofs=oneofs,
)
fields.update(self._get_fields(
message_pb.extension,
address=address,
path=path + (6,),
oneofs=oneofs,
))

# Create a message correspoding to this descriptor.
Expand All @@ -804,6 +846,7 @@ def _load_message(self,
address=address,
documentation=self.docs.get(path, self.EMPTY),
),
oneofs=oneofs,
crwilcox marked this conversation as resolved.
Show resolved Hide resolved
)
return self.proto_messages[address.proto]

Expand Down
11 changes: 11 additions & 0 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Field:
meta: metadata.Metadata = dataclasses.field(
default_factory=metadata.Metadata,
)
oneof: Optional[str] = None

def __getattr__(self, name):
return getattr(self.field_pb, name)
Expand Down Expand Up @@ -206,6 +207,15 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Field':
)


@dataclasses.dataclass(frozen=True)
class Oneof:
"""Description of a field."""
oneof_pb: descriptor_pb2.OneofDescriptorProto

def __getattr__(self, name):
return getattr(self.oneof_pb, name)


@dataclasses.dataclass(frozen=True)
class MessageType:
"""Description of a message (defined with the ``message`` keyword)."""
Expand All @@ -220,6 +230,7 @@ class MessageType:
meta: metadata.Metadata = dataclasses.field(
default_factory=metadata.Metadata,
)
oneofs: Optional[Mapping[str, 'Oneof']] = None

def __getattr__(self, name):
return getattr(self.message_pb, name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ class {{ message.name }}({{ p }}.Message):
{{- p }}.{{ key_field.proto_type }}, {{ p }}.{{ value_field.proto_type }}, number={{ field.number }}
{%- if value_field.enum or value_field.message %},
{{ value_field.proto_type.lower() }}={{ value_field.type.ident.rel(message.ident) }},
{% endif %})
{% endif %}) {# enum or message#}
{% endwith -%}
{% else -%}
{% else -%} {# field.map #}
{{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
{{- p }}.{{ field.proto_type }}, number={{ field.number }}
{% if field.oneof %}, oneof='{{ field.oneof }}'{% endif %}
{%- if field.enum or field.message %},
{{ field.proto_type.lower() }}={{ field.type.ident.rel(message.ident) }},
{% endif %})
{% endif -%}
{% endfor -%}
{% endif %}) {# enum or message #}
{% endif -%} {# field.map #}
{% endfor -%} {# for field in message.fields.values#}
{{ '\n\n' }}
2 changes: 1 addition & 1 deletion gapic/templates/noxfile.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def unit(session):
'--cov-config=.coveragerc',
'--cov-report=term',
'--cov-report=html',
os.path.join('tests', 'unit', '{{ api.naming.versioned_module_name }}'),
os.path.join('tests', 'unit',)
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,9 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
call.return_value = iter([{{ method.output.ident }}()])
{% else -%}
call.return_value = {{ method.output.ident }}(
{%- for field in method.output.fields.values() | rejectattr('message') %}
{%- for field in method.output.fields.values() | rejectattr('message')%}{% if not (field.oneof and not field.proto3_optional) %}
{{ field.name }}={{ field.mock_value }},
{%- endfor %}
{% endif %}{%- endfor %}
)
{% endif -%}
{% if method.client_streaming %}
Expand Down Expand Up @@ -318,14 +318,15 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
assert isinstance(message, {{ method.output.ident }})
{% else -%}
assert isinstance(response, {{ method.client_output.ident }})
{% for field in method.output.fields.values() | rejectattr('message') -%}
{% for field in method.output.fields.values() | rejectattr('message') -%}{% if not (field.oneof and not field.proto3_optional) %}
{% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#}
assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6)
{% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #}
assert response.{{ field.name }} is {{ field.mock_value }}
{% else -%}
assert response.{{ field.name }} == {{ field.mock_value }}
{% endif -%}
{% endif -%} {# end oneof/optional #}
{% endfor %}
{% endif %}

Expand Down Expand Up @@ -368,8 +369,9 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio
{%- else -%}
grpc_helpers_async.FakeStreamUnaryCall
{%- endif -%}({{ method.output.ident }}(
{%- for field in method.output.fields.values() | rejectattr('message') %}
{%- for field in method.output.fields.values() | rejectattr('message') %}{% if not (field.oneof and not field.proto3_optional) %}
{{ field.name }}={{ field.mock_value }},
{%- endif %}
{%- endfor %}
))
{% endif -%}
Expand Down Expand Up @@ -400,14 +402,15 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio
assert isinstance(message, {{ method.output.ident }})
{% else -%}
assert isinstance(response, {{ method.client_output_async.ident }})
{% for field in method.output.fields.values() | rejectattr('message') -%}
{% for field in method.output.fields.values() | rejectattr('message') -%}{% if not (field.oneof and not field.proto3_optional) %}
{% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#}
assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6)
{% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #}
assert response.{{ field.name }} is {{ field.mock_value }}
{% else -%}
assert response.{{ field.name }} == {{ field.mock_value }}
{% endif -%}
{% endif -%} {# oneof/optional #}
{% endfor %}
{% endif %}

Expand Down
2 changes: 2 additions & 0 deletions gapic/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from gapic.utils.cache import cached_property
from gapic.utils.case import to_snake_case
from gapic.utils.code import empty
from gapic.utils.code import nth
from gapic.utils.code import partition
from gapic.utils.doc import doc
from gapic.utils.filename import to_valid_filename
Expand All @@ -29,6 +30,7 @@
'cached_property',
'doc',
'empty',
'nth',
'partition',
'RESERVED_NAMES',
'rst',
Expand Down
15 changes: 14 additions & 1 deletion gapic/utils/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import (Callable, Iterable, List, Tuple, TypeVar)
from typing import (Callable, Iterable, List, Optional, Tuple, TypeVar)
import itertools


def empty(content: str) -> bool:
Expand Down Expand Up @@ -50,3 +51,15 @@ def partition(predicate: Callable[[T], bool],

# Returns trueList, falseList
return results[1], results[0]


def nth(iterable: Iterable[T], n: int, default: Optional[T] = None) -> Optional[T]:
"""Return the nth element of an iterable or a default value.

Args
iterable (Iterable(T)): An iterable on any type.
n (int): The 'index' of the lement to retrieve.
default (Optional(T)): An optional default elemnt if the iterable has
fewer than n elements.
"""
return next(itertools.islice(iterable, n, None), default)
13 changes: 12 additions & 1 deletion test_utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def make_field(
message: wrappers.MessageType = None,
enum: wrappers.EnumType = None,
meta: metadata.Metadata = None,
oneof: str = None,
**kwargs
) -> wrappers.Field:
T = desc.FieldDescriptorProto.Type
Expand All @@ -223,11 +224,13 @@ def make_field(
number=number,
**kwargs
)

return wrappers.Field(
field_pb=field_pb,
enum=enum,
message=message,
meta=meta or metadata.Metadata(),
oneof=oneof,
)


Expand Down Expand Up @@ -322,20 +325,28 @@ def make_enum_pb2(
def make_message_pb2(
name: str,
fields: tuple = (),
oneof_decl: tuple = (),
**kwargs
) -> desc.DescriptorProto:
return desc.DescriptorProto(name=name, field=fields, **kwargs)
return desc.DescriptorProto(name=name, field=fields, oneof_decl=oneof_decl, **kwargs)


def make_field_pb2(name: str, number: int,
type: int = 11, # 11 == message
type_name: str = None,
oneof_index: int = None
) -> desc.FieldDescriptorProto:
return desc.FieldDescriptorProto(
name=name,
number=number,
type=type,
type_name=type_name,
oneof_index=oneof_index,
)

def make_oneof_pb2(name: str) -> desc.OneofDescriptorProto:
return desc.OneofDescriptorProto(
name=name,
)


Expand Down
40 changes: 40 additions & 0 deletions tests/unit/schema/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
make_file_pb2,
make_message_pb2,
make_naming,
make_oneof_pb2,
)


Expand Down Expand Up @@ -239,6 +240,45 @@ def test_proto_keyword_fname():
}


def test_proto_oneof():
# Put together a couple of minimal protos.
fd = (
make_file_pb2(
name='dep.proto',
package='google.dep',
messages=(make_message_pb2(name='ImportedMessage', fields=()),),
),
make_file_pb2(
name='foo.proto',
package='google.example.v1',
messages=(
make_message_pb2(name='Foo', fields=()),
make_message_pb2(
name='Bar',
fields=(
make_field_pb2(name='imported_message', number=1,
type_name='.google.dep.ImportedMessage',
oneof_index=0),
make_field_pb2(
name='primitive', number=2, type=1, oneof_index=0),
),
oneof_decl=(
make_oneof_pb2(name="value_type"),
)
)
)
)
)

# Create an API with those protos.
api_schema = api.API.build(fd, package='google.example.v1')
proto = api_schema.protos['foo.proto']
assert proto.names == {'imported_message', 'Bar', 'primitive', 'Foo'}
oneofs = proto.messages["google.example.v1.Bar"].oneofs
assert len(oneofs) == 1
assert "value_type" in oneofs.keys()


def test_proto_names_import_collision():
# Put together a couple of minimal protos.
fd = (
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/schema/wrappers/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def test_mock_value_int():
assert field.mock_value == '728'


def test_oneof():
REP = descriptor_pb2.FieldDescriptorProto.Label.Value('LABEL_REPEATED')

field = make_field(oneof="oneof_name")
assert field.oneof == "oneof_name"


def test_mock_value_float():
field = make_field(name='foo_bar', type='TYPE_DOUBLE')
assert field.mock_value == '0.728'
Expand Down
Loading