Skip to content

Commit c129834

Browse files
authored
fix(hset_family): Ensure empty hash sets are removed (#4873)
When a search operation is performed on a hash set, expired fields are removed as a side effect. If at the end of such an operation the hash set becomes empty, its key is removed from the database. Signed-off-by: Abhijat Malviya <abhijat@dragonflydb.io>
1 parent 2d96a57 commit c129834

File tree

2 files changed

+104
-26
lines changed

2 files changed

+104
-26
lines changed

src/server/hset_family.cc

Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ OpStatus IncrementValue(optional<string_view> prev_val, IncrByParam* param) {
166166
param->emplace<int64_t>(new_val);
167167

168168
return OpStatus::OK;
169-
};
169+
}
170170

171171
OpStatus OpIncrBy(const OpArgs& op_args, string_view key, string_view field, IncrByParam* param) {
172172
auto& db_slice = op_args.GetDbSlice();
@@ -264,6 +264,62 @@ OpStatus OpIncrBy(const OpArgs& op_args, string_view key, string_view field, Inc
264264
return OpStatus::OK;
265265
}
266266

267+
struct KeyCleanup {
268+
using CleanupFuncT = std::function<void(std::string_view)>;
269+
explicit KeyCleanup(CleanupFuncT func, const std::string_view key_view)
270+
: f{std::move(func)}, key{key_view} {
271+
}
272+
~KeyCleanup() {
273+
if (armed) {
274+
f(key);
275+
}
276+
}
277+
278+
void arm() {
279+
armed = true;
280+
}
281+
282+
CleanupFuncT f;
283+
std::string key;
284+
bool armed{false};
285+
};
286+
287+
void DeleteKey(DbSlice& db_slice, const OpArgs& op_args, std::string_view key) {
288+
if (auto del_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_HASH); del_it) {
289+
del_it->post_updater.Run();
290+
db_slice.Del(op_args.db_cntx, del_it->it);
291+
if (op_args.shard->journal()) {
292+
RecordJournal(op_args, "DEL"sv, {key});
293+
}
294+
}
295+
}
296+
297+
std::pair<OpResult<DbSlice::ConstIterator>, KeyCleanup> FindReadOnly(DbSlice& db_slice,
298+
const OpArgs& op_args,
299+
std::string_view key) {
300+
return std::pair{db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH),
301+
KeyCleanup{[&](const auto& k) { DeleteKey(db_slice, op_args, k); }, key}};
302+
}
303+
304+
// The find and contains functions perform the usual search on string maps, with the added argument
305+
// KeyCleanup. This object is armed if the string map becomes empty during search due to keys being
306+
// expired. An armed object on destruction removes the key which has just become empty.
307+
StringMap::iterator Find(StringMap* sm, const std::string_view field, KeyCleanup& defer_cleanup) {
308+
auto it = sm->Find(field);
309+
if (sm->Empty()) {
310+
defer_cleanup.arm();
311+
}
312+
return it;
313+
}
314+
315+
bool Contains(StringMap* sm, const std::string_view field, KeyCleanup& defer_cleanup) {
316+
auto result = sm->Contains(field);
317+
if (sm->Empty()) {
318+
defer_cleanup.arm();
319+
}
320+
return result;
321+
}
322+
267323
OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor,
268324
const ScanOpts& scan_op) {
269325
constexpr size_t HASH_TABLE_ENTRIES_FACTOR = 2; // return key/value
@@ -274,7 +330,8 @@ OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t
274330
* of returning no or very few elements. (taken from redis code at db.c line 904 */
275331
constexpr size_t INTERATION_FACTOR = 10;
276332

277-
auto find_res = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
333+
DbSlice& db_slice = op_args.GetDbSlice();
334+
auto [find_res, defer_cleanup] = FindReadOnly(db_slice, op_args, key);
278335

279336
if (!find_res) {
280337
DVLOG(1) << "ScanOp: find failed: " << find_res << ", baling out";
@@ -328,6 +385,10 @@ OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t
328385
do {
329386
*cursor = sm->Scan(*cursor, scanCb);
330387
} while (*cursor && max_iterations-- && res.size() < count);
388+
389+
if (sm->Empty()) {
390+
defer_cleanup.arm();
391+
}
331392
}
332393

333394
return res;
@@ -368,13 +429,15 @@ OpResult<uint32_t> OpDel(const OpArgs& op_args, string_view key, CmdArgList valu
368429
StringMap* sm = GetStringMap(pv, op_args.db_cntx);
369430

370431
for (auto s : values) {
371-
bool res = sm->Erase(ToSV(s));
372-
if (res) {
432+
if (sm->Erase(ToSV(s))) {
373433
++deleted;
374-
if (sm->UpperBoundSize() == 0) {
375-
key_remove = true;
376-
break;
377-
}
434+
}
435+
436+
// Even if the previous Erase op did not erase anything, it can remove expired fields as a
437+
// side effect.
438+
if (sm->Empty()) {
439+
key_remove = true;
440+
break;
378441
}
379442
}
380443
}
@@ -395,7 +458,7 @@ OpResult<vector<OptStr>> OpHMGet(const OpArgs& op_args, std::string_view key, Cm
395458
DCHECK(!fields.empty());
396459

397460
auto& db_slice = op_args.GetDbSlice();
398-
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
461+
auto [it_res, defer_cleanup] = FindReadOnly(db_slice, op_args, key);
399462

400463
if (!it_res)
401464
return it_res.status();
@@ -443,8 +506,7 @@ OpResult<vector<OptStr>> OpHMGet(const OpArgs& op_args, std::string_view key, Cm
443506
StringMap* sm = GetStringMap(pv, op_args.db_cntx);
444507

445508
for (size_t i = 0; i < fields.size(); ++i) {
446-
auto it = sm->Find(ToSV(fields[i]));
447-
if (it != sm->end()) {
509+
if (auto it = Find(sm, ToSV(fields[i]), defer_cleanup); it != sm->end()) {
448510
result[i].emplace(it->second, sdslen(it->second));
449511
}
450512
}
@@ -468,7 +530,7 @@ OpResult<uint32_t> OpLen(const OpArgs& op_args, string_view key) {
468530

469531
OpResult<int> OpExist(const OpArgs& op_args, string_view key, string_view field) {
470532
auto& db_slice = op_args.GetDbSlice();
471-
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
533+
auto [it_res, defer_cleanup] = FindReadOnly(db_slice, op_args, key);
472534

473535
if (!it_res) {
474536
if (it_res.status() == OpStatus::KEY_NOTFOUND)
@@ -486,13 +548,13 @@ OpResult<int> OpExist(const OpArgs& op_args, string_view key, string_view field)
486548

487549
DCHECK_EQ(kEncodingStrMap2, pv.Encoding());
488550
StringMap* sm = GetStringMap(pv, op_args.db_cntx);
489-
490-
return sm->Contains(field) ? 1 : 0;
551+
return Contains(sm, field, defer_cleanup) ? 1 : 0;
491552
};
492553

493554
OpResult<string> OpGet(const OpArgs& op_args, string_view key, string_view field) {
494555
auto& db_slice = op_args.GetDbSlice();
495-
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
556+
auto [it_res, defer_cleanup] = FindReadOnly(db_slice, op_args, key);
557+
496558
if (!it_res)
497559
return it_res.status();
498560

@@ -510,12 +572,11 @@ OpResult<string> OpGet(const OpArgs& op_args, string_view key, string_view field
510572

511573
DCHECK_EQ(pv.Encoding(), kEncodingStrMap2);
512574
StringMap* sm = GetStringMap(pv, op_args.db_cntx);
513-
auto it = sm->Find(field);
514-
515-
if (it == sm->end())
516-
return OpStatus::KEY_NOTFOUND;
575+
if (const auto it = Find(sm, field, defer_cleanup); it != sm->end()) {
576+
return string(it->second, sdslen(it->second));
577+
}
517578

518-
return string(it->second, sdslen(it->second));
579+
return OpStatus::KEY_NOTFOUND;
519580
}
520581

521582
OpResult<vector<string>> OpGetAll(const OpArgs& op_args, string_view key, uint8_t mask) {
@@ -570,18 +631,15 @@ OpResult<vector<string>> OpGetAll(const OpArgs& op_args, string_view key, uint8_
570631
// and the enconding is guaranteed to be a DenseSet since we only support expiring
571632
// value with that enconding.
572633
if (res.empty()) {
573-
// post_updater will run immediately
574-
auto it = db_slice.FindMutable(op_args.db_cntx, key).it;
575-
576-
db_slice.Del(op_args.db_cntx, it);
634+
DeleteKey(db_slice, op_args, key);
577635
}
578636

579637
return res;
580638
}
581639

582640
OpResult<size_t> OpStrLen(const OpArgs& op_args, string_view key, string_view field) {
583641
auto& db_slice = op_args.GetDbSlice();
584-
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
642+
auto [it_res, defer_cleanup] = FindReadOnly(db_slice, op_args, key);
585643

586644
if (!it_res) {
587645
if (it_res.status() == OpStatus::KEY_NOTFOUND)
@@ -601,7 +659,7 @@ OpResult<size_t> OpStrLen(const OpArgs& op_args, string_view key, string_view fi
601659
DCHECK_EQ(pv.Encoding(), kEncodingStrMap2);
602660
StringMap* sm = GetStringMap(pv, op_args.db_cntx);
603661

604-
auto it = sm->Find(field);
662+
auto it = Find(sm, field, defer_cleanup);
605663
return it != sm->end() ? sdslen(it->second) : 0;
606664
}
607665

src/server/hset_family_test.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,4 +521,24 @@ TEST_F(HSetFamilyTest, ScanAfterExpireSet) {
521521
EXPECT_THAT(vec, Contains("avalue").Times(1));
522522
}
523523

524+
TEST_F(HSetFamilyTest, KeyRemovedWhenEmpty) {
525+
auto test_cmd = [&](const std::function<void()>& f, const std::string_view tag) {
526+
EXPECT_THAT(Run({"HSET", "a", "afield", "avalue"}), IntArg(1));
527+
EXPECT_THAT(Run({"HEXPIRE", "a", "1", "FIELDS", "1", "afield"}), IntArg(1));
528+
AdvanceTime(1000);
529+
530+
EXPECT_THAT(Run({"EXISTS", "a"}), IntArg(1));
531+
f();
532+
EXPECT_THAT(Run({"EXISTS", "a"}), IntArg(0)) << "failed when testing " << tag;
533+
};
534+
535+
test_cmd([&] { EXPECT_THAT(Run({"HGET", "a", "afield"}), ArgType(RespExpr::NIL)); }, "HGET");
536+
test_cmd([&] { EXPECT_THAT(Run({"HGETALL", "a"}), RespArray(ElementsAre())); }, "HGETALL");
537+
test_cmd([&] { EXPECT_THAT(Run({"HDEL", "a", "afield"}), IntArg(0)); }, "HDEL");
538+
test_cmd([&] { EXPECT_THAT(Run({"HSCAN", "a", "0"}).GetVec()[0], "0"); }, "HSCAN");
539+
test_cmd([&] { EXPECT_THAT(Run({"HMGET", "a", "afield"}), ArgType(RespExpr::NIL)); }, "HMGET");
540+
test_cmd([&] { EXPECT_THAT(Run({"HEXISTS", "a", "afield"}), IntArg(0)); }, "HEXISTS");
541+
test_cmd([&] { EXPECT_THAT(Run({"HSTRLEN", "a", "afield"}), IntArg(0)); }, "HSTRLEN");
542+
}
543+
524544
} // namespace dfly

0 commit comments

Comments
 (0)