use data points instead of file as input to query test
ekzhu committed Dec 20, 2015
1 parent 0b72b5e commit 25fda1f
Showing 7 changed files with 98 additions and 201 deletions.
32 changes: 3 additions & 29 deletions src/lsh/index.go
@@ -4,40 +4,14 @@ import "sync"

 type QueryFunc func(DataPoint) QueryResult

-func QueryIndex(queryIter *PointIterator, queryFunc QueryFunc) QueryResults {
+func ParallelQueryIndex(input []DataPoint, queryFunc QueryFunc, nWorker int) QueryResults {

     // Input Thread
     queries := make(chan DataPoint)
     go func() {
-        p, err := queryIter.Next()
-        for err == nil {
-            queries <- p
-            p, err = queryIter.Next()
+        for _, q := range input {
+            queries <- q
         }
-        queryIter.Close()
         close(queries)
     }()
-
-    results := make(QueryResults, 0)
-    for q := range queries {
-        r := queryFunc(q)
-        results = append(results, r)
-    }
-    return results
-}
-
-func ParallelQueryIndex(queryIter *PointIterator, queryFunc QueryFunc,
-    nWorker int) QueryResults {
-
-    // Input Thread
-    queries := make(chan DataPoint)
-    go func() {
-        p, err := queryIter.Next()
-        for err == nil {
-            queries <- p
-            p, err = queryIter.Next()
-        }
-        queryIter.Close()
-        close(queries)
-    }()

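For orientation, a minimal sketch of how a caller might drive the new slice-based entry point; it is not code from this commit. The function name exampleParallelQuery and the do-nothing query function are made up, only the ParallelQueryIndex signature comes from the hunk above, and the sketch is assumed to sit inside package lsh.

package lsh

// exampleParallelQuery is a hypothetical caller: queries are handed over as an
// in-memory []DataPoint instead of being streamed through a *PointIterator.
func exampleParallelQuery(queries []DataPoint, nWorker int) QueryResults {
    // A stand-in query function; a real caller would perform an index lookup here.
    queryFunc := func(q DataPoint) QueryResult {
        return QueryResult{QueryId: q.Id}
    }
    // Fan the queries out to nWorker workers and collect the results.
    return ParallelQueryIndex(queries, queryFunc, nWorker)
}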
117 changes: 30 additions & 87 deletions src/lsh/knn.go
@@ -2,17 +2,13 @@ package lsh

 import (
     "container/heap"
     "sort"
     "time"
 )

-type Candidate struct {
-    id int
-    distance float64
-}
-
 type KHeap struct {
     k int
-    candidates []Candidate
+    candidates []Neighbour
 }

@@ -22,21 +18,21 @@ func (h KHeap) Len() int {
 func (h KHeap) Less(i, j int) bool {
     // We want to pop out the candidate with largest distance
     // so we use greater than here
-    return h.candidates[i].distance > h.candidates[j].distance
+    return h.candidates[i].Distance > h.candidates[j].Distance
 }

 func (h KHeap) Swap(i, j int) {
     h.candidates[i], h.candidates[j] = h.candidates[j], h.candidates[i]
 }

 func (h *KHeap) Push(x interface{}) {
-    c := x.(Candidate)
+    c := x.(Neighbour)
     if len(h.candidates) < h.k {
         h.candidates = append(h.candidates, c)
         return
     }
     // Check if we can still push to the top-k heap when it is full
-    if h.candidates[0].distance > c.distance {
+    if h.candidates[0].Distance > c.Distance {
         heap.Pop(h)
         heap.Push(h, c)
     }
@@ -52,136 +48,83 @@ func (h *KHeap) Pop() interface{} {
 }

 func NewKHeap(k int) *KHeap {
-    h := make([]Candidate, 0)
+    h := make([]Neighbour, 0)
     return &KHeap{k, h}
 }

 type Knn struct {
-    data []Point
-    ids []int
+    points []DataPoint
 }

-func NewKnn(data []Point, ids []int) *Knn {
-    if len(data) != len(ids) {
-        panic("Mismatch between size of data and ids")
-    }
-    return &Knn{data, ids}
+func NewKnn(points []DataPoint) *Knn {
+    return &Knn{points}
 }

 // Query outputs the top-k closest points from the query point
 // to the chanel out. The sequence of output is NOT sorted by
 // distance.
-func (knn *Knn) Query(q Point, k int, out chan int) {
+func (knn *Knn) Query(q Point, k int, out chan Neighbour) {
     kheap := NewKHeap(k)
     heap.Init(kheap)
-    for i, p := range knn.data {
-        d := p.L2(q)
-        heap.Push(kheap, Candidate{knn.ids[i], d})
+    for _, p := range knn.points {
+        d := p.Point.L2(q)
+        heap.Push(kheap, Neighbour{p.Id, d})
     }
     for i := range kheap.candidates {
-        out <- kheap.candidates[i].id
+        out <- kheap.candidates[i]
     }
 }

 // RunKnn executes the KNN experiment
-func RunKnn(datafile, output string, k, nQuery, nWorker int, parser *PointParser) {
-    // Load data
-    nData := CountPoint(datafile, parser.ByteLen)
-    iter := NewDataPointIterator(datafile, parser)
-    data := make([]Point, nData)
-    ids := make([]int, nData)
-    for i := 0; i < nData; i++ {
-        p, err := iter.Next()
-        if err != nil {
-            panic(err.Error())
-        }
-        data[i] = p.Point
-        ids[i] = p.Id
-    }
-
-    // Run Knn
-    knn := NewKnn(data, ids)
+func RunKnn(data []DataPoint, queries []DataPoint,
+    output string, k, nWorker int) {
+    knn := NewKnn(data)
     queryFunc := func(q DataPoint) QueryResult {
         start := time.Now()
-        out := make(chan int)
+        out := make(chan Neighbour)
         go func() {
             knn.Query(q.Point, k, out)
             close(out)
         }()
-        r := make([]int, 0)
+        ns := make(Neighbours, 0)
         for i := range out {
-            r = append(r, i)
+            ns = append(ns, i)
         }
         dur := time.Since(start)
-        ns := make(Neighbours, len(r))
-        for i := range r {
-            ns[i] = Neighbour{
-                Id: r[i],
-                Distance: q.Point.L2(data[i]),
-            }
-        }
         sort.Sort(ns)
         return QueryResult{
             QueryId: q.Id,
             Neighbours: ns,
             Time: float64(dur) / float64(time.Millisecond),
         }
     }
-    // Select queries
-    queryIds := SelectQueries(nData, nQuery)
-    iter = NewQueryPointIterator(datafile, parser, queryIds)
     // Run queries in parallel
-    results := ParallelQueryIndex(iter, queryFunc, nWorker)
+    results := ParallelQueryIndex(queries, queryFunc, nWorker)
     DumpJson(output, results)
 }

 // RunKnn executes the KNN experiment
-func RunKnnSampleAllPair(datafile, output string, nSample, nWorker int, parser *PointParser) {
-    // Load data
-    nData := CountPoint(datafile, parser.ByteLen)
-    pointIds := SelectQueries(nData, nSample)
-    iter := NewQueryPointIterator(datafile, parser, pointIds)
-    data := make([]Point, nSample)
-    ids := make([]int, nSample)
-    for i := 0; i < nSample; i++ {
-        p, err := iter.Next()
-        if err != nil {
-            panic(err.Error())
-        }
-        data[i] = p.Point
-        ids[i] = p.Id
-    }
-
-    // Run Knn
-    knn := NewKnn(data, ids)
+func RunKnnSampleAllPair(data []DataPoint, output string, nWorker int) {
+    knn := NewKnn(data)
+    nSample := len(data)
     queryFunc := func(q DataPoint) QueryResult {
         start := time.Now()
-        out := make(chan int)
+        out := make(chan Neighbour)
         go func() {
             knn.Query(q.Point, nSample, out)
             close(out)
         }()
-        r := make([]int, 0)
+        ns := make(Neighbours, 0)
         for i := range out {
-            r = append(r, i)
+            ns = append(ns, i)
         }
         dur := time.Since(start)
-        ns := make(Neighbours, len(r))
-        for i := range r {
-            ns[i] = Neighbour{
-                Id: r[i],
-                Distance: q.Point.L2(data[i]),
-            }
-        }
         sort.Sort(ns)
         return QueryResult{
             QueryId: q.Id,
             Neighbours: ns,
             Time: float64(dur) / float64(time.Millisecond),
         }
     }
-    // Select queries
-    queryIds := SelectQueries(nData, nSample)
-    iter = NewQueryPointIterator(datafile, parser, queryIds)
     // Run queries in parallel
-    results := ParallelQueryIndex(iter, queryFunc, nWorker)
+    results := ParallelQueryIndex(data, queryFunc, nWorker)
     DumpJson(output, results)
 }
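The knn.go changes above make KHeap hold Neighbour values directly and have RunKnn and RunKnnSampleAllPair take ready-made []DataPoint slices instead of reading a file. A small sketch of the heap's top-k behaviour, not from this repository: the name exampleTopK and the distances are invented, and it is assumed to live in package lsh so it can read the unexported candidates field.

package lsh

import "container/heap"

// exampleTopK illustrates the KHeap policy: Less orders on *greater* Distance,
// so the root is always the worst of the current top-k and is evicted when a
// closer candidate arrives.
func exampleTopK() []Neighbour {
    h := NewKHeap(2)
    heap.Init(h)
    heap.Push(h, Neighbour{Id: 1, Distance: 5.0})
    heap.Push(h, Neighbour{Id: 2, Distance: 1.0})
    // The heap is full; 3.0 < 5.0, so this push evicts Id 1.
    heap.Push(h, Neighbour{Id: 3, Distance: 3.0})
    // Holds Ids 2 and 3 in heap order, not sorted by distance.
    return h.candidates
}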
22 changes: 11 additions & 11 deletions src/lsh/knn_test.go
@@ -15,7 +15,7 @@ func Test_KHeap(t *testing.T) {
     distances := make([]float64, len(points))
     for i := range points {
         distances[i] = points[i].L2(q)
-        c := Candidate{i, distances[i]}
+        c := Neighbour{i, distances[i]}
         heap.Push(h, c)
         t.Log(c)
         t.Log(h.candidates)
@@ -26,8 +26,8 @@
     sort.Float64s(distances)
     topK := make([]float64, k)
     for i := 0; i < k; i++ {
-        c := heap.Pop(h).(Candidate)
-        topK[i] = c.distance
+        c := heap.Pop(h).(Neighbour)
+        topK[i] = c.Distance
     }
     for i := range topK {
         if topK[i] != distances[k-1-i] {
@@ -40,10 +40,6 @@ func Test_KHeap(t *testing.T) {
 func Test_Knn(t *testing.T) {
     k := 5
     points := randomPoints(20, 10, 100.0)
-    ids := make([]int, len(points))
-    for i := range points {
-        ids[i] = i
-    }
     q := points[0]
     // Build ground truth
     distances := make([]float64, len(points))
@@ -52,16 +48,20 @@ func Test_Knn(t *testing.T) {
     }
     sort.Float64s(distances)
     t.Log("Ground truth distances", distances[:k])
+    data := make([]DataPoint, len(points))
+    for i := range points {
+        data[i] = DataPoint{i, points[i]}
+    }
     // Test Knn query
-    knn := NewKnn(points, ids)
-    out := make(chan int)
+    knn := NewKnn(data)
+    out := make(chan Neighbour)
     go func() {
         knn.Query(q, k, out)
         close(out)
     }()
-    for id := range out {
+    for n := range out {
         // get the point
-        p := points[id]
+        p := points[n.Id]
         // get the distance
         d := p.L2(q)
         t.Log(d)
33 changes: 9 additions & 24 deletions src/lsh/run_forest.go
@@ -5,29 +5,15 @@ import (
     "time"
 )

-func RunForest(datafile, output string,
+func RunForest(data []DataPoint, queries []DataPoint,
+    output string,
     k, nQuery, nWorker int,
-    parser *PointParser,
     dim, m, l int, w float64) {

-    // Load data
-    nData := CountPoint(datafile, parser.ByteLen)
-    iter := NewDataPointIterator(datafile, parser)
-    data := make([]Point, nData)
-    ids := make([]int, nData)
-    for i := 0; i < nData; i++ {
-        p, err := iter.Next()
-        if err != nil {
-            panic(err.Error())
-        }
-        data[i] = p.Point
-        ids[i] = p.Id
-    }
-
     // Build forest index
     forest := NewLshForest(dim, l, m, w)
-    for i, p := range data {
-        forest.Insert(p, ids[i])
+    for _, p := range data {
+        forest.Insert(p.Point, p.Id)
     }
     // Forest query function wrapper
     queryFunc := func(q DataPoint) QueryResult {
@@ -45,8 +31,10 @@
         ns := make(Neighbours, len(r))
         for i := range r {
             ns[i] = Neighbour{
-                Id: r[i],
-                Distance: q.Point.L2(data[i]),
+                Id: r[i],
+                // We assume the id is equal to the index
+                // of the data point in the input data
+                Distance: q.Point.L2(data[r[i]].Point),
             }
         }
         sort.Sort(ns)
@@ -59,10 +47,7 @@
             Time: float64(dur) / float64(time.Millisecond),
         }
     }
-    // Select queries
-    queryIds := SelectQueries(nData, nQuery)
-    iter = NewQueryPointIterator(datafile, parser, queryIds)
     // Run queries in parallel
-    results := ParallelQueryIndex(iter, queryFunc, nWorker)
+    results := ParallelQueryIndex(queries, queryFunc, nWorker)
     DumpJson(output, results)
 }
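The new comment in RunForest ties the lookup data[r[i]].Point to the invariant that each point's Id equals its index in the data slice. A hypothetical helper, not part of this commit, showing how a caller could establish that invariant when building the input; the name wrapPoints is invented.

package lsh

// wrapPoints wraps raw points so that each DataPoint's Id equals its index in
// the slice, which is what the distance lookup in RunForest relies on.
func wrapPoints(points []Point) []DataPoint {
    data := make([]DataPoint, len(points))
    for i, p := range points {
        data[i] = DataPoint{Id: i, Point: p}
    }
    return data
}

A slice built this way can be passed as data to RunForest, RunKnn and RunKnnSampleAllPair, with the query points drawn from the same slice.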
