Skip to content

Commit

Permalink
close to a working implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mitschabaude committed Oct 11, 2024
1 parent f1afee5 commit d5a81b8
Showing 1 changed file with 98 additions and 40 deletions.
138 changes: 98 additions & 40 deletions src/attestations/dynamic-sha256.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Field, Gadgets, Packed, Provable, UInt32, UInt8 } from 'o1js';
import { DynamicArray } from './dynamic-array.ts';
import { StaticArray } from './static-array.ts';
import { assert, chunk } from '../util.ts';
import { assert, chunk, pad } from '../util.ts';

const { SHA256 } = Gadgets;

Expand All @@ -13,25 +13,27 @@ class Bytes extends DynamicArray(UInt8, { maxLength: 80 }) {
}
}
// hierarchy of packed types to do make array ops more efficient
const UInt8x2 = StaticArray(UInt8, 2);
const UInt16 = Packed.create(UInt8x2);
const UInt16x8 = StaticArray(UInt16, 8);
const UInt128 = Packed.create(UInt16x8);
const UInt128x4 = StaticArray(UInt128, 4);
class UInt8x64 extends StaticArray(UInt8, 64) {}
class UInt32x16 extends StaticArray(UInt32, 16) {}
class UInt32x4 extends StaticArray(UInt32, 4) {}
class UInt128 extends Packed.create(UInt32x4) {}
class UInt128x4 extends StaticArray(UInt128, 4) {}
const State = Provable.Array(UInt32, 16);

let bytes = Bytes.fromString('test');

let blocks = createPaddedBlocks(bytes);

let state = blocks.reduce(State, SHA256.initialState, hashBlock);
console.dir(blocks.toValue());

let state = blocks
.map(State, (block) => block.array)
.reduce(State, SHA256.initialState, hashBlock);

/**
* Apply padding to dynamic-length input bytes and split them into sha2 blocks
*/
function createPaddedBlocks(
bytes: DynamicArray<UInt8>
): DynamicArray<UInt32[]> {
function createPaddedBlocks(message: DynamicArray<UInt8>) {
/* padded message looks like this:
M ... M 0x1 0x0 ... 0x0 L L L L L L L L
Expand All @@ -43,44 +45,81 @@ function createPaddedBlocks(
- there are k 0x0 bytes, where k is the smallest number such that
the padded length (in bytes) is a multiple of 64
Corollary: the entire padding is always contained in the same (last) block
Corollaries:
- the entire L section is always contained at the end of the last block
- the 0x1 byte might be in the last block or the one before that
- max number of blocks = ceil((M.maxLength + 9) / 64)
- number of actual blocks = ceil((M.length + 9) / 64) = floor((M.length + 9 + 63) / 64) = floor((M.length + 8) / 64) + 1
- block number of L section = floor((M.length + 8) / 64)
- block number of 0x1 byte index = floor(M.length / 64)
*/

// create chunks of 64 bytes each
let { chunks: blocksOfUInt8, innerLength } = bytes.chunk(64);
// create blocks of 64 bytes each
const maxBlocks = Math.ceil((message.maxLength + 9) / 64);
const BlocksOfBytes = DynamicArray(UInt8x64, { maxLength: maxBlocks });

let lastBlockIndex = UInt32.Unsafe.fromField(message.length.add(8)).div(64);
let numberOfBlocks = lastBlockIndex.value.add(1);
let padded = pad(message.array, maxBlocks * 64, UInt8.from(0));
let chunked = chunk(padded, 64).map(UInt8x64.from);
let blocksOfBytes = new BlocksOfBytes(chunked, numberOfBlocks);

// pack each block of 64 bytes into 32 uint16s
let blocksOfUInt16 = blocksOfUInt8.map(UInt16x8, (block) =>
UInt16x8.from(block.chunk(2).map(UInt16, UInt16.pack))
// pack each block of 64 bytes into 16 uint32s (4 bytes each)
let blocksOfUInt32 = blocksOfBytes.map(UInt32x16, (block) =>
block.chunk(4).map(UInt32, uint32FromBytes)
);

// pack each block of 32 uint16s into 4 uint128s
let blocksOfUInt128 = blocksOfUInt16.map(UInt128x4, (block) =>
UInt128x4.from(block.chunk(8).map(UInt128, UInt128.pack))
// pack each block of 16 uint32s into 4 uint128s (4 uint32s each)
let blocksOfUInt128 = blocksOfUInt32.map(UInt128x4, (block) =>
block.chunk(4).map(UInt128, UInt128.pack)
);

// splice the length in the same way
// length = l0 + 2*(l10 + 8*l11) + 64*blocks.length
let { rest: l0, quotient: l1 } =
UInt32.Unsafe.fromField(innerLength).divMod(2);
let { rest: l10, quotient: l11 } = l1.divMod(8);

// get the last block, and correct sub-blocks within that
let lastIndex = blocksOfUInt128.length.sub(1);
let lastBlock = blocksOfUInt128.getOrUnconstrained(lastIndex);
let lastUint128 = lastBlock.getOrUnconstrained(l11.value).unpack();
let lastUint16 = lastUint128.getOrUnconstrained(l10.value).unpack();

// set 0x1 byte at `length`
lastUint16.setOrDoNothing(l0.value, UInt8.from(0x1));
lastUint128.setOrDoNothing(l11.value, UInt16.pack(lastUint16));
lastBlock.setOrDoNothing(l10.value, UInt128.pack(lastUint128));

throw Error('todo');
// length = l0 + 4*l1 + 16*l2 + 64*l3
let [l0, l1, l2, l3] = splitMultiIndex(
UInt32.Unsafe.fromField(message.length)
);

// hierarchically get blocks at `length` and set to 0x1 byte
let block = blocksOfUInt128.getOrUnconstrained(l3);
let uint32x4 = block.getOrUnconstrained(l2).unpack();
let uint8x4 = uint32ToBytes(uint32x4.getOrUnconstrained(l1));
uint8x4.setOrDoNothing(l0, UInt8.from(0x1));
uint32x4.setOrDoNothing(l1, uint32FromBytes(uint8x4));
block.setOrDoNothing(l2, UInt128.pack(uint32x4));

// set last 64 bits to encoded length (in bits, big-endian encoded)
// in fact, since we assume the length (in bytes) fits in 16 bits, we only need to set the last uint32
let lastBlock = blocksOfUInt128.getOrUnconstrained(lastBlockIndex.value);
let lastUInt128 = lastBlock.get(3).unpack();
lastUInt128.set(2, UInt32.from(0));
lastUInt128.set(3, encodeLength(message.length));
lastBlock.set(3, UInt128.pack(lastUInt128));
blocksOfUInt128.setOrDoNothing(lastBlockIndex.value, lastBlock);

// unpack all blocks to UInt32[]
return blocksOfUInt128.map(UInt32x16, (block) =>
block.array.flatMap((uint128) => uint128.unpack().array)
);
}

function padLastBlock(lastBlock: UInt32[]): UInt32[] {
throw Error('todo');
function splitMultiIndex(index: UInt32) {
let { rest: l0, quotient: l1 } = index.divMod(64);
let { rest: l00, quotient: l01 } = l0.divMod(16);
let { rest: l000, quotient: l001 } = l00.divMod(4);
return [l000.value, l001.value, l01.value, l1.value] as const;
}

function splitMultiIndexGeneral(index: UInt32, sizes: number[]) {
let indices: UInt32[] = Array(sizes.length + 1);

for (let i = sizes.length - 1; i >= 0; i--) {
let { rest, quotient } = index.divMod(sizes[i]!);
indices[i + 1] = quotient;
index = rest;
}
indices[0] = index;
return indices;
}

function hashBlock(state: UInt32[], block: UInt32[]) {
Expand All @@ -90,10 +129,10 @@ function hashBlock(state: UInt32[], block: UInt32[]) {

function bytesToState(bytes: UInt8[]) {
assert(bytes.length === 64, '64 bytes needed to create 16 uint32s');
return chunk(bytes, 4).map(bytesToWord);
return chunk(bytes, 4).map(uint32FromBytes);
}

function bytesToWord(bytes: UInt8[]) {
function uint32FromBytes(bytes: UInt8[] | StaticArray<UInt8>) {
assert(bytes.length === 4, '4 bytes needed to create a uint32');

let word = Field(0);
Expand All @@ -103,3 +142,22 @@ function bytesToWord(bytes: UInt8[]) {

return UInt32.Unsafe.fromField(word);
}

function uint32ToBytes(word: UInt32) {
// witness the bytes
let bytes = Provable.witness(StaticArray(UInt8, 4), () => {
let value = word.value.toBigInt();
return [0, 1, 2, 3].map((i) =>
UInt8.from((value >> BigInt(8 * i)) & 0xffn)
);
});

// prove that the bytes are correct
uint32FromBytes(bytes).assertEquals(word);

return bytes;
}

function encodeLength(lengthInBytes: Field): UInt32 {
return UInt32.Unsafe.fromField(lengthInBytes.mul(8));
}

0 comments on commit d5a81b8

Please sign in to comment.