Skip to content

Commit f93de0b

Browse files
authored
Merge pull request #257 from AidnAS/read-async-recv-timeout
Make PooledSocket.ReadAsync respect receive timeout setting
2 parents 33ebfcf + 227e3f3 commit f93de0b

File tree

2 files changed

+168
-14
lines changed

2 files changed

+168
-14
lines changed

src/Enyim.Caching/Memcached/PooledSocket.cs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public partial class PooledSocket : IDisposable
2424
private bool _isSocketDisposed;
2525
private readonly EndPoint _endpoint;
2626
private readonly int _connectionTimeout;
27-
27+
private readonly int _receiveTimeout;
2828
private NetworkStream _inputStream;
2929
private SslStream _sslStream;
3030
#if NET5_0_OR_GREATER
@@ -71,6 +71,7 @@ public PooledSocket(EndPoint endpoint, TimeSpan connectionTimeout, TimeSpan rece
7171

7272
socket.ReceiveTimeout = rcv;
7373
socket.SendTimeout = rcv;
74+
_receiveTimeout = rcv;
7475

7576
_socket = socket;
7677
}
@@ -425,21 +426,31 @@ public async Task ReadAsync(byte[] buffer, int offset, int count)
425426
{
426427
try
427428
{
428-
int currentRead = (_useSslStream
429-
? await _sslStream.ReadAsync(buffer, offset, shouldRead).ConfigureAwait(false)
430-
: await _inputStream.ReadAsync(buffer, offset, shouldRead).ConfigureAwait(false));
431-
if (currentRead == count)
432-
break;
433-
if (currentRead < 1)
434-
throw new IOException("The socket seems to be disconnected");
435-
436-
read += currentRead;
437-
offset += currentRead;
438-
shouldRead -= currentRead;
429+
var readTask = _useSslStream
430+
? _sslStream.ReadAsync(buffer, offset, shouldRead)
431+
: _inputStream.ReadAsync(buffer, offset, shouldRead);
432+
var timeoutTask = Task.Delay(_receiveTimeout);
433+
434+
if (await Task.WhenAny(readTask, timeoutTask).ConfigureAwait(false) == readTask)
435+
{
436+
int currentRead = await readTask.ConfigureAwait(false);
437+
if (currentRead == count)
438+
break;
439+
if (currentRead < 1)
440+
throw new IOException("The socket seems to be disconnected");
441+
442+
read += currentRead;
443+
offset += currentRead;
444+
shouldRead -= currentRead;
445+
}
446+
else
447+
{
448+
throw new TimeoutException($"Timeout to read from {_endpoint}.");
449+
}
439450
}
440451
catch (Exception ex)
441452
{
442-
if (ex is IOException || ex is SocketException)
453+
if (ex is IOException || ex is SocketException || ex is TimeoutException)
443454
{
444455
_isAlive = false;
445456
}
@@ -648,4 +659,4 @@ private IPEndPoint GetIPEndPoint(EndPoint endpoint)
648659
*
649660
* ************************************************************/
650661

651-
#endregion
662+
#endregion
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
using System;
2+
using System.Diagnostics;
3+
using System.IO;
4+
using System.Net;
5+
using System.Net.Sockets;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
using Enyim.Caching.Memcached;
9+
using Microsoft.Extensions.Logging.Abstractions;
10+
using Xunit;
11+
12+
namespace MemcachedTest;
13+
14+
public class PooledSocketTest
15+
{
16+
[Fact]
17+
public async Task ReadSync_ShouldTimeoutOrFail_WhenServerResponseIsSlow()
18+
{
19+
// Arrange
20+
var logger = new NullLogger<PooledSocketTest>();
21+
const int port = 12345;
22+
var server = new SlowLorisServer();
23+
using var cts = new CancellationTokenSource();
24+
await server.StartAsync(port, cts.Token);
25+
var endpoint = new IPEndPoint(IPAddress.Loopback, port);
26+
var socket = new PooledSocket(
27+
endpoint,
28+
TimeSpan.FromSeconds(5),
29+
TimeSpan.FromMilliseconds(50),
30+
logger,
31+
useSslStream: false,
32+
useIPv6: false,
33+
sslClientAuthOptions: null
34+
);
35+
await socket.ConnectAsync();
36+
var buffer = new byte[server.Response.Length];
37+
38+
// Act
39+
var timer = Stopwatch.StartNew();
40+
var ex = Record.Exception(() =>
41+
{
42+
socket.Read(buffer, 0, server.Response.Length);
43+
});
44+
timer.Stop();
45+
46+
// Assert
47+
Assert.True(timer.Elapsed < TimeSpan.FromMilliseconds(500), "Read took too long");
48+
Assert.NotNull(ex);
49+
Assert.True(
50+
ex is TimeoutException or IOException,
51+
$"Expected TimeoutException or IOException, got {ex.GetType().Name}: {ex.Message}"
52+
);
53+
54+
await cts.CancelAsync();
55+
server.Stop();
56+
}
57+
58+
[Fact]
59+
public async Task ReadAsync_ShouldTimeoutOrFail_WhenServerResponseIsSlow()
60+
{
61+
// Arrange
62+
var logger = new NullLogger<PooledSocket>();
63+
const int port = 12345;
64+
var server = new SlowLorisServer();
65+
using var cts = new CancellationTokenSource();
66+
67+
await server.StartAsync(port, cts.Token);
68+
69+
var endpoint = new IPEndPoint(IPAddress.Loopback, port);
70+
var socket = new PooledSocket(
71+
endpoint,
72+
TimeSpan.FromSeconds(5),
73+
TimeSpan.FromMilliseconds(50),
74+
logger,
75+
useSslStream: false,
76+
useIPv6: false,
77+
sslClientAuthOptions: null
78+
);
79+
80+
await socket.ConnectAsync();
81+
82+
var buffer = new byte[server.Response.Length];
83+
84+
// Act
85+
var timer = Stopwatch.StartNew();
86+
var ex = await Record.ExceptionAsync(async () =>
87+
{
88+
await socket.ReadAsync(buffer, 0, server.Response.Length);
89+
});
90+
timer.Stop();
91+
92+
// Assert
93+
Assert.True(timer.Elapsed < TimeSpan.FromMilliseconds(500), "ReadAsync took too long");
94+
Assert.NotNull(ex);
95+
Assert.True(
96+
ex is TimeoutException or IOException,
97+
$"Expected TimeoutException or IOException, got {ex.GetType().Name}: {ex.Message}"
98+
);
99+
100+
// Cleanup
101+
await cts.CancelAsync();
102+
server.Stop();
103+
}
104+
}
105+
106+
public class SlowLorisServer
107+
{
108+
private TcpListener _listener;
109+
private CancellationToken _token;
110+
public readonly byte[] Response = "Hello, I'm slow!"u8.ToArray();
111+
112+
public Task StartAsync(int port, CancellationToken token)
113+
{
114+
_token = token;
115+
_listener = new TcpListener(IPAddress.Loopback, port);
116+
_listener.Start();
117+
118+
_ = Task.Run(async () =>
119+
{
120+
while (!token.IsCancellationRequested)
121+
{
122+
var client = await _listener.AcceptTcpClientAsync(token);
123+
_ = Task.Run(() => HandleClientAsync(client), token);
124+
}
125+
}, token);
126+
return Task.CompletedTask;
127+
}
128+
129+
private async Task HandleClientAsync(TcpClient client)
130+
{
131+
await using var stream = client.GetStream();
132+
for (var i = 0; i < Response.Length; i++)
133+
{
134+
await stream.WriteAsync(Response, i, 1, _token);
135+
await Task.Delay(100, _token);
136+
}
137+
await stream.FlushAsync(_token);
138+
client.Close();
139+
}
140+
141+
public void Stop() => _listener.Stop();
142+
}
143+

0 commit comments

Comments
 (0)