Skip to content

Commit 6a2ae6e

Browse files
committed
DLPX-88137 Upgrade sshj with Raul's mod dlpx73623
PR URL: https://www.github.com/delphix/sshj/pull/7
1 parent f4d34d8 commit 6a2ae6e

File tree

8 files changed

+652
-34
lines changed

8 files changed

+652
-34
lines changed

src/main/java/net/schmizz/sshj/Config.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,8 @@ public interface Config {
200200
* See {@link #isVerifyHostKeyCertificates()}.
201201
*/
202202
void setVerifyHostKeyCertificates(boolean value);
203+
204+
int getMaxCircularBufferSize();
205+
206+
void setMaxCircularBufferSize(int maxCircularBufferSize);
203207
}

src/main/java/net/schmizz/sshj/ConfigImpl.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ public class ConfigImpl
4949
private boolean waitForServerIdentBeforeSendingClientIdent = false;
5050
private LoggerFactory loggerFactory;
5151
private boolean verifyHostKeyCertificates = true;
52+
// HF-982: default to 16MB buffers.
53+
private int maxCircularBufferSize = 16 * 1024 * 1026;
5254

5355
@Override
5456
public List<Factory.Named<Cipher>> getCipherFactories() {
@@ -175,6 +177,16 @@ public LoggerFactory getLoggerFactory() {
175177
return loggerFactory;
176178
}
177179

180+
@Override
181+
public int getMaxCircularBufferSize() {
182+
return maxCircularBufferSize;
183+
}
184+
185+
@Override
186+
public void setMaxCircularBufferSize(int maxCircularBufferSize) {
187+
this.maxCircularBufferSize = maxCircularBufferSize;
188+
}
189+
178190
@Override
179191
public void setLoggerFactory(LoggerFactory loggerFactory) {
180192
this.loggerFactory = loggerFactory;
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
/*
2+
* Copyright (C)2009 - SSHJ Contributors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package net.schmizz.sshj.common;
17+
18+
public class CircularBuffer<T extends CircularBuffer<T>> {
19+
20+
public static class CircularBufferException
21+
extends SSHException {
22+
23+
public CircularBufferException(String message) {
24+
super(message);
25+
}
26+
}
27+
28+
public static final class PlainCircularBuffer
29+
extends CircularBuffer<PlainCircularBuffer> {
30+
31+
public PlainCircularBuffer(int size, int maxSize) {
32+
super(size, maxSize);
33+
}
34+
}
35+
36+
/**
37+
* Maximum size of the internal array (one plus the maximum capacity of the buffer).
38+
*/
39+
private final int maxSize;
40+
/**
41+
* Internal array for the data. All bytes minus one can be used to avoid empty vs full ambiguity when rpos == wpos.
42+
*/
43+
private byte[] data;
44+
/**
45+
* Next read position. Wraps around the end of the internal array. When it reaches wpos, the buffer becomes empty.
46+
* Can take the value data.length, which is equivalent to 0.
47+
*/
48+
private int rpos;
49+
/**
50+
* Next write position. Wraps around the end of the internal array. If it is equal to rpos, then the buffer is
51+
* empty; the code does not allow wpos to reach rpos from the left. This implies that the buffer can store up to
52+
* data.length - 1 bytes. Can take the value data.length, which is equivalent to 0.
53+
*/
54+
private int wpos;
55+
56+
/**
57+
* Determines the size to which to grow the internal array.
58+
*/
59+
private int getNextSize(int currentSize) {
60+
// Use next power of 2.
61+
int nextSize = 1;
62+
while (nextSize < currentSize) {
63+
nextSize <<= 1;
64+
if (nextSize <= 0) {
65+
return maxSize;
66+
}
67+
}
68+
return Math.min(nextSize, maxSize); // limit to max size
69+
}
70+
71+
/**
72+
* Creates a new circular buffer of the given size. The capacity of the buffer is one less than the size/
73+
*/
74+
public CircularBuffer(int size, int maxSize) {
75+
this.maxSize = maxSize;
76+
if (size > maxSize) {
77+
throw new IllegalArgumentException(
78+
String.format("Initial requested size %d larger than maximum size %d", size, maxSize));
79+
}
80+
int initialSize = getNextSize(size);
81+
this.data = new byte[initialSize];
82+
this.rpos = 0;
83+
this.wpos = 0;
84+
}
85+
86+
/**
87+
* Data available in the buffer for reading.
88+
*/
89+
public int available() {
90+
int available = wpos - rpos;
91+
return available >= 0 ? available : available + data.length; // adjust if wpos is left of rpos
92+
}
93+
94+
private void ensureAvailable(int a)
95+
throws CircularBufferException {
96+
if (available() < a) {
97+
throw new CircularBufferException("Underflow");
98+
}
99+
}
100+
101+
/**
102+
* Returns how many more bytes this buffer can receive.
103+
*/
104+
public int maxPossibleRemainingCapacity() {
105+
// Remaining capacity is one less than remaining space to ensure that wpos does not reach rpos from the left.
106+
int remaining = rpos - wpos - 1;
107+
if (remaining < 0) {
108+
remaining += data.length; // adjust if rpos is left of wpos
109+
}
110+
// Add the maximum amount the internal array can grow.
111+
return remaining + maxSize - data.length;
112+
}
113+
114+
/**
115+
* If the internal array does not have room for "capacity" more bytes, resizes the array to make that room.
116+
*/
117+
void ensureCapacity(int capacity) throws CircularBufferException {
118+
int available = available();
119+
int remaining = data.length - available;
120+
// If capacity fits exactly in the remaining space, expand it; otherwise, wpos would reach rpos from the left.
121+
if (remaining <= capacity) {
122+
int neededSize = available + capacity + 1;
123+
int nextSize = getNextSize(neededSize);
124+
if (nextSize < neededSize) {
125+
throw new CircularBufferException("Attempted overflow");
126+
}
127+
byte[] tmp = new byte[nextSize];
128+
// Copy data to the beginning of the new array.
129+
if (wpos >= rpos) {
130+
System.arraycopy(data, rpos, tmp, 0, available);
131+
wpos -= rpos; // wpos must be relative to the new rpos, which will be 0
132+
} else {
133+
int tail = data.length - rpos;
134+
System.arraycopy(data, rpos, tmp, 0, tail); // segment right of rpos
135+
System.arraycopy(data, 0, tmp, tail, wpos); // segment left of wpos
136+
wpos += tail; // wpos must be relative to the new rpos, which will be 0
137+
}
138+
rpos = 0;
139+
data = tmp;
140+
}
141+
}
142+
143+
/**
144+
* Reads data from this buffer into the provided array.
145+
*/
146+
public void readRawBytes(byte[] destination, int offset, int length) throws CircularBufferException {
147+
ensureAvailable(length);
148+
149+
int rposNext = rpos + length;
150+
if (rposNext <= data.length) {
151+
System.arraycopy(data, rpos, destination, offset, length);
152+
} else {
153+
int tail = data.length - rpos;
154+
System.arraycopy(data, rpos, destination, offset, tail); // segment right of rpos
155+
rposNext = length - tail; // rpos wraps around the end of the buffer
156+
System.arraycopy(data, 0, destination, offset + tail, rposNext); // remainder
157+
}
158+
// This can make rpos equal data.length, which has the same effect as wpos being 0.
159+
rpos = rposNext;
160+
}
161+
162+
/**
163+
* Writes data to this buffer from the provided array.
164+
*/
165+
@SuppressWarnings("unchecked")
166+
public T putRawBytes(byte[] source, int offset, int length) throws CircularBufferException {
167+
ensureCapacity(length);
168+
169+
int wposNext = wpos + length;
170+
if (wposNext <= data.length) {
171+
System.arraycopy(source, offset, data, wpos, length);
172+
} else {
173+
int tail = data.length - wpos;
174+
System.arraycopy(source, offset, data, wpos, tail); // segment right of wpos
175+
wposNext = length - tail; // wpos wraps around the end of the buffer
176+
System.arraycopy(source, offset + tail, data, 0, wposNext); // remainder
177+
}
178+
// This can make wpos equal data.length, which has the same effect as wpos being 0.
179+
wpos = wposNext;
180+
181+
return (T) this;
182+
}
183+
184+
// Used only for testing.
185+
int length() {
186+
return data.length;
187+
}
188+
189+
@Override
190+
public String toString() {
191+
return "CircularBuffer [rpos=" + rpos + ", wpos=" + wpos + ", size=" + data.length + "]";
192+
}
193+
194+
}

src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,7 @@ public String getType() {
164164
}
165165

166166
@Override
167-
public void handle(Message msg, SSHPacket buf)
168-
throws ConnectionException, TransportException {
167+
public void handle(Message msg, SSHPacket buf) throws SSHException {
169168
switch (msg) {
170169

171170
case CHANNEL_DATA:
@@ -354,7 +353,7 @@ protected void finishOff() {
354353
}
355354

356355
protected void gotExtendedData(SSHPacket buf)
357-
throws ConnectionException, TransportException {
356+
throws SSHException {
358357
throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
359358
"Extended data not supported on " + type + " channel");
360359
}
@@ -375,7 +374,7 @@ protected SSHPacket newBuffer(Message cmd) {
375374
}
376375

377376
protected void receiveInto(ChannelInputStream stream, SSHPacket buf)
378-
throws ConnectionException, TransportException {
377+
throws SSHException {
379378
final int len;
380379
try {
381380
len = buf.readUInt32AsInt();

src/main/java/net/schmizz/sshj/connection/channel/ChannelInputStream.java

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,19 @@ public final class ChannelInputStream
3838
private final Channel chan;
3939
private final Transport trans;
4040
private final Window.Local win;
41-
private final Buffer.PlainBuffer buf;
41+
private final CircularBuffer.PlainCircularBuffer buf;
4242
private final byte[] b = new byte[1];
4343

4444
private boolean eof;
4545
private SSHException error;
4646

4747
public ChannelInputStream(Channel chan, Transport trans, Window.Local win) {
4848
this.chan = chan;
49-
log = chan.getLoggerFactory().getLogger(getClass());
49+
this.log = chan.getLoggerFactory().getLogger(getClass());
5050
this.trans = trans;
5151
this.win = win;
52-
buf = new Buffer.PlainBuffer(chan.getLocalMaxPacketSize());
52+
this.buf = new CircularBuffer.PlainCircularBuffer(
53+
chan.getLocalMaxPacketSize(), trans.getConfig().getMaxCircularBufferSize());
5354
}
5455

5556
@Override
@@ -113,48 +114,44 @@ public int read(byte[] b, int off, int len)
113114
len = buf.available();
114115
}
115116
buf.readRawBytes(b, off, len);
116-
if (buf.rpos() > win.getMaxPacketSize() && buf.available() == 0) {
117-
buf.clear();
118-
}
119-
}
120117

121-
if (!chan.getAutoExpand()) {
122-
checkWindow();
118+
if (!chan.getAutoExpand()) {
119+
checkWindow();
120+
}
123121
}
124122

125123
return len;
126124
}
127125

128-
public void receive(byte[] data, int offset, int len)
129-
throws ConnectionException, TransportException {
126+
public void receive(byte[] data, int offset, int len) throws SSHException {
130127
if (eof) {
131128
throw new ConnectionException("Getting data on EOF'ed stream");
132129
}
133130
synchronized (buf) {
134131
buf.putRawBytes(data, offset, len);
135132
buf.notifyAll();
136-
}
137-
// Potential fix for #203 (window consumed below 0).
138-
// This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST
139-
// And the window has not expanded yet.
140-
synchronized (win) {
133+
// Potential fix for #203 (window consumed below 0).
134+
// This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST
135+
// And the window has not expanded yet.
141136
win.consume(len);
142-
}
143-
if (chan.getAutoExpand()) {
144-
checkWindow();
137+
if (chan.getAutoExpand()) {
138+
checkWindow();
139+
}
145140
}
146141
}
147142

148-
private void checkWindow()
149-
throws TransportException {
150-
synchronized (win) {
151-
final long adjustment = win.neededAdjustment();
152-
if (adjustment > 0) {
153-
log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment);
154-
trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST)
155-
.putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment));
156-
win.expand(adjustment);
157-
}
143+
private void checkWindow() throws TransportException {
144+
/*
145+
* Window must fit in remaining buffer capacity. We already expect win.size() amount of data to arrive. The
146+
* difference between that and the remaining capacity is the maximum adjustment we can make to the window.
147+
*/
148+
final long maxAdjustment = buf.maxPossibleRemainingCapacity() - win.getSize();
149+
final long adjustment = Math.min(win.neededAdjustment(), maxAdjustment);
150+
if (adjustment > 0) {
151+
log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment);
152+
trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST)
153+
.putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment));
154+
win.expand(adjustment);
158155
}
159156
}
160157

src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ protected void eofInputStreams() {
210210

211211
@Override
212212
protected void gotExtendedData(SSHPacket buf)
213-
throws ConnectionException, TransportException {
213+
throws SSHException {
214214
try {
215215
final int dataTypeCode = buf.readUInt32AsInt();
216216
if (dataTypeCode == 1)

0 commit comments

Comments
 (0)