Skip to content

Commit

Permalink
nose -> pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Apr 10, 2023
1 parent 72fc2ec commit febe362
Show file tree
Hide file tree
Showing 15 changed files with 1,145 additions and 1,180 deletions.
82 changes: 40 additions & 42 deletions test/accuracy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,59 +24,57 @@
except ImportError:
from urllib.request import urlretrieve # Python 3
import gzip
from nose.plugins.attrib import attr

def _get_index(dataset):
    """Return an Annoy index for ``dataset`` together with its open HDF5 file.

    The vectors file ``test/<dataset>.hdf5`` is downloaded from
    vectors.erikbern.com when it is not already on disk. The index file
    ``test/<dataset>.annoy`` is built with ``build(10)`` and saved the first
    time it is needed, and loaded from disk on later calls.

    Args:
        dataset: Dataset name, e.g. ``'glove-25-angular'``.

    Returns:
        A ``(annoy, dataset_f)`` tuple: the ``AnnoyIndex`` and the
        ``h5py.File`` opened read-only. The file is left open because the
        caller reads the ``test`` and ``neighbors`` datasets from it; the
        caller is responsible for closing it.
    """
    url = 'http://vectors.erikbern.com/%s.hdf5' % dataset
    vectors_fn = os.path.join('test', dataset + '.hdf5')
    index_fn = os.path.join('test', dataset + '.annoy')

    if not os.path.exists(vectors_fn):
        print('downloading', url, '->', vectors_fn)
        # Download under a scratch name and rename only once the transfer has
        # finished, so an interrupted download is never mistaken for a valid
        # cached dataset on the next run.
        partial_fn = vectors_fn + '.part'
        try:
            urlretrieve(url, partial_fn)
            os.rename(partial_fn, vectors_fn)
        finally:
            if os.path.exists(partial_fn):
                os.remove(partial_fn)

    dataset_f = h5py.File(vectors_fn, 'r')
    try:
        distance = dataset_f.attrs['distance']
        f = dataset_f['train'].shape[1]  # second axis of the train matrix
        annoy = AnnoyIndex(f, distance)

        if not os.path.exists(index_fn):
            print('adding items', distance, f)
            for i, v in enumerate(dataset_f['train']):
                annoy.add_item(i, v)

            print('building index')
            annoy.build(10)
            annoy.save(index_fn)
        else:
            annoy.load(index_fn)
    except Exception:
        # Do not leak the HDF5 handle when setup fails part-way through.
        dataset_f.close()
        raise
    return annoy, dataset_f

def _test_index(dataset, exp_accuracy):
    """Assert that Annoy's top-10 overlap on ``dataset`` meets ``exp_accuracy``.

    For every vector in the dataset's ``test`` split, the 10 ids returned by
    ``annoy.get_nns_by_vector(v, 10, 1000)`` are compared with the first 10
    entries of the matching ``neighbors`` row. The accuracy is the percentage
    of ids that appear in both lists, summed over all test vectors.

    Args:
        dataset: Dataset name forwarded to ``_get_index``.
        exp_accuracy: Expected accuracy in percent.

    Raises:
        AssertionError: If a query or reference row does not have 10 entries,
            if the dataset has no test vectors, or if the measured accuracy
            is not above ``exp_accuracy - 1.0``.
    """
    annoy, dataset_f = _get_index(dataset)

    n, k = 0, 0
    # Close the HDF5 file even when an assertion fails, so the handle is not
    # leaked from one test to the next.
    try:
        # Look the dataset up once instead of on every iteration.
        neighbors = dataset_f['neighbors']
        for i, v in enumerate(dataset_f['test']):
            js_fast = annoy.get_nns_by_vector(v, 10, 1000)
            js_real = neighbors[i][:10]
            assert len(js_fast) == 10
            assert len(js_real) == 10

            n += 10
            k += len(set(js_fast).intersection(js_real))
    finally:
        dataset_f.close()

    # Fail with a clear message rather than a ZeroDivisionError below.
    assert n > 0, 'dataset %s has no test vectors' % dataset

    accuracy = 100.0 * k / n
    print('%50s accuracy: %5.2f%% (expected %5.2f%%)' % (dataset, accuracy, exp_accuracy))

    # One-sided check: at most one percentage point below the expected value.
    assert accuracy > exp_accuracy - 1.0

def test_glove_25():
    """Run the accuracy check on glove-25-angular, expecting 69%."""
    _test_index(dataset='glove-25-angular', exp_accuracy=69.00)

def test_nytimes_16():
    """Run the accuracy check on nytimes-16-angular, expecting 80%."""
    _test_index(dataset='nytimes-16-angular', exp_accuracy=80.00)

def test_fashion_mnist():
    """Run the accuracy check on fashion-mnist-784-euclidean, expecting 90%."""
    _test_index(dataset='fashion-mnist-784-euclidean', exp_accuracy=90.00)
Loading

0 comments on commit febe362

Please sign in to comment.