Skip to content

Commit 752e784

Browse files
Pass Attribute to debug function call
1 parent 8fe24bd commit 752e784

File tree

8 files changed

+161
-7
lines changed

8 files changed

+161
-7
lines changed

lib/Target/Lattigo/LattigoEmitter.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,39 @@ LogicalResult LattigoEmitter::printOperation(func::ReturnOp op) {
164164
}
165165

166166
LogicalResult LattigoEmitter::printOperation(func::CallOp op) {
167+
// build debug attribute map for debug call
168+
auto debugAttrMapName = getDebugAttrMapName();
169+
if (isDebugPort(op.getCallee())) {
170+
os << debugAttrMapName << " := make(map[string]string)\n";
171+
for (auto attr : op->getAttrs()) {
172+
// callee is also an attribute internally, skip it
173+
if (attr.getName().getValue() == "callee") {
174+
continue;
175+
}
176+
os << debugAttrMapName << "[\"" << attr.getName().getValue()
177+
<< "\"] = \"";
178+
// Use AsmPrinter to print Attribute
179+
if (mlir::isa<StringAttr>(attr.getValue())) {
180+
os << mlir::cast<StringAttr>(attr.getValue()).getValue() << "\"\n";
181+
} else {
182+
os << attr.getValue() << "\"\n";
183+
}
184+
}
185+
// Use AsmPrinter to print Value to print the defining op
186+
auto ciphertext = op->getOperand(op->getNumOperands() - 1);
187+
os << debugAttrMapName << R"(["op"] = ")" << ciphertext << "\"\n";
188+
}
189+
167190
if (op.getNumResults() > 0) {
168191
os << getCommaSeparatedNames(op.getResults());
169192
os << " := ";
170193
}
171194
os << canonicalizeDebugPort(op.getCallee()) << "(";
172195
os << getCommaSeparatedNames(op.getOperands());
196+
// pass debug attribute map
197+
if (isDebugPort(op.getCallee())) {
198+
os << ", " << debugAttrMapName;
199+
}
173200
os << ")\n";
174201
return success();
175202
}

lib/Target/Lattigo/LattigoEmitter.h

+5
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,11 @@ class LattigoEmitter {
147147
return "err" + std::to_string(errCount++);
148148
}
149149

150+
std::string getDebugAttrMapName() {
151+
static int debugAttrMapCount = 0;
152+
return "debugAttrMap" + std::to_string(debugAttrMapCount++);
153+
}
154+
150155
std::string getCommaSeparatedNames(::mlir::ValueRange values) {
151156
return commaSeparatedValues(values,
152157
[&](Value value) { return getName(value); });

lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp

+36-1
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,12 @@ LogicalResult OpenFhePkeEmitter::printOperation(ModuleOp moduleOp) {
132132
return success();
133133
}
134134

135+
bool OpenFhePkeEmitter::isDebugPort(StringRef debugPortName) {
136+
return debugPortName.rfind("__heir_debug") == 0;
137+
}
138+
135139
StringRef OpenFhePkeEmitter::canonicalizeDebugPort(StringRef debugPortName) {
136-
if (debugPortName.rfind("__heir_debug") == 0) {
140+
if (isDebugPort(debugPortName)) {
137141
return "__heir_debug";
138142
}
139143
return debugPortName;
@@ -177,6 +181,10 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::FuncOp funcOp) {
177181
os << commaSeparatedTypes(funcOp.getArgumentTypes(), [&](Type type) {
178182
return convertType(type, funcOp->getLoc()).value();
179183
});
184+
// debug attribute map for debug call
185+
if (isDebugPort(funcOp.getName())) {
186+
os << ", const std::map<std::string, std::string>&";
187+
}
180188
} else {
181189
os << commaSeparatedValues(funcOp.getArguments(), [&](Value value) {
182190
return convertType(value.getType(), funcOp->getLoc()).value() + " " +
@@ -213,6 +221,29 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::CallOp op) {
213221
return emitError(op.getLoc(), "Only one return value supported");
214222
}
215223

224+
// build debug attribute map for debug call
225+
auto debugAttrMapName = getDebugAttrMapName();
226+
if (isDebugPort(op.getCallee())) {
227+
os << "std::map<std::string, std::string> " << debugAttrMapName << ";\n";
228+
for (auto attr : op->getAttrs()) {
229+
// callee is also an attribute internally, skip it
230+
if (attr.getName().getValue() == "callee") {
231+
continue;
232+
}
233+
os << debugAttrMapName << "[\"" << attr.getName().getValue()
234+
<< "\"] = \"";
235+
// Use AsmPrinter to print Attribute
236+
if (mlir::isa<StringAttr>(attr.getValue())) {
237+
os << mlir::cast<StringAttr>(attr.getValue()).getValue() << "\"\n";
238+
} else {
239+
os << attr.getValue() << "\";\n";
240+
}
241+
}
242+
// Use AsmPrinter to print Value to print the defining op
243+
auto ciphertext = op->getOperand(op->getNumOperands() - 1);
244+
os << debugAttrMapName << R"(["op"] = ")" << ciphertext << "\";\n";
245+
}
246+
216247
if (op.getNumResults() != 0) {
217248
emitAutoAssignPrefix(op.getResult(0));
218249
}
@@ -221,6 +252,10 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::CallOp op) {
221252
os << commaSeparatedValues(op.getOperands(), [&](Value value) {
222253
return variableNames->getNameForValue(value);
223254
});
255+
// pass debug attribute map
256+
if (isDebugPort(op.getCallee())) {
257+
os << ", " << debugAttrMapName;
258+
}
224259
os << ");\n";
225260
return success();
226261
}

lib/Target/OpenFhePke/OpenFhePkeEmitter.h

+6
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,14 @@ class OpenFhePkeEmitter {
101101
LogicalResult emitType(::mlir::Type type, ::mlir::Location loc);
102102

103103
// Canonicalize Debug Port
104+
bool isDebugPort(::llvm::StringRef debugPortName);
104105
::llvm::StringRef canonicalizeDebugPort(::llvm::StringRef debugPortName);
105106

107+
std::string getDebugAttrMapName() {
108+
static int debugAttrMapCount = 0;
109+
return "debugAttrMap" + std::to_string(debugAttrMapCount++);
110+
}
111+
106112
void emitAutoAssignPrefix(::mlir::Value result);
107113
LogicalResult emitTypedAssignPrefix(::mlir::Value result,
108114
::mlir::Location loc);

tests/Dialect/Lattigo/Emitters/emit_lattigo.mlir

+17
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,20 @@ module attributes {scheme.bgv} {
168168
return %evaluator : !evaluator
169169
}
170170
}
171+
172+
// -----
173+
174+
// CHECK-LABEL: func dot_product
175+
// CHECK: ["bound"] = "50"
176+
// CHECK: ["complex"] = "{test = 1.200000e+00 : f64}"
177+
// CHECK: ["random"] = "3 : i64"
178+
// CHECK: ["secret.secret"] = "unit"
179+
// CHECK: ["op"]
180+
181+
module attributes {scheme.bgv} {
182+
func.func private @__heir_debug_0(!lattigo.bgv.evaluator, !lattigo.bgv.parameter, !lattigo.bgv.encoder, !lattigo.rlwe.decryptor, !lattigo.rlwe.ciphertext)
183+
func.func @dot_product(%evaluator: !lattigo.bgv.evaluator, %param: !lattigo.bgv.parameter, %encoder: !lattigo.bgv.encoder, %decryptor: !lattigo.rlwe.decryptor, %ct: !lattigo.rlwe.ciphertext, %ct_0: !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext attributes {mgmt.openfhe_params = #mgmt.openfhe_params<evalAddCount = 8, keySwitchCount = 15>} {
184+
call @__heir_debug_0(%evaluator, %param, %encoder, %decryptor, %ct) {bound = "50", random = 3, complex = {test = 1.2}, secret.secret} : (!lattigo.bgv.evaluator, !lattigo.bgv.parameter, !lattigo.bgv.encoder, !lattigo.rlwe.decryptor, !lattigo.rlwe.ciphertext) -> ()
185+
return %ct : !lattigo.rlwe.ciphertext
186+
}
187+
}

tests/Dialect/Openfhe/Emitters/emit_openfhe_pke.mlir

+30
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,33 @@ module attributes {scheme.ckks} {
179179
return %1 : !openfhe.crypto_context
180180
}
181181
}
182+
183+
// -----
184+
185+
!Z2147565569_i64_ = !mod_arith.int<2147565569 : i64>
186+
!Z65537_i64_ = !mod_arith.int<65537 : i64>
187+
#full_crt_packing_encoding = #lwe.full_crt_packing_encoding<scaling_factor = 0>
188+
#key = #lwe.key<>
189+
#modulus_chain_L0_C0_ = #lwe.modulus_chain<elements = <2147565569 : i64>, current = 0>
190+
!rns_L0_ = !rns.rns<!Z2147565569_i64_>
191+
#ring_Z65537_i64_1_x8_ = #polynomial.ring<coefficientType = !Z65537_i64_, polynomialModulus = <1 + x**8>>
192+
#plaintext_space = #lwe.plaintext_space<ring = #ring_Z65537_i64_1_x8_, encoding = #full_crt_packing_encoding>
193+
#ring_rns_L0_1_x8_ = #polynomial.ring<coefficientType = !rns_L0_, polynomialModulus = <1 + x**8>>
194+
!pt = !lwe.new_lwe_plaintext<application_data = <message_type = i16>, plaintext_space = #plaintext_space>
195+
#ciphertext_space_L0_ = #lwe.ciphertext_space<ring = #ring_rns_L0_1_x8_, encryption_type = lsb>
196+
!ct_L0_ = !lwe.new_lwe_ciphertext<application_data = <message_type = i16>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L0_C0_>
197+
198+
// CHECK: __heir_debug(CryptoContextT, PrivateKeyT, CiphertextT, const std::map<std::string, std::string>&)
199+
// CHECK: ["bound"] = "50"
200+
// CHECK: ["complex"] = "{test = 1.200000e+00 : f64}"
201+
// CHECK: ["random"] = "3 : i64"
202+
// CHECK: ["secret.secret"] = "unit"
203+
// CHECK: ["op"]
204+
205+
module attributes {scheme.bgv} {
206+
func.func private @__heir_debug_0(!openfhe.crypto_context, !openfhe.private_key, !ct_L0_)
207+
func.func @add(%cc: !openfhe.crypto_context, %sk: !openfhe.private_key, %ct: !ct_L0_) -> !ct_L0_ {
208+
call @__heir_debug_0(%cc, %sk, %ct) {bound = "50", random = 3, complex = {test = 1.2}, secret.secret} : (!openfhe.crypto_context, !openfhe.private_key, !ct_L0_) -> ()
209+
return %ct : !ct_L0_
210+
}
211+
}

tests/Examples/lattigo/dot_product_8_debug.go

+19-3
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,32 @@ package dotproduct8debug
33

44
import (
55
"fmt"
6+
"strings"
67

78
"github.com/tuneinsight/lattigo/v6/core/rlwe"
89
"github.com/tuneinsight/lattigo/v6/schemes/bgv"
910
)
1011

11-
func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.Encoder, decryptor *rlwe.Decryptor, ct *rlwe.Ciphertext) {
12+
func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.Encoder, decryptor *rlwe.Decryptor, ct *rlwe.Ciphertext, debugAttrMap map[string]string) {
13+
// print op
14+
op := debugAttrMap["op"]
15+
sepBlockArgument := strings.Index(op, "of type")
16+
// the functional type of the operation
17+
sepOp := strings.Index(op, ": (")
18+
if sepBlockArgument != -1 {
19+
op = op[:sepBlockArgument]
20+
} else if sepOp != -1 {
21+
op = op[:sepOp]
22+
}
23+
fmt.Println(op)
24+
25+
// print the decryption result
1226
value := make([]int64, 8)
1327
pt := decryptor.DecryptNew(ct)
1428
encoder.Decode(pt, value)
15-
fmt.Println(value)
29+
fmt.Printf(" %v\n", value)
30+
31+
// print the noise
1632

1733
// get a new pt with no noise
1834
// in Lattigo, Decrypt won't mod T
@@ -34,5 +50,5 @@ func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.E
3450
total += param.LogQi()[i]
3551
}
3652
// t * e for BGV
37-
fmt.Printf("Noise: %.2f Total: %d\n", max+param.LogT(), total)
53+
fmt.Printf(" Noise: %.2f Total: %d\n", max+param.LogT(), total)
3854
}

tests/Examples/openfhe/dot_product_8_debug_test.cpp

+21-3
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,33 @@ DCRTPoly DecryptCore(const std::vector<DCRTPoly>& cv,
4444
return b;
4545
}
4646

47+
#define OP
4748
#define DECRYPT
4849
#define NOISE
4950

50-
void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct) {
51+
void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct,
52+
const std::map<std::string, std::string>& debugAttrMap) {
53+
#ifdef OP
54+
auto op = debugAttrMap.at("op");
55+
auto sepBlockArgument = op.find("of type");
56+
// the functional type of the operation
57+
auto sepOpenfheOp = op.find(": (");
58+
auto sepLweOp = op.find(": !");
59+
if (sepBlockArgument != std::string::npos) {
60+
op = op.substr(0, sepBlockArgument);
61+
} else if (sepOpenfheOp != std::string::npos) {
62+
op = op.substr(0, sepOpenfheOp);
63+
} else if (sepLweOp != std::string::npos) {
64+
op = op.substr(0, sepLweOp);
65+
}
66+
std::cout << op << std::endl;
67+
#endif
68+
5169
#ifdef DECRYPT
5270
PlaintextT ptxt;
5371
cc->Decrypt(sk, ct, &ptxt);
5472
ptxt->SetLength(8);
55-
std::cout << ptxt << std::endl;
73+
std::cout << " " << ptxt << std::endl;
5674
#endif
5775

5876
#ifdef NOISE
@@ -73,7 +91,7 @@ void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct) {
7391
logQ += logqi;
7492
}
7593

76-
std::cout << "cv " << cv.size() << " Ql " << sizeQl << " logQ: " << logQ
94+
std::cout << " cv " << cv.size() << " Ql " << sizeQl << " logQ: " << logQ
7795
<< " logqi: " << logqi_v << " budget " << logQ - noise - 1
7896
<< " noise: " << noise << std::endl;
7997
#endif

0 commit comments

Comments
 (0)