Skip to content

Commit

Permalink
fix: Map services & rpcs from descriptors
Browse files Browse the repository at this point in the history
  • Loading branch information
semtexzv committed Sep 21, 2023
1 parent 5ec665d commit 88cb277
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 8 deletions.
1 change: 1 addition & 0 deletions protokit_binformat/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ path = "../protokit_derive"
version = "0.1.2"

[dependencies]
bytes = "1.4.0"
indexmap = "1"
thiserror = "1"
bumpalo = { version = "3.13.0", optional = true, default_features = false }
59 changes: 52 additions & 7 deletions protokit_desc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ impl FieldDef {
FieldDescriptorProtoType::TYPE_GROUP => {
let mut n = desc.type_name.as_deref().unwrap();
while let Some(p) = n.find('.') {
n = &n[p + 1 ..]
n = &n[p + 1..]
}
name = set.cache(n);
DataType::Unresolved(set.cache(desc.type_name.as_ref().unwrap()), UnresolvedHint::Group)
Expand Down Expand Up @@ -410,7 +410,7 @@ impl FieldDef {
DataType::Map(map) => {
let mut name = self.name.clone().to_string();
unsafe {
name.as_bytes_mut()[.. 1].make_ascii_uppercase();
name.as_bytes_mut()[..1].make_ascii_uppercase();
}
let map_entry_name = format!("{name}Entry");
fout.type_name = Some(format!(
Expand Down Expand Up @@ -666,6 +666,23 @@ impl ServiceDef {
}
out
}

#[cfg(feature = "descriptors")]
fn from_descriptor(
set: &mut FileSetDef,
file: &FileDescriptorProto,
name: &ArcStr,
desc: &ServiceDescriptorProto,
) -> Self {
Self {
name: name.clone(),
rpc: desc.method.iter().map(|v| {
let name = set.cache(v.name.as_ref().expect("Missing service name"));
(name.clone(), RpcDef::from_descriptor(set, &name, &v))
}).collect(),
options: desc.options.as_deref().cloned().unwrap_or_default(),
}
}
}

#[derive(Debug, Default)]
Expand Down Expand Up @@ -708,6 +725,21 @@ impl RpcDef {
};
out
}
#[cfg(feature = "descriptors")]
fn from_descriptor(
set: &mut FileSetDef,
name: &ArcStr,
desc: &MethodDescriptorProto,
) -> Self {
Self {
name: name.clone(),
req_stream: desc.client_streaming.unwrap_or_default(),
req_typ: DataType::Unresolved(set.cache(desc.input_type.as_ref().unwrap()), UnresolvedHint::Message),
res_stream: desc.server_streaming.unwrap_or_default(),
res_typ: DataType::Unresolved(set.cache(desc.output_type.as_ref().unwrap()), UnresolvedHint::Message),
options: desc.options.as_deref().cloned().unwrap_or_default(),
}
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -932,7 +964,7 @@ impl FileDef {
.transpose()
.unwrap()
.unwrap_or(Syntax::Proto2),
package: set.cache(desc.package.as_ref().unwrap()),
package: set.cache(desc.package.as_ref().expect("Missing package name")),
messages: Default::default(),
enums: Default::default(),
services: Default::default(),
Expand Down Expand Up @@ -972,6 +1004,16 @@ impl FileDef {
let def = MessageDef::from_descriptor(set, file, &name, desc);
this.messages.insert(name, def);
}
fn parse_svc(
set: &mut FileSetDef,
file: &FileDescriptorProto,
this: &mut FileDef,
desc: &ServiceDescriptorProto,
) {
let name = set.cache(desc.name.as_ref().unwrap());
let def = ServiceDef::from_descriptor(set, file, &name, desc);
this.services.insert(name, def);
}

for desc in &desc.enum_type {
parse_enum(set, &mut this, None, desc);
Expand All @@ -980,6 +1022,9 @@ impl FileDef {
for field_desc in &desc.message_type {
parse_msg(set, desc, &mut this, None, field_desc);
}
for svc_desc in &desc.service {
parse_svc(set, desc, &mut this, svc_desc);
}
this
}
#[cfg(feature = "descriptors")]
Expand Down Expand Up @@ -1169,7 +1214,7 @@ fn try_resolve_within_scopes(
let qualified = format!("{scope}{scope_dot}{symbol}");
match (names.get(qualified.as_str()), scope.rfind('.')) {
(Some(v), _) => return Some(*v),
(None, Some(p)) => scope = &scope[.. p],
(None, Some(p)) => scope = &scope[..p],
// Resolve globally without the prefix
(None, None) => return names.get(symbol).copied(),
}
Expand All @@ -1185,14 +1230,14 @@ fn try_resolve_symbol(
if let Some(without_dot) = symbol.strip_prefix('.') {
// We're searching for global symbol. If package prefix matches, we can search for the inner part of the symbol
if let Some(without_package) = without_dot.strip_prefix(file_package) {
let localized_symbol = &without_package[1 ..];
let localized_symbol = &without_package[1..];
return names.get(localized_symbol).cloned();
} else {
eprintln!("Package mismatch: {} within: {}", without_dot, &file_package);
return None;
}
} else if let Some(localized) = symbol.strip_prefix(file_package) {
let localized_symbol = &localized[1 ..];
let localized_symbol = &localized[1..];
return names.get(localized_symbol).cloned();
}

Expand All @@ -1210,7 +1255,7 @@ fn try_resolve_symbol(
) {
(Some(v), _) => return Some(v),
// We need to remove subpackages, because name sections might be of nested messages, not package names
(None, Some(v)) => file_package = &file_package[.. v],
(None, Some(v)) => file_package = &file_package[..v],
(None, None) => {
return try_resolve_within_scopes(names, "", symbol);
}
Expand Down
2 changes: 1 addition & 1 deletion protokit_examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ publish = false


[dependencies]
protokit = { path = "../protokit", version = "0.1.2", features = ["textformat"] }
protokit = { path = "../protokit", version = "0.1.2", features = ["textformat", "grpc"] }

[build-dependencies]
protokit_build = { path = "../protokit_build" }

0 comments on commit 88cb277

Please sign in to comment.