Skip to content

Commit

Permalink
refactor: improve exception pass in packet.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
fu050409 committed Apr 15, 2024
1 parent afccefc commit 2dec0ce
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 40 deletions.
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 23 additions & 34 deletions src/models/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,15 @@ impl<'a> OKE<'a> {
})
}

pub fn from_public_key_bytes(
&mut self,
public_key_bytes: &[u8],
) -> Result<&mut Self, OblivionException> {
self.public_key = Some(PublicKey::from_sec1_bytes(public_key_bytes).unwrap());
pub fn from_public_key_bytes(&mut self, public_key_bytes: &[u8]) -> Result<&mut Self> {
self.public_key = Some(PublicKey::from_sec1_bytes(public_key_bytes)?);
Ok(self)
}

pub async fn from_stream(
&mut self,
stream: &mut Socket,
) -> Result<&mut Self, OblivionException> {
pub async fn from_stream(&mut self, stream: &mut Socket) -> Result<&mut Self> {
let remote_public_key_length = stream.recv_len().await?;
let remote_public_key_bytes = stream.recv(remote_public_key_length).await?;
self.remote_public_key =
Some(PublicKey::from_sec1_bytes(&remote_public_key_bytes).unwrap());
self.remote_public_key = Some(PublicKey::from_sec1_bytes(&remote_public_key_bytes)?);
self.shared_aes_key = Some(generate_shared_key(
self.private_key.as_ref().unwrap(),
self.remote_public_key.as_ref().unwrap(),
Expand All @@ -122,44 +115,40 @@ impl<'a> OKE<'a> {
) -> Result<&mut Self, OblivionException> {
let remote_public_key_length = stream.recv_len().await?;
let remote_public_key_bytes = stream.recv(remote_public_key_length).await?;
self.remote_public_key =
Some(PublicKey::from_sec1_bytes(&remote_public_key_bytes).unwrap());
self.remote_public_key = Some(PublicKey::from_sec1_bytes(&remote_public_key_bytes)?);
let salt_length = stream.recv_len().await?;
self.salt = Some(stream.recv(salt_length).await?);
self.shared_aes_key = Some(generate_shared_key(
self.private_key.as_ref().unwrap(),
self.remote_public_key.as_ref().unwrap(),
&self.salt.as_mut().unwrap(),
self.private_key.unwrap(),
&self.remote_public_key.unwrap(),
self.salt.as_mut().unwrap(),
)?);
Ok(self)
}

pub async fn to_stream(&mut self, stream: &mut Socket) -> Result<(), OblivionException> {
stream.send(&self.plain_data()).await?;
pub async fn to_stream(&mut self, stream: &mut Socket) -> Result<()> {
stream.send(&self.plain_data()?).await?;
Ok(())
}

pub async fn to_stream_with_salt(
&mut self,
stream: &mut Socket,
) -> Result<(), OblivionException> {
stream.send(&self.plain_data()).await?;
stream.send(&self.plain_salt()).await?;
pub async fn to_stream_with_salt(&mut self, stream: &mut Socket) -> Result<()> {
stream.send(&self.plain_data()?).await?;
stream.send(&self.plain_salt()?).await?;
Ok(())
}

pub fn plain_data(&mut self) -> Vec<u8> {
pub fn plain_data(&mut self) -> Result<Vec<u8>> {
let public_key_bytes = self.public_key.unwrap().to_sec1_bytes().to_vec();
let mut plain_data_bytes = length(&public_key_bytes).unwrap();
let mut plain_data_bytes = length(&public_key_bytes)?;
plain_data_bytes.extend(public_key_bytes);
plain_data_bytes
Ok(plain_data_bytes)
}

pub fn plain_salt(&mut self) -> Vec<u8> {
pub fn plain_salt(&mut self) -> Result<Vec<u8>> {
let salt_bytes = self.salt.as_ref().unwrap();
let mut plain_salt_bytes = length(&salt_bytes).unwrap();
let mut plain_salt_bytes = length(&salt_bytes)?;
plain_salt_bytes.extend(salt_bytes);
plain_salt_bytes
Ok(plain_salt_bytes)
}

pub fn get_aes_key(&mut self) -> Vec<u8> {
Expand Down Expand Up @@ -188,7 +177,7 @@ impl OED {
}
}

fn serialize_bytes(&self, data: &[u8], size: Option<usize>) -> Vec<Vec<u8>> {
fn serialize_bytes(&self, data: &[u8], size: Option<usize>) -> Result<Vec<Vec<u8>>> {
let size = if size.is_none() {
let size: usize = 1024;
size
Expand All @@ -202,7 +191,7 @@ impl OED {

for i in (0..data_size).step_by(size) {
let buffer = &data[i..std::cmp::min(i + size, data_size)];
let buffer_length = length(&buffer.to_vec()).unwrap();
let buffer_length = length(&buffer.to_vec())?;
let mut serialized_chunk = Vec::with_capacity(buffer_length.len() + buffer.len());

if i + size > data_size {
Expand All @@ -219,7 +208,7 @@ impl OED {
}

serialized_bytes.push(b"0000".to_vec());
serialized_bytes
Ok(serialized_bytes)
}

pub fn from_json_or_string(
Expand Down Expand Up @@ -332,7 +321,7 @@ impl OED {

self.chunk_size = 0;
for bytes in self
.serialize_bytes(&self.encrypted_data.as_ref().unwrap(), None)
.serialize_bytes(&self.encrypted_data.as_ref().unwrap(), None)?
.iter()
{
stream.send(&bytes).await?;
Expand Down

0 comments on commit 2dec0ce

Please sign in to comment.