forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathHashStoreTest.cpp
99 lines (82 loc) · 2.75 KB
/
HashStoreTest.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <c10/util/irange.h>
#include "StoreTestCommon.hpp"
#include <unistd.h>
#include <iostream>
#include <thread>
#include <c10d/HashStore.hpp>
#include <c10d/PrefixStore.hpp>
// Timeout (in milliseconds) installed on the store right before probing a
// deleted key, so the get() that is expected to fail gives up quickly.
constexpr int64_t kShortStoreTimeoutMillis = 100;

// Exercises basic key/value operations on a HashStore accessed through a
// PrefixStore using `prefix` (an empty prefix is the default).
//
// Covered behavior:
//  - set()/get() round trips for several keys;
//  - compareSet() leaving the value untouched when the expected value does
//    not match, and replacing it when it does;
//  - getNumKeys() before and after a deleteKey();
//  - deleteKey() returning true for a present key and false for an absent one;
//  - get() on a deleted key throwing std::runtime_error once a short timeout
//    is set;
//  - get() waiting for a key that a concurrent writer thread sets.
void testGetSet(std::string prefix = "") {
  // Basic set/get
  {
    auto hashStore = c10::make_intrusive<c10d::HashStore>();
    c10d::PrefixStore store(prefix, hashStore);
    c10d::test::set(store, "key0", "value0");
    c10d::test::set(store, "key1", "value1");
    c10d::test::set(store, "key2", "value2");
    c10d::test::check(store, "key0", "value0");
    c10d::test::check(store, "key1", "value1");
    c10d::test::check(store, "key2", "value2");

    // Check compareSet, does not check return value
    c10d::test::compareSet(store, "key0", "wrongExpectedValue", "newValue");
    c10d::test::check(store, "key0", "value0");
    c10d::test::compareSet(store, "key0", "value0", "newValue");
    c10d::test::check(store, "key0", "newValue");

    auto numKeys = store.getNumKeys();
    EXPECT_EQ(numKeys, 3);
    auto delSuccess = store.deleteKey("key0");
    EXPECT_TRUE(delSuccess);
    numKeys = store.getNumKeys();
    EXPECT_EQ(numKeys, 2);
    auto delFailure = store.deleteKey("badKeyName");
    EXPECT_FALSE(delFailure);

    // "key0" was deleted above; with a short timeout, get() must throw
    // rather than wait for the key to reappear.
    auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
    store.setTimeout(timeout);
    EXPECT_THROW(store.get("key0"), std::runtime_error);
  }
  // get() waits up to timeout_.
  {
    auto hashStore = c10::make_intrusive<c10d::HashStore>();
    c10d::PrefixStore store(prefix, hashStore);
    std::thread th([&]() { c10d::test::set(store, "key0", "value0"); });
    // check() reads the key back (presumably via get(), which the block above
    // shows can throw on timeout). Join the writer before propagating any
    // exception: destroying a still-joinable std::thread during unwinding
    // calls std::terminate, which would abort the whole test binary instead
    // of reporting a failure.
    try {
      c10d::test::check(store, "key0", "value0");
    } catch (...) {
      th.join();
      throw;
    }
    th.join();
  }
}
// Concurrency stress test for add(): numThreads worker threads each call
// store.add("counter", 1) numIterations times on one PrefixStore (using
// `prefix`) that wraps a single HashStore. The final counter value must equal
// numThreads * numIterations.
//
// sem1/sem2 act as a start barrier. Each worker posts sem1 once it is running
// and then blocks on sem2. The main thread waits for all numThreads posts and
// then releases every worker at once, so the add() calls overlap as much as
// possible.
void stressTestStore(std::string prefix = "") {
  const auto numThreads = 4;
  const auto numIterations = 100;

  std::vector<std::thread> threads;
  // The thread count is known up front; avoid regrowing the vector.
  threads.reserve(numThreads);
  c10d::test::Semaphore sem1, sem2;
  auto hashStore = c10::make_intrusive<c10d::HashStore>();
  c10d::PrefixStore store(prefix, hashStore);

  for (C10_UNUSED const auto i : c10::irange(numThreads)) {
    // Construct each std::thread in place from the lambda instead of moving
    // in a temporary std::thread.
    threads.emplace_back([&] {
      sem1.post();
      sem2.wait();
      for (C10_UNUSED const auto j : c10::irange(numIterations)) {
        store.add("counter", 1);
      }
    });
  }

  // Wait until every worker is running, then start them all together.
  sem1.wait(numThreads);
  sem2.post(numThreads);
  for (auto& thread : threads) {
    thread.join();
  }

  std::string expected = std::to_string(numThreads * numIterations);
  c10d::test::check(store, "counter", expected);
}
// Basic get/set/compareSet/delete coverage, with no key prefix.
TEST(HashStoreTest, testGetAndSet) {
  const std::string noPrefix;
  testGetSet(noPrefix);
}

// Same coverage, with every key routed through the "testPrefix" namespace.
TEST(HashStoreTest, testGetAndSetWithPrefix) {
  const std::string keyPrefix = "testPrefix";
  testGetSet(keyPrefix);
}
// Concurrent add() stress test, with no key prefix.
TEST(HashStoreTest, testStressStore) {
  const std::string noPrefix;
  stressTestStore(noPrefix);
}

// Concurrent add() stress test, with keys under the "testPrefix" namespace.
TEST(HashStoreTest, testStressStoreWithPrefix) {
  const std::string keyPrefix = "testPrefix";
  stressTestStore(keyPrefix);
}