Skip to content

Commit 01cc346

Browse files
authored
[Optimizer] Use bitset for qubit masks (#220)
* [Optimizer] Use bitset for qubit masks * code format
1 parent 915d681 commit 01cc346

File tree

4 files changed

+113
-3
lines changed

4 files changed

+113
-3
lines changed

src/quartz/math/bitset.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include "bitset.h"
2+
3+
#include <cassert>
4+
5+
namespace quartz {
6+
7+
Bitset::Bitset(size_t n) { a.assign((n + BLOCK - 1) / BLOCK, 0); }
8+
9+
Bitset::reference::reference(Bitset::value_t *a, int x)
10+
: pos(a + (x >> LOGBLOCK)), dig(((value_t)1) << (x & BLOCK1)) {}
11+
Bitset::reference::operator bool() const { return *pos & dig; }
12+
bool Bitset::reference::operator~() const { return ~*pos & dig; }
13+
Bitset::reference &Bitset::reference::operator=(bool x) {
14+
if (x)
15+
*pos |= dig;
16+
else
17+
*pos &= MASK ^ dig;
18+
return *this;
19+
}
20+
Bitset::reference &Bitset::reference::operator=(const Bitset::reference &x) {
21+
if (*x.pos & x.dig)
22+
*pos |= dig;
23+
else
24+
*pos &= MASK ^ dig;
25+
return *this;
26+
}
27+
Bitset::reference &Bitset::reference::flip() {
28+
*pos ^= dig;
29+
return *this;
30+
}
31+
32+
void Bitset::flip(int x) { a[x >> LOGBLOCK] ^= ((value_t)1) << (x & BLOCK1); }
33+
bool Bitset::operator[](int x) const {
34+
return (a[x >> LOGBLOCK] >> (x & BLOCK1)) & 1;
35+
}
36+
Bitset::reference Bitset::operator[](int x) { return {a.data(), x}; }
37+
Bitset Bitset::operator^(const Bitset &b) const {
38+
assert(a.size() == b.a.size());
39+
Bitset result(a.size() * BLOCK);
40+
for (int i = 0; i < (int)a.size(); i++) {
41+
result.a[i] = a[i] ^ b.a[i];
42+
}
43+
return result;
44+
}
45+
bool Bitset::operator==(const Bitset &b) const {
46+
return a == b.a; // std::vector<>::operator== does the job!
47+
}
48+
49+
std::size_t BitsetHash::operator()(const Bitset &x) const {
50+
std::hash<size_t> hash_fn;
51+
std::size_t result = x.a.size();
52+
for (auto &val : x.a) {
53+
result = result * 17 + hash_fn(val);
54+
}
55+
return result;
56+
}
57+
} // namespace quartz

src/quartz/math/bitset.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
namespace quartz {
6+
7+
class Bitset;
8+
struct BitsetHash {
9+
public:
10+
std::size_t operator()(const Bitset &x) const;
11+
};
12+
class Bitset {
13+
public:
14+
using value_t = unsigned long long;
15+
Bitset() = default;
16+
Bitset(size_t n);
17+
class reference {
18+
public:
19+
value_t *pos;
20+
value_t dig;
21+
reference() = default;
22+
reference(value_t *a, int x);
23+
operator bool() const;
24+
bool operator~() const;
25+
reference &operator=(bool x);
26+
reference &operator=(const reference &x);
27+
reference &flip();
28+
};
29+
void flip(int x);
30+
bool operator[](int x) const;
31+
reference operator[](int x);
32+
Bitset operator^(const Bitset &b) const;
33+
bool operator==(const Bitset &b) const;
34+
friend std::size_t BitsetHash::operator()(const Bitset &x) const;
35+
36+
private:
37+
static const int LOGBLOCK = 6;
38+
static const int BLOCK = (1 << LOGBLOCK);
39+
static const int BLOCK1 = BLOCK - 1;
40+
static const value_t MASK = ((value_t)-1);
41+
std::vector<value_t> a;
42+
};
43+
44+
} // namespace quartz

src/quartz/tasograph/tasograph.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "tasograph.h"
22

3+
#include "quartz/math/bitset.h"
34
#include "substitution.h"
45

56
#include <cassert>
@@ -994,7 +995,7 @@ void Graph::rotation_merging(GateType target_rotation) {
994995
assert(false);
995996
}
996997
// Step 1: calculate the bitmask of each operator
997-
std::unordered_map<Pos, uint64_t, PosHash> bitmasks;
998+
std::unordered_map<Pos, QubitMaskType, PosHash> bitmasks;
998999
std::unordered_map<Pos, int, PosHash> pos_to_qubits;
9991000
std::queue<Op> todos;
10001001

@@ -1003,7 +1004,10 @@ void Graph::rotation_merging(GateType target_rotation) {
10031004
if (it.first.ptr->tp == GateType::input_qubit) {
10041005
todos.push(it.first);
10051006
int qubit_idx = input_qubit_op_2_qubit_idx[it.first];
1006-
bitmasks[Pos(it.first, 0)] = 1 << qubit_idx;
1007+
if (bitmasks.count(Pos(it.first, 0)) == 0) {
1008+
bitmasks[Pos(it.first, 0)] = Bitset(get_num_qubits());
1009+
}
1010+
bitmasks[Pos(it.first, 0)][qubit_idx] = true;
10071011
pos_to_qubits[Pos(it.first, 0)] = qubit_idx;
10081012
} else if (it.first.ptr->tp == GateType::input_param) {
10091013
todos.push(it.first);
@@ -1106,7 +1110,8 @@ void Graph::rotation_merging(GateType target_rotation) {
11061110

11071111
// Step 4: merge rotations with the same bitmasks on the same qubit
11081112
std::unordered_map<
1109-
int, std::unordered_map<uint64_t, std::unordered_set<Pos, PosHash>>>
1113+
int, std::unordered_map<QubitMaskType, std::unordered_set<Pos, PosHash>,
1114+
QubitMaskHash>>
11101115
qubit_2_bm_2_pos;
11111116
for (const auto &pos : covered) {
11121117
if (pos.op.ptr->tp == GateType::cx) {

src/quartz/utils/utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#pragma once
22

3+
#include "quartz/math/bitset.h"
4+
35
#include <complex>
46
#include <filesystem>
57

@@ -22,6 +24,8 @@ using CircuitSeqHashType = unsigned long long;
2224
using PhaseShiftIdType = int;
2325
using EquivalenceHashType = std::pair<unsigned long long, int>;
2426
using InputParamMaskType = unsigned long long;
27+
using QubitMaskType = quartz::Bitset; // for rotation merging
28+
using QubitMaskHash = quartz::BitsetHash;
2529

2630
using namespace std::complex_literals; // so that we can write stuff like 1.0i
2731

0 commit comments

Comments
 (0)