Skip to content

Commit

Permalink
Protobuf deserialization support (#715)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwylde authored Aug 14, 2024
1 parent 964c40c commit 19f258b
Show file tree
Hide file tree
Showing 45 changed files with 965 additions and 166 deletions.
241 changes: 143 additions & 98 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ deltalake = { version = "0.18.2" }
cornucopia = { version = "0.9.0" }
cornucopia_async = {version = "0.6.0"}
deadpool-postgres = "0.12"
prost = "0.12"
prost-reflect = "0.12.0"
prost-build = {version = "0.12" }
prost-types = "0.12"

[profile.release]
debug = 1
Expand Down
3 changes: 2 additions & 1 deletion crates/arroyo-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ arroyo-udf-host = { path = "../arroyo-udf/arroyo-udf-host" }
tonic = { workspace = true }
tonic-reflection = { workspace = true }
tonic-web = { workspace = true }
prost = "0.12"
prost = {workspace = true}
prost-reflect = "0.12.0"
tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1.12"
tower = "0.4"
Expand Down
72 changes: 69 additions & 3 deletions crates/arroyo-api/src/connection_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ use tracing::debug;
use arroyo_connectors::confluent::ConfluentProfile;
use arroyo_connectors::connector_for_type;
use arroyo_connectors::kafka::{KafkaConfig, KafkaTable, SchemaRegistry};
use arroyo_formats::{avro, json};
use arroyo_formats::{avro, json, proto};
use arroyo_operator::connector::ErasedConnector;
use arroyo_rpc::api_types::connections::{
ConnectionProfile, ConnectionSchema, ConnectionTable, ConnectionTablePost, ConnectionType,
SchemaDefinition,
};
use arroyo_rpc::api_types::{ConnectionTableCollection, PaginationQueryParams};
use arroyo_rpc::formats::{AvroFormat, Format, JsonFormat};
use arroyo_rpc::formats::{AvroFormat, Format, JsonFormat, ProtobufFormat};
use arroyo_rpc::public_ids::{generate_id, IdTypes};
use arroyo_rpc::schema_resolver::{
ConfluentSchemaRegistry, ConfluentSchemaSubjectResponse, ConfluentSchemaType,
Expand All @@ -39,6 +39,7 @@ use crate::{
queries::api_queries::{self, DbConnectionTable},
to_micros, AuthData,
};
use arroyo_formats::proto::schema::{protobuf_to_arrow, schema_file_to_descriptor};
use cornucopia_async::{Database, DatabaseSource};

async fn get_and_validate_connector(
Expand Down Expand Up @@ -466,6 +467,7 @@ pub(crate) async fn expand_schema(
Format::Parquet(_) => Ok(schema),
Format::RawString(_) => Ok(schema),
Format::RawBytes(_) => Ok(schema),
Format::Protobuf(_) => expand_proto_schema(schema).await,
}
}

Expand Down Expand Up @@ -533,6 +535,66 @@ async fn expand_avro_schema(
Ok(schema)
}

/// Expands a protobuf `ConnectionSchema`: compiles the user-supplied .proto
/// definition (plus any dependencies) into an encoded descriptor set, resolves
/// the requested message, and converts its fields into the schema's fields.
///
/// All invalid user input — non-protobuf format, missing message name, compile
/// failures, unknown messages, unconvertible fields — is reported as a 400
/// `bad_request` error rather than panicking, since this is reachable from
/// user-facing handlers (`expand_schema`, `test_schema`).
async fn expand_proto_schema(mut schema: ConnectionSchema) -> Result<ConnectionSchema, ErrorResp> {
    let Some(Format::Protobuf(ProtobufFormat {
        message_name,
        compiled_schema,
        ..
    })) = &mut schema.format
    else {
        // A mismatched definition/format pair is a client error, not a server
        // invariant: test_schema dispatches on the definition alone, so this
        // branch is user-reachable and must not panic.
        return Err(bad_request(
            "format must be set to protobuf for a protobuf schema",
        ));
    };

    if let Some(definition) = &schema.definition {
        let SchemaDefinition::ProtobufSchema {
            schema: protobuf_schema,
            dependencies,
        } = &definition
        else {
            return Err(bad_request("Schema is not a protobuf schema"));
        };

        // An empty message name is as useless as a missing one — reject both.
        let message_name = message_name
            .as_ref()
            .filter(|m| !m.is_empty())
            .ok_or_else(|| bad_request("message name must be provided for protobuf schemas"))?;

        // Compile the .proto source to an encoded FileDescriptorSet so it can
        // be stored on the table and shipped to workers.
        let encoded = schema_file_to_descriptor(protobuf_schema, dependencies)
            .await
            .map_err(|e| bad_request(e.to_string()))?;

        let pool = proto::schema::get_pool(&encoded)
            .map_err(|e| bad_request(format!("error handling protobuf: {}", e)))?;
        *compiled_schema = Some(encoded);

        let descriptor = pool.get_message_by_name(message_name).ok_or_else(|| {
            // Help the user pick a valid message by listing what is available,
            // excluding the google.protobuf.* well-known types.
            bad_request(format!(
                "Message '{}' not found in proto definition; messages are {}",
                message_name,
                pool.all_messages()
                    .map(|m| m.full_name().to_string())
                    .filter(|m| !m.starts_with("google.protobuf."))
                    .collect::<Vec<_>>()
                    .join(", ")
            ))
        })?;

        let arrow = protobuf_to_arrow(&descriptor)
            .map_err(|e| bad_request(format!("Failed to convert schema: {}", e)))?;

        // Convert each Arrow field into the API field representation,
        // surfacing the first conversion error as a 400.
        let fields: Result<_, String> = arrow
            .fields
            .into_iter()
            .map(|f| (**f).clone().try_into())
            .collect();

        schema.fields =
            fields.map_err(|e| bad_request(format!("failed to convert schema: {}", e)))?;
    };

    Ok(schema)
}

async fn expand_json_schema(
name: &str,
connector: &str,
Expand Down Expand Up @@ -665,7 +727,7 @@ async fn get_schema(
pub(crate) async fn test_schema(
WithRejection(Json(req), _): WithRejection<Json<ConnectionSchema>, ApiError>,
) -> Result<(), ErrorResp> {
let Some(schema_def) = req.definition else {
let Some(schema_def) = &req.definition else {
return Ok(());
};

Expand All @@ -677,6 +739,10 @@ pub(crate) async fn test_schema(
Ok(())
}
}
SchemaDefinition::ProtobufSchema { .. } => {
let _ = expand_proto_schema(req.clone()).await?;
Ok(())
}
_ => {
// TODO: add testing for other schema types
Ok(())
Expand Down
1 change: 1 addition & 0 deletions crates/arroyo-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ impl IntoResponse for HttpError {
TestSourceMessage,
JsonFormat,
AvroFormat,
ProtobufFormat,
ParquetFormat,
RawStringFormat,
RawBytesFormat,
Expand Down
2 changes: 1 addition & 1 deletion crates/arroyo-compiler-service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ arroyo-server-common = { path = "../arroyo-server-common" }
arroyo-storage = { path = "../arroyo-storage" }

tonic = {workspace = true}
prost = "0.12"
prost = {workspace = true}
tokio = { version = "1", features = ["full"] }
tracing = "0.1"
anyhow = "1.0.75"
Expand Down
2 changes: 1 addition & 1 deletion crates/arroyo-connectors/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ tokio-stream = "0.1"
once_cell = "1.17.1"
typify = "0.0.13"
schemars = "0.8"
prost = "0.12"
prost = {workspace = true}
tonic = {workspace = true}
governor = "0.6.0"
anyhow = "1.0.71"
Expand Down
1 change: 1 addition & 0 deletions crates/arroyo-connectors/src/filesystem/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ impl FileSystemSourceFunc {
}
Format::RawString(_) => todo!(),
Format::RawBytes(_) => todo!(),
Format::Protobuf(_) => todo!("Protobuf not supported"),
}
}

Expand Down
22 changes: 22 additions & 0 deletions crates/arroyo-connectors/src/kafka/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,28 @@ impl KafkaTester {
Format::RawBytes(_) => {
// all bytes are valid
}
Format::Protobuf(_) => {
let aschema: ArroyoSchema = schema.clone().into();
let mut deserializer =
ArrowDeserializer::new(format.clone(), aschema.clone(), None, BadData::Fail {});
let mut builders = aschema.builders();

let mut error = deserializer
.deserialize_slice(&mut builders, &msg, SystemTime::now())
.await
.into_iter()
.next();
if let Some(Err(e)) = deserializer.flush_buffer() {
error.replace(e);
}

if let Some(error) = error {
bail!(
"Failed to parse message according to the provided Protobuf schema: {}",
error.details()
);
}
}
};

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion crates/arroyo-controller/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ arroyo-worker = { path = "../arroyo-worker" }
tonic = {workspace = true}
tonic-reflection = {workspace = true}

prost = "0.12"
prost = {workspace = true}
tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1.12"
rand = "0.8"
Expand Down
2 changes: 1 addition & 1 deletion crates/arroyo-datastream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ hex = "0.4"
tokio = "1"
tonic = {workspace = true}
anyhow = "1.0.70"
prost = "0.12"
prost = {workspace = true}
regex = "1.9.5"
serde_json = "1.0.108"
strum = { version = "0.25.0", features = ["derive"] }
Expand Down
7 changes: 6 additions & 1 deletion crates/arroyo-formats/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,9 @@ bincode = "2.0.0-rc.3"
memchr = "2"
typify = "0.0.13"
schemars = "0.8"
prost = "0.12"
prost = { workspace = true}
prost-reflect = { workspace = true}
prost-build = { workspace = true }
prost-types = { workspace = true}
base64 = "0.22.1"
uuid = { version = "1.10.0", features = ["v4"] }
21 changes: 3 additions & 18 deletions crates/arroyo-formats/src/avro/de.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::float_to_json;
use apache_avro::types::{Value, Value as AvroValue};
use apache_avro::{from_avro_datum, AvroResult, Reader, Schema};
use arroyo_rpc::formats::AvroFormat;
Expand Down Expand Up @@ -81,22 +82,6 @@ pub(crate) async fn avro_messages(
Ok(messages)
}

/// Converts an `f64` into a JSON value: finite floats become JSON numbers,
/// while the non-finite values (which JSON cannot represent) become the
/// strings "+Inf", "-Inf", or "NaN".
fn convert_float(f: f64) -> JsonValue {
    // `from_f64` yields None exactly when the value is NaN or infinite.
    if let Some(number) = serde_json::Number::from_f64(f) {
        return JsonValue::Number(number);
    }

    let label = match (f.is_infinite(), f.is_sign_positive()) {
        (true, true) => "+Inf",
        (true, false) => "-Inf",
        _ => "NaN",
    };
    JsonValue::String(label.to_string())
}

/// Encodes raw bytes as a JSON string by mapping each byte to the char with
/// the same code point (Latin-1 style), matching Avro's JSON byte encoding.
fn encode_vec(v: Vec<u8>) -> JsonValue {
    let mut encoded = String::with_capacity(v.len());
    for byte in v {
        encoded.push(char::from(byte));
    }
    JsonValue::String(encoded)
}
Expand All @@ -114,8 +99,8 @@ pub(crate) fn avro_to_json(value: AvroValue) -> JsonValue {
| Value::TimestampMicros(i)
| Value::LocalTimestampMillis(i)
| Value::LocalTimestampMicros(i) => JsonValue::Number(serde_json::Number::from(i)),
Value::Float(f) => convert_float(f as f64),
Value::Double(f) => convert_float(f),
Value::Float(f) => float_to_json(f as f64),
Value::Double(f) => float_to_json(f),
Value::String(s) | Value::Enum(_, s) => JsonValue::String(s),
// this isn't the standard Avro json encoding, which just
Value::Bytes(b) | Value::Fixed(_, b) => encode_vec(b),
Expand Down
Loading

0 comments on commit 19f258b

Please sign in to comment.