From ed6ff3fe805b6ed2093df533d7a74d5ea1e063b8 Mon Sep 17 00:00:00 2001 From: scx567888 Date: Mon, 14 Oct 2024 14:51:50 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=20MultiPart=20=E8=A7=A3?= =?UTF-8?q?=E6=9E=90=E5=99=A8=20(#99)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cool/scx/http/ScxHttpHeadersHelper.java | 2 +- .../media/multi_part/MultiPartStream.java | 45 ++++++++++--------- .../multi_part/MultiPartStreamCached.java | 15 +++---- .../java/cool/scx/http/test/HeadersTest.java | 28 ++++++++++++ .../cool/scx/http/test/MultiPartTest.java | 4 ++ .../java/cool/scx/io/ByteArrayDataReader.java | 6 +-- .../src/main/java/cool/scx/io/DataReader.java | 30 +++++++++---- .../java/cool/scx/io/LinkedDataReader.java | 6 +-- .../java/cool/scx/io/test/DataReaderTest.java | 6 +-- 9 files changed, 94 insertions(+), 48 deletions(-) create mode 100644 scx-http/src/test/java/cool/scx/http/test/HeadersTest.java diff --git a/scx-http/src/main/java/cool/scx/http/ScxHttpHeadersHelper.java b/scx-http/src/main/java/cool/scx/http/ScxHttpHeadersHelper.java index f0ab9c65..e76ec348 100644 --- a/scx-http/src/main/java/cool/scx/http/ScxHttpHeadersHelper.java +++ b/scx-http/src/main/java/cool/scx/http/ScxHttpHeadersHelper.java @@ -7,7 +7,7 @@ public static ScxHttpHeadersWritable parseHeaders(String headersStr) { var lines = headersStr.split("\r\n"); for (var line : lines) { - int i = line.indexOf(":"); + int i = line.indexOf(':'); if (i != -1) { var key = line.substring(0, i).trim(); var value = line.substring(i + 1).trim(); diff --git a/scx-http/src/main/java/cool/scx/http/media/multi_part/MultiPartStream.java b/scx-http/src/main/java/cool/scx/http/media/multi_part/MultiPartStream.java index 35e56380..07ac2941 100644 --- a/scx-http/src/main/java/cool/scx/http/media/multi_part/MultiPartStream.java +++ b/scx-http/src/main/java/cool/scx/http/media/multi_part/MultiPartStream.java @@ -1,6 +1,5 @@ package cool.scx.http.media.multi_part; -import cool.scx.common.util.ArrayUtils; import cool.scx.http.ScxHttpHeaders; import cool.scx.http.ScxHttpHeadersWritable; import cool.scx.io.InputStreamDataSupplier; @@ -16,15 +15,12 @@ public class MultiPartStream implements MultiPart, Iterator { protected static final byte[] CRLF_CRLF_BYTES = "\r\n\r\n".getBytes(); protected final LinkedDataReader linkedDataReader; - protected final byte[] boundaryHeadCRLFBytes; - protected final byte[] boundaryENDBytes; + protected final byte[] boundaryBytes; protected boolean hasNextPart; protected String boundary; public MultiPartStream(InputStream inputStream, String boundary) { - var boundaryHeadBytes = ArrayUtils.concat("--".getBytes(), boundary.getBytes()); - this.boundaryHeadCRLFBytes = ArrayUtils.concat(boundaryHeadBytes, "\r\n".getBytes()); - this.boundaryENDBytes = ArrayUtils.concat(boundaryHeadBytes, "--".getBytes()); + this.boundaryBytes = ("--" + boundary).getBytes(); this.linkedDataReader = new LinkedDataReader(new InputStreamDataSupplier(inputStream)); this.hasNextPart = readNext(); } @@ -34,36 +30,42 @@ public ScxHttpHeadersWritable readToHeaders() { // head /r/n // /r/n // content - var headersBytes = linkedDataReader.readMatch(CRLF_CRLF_BYTES); + var headersBytes = linkedDataReader.readUntil(CRLF_CRLF_BYTES); var headersStr = new String(headersBytes); return ScxHttpHeaders.of(headersStr); } public byte[] readContentToByte() throws IOException { - //我们需要查找终结点 先假设不是最后一个 那我们就需要查找下一个开始位置 + //因为正常的表单一定是 --xxxxxx 结尾的 所以我们只需要找 下一个分块的起始位置作为结束位置即可 try { - var i = linkedDataReader.indexOf(boundaryHeadCRLFBytes); + var i = linkedDataReader.indexOf(boundaryBytes); // i - 2 因为我们不需要读取内容结尾的 \r\n var bytes = linkedDataReader.read(i - 2); //跳过 \r\n 方便后续读取 linkedDataReader.skip(2); return bytes; } catch (NoMatchFoundException e) { - //可能是最后一个查找 最终终结点 - var i = linkedDataReader.indexOf(boundaryENDBytes); - var bytes = linkedDataReader.read(i - 2); - //跳过 \r\n 方便后续读取 - linkedDataReader.skip(2); - return bytes; + // 理论上一个正常的 MultiPart 不会有这种情况 + throw new RuntimeException("异常状态 !!!"); } } public boolean readNext() { - //查找 --xxxxxxxxx\r\n 没有代表 读取到结尾 + //查找 --xxxxxxxxx try { - var i = linkedDataReader.indexOf(boundaryHeadCRLFBytes); - linkedDataReader.skip(i + boundaryHeadCRLFBytes.length); - return true; + var i = linkedDataReader.indexOf(boundaryBytes); + linkedDataReader.skip(i + boundaryBytes.length); + //向后读取两个字节 + var a = linkedDataReader.read(); + var b = linkedDataReader.read(); + // 判断 是 \r\n or -- + if (a == '\r' && b == '\n') { //还有数据 + return true; + } else if (a == '-' && b == '-') { // 读取到了终结符 + return false; + } else { // 理论上一个正常的 MultiPart 不会有这种情况 + throw new RuntimeException("未知字符 !!! "); + } } catch (NoMatchFoundException e) { return false; } @@ -91,10 +93,11 @@ public MultiPartPart next() { } try { + var part = new MultiPartPartImpl(); + // 读取当前部分的头部信息 var headers = readToHeaders(); - - var part = new MultiPartPartImpl().headers(headers); + part.headers(headers); //读取内容 var content = readContentToByte(); diff --git a/scx-http/src/main/java/cool/scx/http/media/multi_part/MultiPartStreamCached.java b/scx-http/src/main/java/cool/scx/http/media/multi_part/MultiPartStreamCached.java index f4b452c3..cce80980 100644 --- a/scx-http/src/main/java/cool/scx/http/media/multi_part/MultiPartStreamCached.java +++ b/scx-http/src/main/java/cool/scx/http/media/multi_part/MultiPartStreamCached.java @@ -37,17 +37,14 @@ public Path readContentToPath(Path path) throws IOException { try (output) { //我们需要查找终结点 先假设不是最后一个 那我们就需要查找下一个开始位置 try { - var i = linkedDataReader.indexOf(boundaryHeadCRLFBytes); + var i = linkedDataReader.indexOf(boundaryBytes); // i - 2 因为我们不需要读取内容结尾的 \r\n linkedDataReader.read(output, i - 2); //跳过 \r\n 方便后续读取 linkedDataReader.skip(2); } catch (NoMatchFoundException e) { - //可能是最后一个查找 最终终结点 - var i = linkedDataReader.indexOf(boundaryENDBytes); - linkedDataReader.read(output, i - 2); - //跳过 \r\n 方便后续读取 - linkedDataReader.skip(2); + // 理论上一个正常的 MultiPart 不会有这种情况 + throw new RuntimeException("异常状态 !!!"); } } @@ -61,10 +58,11 @@ public MultiPartPart next() { } try { + var part = new MultiPartPartImpl(); + // 读取当前部分的头部信息 var headers = readToHeaders(); - - var part = new MultiPartPartImpl().headers(headers); + part.headers(headers); var b = needCached(headers); if (b) { @@ -74,6 +72,7 @@ public MultiPartPart next() { var content = readContentToByte(); part.body(content); } + // 检查是否有下一个部分 hasNextPart = readNext(); diff --git a/scx-http/src/test/java/cool/scx/http/test/HeadersTest.java b/scx-http/src/test/java/cool/scx/http/test/HeadersTest.java new file mode 100644 index 00000000..689c9ba0 --- /dev/null +++ b/scx-http/src/test/java/cool/scx/http/test/HeadersTest.java @@ -0,0 +1,28 @@ +package cool.scx.http.test; + +import cool.scx.http.MediaType; +import cool.scx.http.ScxHttpHeaders; +import cool.scx.http.content_type.ContentType; + +import java.nio.charset.StandardCharsets; + +public class HeadersTest { + + public static void main(String[] args) { + test1(); + } + + public static void test1() { + long l = System.nanoTime(); + for (int i = 0; i < 9999; i++) { + var h = ScxHttpHeaders.of(); + h.add("Content-Disposition", "form-data; name=myname"); + h.contentLength(100); + h.contentType(ContentType.of(MediaType.APPLICATION_JSON).charset(StandardCharsets.UTF_8)); + var s = h.encode(); + var nw = ScxHttpHeaders.of(s); + } + System.out.println((System.nanoTime() - l) / 1000_000); + } + +} diff --git a/scx-http/src/test/java/cool/scx/http/test/MultiPartTest.java b/scx-http/src/test/java/cool/scx/http/test/MultiPartTest.java index 63a69083..d55324b4 100644 --- a/scx-http/src/test/java/cool/scx/http/test/MultiPartTest.java +++ b/scx-http/src/test/java/cool/scx/http/test/MultiPartTest.java @@ -1,5 +1,6 @@ package cool.scx.http.test; +import cool.scx.common.util.ArrayUtils; import cool.scx.http.MediaType; import cool.scx.http.ScxHttpHeaders; import cool.scx.http.content_type.ContentType; @@ -36,6 +37,9 @@ public static void test1() { ss.write(b); byte[] byteArray = b.toByteArray(); + //复制两遍查看是否会产生错误的读取 + byteArray = ArrayUtils.concat(byteArray, byteArray); + long l = System.nanoTime(); for (int j = 0; j < 9999; j++) { diff --git a/scx-io/src/main/java/cool/scx/io/ByteArrayDataReader.java b/scx-io/src/main/java/cool/scx/io/ByteArrayDataReader.java index 94d2b21c..bf7dc84e 100644 --- a/scx-io/src/main/java/cool/scx/io/ByteArrayDataReader.java +++ b/scx-io/src/main/java/cool/scx/io/ByteArrayDataReader.java @@ -53,7 +53,7 @@ public void read(OutputStream outputStream, int maxLength) throws NoMoreDataExce } @Override - public byte get() throws NoMoreDataException { + public byte peek() throws NoMoreDataException { try { return bytes[position]; } catch (ArrayIndexOutOfBoundsException e) { @@ -62,7 +62,7 @@ public byte get() throws NoMoreDataException { } @Override - public byte[] get(int maxLength) throws NoMoreDataException { + public byte[] peek(int maxLength) throws NoMoreDataException { int availableLength = bytes.length - position; if (availableLength <= 0) { throw new NoMoreDataException(); @@ -74,7 +74,7 @@ public byte[] get(int maxLength) throws NoMoreDataException { } @Override - public void get(OutputStream outputStream, int maxLength) throws NoMoreDataException { + public void peek(OutputStream outputStream, int maxLength) throws NoMoreDataException { int availableLength = bytes.length - position; if (availableLength <= 0) { throw new NoMoreDataException(); diff --git a/scx-io/src/main/java/cool/scx/io/DataReader.java b/scx-io/src/main/java/cool/scx/io/DataReader.java index 790914a4..58e77718 100644 --- a/scx-io/src/main/java/cool/scx/io/DataReader.java +++ b/scx-io/src/main/java/cool/scx/io/DataReader.java @@ -11,6 +11,7 @@ public interface DataReader { * 当没有更多的数据时会抛出异常 * * @return byte + * @throws NoMoreDataException 没有更多数据时抛出 */ byte read() throws NoMoreDataException; @@ -20,6 +21,7 @@ public interface DataReader { * * @param maxLength 最大长度 * @return bytes + * @throws NoMoreDataException 没有更多数据时抛出 */ byte[] read(int maxLength) throws NoMoreDataException; @@ -28,6 +30,7 @@ public interface DataReader { * 当没有更多的数据时会抛出异常 * * @param maxLength 最大长度 + * @throws NoMoreDataException 没有更多数据时抛出 */ void read(OutputStream outputStream, int maxLength) throws NoMoreDataException; @@ -36,8 +39,9 @@ public interface DataReader { * 当没有更多的数据时会抛出异常 * * @return byte + * @throws NoMoreDataException 没有更多数据时抛出 */ - byte get() throws NoMoreDataException; + byte peek() throws NoMoreDataException; /** * 读取指定长度字节 (指针不会移动) @@ -45,22 +49,25 @@ public interface DataReader { * * @param maxLength 最大长度 * @return byte + * @throws NoMoreDataException 没有更多数据时抛出 */ - byte[] get(int maxLength) throws NoMoreDataException; + byte[] peek(int maxLength) throws NoMoreDataException; /** * 向 outputStream 写入指定长度字节 (指针不会移动) * 当没有更多的数据时会抛出异常 * * @param maxLength 最大长度 + * @throws NoMoreDataException 没有更多数据时抛出 */ - void get(OutputStream outputStream, int maxLength) throws NoMoreDataException; + void peek(OutputStream outputStream, int maxLength) throws NoMoreDataException; /** * 查找 指定字节 第一次出现的 index (指针不会移动) * * @param b 指定字节 * @return index 或者 -1 (未找到) + * @throws NoMatchFoundException 没有匹配时抛出 */ int indexOf(byte b) throws NoMatchFoundException; @@ -69,6 +76,7 @@ public interface DataReader { * * @param b 指定字节数组 * @return index 或者 -1 (未找到) + * @throws NoMatchFoundException 没有匹配时抛出 */ int indexOf(byte[] b) throws NoMatchFoundException; @@ -84,8 +92,9 @@ public interface DataReader { * * @param b 指定字节 * @return bytes + * @throws NoMatchFoundException 没有匹配时抛出 */ - default byte[] readMatch(byte b) throws NoMatchFoundException { + default byte[] readUntil(byte b) throws NoMatchFoundException { var index = indexOf(b); var data = read(index); skip(1); @@ -97,8 +106,9 @@ default byte[] readMatch(byte b) throws NoMatchFoundException { * * @param b 指定字节 * @return bytes + * @throws NoMatchFoundException 没有匹配时抛出 */ - default byte[] readMatch(byte[] b) throws NoMatchFoundException { + default byte[] readUntil(byte[] b) throws NoMatchFoundException { var index = indexOf(b); var data = read(index); skip(b.length); @@ -110,10 +120,11 @@ default byte[] readMatch(byte[] b) throws NoMatchFoundException { * * @param b 指定字节 * @return bytes + * @throws NoMatchFoundException 没有匹配时抛出 */ - default byte[] getMatch(byte b) throws NoMatchFoundException { + default byte[] peekUntil(byte b) throws NoMatchFoundException { var index = indexOf(b); - return get(index); + return peek(index); } /** @@ -121,10 +132,11 @@ default byte[] getMatch(byte b) throws NoMatchFoundException { * * @param b 指定字节 * @return bytes + * @throws NoMatchFoundException 没有匹配时抛出 */ - default byte[] getMatch(byte[] b) throws NoMatchFoundException { + default byte[] peekUntil(byte[] b) throws NoMatchFoundException { var index = indexOf(b); - return get(index); + return peek(index); } } diff --git a/scx-io/src/main/java/cool/scx/io/LinkedDataReader.java b/scx-io/src/main/java/cool/scx/io/LinkedDataReader.java index e9d70481..eb3abfbb 100644 --- a/scx-io/src/main/java/cool/scx/io/LinkedDataReader.java +++ b/scx-io/src/main/java/cool/scx/io/LinkedDataReader.java @@ -127,13 +127,13 @@ public void read(OutputStream outputStream, int maxLength) throws NoMoreDataExce } @Override - public byte get() throws NoMoreDataException { + public byte peek() throws NoMoreDataException { ensureAvailable(); return head.bytes[head.position]; } @Override - public byte[] get(int maxLength) throws NoMoreDataException { + public byte[] peek(int maxLength) throws NoMoreDataException { ensureAvailable(); // 确保至少有一个字节可读 var result = new byte[maxLength]; var remaining = maxLength; //剩余字节数 @@ -169,7 +169,7 @@ public byte[] get(int maxLength) throws NoMoreDataException { } @Override - public void get(OutputStream outputStream, int maxLength) throws NoMoreDataException { + public void peek(OutputStream outputStream, int maxLength) throws NoMoreDataException { ensureAvailable(); // 确保至少有一个字节可读 var result = new byte[maxLength]; var remaining = maxLength; //剩余字节数 diff --git a/scx-io/src/test/java/cool/scx/io/test/DataReaderTest.java b/scx-io/src/test/java/cool/scx/io/test/DataReaderTest.java index 01cedef0..c6ba51a7 100644 --- a/scx-io/src/test/java/cool/scx/io/test/DataReaderTest.java +++ b/scx-io/src/test/java/cool/scx/io/test/DataReaderTest.java @@ -22,15 +22,15 @@ public static void test1() { var dataReader = new ByteArrayDataReader("11112345678".getBytes(StandardCharsets.UTF_8)); //不会影响读取 - dataReader.get(99); + dataReader.peek(99); dataReader.indexOf("1".getBytes(StandardCharsets.UTF_8)); - var index = dataReader.readMatch("123".getBytes()); + var index = dataReader.readUntil("123".getBytes()); try { //第二次应该匹配失败 - var index1 = dataReader.readMatch("123".getBytes()); + var index1 = dataReader.readUntil("123".getBytes()); } catch (NoMatchFoundException _) { }