Skip to content

Commit

Permalink
fix: updated mocked mode to be compatible with new types
Browse files Browse the repository at this point in the history
  • Loading branch information
jatZama committed Mar 5, 2024
1 parent b2f9477 commit c95296d
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 141 deletions.
25 changes: 25 additions & 0 deletions .github/workflows/testmock.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name: Pull request tests

on:
pull_request:
branches:
- main

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [18.x]
steps:
- uses: actions/checkout@v3
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v3
with:
node-version: ${{ matrix.node-version }}
- run: cp .env.example .env
- run: npm ci
- name: "npm CI test"
run: |
# sometimes not created and is not tailed
npm run test:mock
1 change: 0 additions & 1 deletion codegen/overloadTests.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import overloads from './overloads.json';
import { OverloadSignature, signatureContractMethodName } from './testgen';

type OverloadTestJSON = {
inputs: (number | bigint | string)[];
Expand Down
133 changes: 92 additions & 41 deletions codegen/templates.ts
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ library TFHE {

supportedBits.forEach((bits) => {
operators.forEach((operator) => {
if (operator.shiftOperator) res.push(tfheShiftOperators(bits, operator, signatures));
if (operator.shiftOperator) res.push(tfheShiftOperators(bits, operator, signatures, mocked));
});
});

Expand All @@ -227,9 +227,9 @@ library TFHE {
res.push(tfheAsEboolUnaryCast(outputBits));
});
supportedBits.forEach((bits) => res.push(tfheUnaryOperators(bits, operators, signatures)));
supportedBits.forEach((bits) => res.push(tfheCustomUnaryOperators(bits, signatures)));
supportedBits.forEach((bits) => res.push(tfheCustomUnaryOperators(bits, signatures, mocked)));

res.push(tfheCustomMethods(ctx));
res.push(tfheCustomMethods(ctx, mocked));

res.push('}\n');

Expand Down Expand Up @@ -428,7 +428,12 @@ function tfheScalarOperator(
return res.join('');
}

function tfheShiftOperators(inputBits: number, operator: Operator, signatures: OverloadSignature[]): string {
function tfheShiftOperators(
inputBits: number,
operator: Operator,
signatures: OverloadSignature[],
mocked: boolean,
): string {
const res: string[] = [];

// Code and test for shift(euint{inputBits},euint8}
Expand All @@ -444,7 +449,12 @@ function tfheShiftOperators(inputBits: number, operator: Operator, signatures: O

const leftExpr = 'a';
const rightExpr = castRightToLeft ? `asEuint${outputBits}(b)` : 'b';
let implExpression = `Impl.${operator.name}(euint${outputBits}.unwrap(${leftExpr}), euint${outputBits}.unwrap(${rightExpr})${scalarFlag})`;
let implExpression;
if (mocked) {
implExpression = `Impl.${operator.name}(euint${outputBits}.unwrap(${leftExpr}), euint${outputBits}.unwrap(${rightExpr}) % ${lhsBits}${scalarFlag})`;
} else {
implExpression = `Impl.${operator.name}(euint${outputBits}.unwrap(${leftExpr}), euint${outputBits}.unwrap(${rightExpr})${scalarFlag})`;
}

if (inputBits >= 8) {
signatures.push({
Expand Down Expand Up @@ -472,16 +482,9 @@ function tfheShiftOperators(inputBits: number, operator: Operator, signatures: O

// Code and test for shift(euint{inputBits},uint8}
scalarFlag = ', true';
const leftOpName = operator.name;
var implExpressionA = `Impl.${operator.name}(euint${outputBits}.unwrap(a), uint256(b)${scalarFlag})`;
var implExpressionB = `Impl.${leftOpName}(euint${outputBits}.unwrap(b), uint256(a)${scalarFlag})`;
var maybeEncryptLeft = '';
if (operator.leftScalarEncrypt) {
// workaround until tfhe-rs left scalar support:
// do the trivial encryption and preserve order of operations
scalarFlag = ', false';
maybeEncryptLeft = `euint${outputBits} aEnc = asEuint${outputBits}(a);`;
implExpressionB = `Impl.${leftOpName}(euint${outputBits}.unwrap(aEnc), euint${8}.unwrap(b)${scalarFlag})`;
implExpression = `Impl.${operator.name}(euint${outputBits}.unwrap(a), uint256(b)${scalarFlag})`;
if (mocked) {
implExpression = `Impl.${operator.name}(euint${outputBits}.unwrap(a), uint256(b) % ${lhsBits}${scalarFlag})`;
}
signatures.push({
name: operator.name,
Expand All @@ -497,7 +500,7 @@ function tfheShiftOperators(inputBits: number, operator: Operator, signatures: O
if (!isInitialized(a)) {
a = asEuint${lhsBits}(0);
}
return ${returnType}.wrap(${implExpressionA});
return ${returnType}.wrap(${implExpression});
}
`);

Expand Down Expand Up @@ -614,8 +617,8 @@ function tfheUnaryOperators(bits: number, operators: Operator[], signatures: Ove
return res.join('\n');
}

function tfheCustomUnaryOperators(bits: number, signatures: OverloadSignature[]): string {
return `
function tfheCustomUnaryOperators(bits: number, signatures: OverloadSignature[], mocked: boolean): string {
let result = `
// Convert a serialized 'ciphertext' to an encrypted euint${bits} integer.
function asEuint${bits}(bytes memory ciphertext) internal pure returns (euint${bits}) {
return euint${bits}.wrap(Impl.verify(ciphertext, Common.euint${bits}_t));
Expand All @@ -626,10 +629,18 @@ function tfheCustomUnaryOperators(bits: number, signatures: OverloadSignature[])
return euint${bits}.wrap(Impl.trivialEncrypt(value, Common.euint${bits}_t));
}
`;
if (mocked) {
result += `
// Decrypts the encrypted 'value'.
function decrypt(euint${bits} value) internal view returns (${getUint(bits)}) {
return ${getUint(bits)}(Impl.decrypt(euint${bits}.unwrap(value)) % 2**${bits});
}
// Reencrypt the given 'value' under the given 'publicKey'.
// Return a serialized euint${bits} ciphertext.
function reencrypt(euint${bits} value, bytes32 publicKey) internal view returns (bytes memory reencrypted) {
return Impl.reencrypt(euint${bits}.unwrap(value), publicKey);
return Impl.reencrypt(euint${bits}.unwrap(value) % 2**${bits}, publicKey);
}
// Reencrypt the given 'value' under the given 'publicKey'.
Expand All @@ -639,17 +650,40 @@ function tfheCustomUnaryOperators(bits: number, signatures: OverloadSignature[])
bits,
)} defaultValue) internal view returns (bytes memory reencrypted) {
if (euint${bits}.unwrap(value) != 0) {
return Impl.reencrypt(euint${bits}.unwrap(value), publicKey);
return Impl.reencrypt(euint${bits}.unwrap(value) % 2**${bits}, publicKey);
} else {
return Impl.reencrypt(euint${bits}.unwrap(asEuint${bits}(defaultValue)), publicKey);
return Impl.reencrypt(euint${bits}.unwrap(asEuint${bits}(defaultValue)) % 2**${bits}, publicKey);
}
}
`;
} else {
result += `
// Decrypts the encrypted 'value'.
function decrypt(euint${bits} value) internal view returns (${getUint(bits)}) {
return ${getUint(bits)}(Impl.decrypt(euint${bits}.unwrap(value)));
}
// Reencrypt the given 'value' under the given 'publicKey'.
// Return a serialized euint${bits} ciphertext.
function reencrypt(euint${bits} value, bytes32 publicKey) internal view returns (bytes memory reencrypted) {
return Impl.reencrypt(euint${bits}.unwrap(value), publicKey);
}
// Reencrypt the given 'value' under the given 'publicKey'.
// If 'value' is not initialized, the returned value will contain the 'defaultValue' constant.
// Return a serialized euint${bits} ciphertext.
function reencrypt(euint${bits} value, bytes32 publicKey, ${getUint(
bits,
)} defaultValue) internal view returns (bytes memory reencrypted) {
if (euint${bits}.unwrap(value) != 0) {
return Impl.reencrypt(euint${bits}.unwrap(value), publicKey);
} else {
return Impl.reencrypt(euint${bits}.unwrap(asEuint${bits}(defaultValue)), publicKey);
}
}
`;
}
return result;
}

function unaryOperatorImpl(op: Operator): string {
Expand All @@ -661,8 +695,8 @@ function unaryOperatorImpl(op: Operator): string {
`;
}

function tfheCustomMethods(ctx: CodegenContext): string {
return `
function tfheCustomMethods(ctx: CodegenContext, mocked: boolean): string {
let result = `
// Optimistically require that 'b' is true.
//
// This function does not evaluate 'b' at the time of the call.
Expand All @@ -687,11 +721,6 @@ function tfheCustomMethods(ctx: CodegenContext): string {
Impl.optReq(euint8.unwrap(asEuint8(b)));
}
// Decrypts the encrypted 'value'.
function decrypt(ebool value) internal view returns (bool) {
return (Impl.decrypt(ebool.unwrap(value)) != 0);
}
// Reencrypt the given 'value' under the given 'publicKey'.
// Return a serialized euint8 value.
function reencrypt(ebool value, bytes32 publicKey) internal view returns (bytes memory reencrypted) {
Expand Down Expand Up @@ -753,6 +782,22 @@ function tfheCustomMethods(ctx: CodegenContext): string {
return euint32.wrap(Impl.randBounded(upperBound, Common.euint32_t));
}
`;
if (mocked) {
result += `
// Decrypts the encrypted 'value'.
function decrypt(ebool value) internal view returns (bool) {
return (Impl.decrypt(ebool.unwrap(value)) % 2 == 1);
}
`;
} else {
result += `
// Decrypts the encrypted 'value'.
function decrypt(ebool value) internal view returns (bool) {
return (Impl.decrypt(ebool.unwrap(value)) != 0);
}
`;
}
return result;
}

function implCustomMethods(ctx: CodegenContext): string {
Expand Down Expand Up @@ -971,18 +1016,24 @@ library Impl {
}
function cast(uint256 ciphertext, uint8 toType) internal pure returns (uint256 result) {
if (toType == 0) {
result = uint256(uint8(ciphertext));
}
if (toType == 1) {
result = uint256(uint16(ciphertext));
}
if (toType == 2) {
result = uint256(uint32(ciphertext));
}
if (toType == 3) {
result = uint256(uint64(ciphertext));
}
if (toType == 0) {
result = uint256(uint8(ciphertext));
}
if (toType == 1) {
result = uint256(uint8(ciphertext));
}
if (toType == 2) {
result = uint256(uint8(ciphertext));
}
if (toType == 3) {
result = uint256(uint16(ciphertext));
}
if (toType == 4) {
result = uint256(uint32(ciphertext));
}
if (toType == 5) {
result = uint256(uint64(ciphertext));
}
}
function trivialEncrypt(uint256 value, uint8 /*toType*/) internal pure returns (uint256 result) {
Expand Down
1 change: 0 additions & 1 deletion codegen/testgen.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { strict as assert } from 'node:assert';

import { Operator } from './common';
import { overloadTests } from './overloadTests';
import { getUint } from './utils';

Expand Down
Loading

0 comments on commit c95296d

Please sign in to comment.