Skip to content

Commit ba8db91

Browse files
committed
Redo MessageSerializer with unions. Still has bugs
Change-Id: Ib8beb014310219a7ab8263802ec94d2ea5af6805
1 parent 21854cc commit ba8db91

File tree

3 files changed

+75
-110
lines changed

3 files changed

+75
-110
lines changed

cpp/src/arrow/ipc/adapter.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,12 @@ class RecordBatchWriter : public ArrayVisitor {
129129
num_rows_, body_length, field_nodes_, buffer_meta_, &metadata_fb));
130130

131131
// Need to write 4 bytes (metadata size), the metadata, plus padding to
132-
// fall on a 64-byte offset
133-
int64_t padded_metadata_length =
134-
BitUtil::RoundUpToMultipleOf64(metadata_fb->size() + 4);
132+
// fall on an 8-byte offset
133+
int64_t padded_metadata_length = BitUtil::CeilByte(metadata_fb->size() + 4);
135134

136135
// The returned metadata size includes the length prefix, the flatbuffer,
137136
// plus padding
138-
*metadata_length = padded_metadata_length;
137+
*metadata_length = static_cast<int32_t>(padded_metadata_length);
139138

140139
// Write the flatbuffer size prefix
141140
int32_t flatbuffer_size = metadata_fb->size();
@@ -604,7 +603,9 @@ Status ReadRecordBatchMetadata(int64_t offset, int32_t metadata_length,
604603
return Status::Invalid(ss.str());
605604
}
606605

607-
*metadata = std::make_shared<RecordBatchMetadata>(buffer, sizeof(int32_t));
606+
std::shared_ptr<Message> message;
607+
RETURN_NOT_OK(Message::Open(buffer, 4, &message));
608+
*metadata = std::make_shared<RecordBatchMetadata>(message);
608609
return Status::OK();
609610
}
610611

cpp/src/arrow/ipc/metadata-internal.cc

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -320,23 +320,10 @@ Status MessageBuilder::SetRecordBatch(int32_t length, int64_t body_length,
320320
Status WriteRecordBatchMetadata(int32_t length, int64_t body_length,
321321
const std::vector<flatbuf::FieldNode>& nodes,
322322
const std::vector<flatbuf::Buffer>& buffers, std::shared_ptr<Buffer>* out) {
323-
flatbuffers::FlatBufferBuilder fbb;
324-
325-
auto batch = flatbuf::CreateRecordBatch(
326-
fbb, length, fbb.CreateVectorOfStructs(nodes), fbb.CreateVectorOfStructs(buffers));
327-
328-
fbb.Finish(batch);
329-
330-
int32_t size = fbb.GetSize();
331-
332-
auto result = std::make_shared<PoolBuffer>();
333-
RETURN_NOT_OK(result->Resize(size));
334-
335-
uint8_t* dst = result->mutable_data();
336-
memcpy(dst, fbb.GetBufferPointer(), size);
337-
338-
*out = result;
339-
return Status::OK();
323+
MessageBuilder builder;
324+
RETURN_NOT_OK(builder.SetRecordBatch(length, body_length, nodes, buffers));
325+
RETURN_NOT_OK(builder.Finish());
326+
return builder.GetBuffer(out);
340327
}
341328

342329
Status MessageBuilder::Finish() {

java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java

Lines changed: 65 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -70,33 +70,24 @@ public static int bytesToInt(byte[] bytes) {
7070
*/
7171
public static long serialize(WriteChannel out, Schema schema) throws IOException {
7272
FlatBufferBuilder builder = new FlatBufferBuilder();
73-
builder.finish(schema.getSchema(builder));
74-
ByteBuffer serializedBody = builder.dataBuffer();
75-
ByteBuffer serializedHeader =
76-
serializeHeader(MessageHeader.Schema, serializedBody.remaining());
77-
78-
long size = out.writeIntLittleEndian(serializedHeader.remaining());
79-
size += out.write(serializedHeader);
80-
size += out.write(serializedBody);
73+
int schemaOffset = schema.getSchema(builder);
74+
ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.Schema, schemaOffset, 0);
75+
long size = out.writeIntLittleEndian(serializedMessage.remaining());
76+
size += out.write(serializedMessage);
8177
return size;
8278
}
8379

8480
/**
8581
* Deserializes a schema object. Format is from serialize().
8682
*/
8783
public static Schema deserializeSchema(ReadChannel in) throws IOException {
88-
Message header = deserializeHeader(in, MessageHeader.Schema);
89-
if (header == null) {
84+
Message message = deserializeMessage(in, MessageHeader.Schema);
85+
if (message == null) {
9086
throw new IOException("Unexpected end of input. Missing schema.");
9187
}
9288

93-
// Now read the schema.
94-
ByteBuffer buffer = ByteBuffer.allocate((int)header.bodyLength());
95-
if (in.readFully(buffer) != header.bodyLength()) {
96-
throw new IOException("Unexpected end of input trying to read schema.");
97-
}
98-
buffer.rewind();
99-
return Schema.deserialize(buffer);
89+
return Schema.convertSchema((org.apache.arrow.flatbuf.Schema)
90+
message.header(new org.apache.arrow.flatbuf.Schema()));
10091
}
10192

10293
/**
@@ -106,37 +97,22 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
10697
throws IOException {
10798
long start = out.getCurrentPosition();
10899
int bodyLength = batch.computeBodyLength();
109-
ByteBuffer metadata = WriteChannel.serialize(batch);
110-
111-
int messageLength = 4 + metadata.remaining() + bodyLength;
112-
ByteBuffer serializedHeader =
113-
serializeHeader(MessageHeader.RecordBatch, messageLength);
114-
115-
// Compute the required alignment. This is not a great way to do it. The issue is
116-
// that we need to know the message size to serialize the message header but the
117-
// size depends on the alignment, which depends on the message header.
118-
// This will serialize the header again with the updated size alignment adjusted.
119-
// TODO: We really just want sizeof(MessageHeader) from the serializeHeader() above.
120-
// Is there a way to do this?
121-
long bufferOffset = start + 4 + serializedHeader.remaining() + 4 + metadata.remaining();
122-
if (bufferOffset % 8 != 0) {
123-
messageLength += 8 - bufferOffset % 8;
124-
serializedHeader = serializeHeader(MessageHeader.RecordBatch, messageLength);
125-
}
126100

127-
// Write message header.
128-
out.writeIntLittleEndian(serializedHeader.remaining());
129-
out.write(serializedHeader);
101+
FlatBufferBuilder builder = new FlatBufferBuilder();
102+
int batchOffset = batch.writeTo(builder);
103+
104+
ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch,
105+
batchOffset, bodyLength);
130106

131-
// Write batch header. with the 4 byte little endian prefix
132-
out.writeIntLittleEndian(metadata.remaining());
133-
int metadataSize = metadata.remaining();
134-
long batchStart = out.getCurrentPosition();
135-
out.write(metadata);
107+
long metadataStart = out.getCurrentPosition();
108+
out.writeIntLittleEndian(serializedMessage.remaining());
109+
out.write(serializedMessage);
136110

137111
// Align the output to 8 byte boundary.
138112
out.align();
139113

114+
long metadataSize = out.getCurrentPosition() - metadataStart;
115+
140116
long bufferStart = out.getCurrentPosition();
141117
List<ArrowBuf> buffers = batch.getBuffers();
142118
List<ArrowBuffer> buffersLayout = batch.getBuffersLayout();
@@ -154,31 +130,31 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
154130
" != " + startPosition + layout.getSize());
155131
}
156132
}
157-
return new ArrowBlock(batchStart, metadataSize, out.getCurrentPosition() - bufferStart);
133+
return new ArrowBlock(start, (int) metadataSize, out.getCurrentPosition() - bufferStart);
158134
}
159135

160136
/**
161137
* Deserializes a RecordBatch
162138
*/
163139
public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in,
164140
BufferAllocator alloc) throws IOException {
165-
Message header = deserializeHeader(in, MessageHeader.RecordBatch);
166-
if (header == null) return null;
141+
Message message = deserializeMessage(in, MessageHeader.RecordBatch);
142+
if (message == null) return null;
167143

168-
int messageLen = (int)header.bodyLength();
169-
// Now read the buffer. This has the metadata followed by the data.
170-
ArrowBuf buffer = alloc.buffer(messageLen);
171-
long readPosition = in.getCurrentPositiion();
172-
if (in.readFully(buffer, messageLen) != messageLen) {
173-
throw new IOException("Unexpected end of input trying to read batch.");
144+
if (message.bodyLength() > Integer.MAX_VALUE) {
145+
throw new IOException("Cannot currently deserialize record batches over 2GB");
174146
}
175147

176-
// Read the length of the metadata.
177-
int metadataLen = buffer.readInt();
178-
buffer = buffer.slice(4, messageLen - 4);
179-
readPosition += 4;
180-
messageLen -= 4;
181-
return deserializeRecordBatch(buffer, readPosition, metadataLen, messageLen);
148+
RecordBatch recordBatchFB = (RecordBatch) message.header(new RecordBatch());
149+
150+
int bodyLength = (int) message.bodyLength();
151+
152+
// Now read the record batch body
153+
ArrowBuf buffer = alloc.buffer(bodyLength);
154+
if (in.readFully(buffer, bodyLength) != bodyLength) {
155+
throw new IOException("Unexpected end of input trying to read batch.");
156+
}
157+
return deserializeRecordBatch(recordBatchFB, buffer);
182158
}
183159

184160
/**
@@ -188,41 +164,41 @@ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in,
188164
public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, ArrowBlock block,
189165
BufferAllocator alloc) throws IOException {
190166
long readPosition = in.getCurrentPositiion();
167+
168+
// Metadata length contains byte padding
191169
long totalLen = block.getMetadataLength() + block.getBodyLength();
192-
if ((readPosition + block.getMetadataLength()) % 8 != 0) {
193-
// Compute padded size.
194-
totalLen += (8 - (readPosition + block.getMetadataLength()) % 8);
195-
}
196170

197171
if (totalLen > Integer.MAX_VALUE) {
198172
throw new IOException("Cannot currently deserialize record batches over 2GB");
199173
}
200174

201-
202175
ArrowBuf buffer = alloc.buffer((int) totalLen);
203176
if (in.readFully(buffer, (int) totalLen) != totalLen) {
204177
throw new IOException("Unexpected end of input trying to read batch.");
205178
}
206179

207-
return deserializeRecordBatch(buffer, readPosition, block.getMetadataLength(), (int) totalLen);
180+
return deserializeRecordBatch(buffer, block.getMetadataLength(), (int) totalLen);
208181
}
209182

210183
// Deserializes a record batch. Buffer should start at the RecordBatch and include
211184
// all the bytes for the metadata and then data buffers.
212-
private static ArrowRecordBatch deserializeRecordBatch(
213-
ArrowBuf buffer, long readPosition, int metadataLen, int bufferLen) {
185+
private static ArrowRecordBatch deserializeRecordBatch(ArrowBuf buffer, int metadataLen,
186+
int bufferLen) {
214187
// Read the metadata.
215188
RecordBatch recordBatchFB =
216189
RecordBatch.getRootAsRecordBatch(buffer.nioBuffer().asReadOnlyBuffer());
217190

218191
int bufferOffset = metadataLen;
219-
readPosition += bufferOffset;
220-
if (readPosition % 8 != 0) {
221-
bufferOffset += (int)(8 - readPosition % 8);
222-
}
223192

224193
// Now read the body
225194
final ArrowBuf body = buffer.slice(bufferOffset, bufferLen - bufferOffset);
195+
return deserializeRecordBatch(recordBatchFB, body);
196+
}
197+
198+
// Deserializes a record batch given the Flatbuffer metadata and in-memory body
199+
private static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB,
200+
ArrowBuf body) {
201+
// Now read the body
226202
int nodesLength = recordBatchFB.nodesLength();
227203
List<ArrowFieldNode> nodes = new ArrayList<>();
228204
for (int i = 0; i < nodesLength; ++i) {
@@ -237,43 +213,44 @@ private static ArrowRecordBatch deserializeRecordBatch(
237213
}
238214
ArrowRecordBatch arrowRecordBatch =
239215
new ArrowRecordBatch(recordBatchFB.length(), nodes, buffers);
240-
buffer.release();
216+
body.release();
241217
return arrowRecordBatch;
242218
}
243219

244220
/**
245221
* Serializes a message header.
246222
*/
247-
private static ByteBuffer serializeHeader(byte headerType, int bodyLength) {
248-
FlatBufferBuilder headerBuilder = new FlatBufferBuilder();
249-
Message.startMessage(headerBuilder);
250-
Message.addHeaderType(headerBuilder, headerType);
251-
Message.addVersion(headerBuilder, MetadataVersion.V1);
252-
Message.addBodyLength(headerBuilder, bodyLength);
253-
headerBuilder.finish(Message.endMessage(headerBuilder));
254-
return headerBuilder.dataBuffer();
223+
private static ByteBuffer serializeMessage(FlatBufferBuilder builder, byte headerType,
224+
int headerOffset, int bodyLength) {
225+
Message.startMessage(builder);
226+
Message.addHeaderType(builder, headerType);
227+
Message.addHeader(builder, headerOffset);
228+
Message.addVersion(builder, MetadataVersion.V1);
229+
Message.addBodyLength(builder, bodyLength);
230+
builder.finish(Message.endMessage(builder));
231+
return builder.dataBuffer();
255232
}
256233

257-
private static Message deserializeHeader(ReadChannel in, byte headerType) throws IOException {
258-
// Read the header size. There is an i32 little endian prefix.
234+
private static Message deserializeMessage(ReadChannel in, byte headerType) throws IOException {
235+
// Read the message size. There is an i32 little endian prefix.
259236
ByteBuffer buffer = ByteBuffer.allocate(4);
260237
if (in.readFully(buffer) != 4) {
261238
return null;
262239
}
263240

264-
int headerLength = bytesToInt(buffer.array());
265-
buffer = ByteBuffer.allocate(headerLength);
266-
if (in.readFully(buffer) != headerLength) {
241+
int messageLength = bytesToInt(buffer.array());
242+
buffer = ByteBuffer.allocate(messageLength);
243+
if (in.readFully(buffer) != messageLength) {
267244
throw new IOException(
268-
"Unexpected end of stream trying to read header.");
245+
"Unexpected end of stream trying to read message.");
269246
}
270247
buffer.rewind();
271248

272-
Message header = Message.getRootAsMessage(buffer);
273-
if (header.headerType() != headerType) {
249+
Message message = Message.getRootAsMessage(buffer);
250+
if (message.headerType() != headerType) {
274251
throw new IOException("Invalid message: expecting " + headerType +
275-
". Message contained: " + header.headerType());
252+
". Message contained: " + message.headerType());
276253
}
277-
return header;
254+
return message;
278255
}
279256
}

0 commit comments

Comments
 (0)