Skip to content

Commit

Permalink
Make DeserializationSchema.deserialize return List
Browse files Browse the repository at this point in the history
  • Loading branch information
ruanwenjun committed Nov 15, 2023
1 parent 3b6de37 commit e5a02b3
Show file tree
Hide file tree
Showing 27 changed files with 229 additions and 234 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,22 @@

package org.apache.seatunnel.api.serialization;

import org.apache.seatunnel.api.source.Collector;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;

import java.io.IOException;
import java.io.Serializable;
import java.util.List;

public interface DeserializationSchema<T> extends Serializable {

/**
* Deserializes the byte message.
*
* @param message The message, as a byte array.
* @return The deserialized message as a SeaTunnel Row (null if the message cannot be
* @return The deserialized message as a list of SeaTunnel Rows (empty list if the message cannot be
* deserialized).
*/
T deserialize(byte[] message) throws IOException;

default void deserialize(byte[] message, Collector<T> out) throws IOException {
T deserialize = deserialize(message);
if (deserialize != null) {
out.collect(deserialize);
}
}
List<T> deserialize(byte[] message) throws IOException;

SeaTunnelDataType<T> getProducedType();
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,14 @@
import org.apache.seatunnel.api.serialization.DeserializationSchema;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;

import org.apache.commons.collections4.CollectionUtils;

import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.util.List;

@Slf4j
public class AmazonSqsDeserializer implements SeaTunnelRowDeserializer {

private final DeserializationSchema<SeaTunnelRow> deserializationSchema;
Expand All @@ -33,8 +39,19 @@ public AmazonSqsDeserializer(DeserializationSchema<SeaTunnelRow> deserialization
@Override
public SeaTunnelRow deserializeRow(String row) {
try {
return deserializationSchema.deserialize(row.getBytes());
List<SeaTunnelRow> seaTunnelRows = deserializationSchema.deserialize(row.getBytes());
if (CollectionUtils.isEmpty(seaTunnelRows)) {
log.warn("The AmazonSqsDeserializer deserialize result is empty");
return null;
}
if (seaTunnelRows.size() != 1) {
log.warn(
"The AmazonSqsDeserializer only support one row, but got {} rows, will drop the extra rows",
seaTunnelRows.size());
}
return seaTunnelRows.get(0);
} catch (IOException e) {
log.error("Failed to deserialize row: {}", row, e);
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ public final class SeaTunnelRowDebeziumDeserializeSchema
@Override
public void deserialize(SourceRecord record, Collector<SeaTunnelRow> collector)
throws Exception {
// TODO: move this kind of logic out of the deserialization schema
if (isSchemaChangeBeforeWatermarkEvent(record)) {
collector.markSchemaChangeBeforeCheckpoint();
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;

@Slf4j
Expand Down Expand Up @@ -97,16 +98,18 @@ public void read(String path, String tableId, Collector<SeaTunnelRow> output)
.forEach(
line -> {
try {
SeaTunnelRow seaTunnelRow =
List<SeaTunnelRow> seaTunnelRows =
deserializationSchema.deserialize(line.getBytes());
if (isMergePartition) {
int index = seaTunnelRowType.getTotalFields();
for (String value : partitionsMap.values()) {
seaTunnelRow.setField(index++, value);
for (SeaTunnelRow seaTunnelRow : seaTunnelRows) {
if (isMergePartition) {
int index = seaTunnelRowType.getTotalFields();
for (String value : partitionsMap.values()) {
seaTunnelRow.setField(index++, value);
}
}
seaTunnelRow.setTableId(tableId);
output.collect(seaTunnelRow);
}
seaTunnelRow.setTableId(tableId);
output.collect(seaTunnelRow);
} catch (IOException e) {
String errorMsg =
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand Down Expand Up @@ -94,32 +95,34 @@ public void read(String path, String tableId, Collector<SeaTunnelRow> output)
.forEach(
line -> {
try {
SeaTunnelRow seaTunnelRow =
List<SeaTunnelRow> seaTunnelRows =
deserializationSchema.deserialize(line.getBytes());
if (!readColumns.isEmpty()) {
// need column projection
Object[] fields;
if (isMergePartition) {
fields =
new Object
[readColumns.size()
+ partitionsMap.size()];
} else {
fields = new Object[readColumns.size()];
}
for (int i = 0; i < indexes.length; i++) {
fields[i] = seaTunnelRow.getField(indexes[i]);
for (SeaTunnelRow seaTunnelRow : seaTunnelRows) {
if (!readColumns.isEmpty()) {
// need column projection
Object[] fields;
if (isMergePartition) {
fields =
new Object
[readColumns.size()
+ partitionsMap.size()];
} else {
fields = new Object[readColumns.size()];
}
for (int i = 0; i < indexes.length; i++) {
fields[i] = seaTunnelRow.getField(indexes[i]);
}
seaTunnelRow = new SeaTunnelRow(fields);
}
seaTunnelRow = new SeaTunnelRow(fields);
}
if (isMergePartition) {
int index = seaTunnelRowType.getTotalFields();
for (String value : partitionsMap.values()) {
seaTunnelRow.setField(index++, value);
if (isMergePartition) {
int index = seaTunnelRowType.getTotalFields();
for (String value : partitionsMap.values()) {
seaTunnelRow.setField(index++, value);
}
}
seaTunnelRow.setTableId(tableId);
output.collect(seaTunnelRow);
}
seaTunnelRow.setTableId(tableId);
output.collect(seaTunnelRow);
} catch (IOException e) {
String errorMsg =
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
import org.apache.seatunnel.connectors.seatunnel.google.sheets.exception.GoogleSheetsConnectorException;

import org.apache.commons.collections4.CollectionUtils;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -51,7 +53,16 @@ public SeaTunnelRow deserializeRow(List<Object> row) {
}
}
String rowStr = objectMapper.writeValueAsString(map);
return deserializationSchema.deserialize(rowStr.getBytes());
List<SeaTunnelRow> seaTunnelRows = deserializationSchema.deserialize(rowStr.getBytes());
if (CollectionUtils.isEmpty(seaTunnelRows)) {
return null;
}
if (seaTunnelRows.size() != 1) {
throw new GoogleSheetsConnectorException(
CommonErrorCodeDeprecated.JSON_OPERATION_FAILED,
"Object json deserialization failed, the data contains multiple rows");
}
return seaTunnelRows.get(0);
} catch (IOException e) {
throw new GoogleSheetsConnectorException(
CommonErrorCodeDeprecated.JSON_OPERATION_FAILED,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import lombok.AllArgsConstructor;

import java.io.IOException;
import java.util.List;

@AllArgsConstructor
public class DeserializationCollector {
Expand All @@ -35,8 +36,8 @@ public void collect(byte[] message, Collector<SeaTunnelRow> out) throws IOExcept
if (deserializationSchema instanceof JsonDeserializationSchema) {
((JsonDeserializationSchema) deserializationSchema).collect(message, out);
} else {
SeaTunnelRow deserialize = deserializationSchema.deserialize(message);
out.collect(deserialize);
List<SeaTunnelRow> seaTunnelRows = deserializationSchema.deserialize(message);
seaTunnelRows.forEach(out::collect);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@

import lombok.AllArgsConstructor;

import java.util.Collections;
import java.util.List;

@AllArgsConstructor
public class SimpleTextDeserializationSchema implements DeserializationSchema<SeaTunnelRow> {

private SeaTunnelRowType rowType;

@Override
public SeaTunnelRow deserialize(byte[] message) {
return new SeaTunnelRow(new Object[] {new String(message)});
public List<SeaTunnelRow> deserialize(byte[] message) {
return Collections.singletonList(new SeaTunnelRow(new Object[] {new String(message)}));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,12 @@ public void pollNext(Collector<SeaTunnelRow> output) throws Exception {
CompatibleKafkaConnectDeserializationSchema) {
((CompatibleKafkaConnectDeserializationSchema)
deserializationSchema)
.deserialize(
record, output);
.deserialize(record)
.forEach(output::collect);
} else {
deserializationSchema.deserialize(
record.value(), output);
deserializationSchema
.deserialize(record.value())
.forEach(output::collect);
}
} catch (IOException e) {
if (this.messageFormatErrorHandleWay
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;

import org.apache.seatunnel.api.serialization.DeserializationSchema;
import org.apache.seatunnel.api.source.Collector;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.common.utils.JsonUtils;
import org.apache.seatunnel.format.json.canal.CanalJsonDeserializationSchema;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
* for pulsar-connector, the data format is
Expand All @@ -52,22 +52,19 @@ public PulsarCanalDecorator(CanalJsonDeserializationSchema canalJsonDeserializat
}

@Override
public SeaTunnelRow deserialize(byte[] message) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public void deserialize(byte[] message, Collector<SeaTunnelRow> out) throws IOException {
public List<SeaTunnelRow> deserialize(byte[] message) {
JsonNode pulsarCanal = JsonUtils.parseObject(message);
ArrayNode canalList = JsonUtils.parseArray(pulsarCanal.get(MESSAGE).asText());
Iterator<JsonNode> canalIterator = canalList.elements();
List<SeaTunnelRow> seaTunnelRows = new ArrayList<>();
while (canalIterator.hasNext()) {
JsonNode next = canalIterator.next();
// reconvert pulsar handler, reference to
// https://github.com/apache/pulsar/blob/master/pulsar-io/canal/src/main/java/org/apache/pulsar/io/canal/MessageUtils.java
ObjectNode root = reconvertPulsarData((ObjectNode) next);
canalJsonDeserializationSchema.deserialize(root, out);
seaTunnelRows.addAll(canalJsonDeserializationSchema.deserialize(root));
}
return seaTunnelRows;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public void pollNext(Collector<T> output) throws Exception {
final Message<byte[]> message = recordWithSplitId.get().getMessage();
synchronized (output.getCheckpointLock()) {
splitStates.get(splitId).setLatestConsumedId(message.getMessageId());
deserialization.deserialize(message.getData(), output);
deserialization.deserialize(message.getData()).forEach(output::collect);
}
}
if (noMoreSplitsAssignment && finishedSplits.size() == splitStates.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.seatunnel.connectors.seatunnel.pulsar.source;

import org.apache.seatunnel.api.source.Collector;
import org.apache.seatunnel.api.table.type.BasicType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
Expand All @@ -28,11 +27,8 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import lombok.Getter;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

public class PulsarCanalDecoratorTest {
Expand Down Expand Up @@ -61,27 +57,13 @@ void decoder() throws IOException {
PulsarCanalDecorator pulsarCanalDecorator =
new PulsarCanalDecorator(canalJsonDeserializationSchema);

SimpleCollector simpleCollector = new SimpleCollector();
pulsarCanalDecorator.deserialize(json.getBytes(StandardCharsets.UTF_8), simpleCollector);
Assertions.assertFalse(simpleCollector.getList().isEmpty());
for (SeaTunnelRow seaTunnelRow : simpleCollector.list) {
List<SeaTunnelRow> seaTunnelRows =
pulsarCanalDecorator.deserialize(json.getBytes(StandardCharsets.UTF_8));
Assertions.assertFalse(seaTunnelRows.isEmpty());
for (SeaTunnelRow seaTunnelRow : seaTunnelRows) {
for (Object field : seaTunnelRow.getFields()) {
Assertions.assertNotNull(field);
}
}
}

private static class SimpleCollector implements Collector<SeaTunnelRow> {
@Getter private List<SeaTunnelRow> list = new ArrayList<>();

@Override
public void collect(SeaTunnelRow record) {
list.add(record);
}

@Override
public Object getCheckpointLock() {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void pollNext(Collector output) throws Exception {
return;
}
deliveryTagsProcessedForCurrentSnapshot.add(envelope.getDeliveryTag());
deserializationSchema.deserialize(body, output);
deserializationSchema.deserialize(body).forEach(output::collect);
}

if (Boundedness.BOUNDED.equals(context.getBoundedness())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,14 @@ public void pollNext(Collector<SeaTunnelRow> output) throws Exception {
SeaTunnelDataType<SeaTunnelRow> seaTunnelRowType =
deserializationSchema.getProducedType();
valuesMap.put(((SeaTunnelRowType) seaTunnelRowType).getFieldName(0), k);
deserializationSchema.deserialize(
JsonUtils.toJsonString(valuesMap).getBytes(), output);
deserializationSchema
.deserialize(JsonUtils.toJsonString(valuesMap).getBytes())
.forEach(output::collect);
}
} else {
deserializationSchema.deserialize(value.getBytes(), output);
deserializationSchema
.deserialize(value.getBytes())
.forEach(output::collect);
}
}
}
Expand Down
Loading

0 comments on commit e5a02b3

Please sign in to comment.