Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/defrag.c
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,9 @@ void *activeDefragHfieldAndUpdateRef(void *ptr, void *privdata) {

/* Before the key is released, obtain the link to
* ensure we can safely access and update the key. */
dictUseStoredKeyApi(d, 1);
const void *key = dictStoredKey2Key(d, ptr);
link = dictFindLink(d, ptr, NULL);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 Bug: defrag.c: extracted lookup key is unused in dictFindLink call

In activeDefragHfieldAndUpdateRef(), line 280 correctly extracts the lookup key via dictStoredKey2Key(d, ptr) into the local variable key, but line 281 still passes the original ptr (the stored-key) to dictFindLink(). Since dictFindLink expects a lookup key (it hashes via hashFunction and compares using keyCompare), passing the stored-key will compute the wrong hash and/or fail comparisons when keyFromStoredKey is set.

The PR context explicitly identifies this as one of the critical bugs to fix: "The function calls dictFindLink(d, ptr, NULL) where ptr is a stored-key... This should be converted to a lookup key first."

This will cause defrag to fail to find entries in hash dictionaries that use the stored-key optimization (e.g., dbDictType), potentially leading to assertion failures on line 282 (serverAssert(link)) or silent data corruption if the assertion is disabled.

Was this helpful? React with 👍 / 👎

Suggested change
link = dictFindLink(d, ptr, NULL);
link = dictFindLink(d, key, NULL);
  • Apply suggested fix

serverAssert(link);
dictUseStoredKeyApi(d, 0);

Entry *newEntry = activeDefragEntry(ptr);
if (newEntry)
Expand Down Expand Up @@ -481,7 +480,6 @@ void activeDefragLuaScriptDictCallback(void *privdata, const dictEntry *de, dict
}

void activeDefragHfieldDictCallback(void *privdata, const dictEntry *de, dictEntryLink plink) {
UNUSED(plink);
dict *d = privdata;
Entry *newEntry = NULL, *entry = dictGetKey(de);

Expand All @@ -490,7 +488,7 @@ void activeDefragHfieldDictCallback(void *privdata, const dictEntry *de, dictEnt
* during the hash expiry ebuckets defragmentation phase. */
if (entryGetExpiry(entry) == EB_EXPIRE_TIME_INVALID) {
if ((newEntry = activeDefragEntry(entry))) {
/* Hash dicts use no_value=1, so we must use dictSetKeyAtLink */
/* Hash dicts use no_value=1, so we must use dictSetKeyAtLink */
dictSetKeyAtLink(d, newEntry, &plink, 0);
}
}
Expand Down
79 changes: 41 additions & 38 deletions src/dict.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ static void dictSetNext(dictEntry *de, dictEntry *next);
static int dictDefaultCompare(dictCmpCache *cache, const void *key1, const void *key2);
static dictEntryLink dictFindLinkInternal(dict *d, const void *key, dictEntryLink *bucket);
dictEntryLink dictFindLinkForInsert(dict *d, const void *key, dictEntry **existing);
static dictEntry *dictInsertKeyAtLink(dict *d, void *key, dictEntryLink link);
static dictEntry *dictInsertKeyAtLink(dict *d, void *key __stored_key, dictEntryLink link);

/* -------------------------- unused --------------------------- */
void dictSetSignedIntegerVal(dictEntry *de, int64_t val);
Expand All @@ -89,18 +89,22 @@ int64_t dictIncrSignedIntegerVal(dictEntry *de, int64_t val);

typedef int (*keyCmpFunc)(dictCmpCache *cache, const void *key1, const void *key2);
static inline keyCmpFunc dictGetCmpFunc(dict *d) {
if (d->useStoredKeyApi && d->type->storedKeyCompare)
return d->type->storedKeyCompare;
if (d->type->keyCompare)
return d->type->keyCompare;
return dictDefaultCompare;
}

static inline uint64_t dictHashKey(dict *d, const void *key, int isStoredKey) {
if (isStoredKey && d->type->storedHashFunction)
return d->type->storedHashFunction(key);
else
return d->type->hashFunction(key);
/* Convert a stored-key into its lookup-key form.
 * When the dict type supplies keyFromStoredKey, delegate to it; otherwise the
 * stored-key already is the lookup key and is returned unchanged. */
static const void *dictStoredKey2Key(dict *d, const void *key __stored_key) {
    if (d->type->keyFromStoredKey == NULL)
        return key;
    return d->type->keyFromStoredKey(key);
}

/* Validate that stored-key to key conversion works correctly.
 * With a keyFromStoredKey callback the extracted lookup key must be non-NULL;
 * without one, conversion must be the identity (extracted == key).
 * Returns nonzero on success, 0 on failure. */
static int validateStoredKeyConversion(dict *d, const void *key __stored_key) {
    const void *lookup = dictStoredKey2Key(d, key);
    return d->type->keyFromStoredKey ? (lookup != NULL) : (lookup == key);
}

/* -------------------------- hash functions -------------------------------- */
Expand All @@ -118,7 +122,7 @@ uint64_t siphash(const uint8_t *in, const size_t inlen, const uint8_t *k);
uint64_t siphash_nocase(const uint8_t *in, const size_t inlen, const uint8_t *k);

uint64_t dictGenHashFunction(const void *key, size_t len) {
return siphash(key,len,dict_hash_function_seed);
return siphash(key, len, dict_hash_function_seed);
}

uint64_t dictGenCaseHashFunction(const unsigned char *buf, size_t len) {
Expand Down Expand Up @@ -150,7 +154,7 @@ static inline int entryIsNormal(const dictEntry *de) {
}

/* Creates an entry without a value field. */
static inline dictEntry *createEntryNoValue(void *key, dictEntry *next) {
static inline dictEntry *createEntryNoValue(void *key __stored_key, dictEntry *next) {
dictEntryNoValue *entry = zmalloc(sizeof(*entry));
entry->key = key;
entry->next = next;
Expand Down Expand Up @@ -222,7 +226,6 @@ int _dictInit(dict *d, dictType *type)
d->rehashidx = -1;
d->pauserehash = 0;
d->pauseAutoResize = 0;
d->useStoredKeyApi = 0;
return DICT_OK;
}

Expand Down Expand Up @@ -333,10 +336,11 @@ static void rehashEntriesInBucketAtIndex(dict *d, uint64_t idx) {
dictEntry *nextde;
while (de) {
nextde = dictGetNext(de);
void *key = dictGetKey(de);
void *storedKey = dictGetKey(de);
/* Get the index in the new hash table */
if (d->ht_size_exp[1] > d->ht_size_exp[0]) {
h = dictHashKey(d, key, 1) & DICTHT_SIZE_MASK(d->ht_size_exp[1]);
const void *key = dictStoredKey2Key(d, storedKey);
h = dictGetHash(d, key) & DICTHT_SIZE_MASK(d->ht_size_exp[1]);
} else {
/* We're shrinking the table. The tables sizes are powers of
* two, so we simply mask the bucket index in the larger table
Expand All @@ -351,13 +355,13 @@ static void rehashEntriesInBucketAtIndex(dict *d, uint64_t idx) {
if (!entryIsKey(de)) zfree(decodeMaskedPtr(de));

if (d->type->keys_are_odd)
de = key; /* ENTRY_PTR_IS_ODD_KEY trivially set by the odd key. */
de = storedKey; /* ENTRY_PTR_IS_ODD_KEY trivially set by the odd key. */
else
de = encodeMaskedPtr(key, ENTRY_PTR_IS_EVEN_KEY);
de = encodeMaskedPtr(storedKey, ENTRY_PTR_IS_EVEN_KEY);

} else if (entryIsKey(de)) {
/* We don't have an allocated entry but we need one. */
de = createEntryNoValue(key, d->ht_table[1][h]);
de = createEntryNoValue(storedKey, d->ht_table[1][h]);
} else {
dictSetNext(de, d->ht_table[1][h]);
}
Expand Down Expand Up @@ -486,7 +490,7 @@ int _dictBucketRehash(dict *d, uint64_t idx) {
}

/* Add an element to the target hash table */
int dictAdd(dict *d, void *key, void *val)
int dictAdd(dict *d, void *key __stored_key, void *val)
{
dictEntry *entry = dictAddRaw(d,key,NULL);

Expand Down Expand Up @@ -519,10 +523,10 @@ int dictCompareKeys(dict *d, const void *key1, const void *key2) {
*
* If key was added, the hash entry is returned to be manipulated by the caller.
*/
dictEntry *dictAddRaw(dict *d, void *key, dictEntry **existing)
dictEntry *dictAddRaw(dict *d, void *key __stored_key, dictEntry **existing)
{
/* Get the position for the new key or NULL if the key already exists. */
void *position = dictFindLinkForInsert(d, key, existing);
void *position = dictFindLinkForInsert(d, dictStoredKey2Key(d, key), existing);
if (!position) return NULL;

/* Dup the key if necessary. */
Expand All @@ -535,7 +539,7 @@ dictEntry *dictAddRaw(dict *d, void *key, dictEntry **existing)
* call to dictFindLinkForInsert(). This is a low level function which allows
* splitting dictAddRaw in two parts. Normally, dictAddRaw or dictAdd should be
* used instead. It assumes that dictExpandIfNeeded() was called before. */
dictEntry *dictInsertKeyAtLink(dict *d, void *key, dictEntryLink link) {
dictEntry *dictInsertKeyAtLink(dict *d, void *key __stored_key, dictEntryLink link) {
dictEntryLink bucket = link; /* It's a bucket, but the API hides that. */
dictEntry *entry;
/* If rehashing is ongoing, we insert in table 1, otherwise in table 0.
Expand Down Expand Up @@ -580,7 +584,7 @@ dictEntry *dictInsertKeyAtLink(dict *d, void *key, dictEntryLink link) {
* Return 1 if the key was added from scratch, 0 if there was already an
* element with such key and dictReplace() just performed a value update
* operation. */
int dictReplace(dict *d, void *key, void *val)
int dictReplace(dict *d, void *key __stored_key, void *val)
{
dictEntry *entry, *existing;

Expand Down Expand Up @@ -611,7 +615,7 @@ int dictReplace(dict *d, void *key, void *val)
* existing key is returned.)
*
* See dictAddRaw() for more information. */
dictEntry *dictAddOrFind(dict *d, void *key) {
dictEntry *dictAddOrFind(dict *d, void *key __stored_key) {
dictEntry *entry, *existing;
entry = dictAddRaw(d,key,&existing);
return entry ? entry : existing;
Expand All @@ -629,7 +633,7 @@ static dictEntry *dictGenericDelete(dict *d, const void *key, int nofree) {
/* dict is empty */
if (dictSize(d) == 0) return NULL;

h = dictHashKey(d, key, d->useStoredKeyApi);
h = dictGetHash(d, key);
idx = h & DICTHT_SIZE_MASK(d->ht_size_exp[0]);

/* Rehash the hash table if needed */
Expand All @@ -643,7 +647,7 @@ static dictEntry *dictGenericDelete(dict *d, const void *key, int nofree) {
he = d->ht_table[table][idx];
prevHe = NULL;
while(he) {
void *he_key = dictGetKey(he);
const void *he_key = dictStoredKey2Key(d, dictGetKey(he));
if (key == he_key || cmpFunc(&cmpCache, key, he_key)) {
/* Unlink the element from the list */
if (prevHe)
Expand Down Expand Up @@ -772,7 +776,7 @@ static dictEntryLink dictFindLinkInternal(dict *d, const void *key, dictEntryLin
if (dictSize(d) == 0) return NULL;
}

const uint64_t hash = dictHashKey(d, key, d->useStoredKeyApi);
const uint64_t hash = dictGetHash(d, key);
idx = hash & DICTHT_SIZE_MASK(d->ht_size_exp[0]);
keyCmpFunc cmpFunc = dictGetCmpFunc(d);

Expand All @@ -790,7 +794,7 @@ static dictEntryLink dictFindLinkInternal(dict *d, const void *key, dictEntryLin
link = &(d->ht_table[table][idx]);
if (bucket) *bucket = link;
while(link && *link) {
void *visitedKey = dictGetKey(*link);
const void *visitedKey = dictStoredKey2Key(d, dictGetKey(*link));

/* Prefetch the next entry to improve cache efficiency */
redis_prefetch_read(dictGetNext(*link));
Expand Down Expand Up @@ -880,7 +884,7 @@ dictEntryLink dictFindLink(dict *d, const void *key, dictEntryLink *bucket) {
* newItem: 1 = Add a key with a new dictEntry.
* 0 = Set a key to an existing dictEntry.
*/
void dictSetKeyAtLink(dict *d, void *key, dictEntryLink *link, int newItem) {
void dictSetKeyAtLink(dict *d, void *key __stored_key, dictEntryLink *link, int newItem) {
dictEntryLink dummy = NULL;
if (link == NULL) link = &dummy;
void *addedKey = (d->type->keyDup) ? d->type->keyDup(d, key) : key;
Expand All @@ -895,9 +899,7 @@ void dictSetKeyAtLink(dict *d, void *key, dictEntryLink *link, int newItem) {
if (snap[0] != d->ht_size_exp[0] || snap[1] != d->ht_size_exp[1] || *link == NULL) {
dictEntryLink bucket;
/* Bypass dictFindLink() to search bucket even if dict is empty!!! */
dictUseStoredKeyApi(d, 1);
*link = dictFindLinkInternal(d, key, &bucket);
dictUseStoredKeyApi(d, 0);
*link = dictFindLinkInternal(d, dictStoredKey2Key(d, key), &bucket);
assert(bucket != NULL);
assert(*link == NULL);
*link = bucket; /* On newItem the link should be the bucket */
Expand All @@ -907,9 +909,9 @@ void dictSetKeyAtLink(dict *d, void *key, dictEntryLink *link, int newItem) {
}

/* Setting key of existing dictEntry (newItem == 0)*/

if (*link == NULL) {
*link = dictFindLink(d, key, NULL);
*link = dictFindLink(d, addedKey, NULL);
assert(*link != NULL);
}

Expand Down Expand Up @@ -959,15 +961,15 @@ dictEntryLink dictTwoPhaseUnlinkFind(dict *d, const void *key, int *table_index)
if (dictSize(d) == 0) return NULL; /* dict is empty */
if (dictIsRehashing(d)) _dictRehashStep(d);

h = dictHashKey(d, key, d->useStoredKeyApi);
h = dictGetHash(d, key);
keyCmpFunc cmpFunc = dictGetCmpFunc(d);

for (table = 0; table <= 1; table++) {
idx = h & DICTHT_SIZE_MASK(d->ht_size_exp[table]);
if (table == 0 && (long)idx < d->rehashidx) continue;
dictEntry **ref = &d->ht_table[table][idx];
while (ref && *ref) {
void *de_key = dictGetKey(*ref);
const void *de_key = dictStoredKey2Key(d, dictGetKey(*ref));
if (key == de_key || cmpFunc(&cmpCache, key, de_key)) {
*table_index = table;
dictPauseRehashing(d);
Expand All @@ -993,7 +995,7 @@ void dictTwoPhaseUnlinkFree(dict *d, dictEntryLink plink, int table_index) {
dictResumeRehashing(d);
}

void dictSetKey(dict *d, dictEntry* de, void *key) {
void dictSetKey(dict *d, dictEntry* de, void *key __stored_key) {
assert(!d->type->no_value);
if (d->type->keyDup)
de->key = d->type->keyDup(d, key);
Expand Down Expand Up @@ -1747,7 +1749,7 @@ dictEntryLink dictFindLinkForInsert(dict *d, const void *key, dictEntry **existi
unsigned long idx, table;
dictCmpCache cmpCache = {0};
dictEntry *he;
uint64_t hash = dictHashKey(d, key, d->useStoredKeyApi);
uint64_t hash = dictGetHash(d, key);
if (existing) *existing = NULL;
idx = hash & DICTHT_SIZE_MASK(d->ht_size_exp[0]);

Expand All @@ -1764,7 +1766,7 @@ dictEntryLink dictFindLinkForInsert(dict *d, const void *key, dictEntry **existi
/* Search if this slot does not already contain the given key */
he = d->ht_table[table][idx];
while(he) {
void *he_key = dictGetKey(he);
const void *he_key = dictStoredKey2Key(d, dictGetKey(he));
if (key == he_key || cmpFunc(&cmpCache, key, he_key)) {
if (existing) *existing = he;
return NULL;
Expand Down Expand Up @@ -1802,8 +1804,9 @@ void dictSetResizeEnabled(dictResizeEnable enable) {
dict_can_resize = enable;
}

/* Compiler inlines this for internal calls within dict.c (verified with -O3). */
uint64_t dictGetHash(dict *d, const void *key) {
return dictHashKey(d, key, d->useStoredKeyApi);
return d->type->hashFunction(key);
}

/* Provides the old and new ht size for a given dictionary during rehashing. This method
Expand Down
Loading
Loading