Skip to content

Commit 46e853b

Browse files
committed
Add basic concurrency support in kmeans
1 parent 5838de1 commit 46e853b

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

kmeans.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"image/color"
66
"math"
77
"math/rand"
8+
"sync"
89
)
910

1011
type cluster struct {
@@ -37,14 +38,23 @@ func getClusters(k int, set image.Image) []*cluster {
3738
}
3839

3940
func partition(clr []*cluster, set image.Image) {
41+
var w sync.WaitGroup
42+
w.Add(set.Bounds().Max.Y * set.Bounds().Max.X)
43+
var mtx sync.Mutex
44+
4045
// assign all data points to the nearest cluster
4146
for y := set.Bounds().Min.Y; y < set.Bounds().Max.Y; y++ {
4247
for x := set.Bounds().Min.X; x < set.Bounds().Max.X; x++ {
43-
v := set.At(x, y)
44-
i := indexNewCentroid(clr, v)
45-
clr[i].members = append(clr[i].members, image.Pt(x, y))
48+
go func(x, y int) {
49+
i := indexNewCentroid(clr, set.At(x, y))
50+
mtx.Lock()
51+
clr[i].members = append(clr[i].members, image.Pt(x, y)) // with go
52+
mtx.Unlock()
53+
w.Done()
54+
}(x, y)
4655
}
4756
}
57+
w.Wait()
4858

4959
var getAverage = func(pts []image.Point) color.Color {
5060
lPts := len(pts)
@@ -63,7 +73,7 @@ func partition(clr []*cluster, set image.Image) {
6373
i := uint8(rSum / uint32(lPts))
6474
j := uint8(gSum / uint32(lPts))
6575
k := uint8(bSum / uint32(lPts))
66-
return color.RGBA{R: i, G: j, B: k, A: 0}
76+
return color.RGBA{R: i, G: j, B: k, A: 100}
6777
}
6878

6979
// update each centroid per cluster
@@ -94,9 +104,8 @@ func indexNewCentroid(clr []*cluster, p color.Color) (idx int) {
94104
min := uint32(math.MaxUint32)
95105

96106
for i, c := range clr {
97-
cl := c.centroid
98107
// find cluster with lowest color-distance
99-
d := euclidDis(cl, p)
108+
d := euclidDis(c.centroid, p)
100109
if d < min {
101110
min = d
102111
idx = i

main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ func main() {
3535
log.Fatal(err)
3636
}
3737

38+
img := clusterImage(*k, imgData)
39+
3840
cImg, err := os.Create(*outPath)
3941
if err != nil {
4042
log.Fatal(err)
4143
}
4244
defer cImg.Close()
4345
defer cImg.Seek(0, 0)
4446

45-
img := clusterImage(*k, imgData)
4647
png.Encode(cImg, img)
4748
}

0 commit comments

Comments
 (0)