From b9864be77abd6743e63ea1bffd63136e89c84260 Mon Sep 17 00:00:00 2001 From: Lily Lin Date: Tue, 12 May 2020 12:01:06 -0700 Subject: [PATCH] Attempt to switch to unsigned ints/bytes using longs/shorts --- .../github/rctcwyvrn/blake3java/Blake3.java | 122 ++++++++++-------- src/com/github/rctcwyvrn/blake3java/Main.java | 14 +- 2 files changed, 78 insertions(+), 58 deletions(-) diff --git a/src/com/github/rctcwyvrn/blake3java/Blake3.java b/src/com/github/rctcwyvrn/blake3java/Blake3.java index f53e83a..70a3927 100644 --- a/src/com/github/rctcwyvrn/blake3java/Blake3.java +++ b/src/com/github/rctcwyvrn/blake3java/Blake3.java @@ -26,7 +26,7 @@ public class Blake3 { private static final int DERIVE_KEY_CONTEXT = 32; private static final int DERIVE_KEY_MATERIAL = 64; - private static final int[] IV = { + private static final long[] IV = { 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19 }; @@ -34,16 +34,17 @@ public class Blake3 { 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 }; - private static int wrappingAdd(int a, int b){ - return (a + b); //TODO: Should be mod something, ill figure it out + private static long wrappingAdd(long a, long b){ + return (a + b) & 0xffffffffL; //TODO: Should be mod something, ill figure it out } // TODO: Does java's rotate right work? It uses the two's complement signed representation right? Does this function work or does the java one? -// private static int rotateRight(int x, int len){ -// return (x >> len) | (x << (32 - len)); -// } + // Theoretically they should both work as long as the integer values are positive, which they should always be + private static long rotateRight(long x, int len){ + return ((x >> len) | (x << (32 - len))) & 0xffffffffL; + } - private static void g(int[] state, int a, int b, int c, int d, int mx, int my){ + private static void g(long[] state, int a, int b, int c, int d, long mx, long my){ state[a] = wrappingAdd(wrappingAdd(state[a], state[b]), mx); state[d] = rotateRight((state[d] ^ state[a]), 16); state[c] = wrappingAdd(state[c], state[d]); @@ -54,7 +55,7 @@ private static void g(int[] state, int a, int b, int c, int d, int mx, int my){ state[b] = rotateRight((state[b] ^ state[c]), 7); } - private static void roundFn(int[] state, int[] m){ + private static void roundFn(long[] state, long[] m){ // Mix columns g(state,0,4,8,12,m[0],m[1]); g(state,1,5,9,13,m[2],m[3]); @@ -68,18 +69,18 @@ private static void roundFn(int[] state, int[] m){ g(state,3,4,9,14,m[14],m[15]); } - private static int[] permute(int[] m){ - int[] permuted = new int[16]; + private static long[] permute(long[] m){ + long[] permuted = new long[16]; for(int i = 0;i<16;i++){ permuted[i] = m[MSG_PERMUTATION[i]]; } return permuted; } - private static int[] compress(int[] chainingValue, int[] blockWords, long counter, int blockLen, int flags){ - int counterInt = (int) counter; - int counterShift = (int) (counter >> 32); - int[] state = { + private static long[] compress(long[] chainingValue, long[] blockWords, long counter, int blockLen, long flags){ + long counterInt = counter & 0xffffffffL; + long counterShift = (counter >> 32) & 0xffffffffL; + long[] state = { chainingValue[0], chainingValue[1], chainingValue[2], @@ -118,13 +119,19 @@ private static int[] compress(int[] chainingValue, int[] blockWords, long counte return state; } - private static int[] wordsFromLEBytes(byte[] bytes){ + private static long[] wordsFromLEBytes(short[] bytes){ if ((bytes.length != 64)) throw new AssertionError(); - int[] words = new int[16]; - ByteBuffer buf = ByteBuffer.wrap(bytes); + long[] words = new long[16]; + + //ByteBuffer buf = ByteBuffer.wrap(bytes); + ByteBuffer buf = ByteBuffer.allocate(64); + for(short b: bytes){ + buf.put((byte) (b & 0xff)); + } + buf.order(ByteOrder.LITTLE_ENDIAN); for(int i = 0; i<16; i++){ - words[i] = buf.getInt(i*4); + words[i] = buf.getInt(i*4) & 0xffffffffL; } return words; } @@ -133,13 +140,13 @@ private static int[] wordsFromLEBytes(byte[] bytes){ // Is either chained into the next node using chainingValue() // Or used to calculate the hash digest using rootOutputBytes() private static class Node { - int[] inputChainingValue; - int[] blockWords; + long[] inputChainingValue; + long[] blockWords; long counter; int blockLen; int flags; - private Node(int[] inputChainingValue, int[] blockWords, long counter, int blockLen, int flags) { + private Node(long[] inputChainingValue, long[] blockWords, long counter, int blockLen, int flags) { this.inputChainingValue = inputChainingValue; this.blockWords = blockWords; this.counter = counter; @@ -148,25 +155,30 @@ private Node(int[] inputChainingValue, int[] blockWords, long counter, int block } // Return the 8 int CV - private int[] chainingValue(){ + private long[] chainingValue(){ return Arrays.copyOfRange( compress(inputChainingValue, blockWords, counter, blockLen, flags), 0,8); } - private byte[] rootOutputBytes(int outLen){ + // Fuck java's lack of unsigned byte >:C + private short[] rootOutputBytes(int outLen){ int outputCounter = 0; int outputsNeeded = Math.floorDiv(outLen,(2*OUT_LEN)) + 1; - byte[] hash = new byte[outLen]; + short[] hash = new short[outLen]; int i = 0; while(outputCounter < outputsNeeded){ - int[] words = compress(inputChainingValue, blockWords, outputCounter, blockLen,flags | ROOT ); - - for(int word: words){ - for(byte b: ByteBuffer.allocate(4).putInt(word).array()){ - hash[i] = b; + long[] words = compress(inputChainingValue, blockWords, outputCounter, blockLen,flags | ROOT ); + + for(long word: words){ + for(byte b: ByteBuffer.allocate(4) + .order(ByteOrder.LITTLE_ENDIAN) + .putInt((int) (word & 0xffffffffL)) + .array()){ + hash[i] = (short) (b & 0xff); i+=1; if(i == outLen){ + System.out.println(Arrays.toString(hash)); return hash; } } @@ -179,14 +191,14 @@ private byte[] rootOutputBytes(int outLen){ // Helper object for creating new Nodes and chaining them private class ChunkState { - int[] chainingValue; + long[] chainingValue; int chunkCounter; - byte[] block = new byte[64]; + short[] block = new short[BLOCK_LEN]; byte blockLen = 0; byte blocksCompressed = 0; int flags; - public ChunkState(int[] key, int chunkCounter, int flags){ + public ChunkState(long[] key, int chunkCounter, int flags){ this.chainingValue = key; this.chunkCounter = chunkCounter; this.flags = flags; @@ -200,17 +212,17 @@ private int startFlag(){ return blocksCompressed == 0? CHUNK_START: 0; } - private void update(byte[] input) { + private void update(short[] input) { while (input.length != 0) { // Chain the next 64 byte block into this chunk/node if (blockLen == BLOCK_LEN) { - int[] blockWords = wordsFromLEBytes(block); + long[] blockWords = wordsFromLEBytes(block); this.chainingValue = Arrays.copyOfRange( compress(this.chainingValue, blockWords, this.chunkCounter, BLOCK_LEN,this.flags | this.startFlag()), 0, 8); blocksCompressed += 1; - this.block = new byte[64]; + this.block = new short[BLOCK_LEN]; this.blockLen = 0; } @@ -233,8 +245,8 @@ private Node createNode(){ // Hasher private ChunkState chunkState; - private int[] key; - private int[][] cvStack = new int[54][]; + private long[] key; + private long[][] cvStack = new long[54][]; private byte cvStackLen = 0; private int flags; @@ -242,13 +254,13 @@ public Blake3(){ this(IV,0); } - public Blake3(int[] key, int flags){ + public Blake3(long[] key, int flags){ this.chunkState = new ChunkState(key, 0, flags); this.key = key; this.flags = flags; } - public Blake3(byte[] key){ + public Blake3(short[] key){ this(wordsFromLEBytes(key), KEYED_HASH); } @@ -256,36 +268,36 @@ public Blake3(String context){ Blake3 contextHasher = new Blake3(IV, DERIVE_KEY_CONTEXT); } - private void pushStack(int[] cv){ + private void pushStack(long[] cv){ this.cvStack[this.cvStackLen] = cv; cvStackLen+=1; } - private int[] popStack(){ + private long[] popStack(){ this.cvStackLen-=1; return cvStack[cvStackLen]; } // Combines the chaining values of two children to create the parent node - private static Node parentNode(int[] leftChildCV, int[] rightChildCV, int[] key, int flags){ - int[] blockWords = new int[16]; + private static Node parentNode(long[] leftChildCV, long[] rightChildCV, long[] key, int flags){ + long[] blockWords = new long[16]; int i = 0; - for(int x: leftChildCV){ + for(long x: leftChildCV){ blockWords[i] = x; i+=1; } - for(int x: rightChildCV){ + for(long x: rightChildCV){ blockWords[i] = x; i+=1; } return new Node(key, blockWords, 0, BLOCK_LEN, PARENT | flags); } - private static int[] parentCV(int[] leftChildCV, int[] rightChildCV, int[] key, int flags){ + private static long[] parentCV(long[] leftChildCV, long[] rightChildCV, long[] key, int flags){ return parentNode(leftChildCV, rightChildCV, key, flags).chainingValue(); } - private void addChunkChainingValue(int[] newCV, long totalChunks){ + private void addChunkChainingValue(long[] newCV, long totalChunks){ while((totalChunks & 1) == 0){ newCV = parentCV(popStack(), newCV, key, flags); totalChunks >>=1; @@ -293,12 +305,20 @@ private void addChunkChainingValue(int[] newCV, long totalChunks){ pushStack(newCV); } - public void update(byte[] input){ + public void update(String input){ + byte[] inputBytes = input.getBytes(); + short[] converted = new short[inputBytes.length]; + for(int i = 0; i < inputBytes.length; i++){ + converted[i] = (short) (0xff & inputBytes[i]); + } + updateRaw(converted); + } + public void updateRaw(short[] input){ while(input.length != 0) { // If this chunk has chained in 16 64 bytes of input, add it's CV to the stack if (chunkState.len() == CHUNK_LEN) { - int[] chunkCV = chunkState.createNode().chainingValue(); + long[] chunkCV = chunkState.createNode().chainingValue(); int totalChunks = chunkState.chunkCounter + 1; addChunkChainingValue(chunkCV, totalChunks); chunkState = new ChunkState(key, totalChunks, flags); @@ -311,7 +331,7 @@ public void update(byte[] input){ } } - public byte[] digest(int hashLen){ + public short[] digest(int hashLen){ Node node = this.chunkState.createNode(); int parentNodesRemaining = cvStackLen; while(parentNodesRemaining > 0){ @@ -334,7 +354,7 @@ public String hexdigest(){ return hexdigest(32); } - private static String bytesToHex(byte[] bytes) { + private static String bytesToHex(short[] bytes) { char[] hexChars = new char[bytes.length * 2]; for (int j = 0; j < bytes.length; j++) { int v = bytes[j] & 0xFF; diff --git a/src/com/github/rctcwyvrn/blake3java/Main.java b/src/com/github/rctcwyvrn/blake3java/Main.java index e806b78..03df63f 100644 --- a/src/com/github/rctcwyvrn/blake3java/Main.java +++ b/src/com/github/rctcwyvrn/blake3java/Main.java @@ -1,5 +1,7 @@ package com.github.rctcwyvrn.blake3java; +import javax.swing.*; +import java.nio.charset.Charset; import java.util.Arrays; public class Main { @@ -12,14 +14,12 @@ public static void main(String[] args){ // System.out.println(Arrays.toString(Arrays.copyOfRange(input, 4 + 1, input.length))); Blake3 hasher = new Blake3(); -// hasher.update("AAAAAAAAAAAAAAAA".getBytes()); -// hasher.update("AAAAAAAAAAAAAAAA".getBytes()); -// hasher.update("AAAAAAAAAAAAAAAA".getBytes()); -// hasher.update("AAAAAAAAAAAAAAAA".getBytes()); -// hasher.update("AAAAAAAAAAAAAAAA".getBytes()); +// hasher.update("AAAAAAAAAAAAAAAA"); - hasher.update("abc".getBytes()); - hasher.update("def".getBytes()); + //hasher.update("abc"); + //hasher.update("def"); + + hasher.update("This is a string"); //Should be 718b749f12a61257438b2ea6643555fd995001c9d9ff84764f93f82610a780f2 String hexhash = hasher.hexdigest(); System.out.println(hexhash);