Skip to content

Commit

Permalink
Refactoring Java parsing (3.20.x) (#10666)
Browse files Browse the repository at this point in the history
* Porting java cleanup

* Update changelog

* Fix absl usage

* Extension patch

* Remove extra allocations
  • Loading branch information
mkruskal-google authored Sep 29, 2022
1 parent 139068b commit ae6d69d
Show file tree
Hide file tree
Showing 40 changed files with 1,690 additions and 940 deletions.
11 changes: 11 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
2022-09-27 version 3.20.3 (C++/Java/Python/PHP/Objective-C/C#/Ruby)
Java
* Refactoring java full runtime to reuse sub-message builders and prepare to
migrate parsing logic from parse constructor to builder.
* Move proto wireformat parsing functionality from the private "parsing
constructor" to the Builder class.
* Change the Lite runtime to prefer merging from the wireformat into mutable
messages rather than building up a new immutable object before merging. This
way results in fewer allocations and copy operations.
* Make message-type extensions merge from wire-format instead of building up instances and merging afterwards. This has much better performance.

2022-09-13 version 3.20.2 (C++/Java/Python/PHP/Objective-C/C#/Ruby)

C++
Expand Down
27 changes: 11 additions & 16 deletions java/core/src/main/java/com/google/protobuf/AbstractMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -424,27 +424,22 @@ public BuilderType mergeFrom(
throws IOException {
boolean discardUnknown = input.shouldDiscardUnknownFields();
final UnknownFieldSet.Builder unknownFields =
discardUnknown ? null : UnknownFieldSet.newBuilder(getUnknownFields());
while (true) {
final int tag = input.readTag();
if (tag == 0) {
break;
}

MessageReflection.BuilderAdapter builderAdapter =
new MessageReflection.BuilderAdapter(this);
if (!MessageReflection.mergeFieldFrom(
input, unknownFields, extensionRegistry, getDescriptorForType(), builderAdapter, tag)) {
// end group tag
break;
}
}
discardUnknown ? null : getUnknownFieldSetBuilder();
MessageReflection.mergeMessageFrom(this, unknownFields, input, extensionRegistry);
if (unknownFields != null) {
setUnknownFields(unknownFields.build());
setUnknownFieldSetBuilder(unknownFields);
}
return (BuilderType) this;
}

protected UnknownFieldSet.Builder getUnknownFieldSetBuilder() {
return UnknownFieldSet.newBuilder(getUnknownFields());
}

protected void setUnknownFieldSetBuilder(final UnknownFieldSet.Builder builder) {
setUnknownFields(builder.build());
}

@Override
public BuilderType mergeUnknownFields(final UnknownFieldSet unknownFields) {
setUnknownFields(
Expand Down
146 changes: 86 additions & 60 deletions java/core/src/main/java/com/google/protobuf/ArrayDecoders.java
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,29 @@ static int decodeBytes(byte[] data, int position, Registers registers)
@SuppressWarnings({"unchecked", "rawtypes"})
static int decodeMessageField(
Schema schema, byte[] data, int position, int limit, Registers registers) throws IOException {
Object msg = schema.newInstance();
int offset = mergeMessageField(msg, schema, data, position, limit, registers);
schema.makeImmutable(msg);
registers.object1 = msg;
return offset;
}

/** Decodes a group value. */
@SuppressWarnings({"unchecked", "rawtypes"})
static int decodeGroupField(
Schema schema, byte[] data, int position, int limit, int endGroup, Registers registers)
throws IOException {
Object msg = schema.newInstance();
int offset = mergeGroupField(msg, schema, data, position, limit, endGroup, registers);
schema.makeImmutable(msg);
registers.object1 = msg;
return offset;
}

@SuppressWarnings({"unchecked", "rawtypes"})
static int mergeMessageField(
Object msg, Schema schema, byte[] data, int position, int limit, Registers registers)
throws IOException {
int length = data[position++];
if (length < 0) {
position = decodeVarint32(length, data, position, registers);
Expand All @@ -244,27 +267,28 @@ static int decodeMessageField(
if (length < 0 || length > limit - position) {
throw InvalidProtocolBufferException.truncatedMessage();
}
Object result = schema.newInstance();
schema.mergeFrom(result, data, position, position + length, registers);
schema.makeImmutable(result);
registers.object1 = result;
schema.mergeFrom(msg, data, position, position + length, registers);
registers.object1 = msg;
return position + length;
}

/** Decodes a group value. */
@SuppressWarnings({"unchecked", "rawtypes"})
static int decodeGroupField(
Schema schema, byte[] data, int position, int limit, int endGroup, Registers registers)
static int mergeGroupField(
Object msg,
Schema schema,
byte[] data,
int position,
int limit,
int endGroup,
Registers registers)
throws IOException {
// A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
// and it can't be used in group fields).
final MessageSchema messageSchema = (MessageSchema) schema;
Object result = messageSchema.newInstance();
// It's OK to directly use parseProto2Message since proto3 doesn't have group.
final int endPosition =
messageSchema.parseProto2Message(result, data, position, limit, endGroup, registers);
messageSchema.makeImmutable(result);
registers.object1 = result;
messageSchema.parseProto2Message(msg, data, position, limit, endGroup, registers);
registers.object1 = msg;
return endPosition;
}

Expand Down Expand Up @@ -848,26 +872,19 @@ static int decodeExtension(
break;
}
case ENUM:
{
IntArrayList list = new IntArrayList();
position = decodePackedVarint32List(data, position, list, registers);
UnknownFieldSetLite unknownFields = message.unknownFields;
if (unknownFields == UnknownFieldSetLite.getDefaultInstance()) {
unknownFields = null;
}
unknownFields =
SchemaUtil.filterUnknownEnumList(
fieldNumber,
list,
extension.descriptor.getEnumType(),
unknownFields,
unknownFieldSchema);
if (unknownFields != null) {
message.unknownFields = unknownFields;
{
IntArrayList list = new IntArrayList();
position = decodePackedVarint32List(data, position, list, registers);
SchemaUtil.filterUnknownEnumList(
message,
fieldNumber,
list,
extension.descriptor.getEnumType(),
null,
unknownFieldSchema);
extensions.setField(extension.descriptor, list);
break;
}
extensions.setField(extension.descriptor, list);
break;
}
default:
throw new IllegalStateException(
"Type cannot be packed: " + extension.descriptor.getLiteType());
Expand All @@ -879,13 +896,8 @@ static int decodeExtension(
position = decodeVarint32(data, position, registers);
Object enumValue = extension.descriptor.getEnumType().findValueByNumber(registers.int1);
if (enumValue == null) {
UnknownFieldSetLite unknownFields = ((GeneratedMessageLite) message).unknownFields;
if (unknownFields == UnknownFieldSetLite.getDefaultInstance()) {
unknownFields = UnknownFieldSetLite.newInstance();
((GeneratedMessageLite) message).unknownFields = unknownFields;
}
SchemaUtil.storeUnknownEnum(
fieldNumber, registers.int1, unknownFields, unknownFieldSchema);
message, fieldNumber, registers.int1, null, unknownFieldSchema);
return position;
}
// Note, we store the integer value instead of the actual enum object in FieldSet.
Expand Down Expand Up @@ -942,38 +954,52 @@ static int decodeExtension(
value = registers.object1;
break;
case GROUP:
final int endTag = (fieldNumber << 3) | WireFormat.WIRETYPE_END_GROUP;
position = decodeGroupField(
Protobuf.getInstance().schemaFor(extension.getMessageDefaultInstance().getClass()),
data, position, limit, endTag, registers);
value = registers.object1;
break;

{
final int endTag = (fieldNumber << 3) | WireFormat.WIRETYPE_END_GROUP;
final Schema fieldSchema =
Protobuf.getInstance()
.schemaFor(extension.getMessageDefaultInstance().getClass());
if (extension.isRepeated()) {
position = decodeGroupField(fieldSchema, data, position, limit, endTag, registers);
extensions.addRepeatedField(extension.descriptor, registers.object1);
} else {
Object oldValue = extensions.getField(extension.descriptor);
if (oldValue == null) {
oldValue = fieldSchema.newInstance();
extensions.setField(extension.descriptor, oldValue);
}
position =
mergeGroupField(
oldValue, fieldSchema, data, position, limit, endTag, registers);
}
return position;
}
case MESSAGE:
position = decodeMessageField(
Protobuf.getInstance().schemaFor(extension.getMessageDefaultInstance().getClass()),
data, position, limit, registers);
value = registers.object1;
break;

{
final Schema fieldSchema =
Protobuf.getInstance()
.schemaFor(extension.getMessageDefaultInstance().getClass());
if (extension.isRepeated()) {
position = decodeMessageField(fieldSchema, data, position, limit, registers);
extensions.addRepeatedField(extension.descriptor, registers.object1);
} else {
Object oldValue = extensions.getField(extension.descriptor);
if (oldValue == null) {
oldValue = fieldSchema.newInstance();
extensions.setField(extension.descriptor, oldValue);
}
position =
mergeMessageField(oldValue, fieldSchema, data, position, limit, registers);
}
return position;
}
case ENUM:
throw new IllegalStateException("Shouldn't reach here.");
}
}
if (extension.isRepeated()) {
extensions.addRepeatedField(extension.descriptor, value);
} else {
switch (extension.getLiteType()) {
case MESSAGE:
case GROUP:
Object oldValue = extensions.getField(extension.descriptor);
if (oldValue != null) {
value = Internal.mergeMessage(oldValue, value);
}
break;
default:
break;
}
extensions.setField(extension.descriptor, value);
}
}
Expand Down
32 changes: 20 additions & 12 deletions java/core/src/main/java/com/google/protobuf/BinaryReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,15 @@ public <T> T readMessageBySchemaWithCheck(

private <T> T readMessage(Schema<T> schema, ExtensionRegistryLite extensionRegistry)
throws IOException {
T newInstance = schema.newInstance();
mergeMessageField(newInstance, schema, extensionRegistry);
schema.makeImmutable(newInstance);
return newInstance;
}

@Override
public <T> void mergeMessageField(
T target, Schema<T> schema, ExtensionRegistryLite extensionRegistry) throws IOException {
int size = readVarint32();
requireBytes(size);

Expand All @@ -257,15 +266,10 @@ private <T> T readMessage(Schema<T> schema, ExtensionRegistryLite extensionRegis
limit = newLimit;

try {
// Allocate and read the message.
T message = schema.newInstance();
schema.mergeFrom(message, this, extensionRegistry);
schema.makeImmutable(message);

schema.mergeFrom(target, this, extensionRegistry);
if (pos != newLimit) {
throw InvalidProtocolBufferException.parseFailure();
}
return message;
} finally {
// Restore the limit.
limit = prevLimit;
Expand All @@ -290,19 +294,23 @@ public <T> T readGroupBySchemaWithCheck(

private <T> T readGroup(Schema<T> schema, ExtensionRegistryLite extensionRegistry)
throws IOException {
T newInstance = schema.newInstance();
mergeGroupField(newInstance, schema, extensionRegistry);
schema.makeImmutable(newInstance);
return newInstance;
}

@Override
public <T> void mergeGroupField(
T target, Schema<T> schema, ExtensionRegistryLite extensionRegistry) throws IOException {
int prevEndGroupTag = endGroupTag;
endGroupTag = WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WIRETYPE_END_GROUP);

try {
// Allocate and read the message.
T message = schema.newInstance();
schema.mergeFrom(message, this, extensionRegistry);
schema.makeImmutable(message);

schema.mergeFrom(target, this, extensionRegistry);
if (tag != endGroupTag) {
throw InvalidProtocolBufferException.parseFailure();
}
return message;
} finally {
// Restore the old end group tag.
endGroupTag = prevEndGroupTag;
Expand Down
Loading

0 comments on commit ae6d69d

Please sign in to comment.