feat: support parallel compute hash function
flywukong committed Jul 12, 2023
1 parent 42e346d commit 6ec64a8
Showing 3 changed files with 223 additions and 10 deletions.
5 changes: 5 additions & 0 deletions go/hash/checksum.go
@@ -6,6 +6,11 @@ import (
"fmt"
)

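// SegmentInfo carries one segment of the payload together with its
// index, so results computed out of order can be reassembled in order.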
type SegmentInfo struct {
SegmentId int
Data []byte
}

// GenerateChecksum generates the checksum of one piece of data
func GenerateChecksum(pieceData []byte) []byte {
hash := sha256.New()
167 changes: 158 additions & 9 deletions go/hash/hash.go
@@ -3,8 +3,10 @@ package hash
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"runtime"
"sync"

storageTypes "github.com/bnb-chain/greenfield/x/storage/types"
@@ -13,6 +15,12 @@ import (
"github.com/bnb-chain/greenfield-common/go/redundancy"
)

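// maxThreadNum caps the number of concurrent hash worker goroutines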
const maxThreadNum = 5

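// ComputeHashOption holds options for hash computation; the mutex is
// not yet referenced by the code added in this change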
type ComputeHashOption struct {
mutex *sync.Mutex
}

// IntegrityHasher computes the integrity hash
type IntegrityHasher struct {
ecDataHashes [][][]byte
@@ -179,16 +187,9 @@ func ComputeIntegrityHash(reader io.Reader, segmentSize int64, dataShards, parityShards int) ([][]byte, int64, storageTypes.RedundancyType, error) {
// compute segment hash
checksum := GenerateChecksum(data)
segChecksumList = append(segChecksumList, checksum)
- // get erasure encode bytes
- encodeShards, err := redundancy.EncodeRawSegment(data, dataShards, parityShards)
- if err != nil {
- return nil, 0, storageTypes.REDUNDANCY_EC_TYPE, err
- }
-
- for index, shard := range encodeShards {
- // compute hash of pieces
- piecesHash := GenerateChecksum(shard)
- encodeDataHash[index] = append(encodeDataHash[index], piecesHash)
+ if err = encodeAndComputeHash(encodeDataHash, data, dataShards, parityShards); err != nil {
+ return nil, 0, storageTypes.REDUNDANCY_EC_TYPE, err
}
}
}
@@ -212,6 +213,22 @@ func ComputeIntegrityHash(reader io.Reader, segmentSize int64, dataShards, parityShards int) ([][]byte, int64, storageTypes.RedundancyType, error) {
return hashList, contentLen, storageTypes.REDUNDANCY_EC_TYPE, nil
}

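// encodeAndComputeHash erasure-encodes a single segment and appends the
// checksum of each shard to the corresponding slice of encodeDataHash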
func encodeAndComputeHash(encodeDataHash [][][]byte, segment []byte, dataShards, parityShards int) error {
// erasure-encode the segment into data and parity shards
encodeShards, err := redundancy.EncodeRawSegment(segment, dataShards, parityShards)
if err != nil {
return err
}

for index, shard := range encodeShards {
// compute hash of pieces
piecesHash := GenerateChecksum(shard)
encodeDataHash[index] = append(encodeDataHash[index], piecesHash)
}

return nil
}

// ComputerHashFromFile opens a local file and computes its hash result and segment size
func ComputerHashFromFile(filePath string, segmentSize int64, dataShards, parityShards int) ([][]byte, int64, storageTypes.RedundancyType, error) {
f, err := os.Open(filePath)
@@ -229,3 +246,135 @@ func ComputerHashFromBuffer(content []byte, segmentSize int64, dataShards, parityShards int) ([][]byte, int64, storageTypes.RedundancyType, error) {
reader := bytes.NewReader(content)
return ComputeIntegrityHash(reader, segmentSize, dataShards, parityShards)
}

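// computePieceHashes erasure-encodes a single segment and returns the
// checksum of each resulting shard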
func computePieceHashes(segment []byte, dataShards, parityShards int) ([][]byte, error) {
// erasure-encode the segment into data and parity shards
encodeShards, err := redundancy.EncodeRawSegment(segment, dataShards, parityShards)
if err != nil {
return nil, err
}

var pieceChecksumList [][]byte
for _, shard := range encodeShards {
// compute hash of pieces
piecesHash := GenerateChecksum(shard)
pieceChecksumList = append(pieceChecksumList, piecesHash)
}

return pieceChecksumList, nil
}

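// hashWorker consumes segments from the jobs channel, computes the
// segment checksum and the per-piece checksums, and stores both results
// in the sync.Maps keyed by segment id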
func hashWorker(jobs <-chan SegmentInfo, errChan chan<- error, dataShards, parityShards int, wg *sync.WaitGroup, checksumMap *sync.Map, pieceHashMap *sync.Map) {
defer wg.Done()

for segInfo := range jobs {
checksum := GenerateChecksum(segInfo.Data)
checksumMap.Store(segInfo.SegmentId, checksum)

pieceCheckSumList, err := computePieceHashes(segInfo.Data, dataShards, parityShards)
if err != nil {
// report the error without blocking and keep draining jobs so the
// producer goroutine never blocks on a full job channel
select {
case errChan <- err:
default:
}
continue
}
pieceHashMap.Store(segInfo.SegmentId, pieceCheckSumList)
}
}

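// ComputeIntegrityHashParallel splits the reader content into segments
// and computes the integrity hashes with a pool of worker goroutines;
// it produces the same hash roots as ComputeIntegrityHash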
func ComputeIntegrityHashParallel(reader io.Reader, segmentSize int64, dataShards, parityShards int) ([][]byte, int64, storageTypes.RedundancyType, error) {
var (
segChecksumList [][]byte
ecShards = dataShards + parityShards
contentLen = int64(0)
wg sync.WaitGroup
)

segHashMap := &sync.Map{}
pieceHashMap := &sync.Map{}
encodeDataHash := make([][][]byte, ecShards)

hashList := make([][]byte, ecShards+1)

jobChan := make(chan SegmentInfo, 100)
errChan := make(chan error, 1)
// cap the number of worker goroutines at maxThreadNum
threadNum := runtime.NumCPU() / 2
if threadNum > maxThreadNum {
threadNum = maxThreadNum
}
// start workers to compute hash of each segment
for i := 0; i < threadNum; i++ {
wg.Add(1)
go hashWorker(jobChan, errChan, dataShards, parityShards, &wg, segHashMap, pieceHashMap)
}

jobNum := 0
for {
seg := make([]byte, segmentSize)
n, err := reader.Read(seg)
if err != nil {
if err != io.EOF {
log.Error().Msg("failed to read content:" + err.Error())
return nil, 0, storageTypes.REDUNDANCY_EC_TYPE, err
}
break
}

if n > 0 && n <= int(segmentSize) {
contentLen += int64(n)
data := seg[:n]
// dispatch the segment to a hash worker
jobChan <- SegmentInfo{SegmentId: jobNum, Data: data}
jobNum++
}
}
close(jobChan)

for i := 0; i < ecShards; i++ {
encodeDataHash[i] = make([][]byte, jobNum)
}

wg.Wait()
close(errChan)

// check error
for err := range errChan {
if err != nil {
log.Error().Msg("err chan detected err:" + err.Error())
return nil, 0, storageTypes.REDUNDANCY_EC_TYPE, err
}
}

for i := 0; i < jobNum; i++ {
value, ok := segHashMap.Load(i)
if !ok {
return nil, 0, storageTypes.REDUNDANCY_EC_TYPE, fmt.Errorf("fail to load the segment hash")
}
segChecksumList = append(segChecksumList, value.([]byte))

pieceHashes, ok := pieceHashMap.Load(i)
if !ok {
return nil, 0, storageTypes.REDUNDANCY_EC_TYPE, fmt.Errorf("fail to load the piece hashes")
}
hashValues := pieceHashes.([][]byte)
for j := 0; j < len(encodeDataHash); j++ {
encodeDataHash[j][i] = hashValues[j]
}
}

// compute the integrity hash of the PrimarySP from the segment checksums
hashList[0] = GenerateIntegrityHash(segChecksumList)

// compute the integrity hash of each SecondarySP
spLen := len(encodeDataHash)
wg.Add(spLen)
for spID, content := range encodeDataHash {
go func(data [][]byte, id int) {
defer wg.Done()
hashList[id+1] = GenerateIntegrityHash(data)
}(content, spID)
}

wg.Wait()
return hashList, contentLen, storageTypes.REDUNDANCY_EC_TYPE, nil
}
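
For context, a minimal caller of the new API might look like the sketch below. This is illustrative only: the file name, the 16 MB segment size, and the 4+2 shard split are assumptions, not values fixed by this change.

package main

import (
	"fmt"
	"os"

	"github.com/bnb-chain/greenfield-common/go/hash"
)

func main() {
	// hypothetical input file; any io.Reader works
	f, err := os.Open("object.bin")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// segment size and shard counts below are illustrative
	hashes, size, redundancyType, err := hash.ComputeIntegrityHashParallel(f, 16*1024*1024, 4, 2)
	if err != nil {
		panic(err)
	}
	// hashes[0] is the PrimarySP integrity hash; hashes[1:] belong to the SecondarySPs
	fmt.Printf("content length: %d, redundancy type: %v, hash roots: %d\n", size, redundancyType, len(hashes))
}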
61 changes: 60 additions & 1 deletion go/hash/hash_test.go
@@ -87,7 +87,66 @@ func TestHashResult(t *testing.T) {

for id, hash := range hashList {
if base64.StdEncoding.EncodeToString(hash) != expectedHashList[id] {
t.Errorf("compare hash error")
t.Errorf("compare hash error, id: %d, hash1, %s, hash2 %s \n", id, base64.StdEncoding.EncodeToString(hash), expectedHashList[id])
}
}
}

func TestParallelHashResult(t *testing.T) {
var buffer bytes.Buffer
line := `1234567890,1234567890,1234567890,1234567890,1234567890,1234567890,1234567890,1234567890,1234567890`

// generate a buffer of roughly 100 MB (1,048,576 lines of ~107 bytes each)
for i := 0; i < 1024*1024; i++ {
buffer.WriteString(fmt.Sprintf("[%05d] %s\n", i, line))
}

hashList, _, _, err := ComputeIntegrityHashParallel(bytes.NewReader(buffer.Bytes()), int64(segmentSize), redundancy.DataBlocks, redundancy.ParityBlocks)
if err != nil {
t.Error(err)
}

// expected hashes, generated on the SP side
expectedHashList := []string{
"6YA/kt2H0pS6+/tyR20LCqqeWmNCelS4wQcEUIhnAko=",
"C00Wks+pfo6NBQkG8iRGN5M0EtTvUAwMyaQ8+RsG4rA=",
"Z5AW9CvNIsDo9jtxeQysSpn2ayNml3Kr4ksm/2WUu8s=",
"dMlsKDw2dGRUygEgkyHJvOHYn9jVtycpUb7zvIGvEEk=",
"v7vNLlbIg+27zFAOYfT2UDkoAId53Z1gDkcTA7VWT5A=",
"1b7QsyQ8QT+7UoMU7K1SRhKOfIylogIfrSFsKJUfi4U=",
"/7A2gwAnaJ5jFuK6sbov6iFAkhfOga4wdAK/NlCuJBo=",
}

for id, hash := range hashList {
if base64.StdEncoding.EncodeToString(hash) != expectedHashList[id] {
t.Errorf("compare hash error, id: %d, hash1, %s, hash2 %s \n", id, base64.StdEncoding.EncodeToString(hash), expectedHashList[id])
}
}
}

func TestCompareHashResult(t *testing.T) {
var buffer bytes.Buffer
line := `1234567890,1234567890,1234567890,1234567890,1234567890,1234567890,1234567890,1234567890,1234567890`

for i := 0; i < 1024*500; i++ {
buffer.WriteString(fmt.Sprintf("[%05d] %s\n", i, line))
}

hashList, _, _, err := ComputeIntegrityHash(bytes.NewReader(buffer.Bytes()), int64(segmentSize), redundancy.DataBlocks, redundancy.ParityBlocks)
if err != nil {
t.Error(err)
}

expectedHashList := hashList
hashList, _, _, err = ComputeIntegrityHashParallel(bytes.NewReader(buffer.Bytes()), int64(segmentSize), redundancy.DataBlocks, redundancy.ParityBlocks)
if err != nil {
t.Error(err)
}

// Compare serial and parallel version results
for id, hash := range hashList {
if !bytes.Equal(hash, expectedHashList[id]) {
t.Errorf("compare hash error, id: %d, hash1, %s, hash2 %s \n", id, hash, expectedHashList[id])
}
}
}