Skip to content

Commit 2bb8d69

Browse files
golf1052Sanders Lauture
authored andcommitted
Fix AttachmentCipherInputStream so that underlying file is reusable
If the Stream length is shortened to remove the MAC the cipherfile will not be usable again. Instead save the MAC before removing it from the file and add it back on before the Stream is disposed. Also ensure we dispose all Streams correctly.
1 parent 21f242f commit 2bb8d69

File tree

6 files changed

+130
-53
lines changed

6 files changed

+130
-53
lines changed

libsignal-service-dotnet-tests/crypto/AttachmentCipherTest.cs

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,68 @@ public void Test_Attachment_EncryptDecrypt()
1919
byte[] plaintextInput = Encoding.UTF8.GetBytes("Peter Parker");
2020
EncryptResult encryptResult = EncryptData(plaintextInput, key);
2121
string cipherFile = WriteToFile(encryptResult.ciphertext);
22-
Stream inputStream = AttachmentCipherInputStream.CreateForAttachment(File.Open(cipherFile, FileMode.Open), plaintextInput.Length, key, encryptResult.digest);
22+
using Stream inputStream = AttachmentCipherInputStream.CreateForAttachment(File.Open(cipherFile, FileMode.Open), plaintextInput.Length, key, encryptResult.digest);
2323
byte[] plaintextOutput = ReadInputStreamFully(inputStream);
2424

2525
CollectionAssert.AreEqual(plaintextInput, plaintextOutput);
2626

2727
DeleteFile(cipherFile);
2828
}
2929

30+
[TestMethod]
31+
public void Test_Attachment_EncryptDecryptMultipleTimes()
32+
{
33+
// Test that the file passed to AttachmentCipherInputStream can be reused.
34+
byte[] key = Util.GetSecretBytes(64);
35+
byte[] plaintextInput = Encoding.UTF8.GetBytes("Peter Parker");
36+
EncryptResult encryptResult = EncryptData(plaintextInput, key);
37+
string cipherFile = WriteToFile(encryptResult.ciphertext);
38+
39+
for (int i = 0; i < 10; i++)
40+
{
41+
using Stream inputStream = AttachmentCipherInputStream.CreateForAttachment(File.Open(cipherFile, FileMode.Open), plaintextInput.Length, key, encryptResult.digest);
42+
byte[] plaintextOutput = ReadInputStreamFully(inputStream);
43+
44+
CollectionAssert.AreEqual(plaintextInput, plaintextOutput);
45+
}
46+
47+
DeleteFile(cipherFile);
48+
}
49+
3050
[TestMethod]
3151
public void Test_Attachment_EncryptDecryptEmpty()
3252
{
3353
byte[] key = Util.GetSecretBytes(64);
3454
byte[] plaintextInput = Encoding.UTF8.GetBytes(string.Empty);
3555
EncryptResult encryptResult = EncryptData(plaintextInput, key);
3656
string cipherFile = WriteToFile(encryptResult.ciphertext);
37-
Stream inputStream = AttachmentCipherInputStream.CreateForAttachment(File.Open(cipherFile, FileMode.Open), plaintextInput.Length, key, encryptResult.digest);
57+
using Stream inputStream = AttachmentCipherInputStream.CreateForAttachment(File.Open(cipherFile, FileMode.Open), plaintextInput.Length, key, encryptResult.digest);
3858
byte[] plaintextOutput = ReadInputStreamFully(inputStream);
3959

4060
CollectionAssert.AreEqual(plaintextInput, plaintextOutput);
4161

4262
DeleteFile(cipherFile);
4363
}
4464

65+
[TestMethod]
66+
public void Test_Attachment_EncryptDecryptEmptyMultipleTimes()
67+
{
68+
byte[] key = Util.GetSecretBytes(64);
69+
byte[] plaintextInput = Encoding.UTF8.GetBytes(string.Empty);
70+
EncryptResult encryptResult = EncryptData(plaintextInput, key);
71+
string cipherFile = WriteToFile(encryptResult.ciphertext);
72+
73+
for (int i = 0; i < 10; i++)
74+
{
75+
using Stream inputStream = AttachmentCipherInputStream.CreateForAttachment(File.Open(cipherFile, FileMode.Open), plaintextInput.Length, key, encryptResult.digest);
76+
byte[] plaintextOutput = ReadInputStreamFully(inputStream);
77+
78+
CollectionAssert.AreEqual(plaintextInput, plaintextOutput);
79+
}
80+
81+
DeleteFile(cipherFile);
82+
}
83+
4584
[TestMethod]
4685
public void Test_Attachment_DecryptFailOnBadKey()
4786
{
@@ -57,7 +96,8 @@ public void Test_Attachment_DecryptFailOnBadKey()
5796

5897
cipherFile = WriteToFile(encryptResult.ciphertext);
5998

60-
AttachmentCipherInputStream.CreateForAttachment(File.Open(cipherFile, FileMode.Open), plaintextInput.Length, badKey, encryptResult.digest);
99+
using FileStream fileStream = File.Open(cipherFile, FileMode.Open);
100+
AttachmentCipherInputStream.CreateForAttachment(fileStream, plaintextInput.Length, badKey, encryptResult.digest);
61101
}
62102
catch (InvalidMessageException)
63103
{
@@ -89,7 +129,8 @@ public void Test_Attachmetn_DecryptFailOnBadDigest()
89129

90130
cipherFile = WriteToFile(encryptResult.ciphertext);
91131

92-
AttachmentCipherInputStream.CreateForAttachment(File.Open(cipherFile, FileMode.Open), plaintextInput.Length, key, badDigest);
132+
using FileStream fileStream = File.Open(cipherFile, FileMode.Open);
133+
AttachmentCipherInputStream.CreateForAttachment(fileStream, plaintextInput.Length, key, badDigest);
93134
}
94135
catch (InvalidMessageException)
95136
{
@@ -120,7 +161,8 @@ public void Test_Attachment_DecryptFailOnNullDigest()
120161

121162
cipherFile = WriteToFile(encryptResult.ciphertext);
122163

123-
AttachmentCipherInputStream.CreateForAttachment(File.Open(cipherFile, FileMode.Open), plaintextInput.Length, key, null);
164+
using FileStream fileStream = File.Open(cipherFile, FileMode.Open);
165+
AttachmentCipherInputStream.CreateForAttachment(fileStream, plaintextInput.Length, key, null);
124166
}
125167
catch (InvalidMessageException)
126168
{
@@ -155,7 +197,8 @@ public void Test_Attachment_DecryptFailOnBadMac()
155197

156198
cipherFile = WriteToFile(badMacCiphertext);
157199

158-
AttachmentCipherInputStream.CreateForAttachment(File.Open(cipherFile, FileMode.Open), plaintextInput.Length, key, encryptResult.digest);
200+
using FileStream fileStream = File.Open(cipherFile, FileMode.Open);
201+
AttachmentCipherInputStream.CreateForAttachment(fileStream, plaintextInput.Length, key, encryptResult.digest);
159202
}
160203
catch (InvalidMessageException)
161204
{
@@ -178,7 +221,7 @@ public void Test_Sticker_EncryptDecrypt()
178221
byte[] packKey = Util.GetSecretBytes(32);
179222
byte[] plaintextInput = Encoding.UTF8.GetBytes("Peter Parker");
180223
EncryptResult encryptResult = EncryptData(plaintextInput, ExpandPackKey(packKey));
181-
Stream inputStream = AttachmentCipherInputStream.CreateForStickerData(encryptResult.ciphertext, packKey);
224+
using Stream inputStream = AttachmentCipherInputStream.CreateForStickerData(encryptResult.ciphertext, packKey);
182225
byte[] plaintextOutput = ReadInputStreamFully(inputStream);
183226

184227
CollectionAssert.AreEqual(plaintextInput, plaintextOutput);
@@ -190,7 +233,7 @@ public void Test_Sticker_EncryptDecryptEmpty()
190233
byte[] packKey = Util.GetSecretBytes(32);
191234
byte[] plaintextInput = Encoding.UTF8.GetBytes(string.Empty);
192235
EncryptResult encryptResult = EncryptData(plaintextInput, ExpandPackKey(packKey));
193-
Stream inputStream = AttachmentCipherInputStream.CreateForStickerData(encryptResult.ciphertext, packKey);
236+
using Stream inputStream = AttachmentCipherInputStream.CreateForStickerData(encryptResult.ciphertext, packKey);
194237
byte[] plaintextOutput = ReadInputStreamFully(inputStream);
195238

196239
CollectionAssert.AreEqual(plaintextInput, plaintextOutput);
@@ -270,14 +313,7 @@ private static void DeleteFile(string path)
270313
{
271314
if (File.Exists(path))
272315
{
273-
try
274-
{
275-
File.Delete(path);
276-
}
277-
catch (IOException)
278-
{
279-
// for some reason this fails
280-
}
316+
File.Delete(path);
281317
}
282318
}
283319

libsignal-service-dotnet/crypto/AttachmentCipherInputStream.cs

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ public class AttachmentCipherInputStream : Stream
2020
private readonly Aes aes;
2121
private readonly CryptoStream cipher;
2222
private readonly ICryptoTransform decryptor;
23+
private readonly byte[] mac;
24+
private bool disposed = false;
2325

2426
private readonly long totalDataSize;
2527
private long totalRead = 0;
@@ -59,12 +61,8 @@ public static Stream CreateForAttachment(Stream inputStream, long plaintextLengt
5961

6062
VerifyMac(inputStream, inputStream.Length, mac, digest);
6163
inputStream.Seek(0, SeekOrigin.Begin);
62-
// We need to truncate the MAC off the end of the input stream or CryptoStream will fail to decrypt
63-
// correctly because it will think there's more data to decrypt when the MAC isn't actually part of
64-
// what needs to be decrypted.
65-
inputStream.SetLength(inputStream.Length - MacLength);
6664

67-
Stream stream = new AttachmentCipherInputStream(inputStream, parts[0], inputStream.Length - BLOCK_SIZE);
65+
Stream stream = new AttachmentCipherInputStream(inputStream, parts[0], inputStream.Length - BLOCK_SIZE - MacLength);
6866

6967
if (plaintextLength > 0)
7068
{
@@ -103,12 +101,8 @@ public static Stream CreateForStickerData(byte[] data, byte[] packKey)
103101
MemoryStream inputStream = new MemoryStream(data);
104102
VerifyMac(inputStream, data.Length, mac, null);
105103
inputStream.Seek(0, SeekOrigin.Begin);
106-
// We need to truncate the MAC off the end of the input stream or CryptoStream will fail to decrypt
107-
// correctly because it will think there's more data to decrypt when the MAC isn't actually part of
108-
// what needs to be decrypted.
109-
inputStream.SetLength(inputStream.Length - MacLength);
110104

111-
return new AttachmentCipherInputStream(inputStream, parts[0], data.Length - BLOCK_SIZE);
105+
return new AttachmentCipherInputStream(inputStream, parts[0], data.Length - BLOCK_SIZE - MacLength);
112106
}
113107
catch (InvalidMacException ex)
114108
{
@@ -120,6 +114,14 @@ private AttachmentCipherInputStream(Stream inputStream, byte[] key, long totalDa
120114
{
121115
this.inputStream = inputStream;
122116

117+
// We need to truncate the MAC off the end of the input stream or CryptoStream will fail to decrypt
118+
// correctly because it will think there's more data to decrypt when the MAC isn't actually part of
119+
// what needs to be decrypted. Truncating the MAC off the end of the stream however will make the
120+
// stream not reusable so store the MAC so we can add it back before we dispose the stream.
121+
const int MacLength = 32;
122+
mac = GetMac(inputStream, MacLength);
123+
inputStream.SetLength(inputStream.Length - MacLength);
124+
123125
byte[] iv = new byte[BLOCK_SIZE];
124126
ReadFully(iv);
125127

@@ -135,9 +137,20 @@ private AttachmentCipherInputStream(Stream inputStream, byte[] key, long totalDa
135137
this.totalDataSize = totalDataSize;
136138
}
137139

140+
protected override void Dispose(bool disposing)
141+
{
142+
if (!disposed)
143+
{
144+
disposed = true;
145+
inputStream.Seek(0, SeekOrigin.End);
146+
inputStream.Write(mac, 0, mac.Length);
147+
}
148+
inputStream.Dispose();
149+
base.Dispose(disposing);
150+
}
151+
138152
public override void Flush()
139153
{
140-
throw new NotImplementedException();
141154
}
142155

143156
public override int Read(byte[] buffer, int offset, int count)
@@ -207,6 +220,15 @@ private static void VerifyMac(Stream inputStream, long length, IncrementalHash m
207220
}
208221
}
209222

223+
private byte[] GetMac(Stream inputStream, int macLength)
224+
{
225+
byte[] mac = new byte[macLength];
226+
inputStream.Seek(-macLength, SeekOrigin.End);
227+
Util.ReadFully(inputStream, mac);
228+
inputStream.Seek(0, SeekOrigin.Begin);
229+
return mac;
230+
}
231+
210232
private void ReadFully(byte[] buffer)
211233
{
212234
int offset = 0;

libsignal-service-dotnet/crypto/DigestingOutputStream.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
using System;
2-
using System.Collections.Generic;
1+
using System;
32
using System.IO;
43
using System.Security.Cryptography;
5-
using System.Text;
64

75
namespace libsignalservice.crypto
86
{
@@ -22,6 +20,12 @@ public DigestingOutputStream(Stream outputStream)
2220
OutputStream = outputStream;
2321
}
2422

23+
protected override void Dispose(bool disposing)
24+
{
25+
OutputStream.Dispose();
26+
base.Dispose(disposing);
27+
}
28+
2529
public override void Flush()
2630
{
2731
OutputStream.Flush();

libsignal-service-dotnet/crypto/PaddingInputStream.cs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,41 @@ namespace libsignalservice.crypto
66
{
77
internal class PaddingInputStream : Stream
88
{
9-
private readonly Stream InputStream;
10-
private long PaddingRemaining;
9+
private readonly Stream inputStream;
10+
private long paddingRemaining;
1111

1212
public override bool CanRead => true;
1313
public override bool CanSeek => false;
1414
public override bool CanWrite => false;
15-
public override long Length { get => InputStream.Length + Util.ToIntExact(PaddingRemaining); }
15+
public override long Length { get => inputStream.Length + Util.ToIntExact(paddingRemaining); }
1616
public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
1717

1818
public PaddingInputStream(Stream inputStream, long plainTextLength)
1919
{
20-
InputStream = inputStream;
21-
PaddingRemaining = GetPaddedSize(plainTextLength) - plainTextLength;
20+
this.inputStream = inputStream;
21+
paddingRemaining = GetPaddedSize(plainTextLength) - plainTextLength;
22+
}
23+
24+
protected override void Dispose(bool disposing)
25+
{
26+
inputStream.Dispose();
27+
base.Dispose(disposing);
2228
}
2329

2430
public override void Flush()
2531
{
26-
throw new NotImplementedException();
2732
}
2833

2934
public override int Read(byte[] buffer, int offset, int count)
3035
{
31-
int result = InputStream.Read(buffer, offset, count);
36+
int result = inputStream.Read(buffer, offset, count);
3237
if (result >= 0)
3338
return result;
3439

35-
if (PaddingRemaining > 0)
40+
if (paddingRemaining > 0)
3641
{
37-
count = Math.Min(count, Util.ToIntExact(PaddingRemaining));
38-
PaddingRemaining -= count;
42+
count = Math.Min(count, Util.ToIntExact(paddingRemaining));
43+
paddingRemaining -= count;
3944
return count;
4045
}
4146
return 0;

libsignal-service-dotnet/crypto/ProfileCipherInputStream.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace libsignalservice.crypto
1111
public class ProfileCipherInputStream : Stream
1212
{
1313
private readonly GcmBlockCipher cipher;
14-
private readonly Stream InputStream;
14+
private readonly Stream inputStream;
1515

1616
private bool finished = false;
1717

@@ -28,12 +28,17 @@ public ProfileCipherInputStream(Stream inputStream, byte[] key)
2828
byte[] nonce = new byte[12];
2929
Util.ReadFully(inputStream, nonce);
3030
cipher.Init(false, new AeadParameters(new KeyParameter(key), 128, nonce));
31-
InputStream = inputStream;
31+
this.inputStream = inputStream;
32+
}
33+
34+
protected override void Dispose(bool disposing)
35+
{
36+
inputStream.Dispose();
37+
base.Dispose(disposing);
3238
}
3339

3440
public override void Flush()
3541
{
36-
throw new NotImplementedException();
3742
}
3843

3944
/// <summary>
@@ -50,7 +55,7 @@ public override int Read(byte[] buffer, int offset, int count)
5055
try
5156
{
5257
byte[] ciphertext = new byte[count / 2];
53-
int read = InputStream.Read(ciphertext, 0, ciphertext.Length);
58+
int read = inputStream.Read(ciphertext, 0, ciphertext.Length);
5459

5560
if (read <= 0)
5661
{

libsignal-service-dotnet/messages/multidevice/ChunkedInputStream.cs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ public int ReadRawVarint32()// throws IOException
7070

7171
internal class LimitedInputStream : Stream
7272
{
73-
private Stream InputStream;
74-
private long Left;
73+
private Stream inputStream;
74+
private long left;
7575

7676
public override bool CanRead => true;
7777
public override bool CanSeek => false;
@@ -81,24 +81,29 @@ internal class LimitedInputStream : Stream
8181

8282
internal LimitedInputStream(Stream inputStream, long limit)
8383
{
84-
InputStream = inputStream;
85-
Left = limit;
84+
this.inputStream = inputStream;
85+
left = limit;
86+
}
87+
88+
protected override void Dispose(bool disposing)
89+
{
90+
inputStream.Dispose();
91+
base.Dispose(disposing);
8692
}
8793

8894
public override void Flush()
8995
{
90-
throw new NotImplementedException();
9196
}
9297

9398
public override int Read(byte[] buffer, int offset, int count)
9499
{
95-
if (Left == 0)
100+
if (left == 0)
96101
return 0;
97102

98-
count = (int) Math.Min(count, Left);
99-
int result = InputStream.Read(buffer, offset, count);
103+
count = (int) Math.Min(count, left);
104+
int result = inputStream.Read(buffer, offset, count);
100105
if (result > 0)
101-
Left -= result;
106+
left -= result;
102107
return result;
103108
}
104109

0 commit comments

Comments
 (0)