Skip to content

[ORAM batch handling] oramExecutor with grpc queuing for batches #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 130 additions & 33 deletions pkg/oramExecutor/oramExecutor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"fmt"
"os"
"strings"
"sync"
"sync/atomic"

executor "github.com/project/ObliSql/api/oramExecutor"
// "github.com/redis/go-redis/v9"
Expand All @@ -19,10 +21,35 @@ const (
stashSize = 100000 // Maximum number of blocks in stash
)

type Operation struct {
RequestID uint64
Key string
Value string
Index int
}

type KVPair struct {
channelId string
Key string
Value string
}

type responseChannel struct {
m *sync.RWMutex
channel chan KVPair
}

type MyOram struct {
executor.UnimplementedExecutorServer
// rdb *redis.Client
o *ORAM

batchSize int

channelMap map[string]responseChannel
requestNumber atomic.Int64
channelLock sync.RWMutex

oramExecutorChannel chan *KVPair
}

type tempBlock struct {
Expand All @@ -34,55 +61,116 @@ type StringPair struct {
Second string
}

func (e MyOram) ExecuteBatch(ctx context.Context, req *executor.RequestBatchORAM) (*executor.RespondBatchORAM, error) {
fmt.Printf("Got a request with ID: %d \n", req.RequestId)

// set batchsize
batchSize := 60

// Batching(requests []request.Request, batchSize int)
func (e *MyOram) ExecuteBatch(ctx context.Context, req *executor.RequestBatchORAM) (*executor.RespondBatchORAM, error) {
if len(req.Keys) != len(req.Values) {
return nil, fmt.Errorf("keys and values length mismatch")
}

var replyKeys []string
var replyVals []string
reqNum := e.requestNumber.Add(1) // New id for this client/batch channel

for start := 0; start < len(req.Values); start += batchSize {
recv_resp := make([]KVPair, 0, len(req.Keys)) // This will store completed key value pairs

var requestList []Request
var returnValues []string
channelId := fmt.Sprintf("%d-%d", req.RequestId, reqNum)
localRespChannel := make(chan KVPair, len(req.Keys))

end := start + batchSize
if end > len(req.Values) {
end = len(req.Values) // Ensure we don't go out of bounds
e.channelLock.Lock() // Add channel to global map
e.channelMap[channelId] = responseChannel{
m: &sync.RWMutex{},
channel: localRespChannel,
}
e.channelLock.Unlock()

sent := 0
for i, key := range req.Keys {
value := req.Values[i]
kv := &KVPair{
channelId: channelId,
Key: key,
Value: value,
}
// Block if the channel is full
sent++
e.oramExecutorChannel <- kv
}

// Slice the keys and values for the current batch
batchKeys := req.Keys[start:end]
batchValues := req.Values[start:end]
// Finished adding keys to ORAM channel

for i := range batchKeys {
// Read operation
currentRequest := Request{
Key: batchKeys[i],
Value: batchValues[i],
}
// Now wait for responses
for i := 0; i < len(req.Keys); i++ {
item := <-localRespChannel
recv_resp = append(recv_resp, item)
}

requestList = append(requestList, currentRequest)
}
close(localRespChannel)

returnValues, _ = e.o.Batching(requestList, batchSize)
e.channelLock.Lock()
delete(e.channelMap, channelId)
e.channelLock.Unlock()

replyKeys = append(replyKeys, batchKeys...)
replyVals = append(replyVals, returnValues...)
sendKeys := make([]string, 0, len(req.Keys))
sendVal := make([]string, 0, len(req.Keys))

for _, v := range recv_resp {
sendKeys = append(sendKeys, v.Key)
sendVal = append(sendVal, v.Value)
}

// Return response with original request ID
return &executor.RespondBatchORAM{
RequestId: req.RequestId,
Keys: replyKeys,
Values: replyVals,
Keys: sendKeys,
Values: sendVal,
}, nil
}

func (e *MyOram) processBatches() {
for {

if len(e.oramExecutorChannel) >= e.batchSize {
var requestList []Request

var chanIds []string

for i := 0; i < e.batchSize; i++ {
op := <-e.oramExecutorChannel // Read from channel

chanIds = append(chanIds, op.channelId)

requestList = append(requestList, Request{
Key: op.Key,
Value: op.Value,
})
}
// Execute ORAM batch
returnValues, err := e.o.Batching(requestList, e.batchSize)
if err != nil {
// Handle error (e.g., log and continue)
fmt.Printf("ORAM batch error: %v\n", err)
continue
}

channelCache := make(map[string]chan KVPair, e.batchSize)

e.channelLock.RLock()
for _, v := range chanIds {

channelCache[v] = e.channelMap[v].channel

}
e.channelLock.RUnlock()

for i := 0; i < e.batchSize; i++ {
newKVPair := KVPair{
Key: requestList[i].Key,
Value: returnValues[i],
}
responseChannel := channelCache[chanIds[i]]
responseChannel <- newKVPair
}
}
}
}

func NewORAM(LogCapacity, Z, StashSize int, redisAddr string, tracefile string, useSnapshot bool, key []byte) (*MyOram, error) {
// If key is not provided (nil or empty), generate a random key
if len(key) == 0 {
Expand Down Expand Up @@ -111,6 +199,7 @@ func NewORAM(LogCapacity, Z, StashSize int, redisAddr string, tracefile string,
// Load the Stashmap and Keymap into memory
// Allow redis to update state using dump.rdb
oram.loadSnapshotMaps()
fmt.Println("ORAM snapshot loaded successfully!")
} else {
// Clear the Redis database to ensure a fresh start
if err := client.FlushDB(); err != nil {
Expand Down Expand Up @@ -171,9 +260,17 @@ func NewORAM(LogCapacity, Z, StashSize int, redisAddr string, tracefile string,
}

myOram := &MyOram{
o: oram,
o: oram,
batchSize: 60, // Set from config or constant
channelMap: make(map[string]responseChannel),
channelLock: sync.RWMutex{},
oramExecutorChannel: make(chan *KVPair),
}

myOram.oramExecutorChannel = make(chan *KVPair, 100000)

go myOram.processBatches() // Start batch processing

return myOram, nil
}

Expand Down