#include <array>
#include <vector>
#include <iostream>
#include <cstdlib>
#include <cstdint>
#include <random>
#include <algorithm>
#include <unistd.h>
#include <cassert>

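// Lemire-style "fastrange": map a hash to [0, range) without a modulo by
// taking the high half of the widened product hash * range.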
inline size_t fastrange64(uint64_t hash, size_t range) {
  __uint128_t wide = __uint128_t{range} * hash;
  return static_cast<size_t>(wide >> 64);
}

inline uint32_t fastrange32(uint32_t hash, uint32_t range) {
  uint64_t wide = uint64_t{hash} * range;
  return static_cast<uint32_t>(wide >> 32);
}

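// "Smash" padding: allow band start positions to overhang both ends of a
// shard so that slots near the boundaries get roughly as much coverage as
// interior slots.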
// Best is around 20/20, but this can make for slightly faster queries
static constexpr uint32_t front_smash = 32;
static constexpr uint32_t back_smash = 31;

struct GaussData {
  uint64_t row = 0;
  uint32_t start = 0;
  uint32_t pivot = 0;
  void Reset(uint64_t h, uint32_t len) {
    // Pick a 64-wide band start in [0, len - 64], using the smash padding so
    // the edge slots are not starved.
    uint32_t addrs = len - 63 + front_smash + back_smash;
    start = fastrange32((uint32_t)(h >> 32), addrs);
    start = std::max(start, front_smash);
    start -= front_smash;
    start = std::min(start, len - 64);
    assert(start < len - 63);
    // Mix h into a coefficient row over GF(2); the forced top bit guarantees
    // a nonzero row. (The constant product folds to a single constant.)
    row = h + 0x9e3779b97f4a7c13 * 0x9e3779b97f4a7c13;
    row ^= h >> 32;
    row |= (uint64_t{1} << 63);
    pivot = 0;
  }
};

static inline uint32_t getShard(uint64_t h, uint32_t shards) {
  // The top 32 bits select the shard; the same bits also order band start
  // positions within a shard.
  return fastrange32((uint32_t)(h >> 32), shards);
}

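// Map the low 10 bits of the hash to a section index in [0, 256) with a
// deliberately non-uniform distribution: more than half of all keys land in
// section 0, and granularity is finest (1/1024 per section) in the range
// where the keep/bump threshold typically falls.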
static inline uint32_t getSection(uint64_t h) {
  uint32_t v = h & 1023;
  if (v < 300) {
    return v / 3;
  } else if (v < 428) {
    return v - 200;
  } else if (v < 512) {
    return (v + 256) / 3;
// } else if (v < 532) {
//   return (v + 1516) / 8;
  } else {
    return 0;
  }
}

static inline uint64_t rot64(uint64_t h, int count) {
  // Mask both shift counts so that count == 0 is well-defined
  // (h >> 64 would be undefined behavior).
  return (h << (count & 63)) | (h >> ((64 - count) & 63));
}

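// Experiment: build a sharded, banded Gaussian-elimination system over random
// keys. Each shard keeps as many whole sections as fit, solves a 64-wide
// banded GF(2) system, and "bumps" the remaining sections to strictly later
// shards. Assumed invocation (argument names are mine): ./a.out <nkeys>
// <space_factor>, e.g. ./a.out 1000000 1.05.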
int main(int argc, char *argv[]) {
  if (argc < 3) {
    std::cerr << "usage: " << argv[0] << " <nkeys> <space_factor>" << std::endl;
    return 2;
  }
  std::mt19937_64 rand(getpid());

  uint32_t nkeys = (uint32_t)std::atoi(argv[1]);
  double f = std::atof(argv[2]);
  // Target total slot count, split into a power-of-two number of shards of at
  // most 1414 slots each.
  uint32_t lenish = (uint32_t)(f * nkeys + 0.5);
  uint32_t shards = 1;
  while (lenish / shards > 1414) {
    shards *= 2;
  }
  // Per-shard lengths are rounded to a multiple of 64 (the band width).
  uint32_t avg_len_per_shard = (lenish + shards / 2) / shards;
  uint32_t min_len_per_shard = avg_len_per_shard & ~uint32_t{63};
  uint32_t max_len_per_shard = (avg_len_per_shard + 63) & ~uint32_t{63};

  std::array<std::vector<uint64_t>, 256> *hashes = new std::array<std::vector<uint64_t>, 256>[shards];
  for (uint32_t i = 0; i < nkeys; ++i) {
    uint64_t h = (uint64_t)rand();
    // Tweak a narrow slice of the hash space: when bit 63 and bits 7..9 are
    // all set, clear the top bit (apparently to keep these keys out of the
    // way of the forced-top-bit alternate hashes used for bumping).
    if ((h & uint64_t{0x8000000000000380}) == uint64_t{0x8000000000000380}) {
      h -= uint64_t{0x8000000000000000};
    }
    hashes[getShard(h, shards)][getSection(h)].push_back(h);
  }

  // Per-slot rows for the current shard, plus keys bumped into each shard
  // from earlier shards.
  GaussData *data = new GaussData[max_len_per_shard];
  std::vector<uint64_t> shard_hashes;
  std::vector<uint64_t> *bumped = new std::vector<uint64_t>[shards];

  for (uint32_t shard = 0; shard < shards; ++shard) {
    // Consecutive 64-aligned boundaries tile the rounded total, so each shard
    // gets either min_len_per_shard or max_len_per_shard slots.
    uint32_t len_this_shard = ((shard * avg_len_per_shard + 63 + avg_len_per_shard) & ~uint32_t{63}) - ((shard * avg_len_per_shard + 63) & ~uint32_t{63});
    assert(len_this_shard == min_len_per_shard || len_this_shard == max_len_per_shard);

    // Greedily keep whole sections, in section order, while they still fit in
    // this shard's slots; keys bumped in from earlier shards are always kept.
    uint32_t last_section = 0;
    size_t kept_count = hashes[shard][last_section].size() + bumped[shard].size();
    for (; last_section < 255; ++last_section) {
      size_t next_count = hashes[shard][last_section + 1].size();
      if (kept_count + next_count > len_this_shard) {
        break;
      }
      kept_count += next_count;
    }
    std::cout << "pre-kept@" << shard << " = " << kept_count << " / " << len_this_shard << " (" << (1.0 * kept_count / len_this_shard) << ") (last=" << last_section << ")" << std::endl;
    if (shard == shards - 1) {
      // Last shard: nowhere left to bump to.
      if (last_section < 255) {
        uint32_t overflow_count = 0;
        for (uint32_t i = last_section + 1; i < 256; ++i) {
          overflow_count += (uint32_t)hashes[shard][i].size();
        }
        std::cout << "overflow! " << overflow_count << std::endl;
        return 1;
      }
    } else {
      if (kept_count > len_this_shard) {
        // Even section 0 plus bumped-in keys exceed this shard's slots.
        std::cout << "early overflow!" << std::endl;
        return 1;
      }
    }

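    // On each (re)try, kept keys are premultiplied by a seed derived from
    // last_section, so dropping a section also rehashes everything kept.
    // Bumped-in keys arrive already multiplied by the bumping shard's seed.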
  retry:
    uint64_t seed = rot64(uint64_t{0x9e3779b97f4a7c13}, (last_section * 13) & 63);
    for (uint64_t h : bumped[shard]) {
      shard_hashes.push_back(h /** seed */);
    }
    for (uint32_t i = 0; i <= last_section; ++i) {
      for (uint64_t h : hashes[shard][i]) {
        //shard_hashes.push_back(rot64(h * 0x9e3779b97f4a7c13, (last_section * 13) & 63) * 0x9e3779b97f4a7c13);
        shard_hashes.push_back(h * seed);
      }
    }
    assert(kept_count == shard_hashes.size());
    // Sorting by hash also sorts by band start position (both derive from the
    // high bits), which keeps the system banded for the elimination below.
    std::sort(shard_hashes.begin(), shard_hashes.end());
    for (uint64_t i = 0; i < kept_count; ++i) {
      data[i].Reset(shard_hashes[i], len_this_shard);
    }
    shard_hashes.clear();
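    // Banded Gaussian elimination over GF(2): each row spans 64 columns from
    // its band start, so only nearby rows below it can share a pivot column.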
    for (uint32_t i = 0; i < kept_count; ++i) {
      GaussData &di = data[i];
      if (di.row == 0) {
        // Linearly dependent row: drop the last kept section and retry with a
        // different seed.
        if (last_section == 0) {
          std::cout << "early2 overflow!" << std::endl;
          return 1;
        }
        kept_count -= hashes[shard][last_section].size();
        --last_section;
        goto retry;
      }
      // The lowest set bit picks this row's pivot column.
      int tz = __builtin_ctzll(di.row);
      di.pivot = di.start + tz;
      for (uint32_t j = i + 1; j < kept_count; ++j) {
        GaussData &dj = data[j];
        assert(dj.start >= di.start);
        if (di.pivot < dj.start) {
          break;
        }
        if ((dj.row >> (di.pivot - dj.start)) & 1) {
          dj.row ^= (di.row >> (dj.start - di.start));
          // TODO?: forward-looking check for 0
        }
      }
    }
    // OK
    std::cout << "kept@" << shard << " = " << kept_count << " / " << len_this_shard << " (" << (1.0 * kept_count / len_this_shard) << ") (last=" << last_section << ")" << std::endl;
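    // Bump every unkept section: derive an alternate hash that lands in a
    // strictly later shard by forcing the top bit, shifting the original
    // leading bits down one, and randomizing only bits low enough (per
    // keep_mask) that they cannot make the shard index smaller.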
    if (shard < shards - 1) {
      for (uint32_t i = last_section + 1; i < 256; ++i) {
        // bump
        uint64_t keep_mask = shards / 2;
        if (keep_mask > 0) {
          // Extend the mask across the shard index's leading run of one-bits...
          while ((shard & keep_mask) == keep_mask && (keep_mask & 1) == 0) {
            keep_mask |= keep_mask / 2;
          }
          // ...then align it to the top of the 64-bit hash.
          while (keep_mask < uint64_t{0x8000000000000000}) {
            keep_mask <<= 1;
          }
        }
        uint64_t other_mask = ~keep_mask >> 1;
        for (uint64_t h : hashes[shard][i]) {
          uint64_t rot_h = (h >> 32) | (h << 32);
          uint64_t alt_h = (uint64_t{0x8000000000000000} | (h >> 1)) ^ (rot_h & other_mask);
          uint32_t new_shard = getShard(alt_h, shards);
          assert(new_shard > shard);
          bumped[new_shard].push_back(h * seed);
        }
      }
    }
  }

  return 0;
}