Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements for Trace Computation #628

Merged
merged 16 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/mqt-core/dd/DDpackageConfig.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

namespace dd {
struct DDPackageConfig {
// Note the order of parameters here must be the *same* as in the template
// definition.
static constexpr std::size_t UT_VEC_NBUCKET = 32768U;
static constexpr std::size_t UT_VEC_INITIAL_ALLOCATION_SIZE = 2048U;
static constexpr std::size_t UT_MAT_NBUCKET = 32768U;
Expand All @@ -22,6 +20,8 @@ struct DDPackageConfig {
static constexpr std::size_t CT_MAT_MAT_MULT_NBUCKET = 16384U;
static constexpr std::size_t CT_VEC_KRON_NBUCKET = 4096U;
static constexpr std::size_t CT_MAT_KRON_NBUCKET = 4096U;
static constexpr std::size_t CT_DM_TRACE_NBUCKET = 1U;
static constexpr std::size_t CT_MAT_TRACE_NBUCKET = 4096U;
static constexpr std::size_t CT_VEC_INNER_PROD_NBUCKET = 4096U;
static constexpr std::size_t CT_DM_NOISE_NBUCKET = 1U;
static constexpr std::size_t UT_DM_NBUCKET = 1U;
Expand Down Expand Up @@ -63,6 +63,8 @@ struct DensityMatrixSimulatorDDPackageConfig : public dd::DDPackageConfig {
static constexpr std::size_t UT_MAT_INITIAL_ALLOCATION_SIZE = 1U;
static constexpr std::size_t CT_VEC_KRON_NBUCKET = 1U;
static constexpr std::size_t CT_MAT_KRON_NBUCKET = 1U;
static constexpr std::size_t CT_DM_TRACE_NBUCKET = 4096U;
static constexpr std::size_t CT_MAT_TRACE_NBUCKET = 1U;
static constexpr std::size_t CT_VEC_INNER_PROD_NBUCKET = 1U;
static constexpr std::size_t STOCHASTIC_CACHE_OPS = 1U;
static constexpr std::size_t CT_VEC_ADD_MAG_NBUCKET = 1U;
Expand Down
91 changes: 80 additions & 11 deletions include/mqt-core/dd/Package.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "dd/CachedEdge.hpp"
#include "dd/Complex.hpp"
#include "dd/ComplexNumbers.hpp"
#include "dd/ComplexValue.hpp"
#include "dd/ComputeTable.hpp"
#include "dd/DDDefinitions.hpp"
#include "dd/DDpackageConfig.hpp"
Expand Down Expand Up @@ -256,16 +257,22 @@ template <class Config> class Package {
}
// invalidate all compute tables involving matrices if any matrix node has
// been collected
if (mCollect > 0 || dCollect > 0) {
if (mCollect > 0) {
matrixAdd.clear();
conjugateMatrixTranspose.clear();
matrixKronecker.clear();
matrixTrace.clear();
matrixVectorMultiplication.clear();
matrixMatrixMultiplication.clear();
stochasticNoiseOperationCache.clear();
}
// invalidate all compute tables involving density matrices if any density
// matrix node has been collected
if (dCollect > 0) {
densityAdd.clear();
densityDensityMultiplication.clear();
densityNoise.clear();
densityTrace.clear();
}
// invalidate all compute tables where any component of the entry contains
// numbers from the complex table if any complex numbers were collected
Expand All @@ -276,10 +283,12 @@ template <class Config> class Package {
vectorInnerProduct.clear();
vectorKronecker.clear();
matrixKronecker.clear();
matrixTrace.clear();
stochasticNoiseOperationCache.clear();
densityAdd.clear();
densityDensityMultiplication.clear();
densityNoise.clear();
densityTrace.clear();
burgholzer marked this conversation as resolved.
Show resolved Hide resolved
}
return vCollect > 0 || mCollect > 0 || cCollect > 0;
}
Expand Down Expand Up @@ -885,11 +894,13 @@ template <class Config> class Package {
vectorInnerProduct.clear();
vectorKronecker.clear();
matrixKronecker.clear();
matrixTrace.clear();

stochasticNoiseOperationCache.clear();
densityAdd.clear();
densityDensityMultiplication.clear();
densityNoise.clear();
densityTrace.clear();
}

///
Expand Down Expand Up @@ -1939,6 +1950,19 @@ template <class Config> class Package {
/// (Partial) trace
///
public:
UnaryComputeTable<dNode*, dCachedEdge, Config::CT_DM_TRACE_NBUCKET>
densityTrace{};
UnaryComputeTable<mNode*, mCachedEdge, Config::CT_MAT_TRACE_NBUCKET>
matrixTrace{};

/**
 * @brief Select the trace compute table matching the given node type.
 * @tparam Node The decision diagram node type the trace is computed for.
 * @return A reference to `matrixTrace` if `Node` is `mNode`, and a reference
 * to `densityTrace` for every other node type.
 */
template <class Node> [[nodiscard]] auto& getTraceComputeTable() {
  constexpr bool isMatrixNode = std::is_same_v<Node, mNode>;
  if constexpr (!isMatrixNode) {
    return densityTrace;
  } else {
    return matrixTrace;
  }
}

mEdge partialTrace(const mEdge& a, const std::vector<bool>& eliminate) {
auto r = trace(a, eliminate, eliminate.size());
return {r.p, cn.lookup(r.w)};
Expand All @@ -1947,7 +1971,8 @@ template <class Config> class Package {
template <class Node>
ComplexValue trace(const Edge<Node>& a, const std::size_t numQubits) {
if (a.isIdentity()) {
return a.w * std::pow(2, numQubits);
return static_cast<ComplexValue>(a.w);
;
}
const auto eliminate = std::vector<bool>(numQubits, true);
return trace(a, eliminate, numQubits).w;
Expand Down Expand Up @@ -1975,7 +2000,23 @@ template <class Config> class Package {
}

private:
/// TODO: introduce a compute table for the trace?
/**
* @brief Computes the normalized (partial) trace using a compute table to
* store results for eliminated nodes.
* @details At each level, perform lookup and store results in the compute
* table only if all lower level qubits are eliminated as well.
*
* This optimization allows the full trace
* computation to scale linearly with respect to the number of nodes.
* However, the partial trace computation still scales with the number of
* paths in the DD when bottom qubits are to be eliminated.
*
* Normalization is continuously applied, dividing by two at each level
* marked for elimination, thereby ensuring that the result is mapped to the
* interval [0,1].
* @note Normalization is only applied to matrix nodes, as the trace
* of density matrices equals 1 by definition.
*/
template <class Node>
CachedEdge<Node> trace(const Edge<Node>& a,
const std::vector<bool>& eliminate, std::size_t level,
Expand All @@ -1985,24 +2026,51 @@ template <class Config> class Package {
return CachedEdge<Node>::zero();
}

if (std::none_of(eliminate.begin(), eliminate.end(),
// If `a` is the identity matrix or there is nothing left to eliminate,
// then simply return `a`
if (a.isIdentity() ||
std::none_of(eliminate.begin(),
eliminate.begin() +
static_cast<std::vector<bool>::difference_type>(level),
[](bool v) { return v; })) {
return CachedEdge<Node>{a.p, aWeight};
}

if (a.isIdentity()) {
const auto elims =
std::count(eliminate.begin(),
eliminate.begin() + static_cast<int64_t>(level), true);
return CachedEdge<Node>{a.p, aWeight * std::pow(2, elims)};
}

const auto v = a.p->v;
if (eliminate[v]) {
// Lookup nodes marked for elimination in the compute table if all
// lower level qubits are eliminated as well: if the trace has already
// been computed, return the result
auto& computeTable = getTraceComputeTable<Node>();
if (std::all_of(
eliminate.begin(),
eliminate.begin() +
static_cast<std::vector<bool>::difference_type>(level),
[](bool e) { return e; })) {
TeWas marked this conversation as resolved.
Show resolved Hide resolved
if (const auto* r = computeTable.lookup(a.p); r) {
return {r->p, r->w * aWeight};
}
}

const auto elims = alreadyEliminated + 1;
auto r = add2(trace(a.p->e[0], eliminate, level - 1, elims),
trace(a.p->e[3], eliminate, level - 1, elims), v - 1);

// The resulting weight is continuously normalized to the range [0,1] for
// matrix nodes
if constexpr (std::is_same_v<Node, mNode>) {
r.w = r.w / 2.0;
}

// Insert result into compute table if all lower level qubits are
// eliminated as well
if (std::all_of(
eliminate.begin(),
eliminate.begin() +
static_cast<std::vector<bool>::difference_type>(level),
[](bool e) { return e; })) {
computeTable.insert(a.p, {r.p, r.w});
}
r.w = r.w * aWeight;
return r;
}
Expand All @@ -2019,6 +2087,7 @@ template <class Config> class Package {
eliminate.begin(), eliminate.end(), true)) -
alreadyEliminated));
auto r = makeDDNode(adjustedV, edge);
// nodes that were not eliminated are not added to the compute table
TeWas marked this conversation as resolved.
Show resolved Hide resolved
r.w = r.w * aWeight;
return r;
}
Expand Down
81 changes: 77 additions & 4 deletions test/dd/test_package.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,21 +285,94 @@ TEST(DDPackageTest, IdentityTrace) {
auto dd = std::make_unique<dd::Package<>>(4);
auto fullTrace = dd->trace(dd->makeIdent(), 4);

ASSERT_EQ(fullTrace.r, 16.);
ASSERT_EQ(fullTrace.r, 1.);
}

TEST(DDPackageTest, CNotKronTrace) {
  // Full trace of the Kronecker product of two controlled-X gates on a
  // four-qubit package; the expected value is 0.25.
  constexpr std::size_t numQubits = 4;
  const auto package = std::make_unique<dd::Package<>>(numQubits);
  auto controlledX = package->makeGateDD(dd::X_MAT, 1_pc, 0);
  auto doubledControlledX = package->kronecker(controlledX, controlledX, 2);
  const auto traceValue = package->trace(doubledControlledX, numQubits);
  ASSERT_EQ(traceValue, 0.25);
}

TEST(DDPackageTest, PartialIdentityTrace) {
  // The (partial) trace returns an identity edge unchanged, i.e. with its
  // original weight. Tracing one qubit out of the identity therefore yields
  // weight 1, and multiplying the result with itself must give 1 as well.
  // A second, contradictory expectation of 4.0 on the same value was removed.
  auto dd = std::make_unique<dd::Package<>>(2);
  auto tr = dd->partialTrace(dd->makeIdent(), {false, true});
  auto mul = dd->multiply(tr, tr);
  EXPECT_EQ(dd::RealNumber::val(mul.w.r), 1.);
}

TEST(DDPackageTest, PartialSWapMatTrace) {
  // Partially trace qubit 0 out of a two-qubit SWAP gate and compare against
  // the full trace. A single test header is kept (two were stacked on top of
  // each other), and the assertion `ptr.w * ptr.w == 1` was dropped because it
  // contradicts the expectation below that the resulting weight is 0.5.
  auto dd = std::make_unique<dd::Package<>>(2);
  auto swapGate = dd->makeTwoQubitGateDD(dd::SWAP_MAT, 0, 1);
  auto ptr = dd->partialTrace(swapGate, {true, false});
  auto fullTrace = dd->trace(ptr, 1);
  auto fullTraceOriginal = dd->trace(swapGate, 2);
  EXPECT_EQ(dd::RealNumber::val(ptr.w.r), 0.5);
  // Check that successively tracing out subsystems is the same as computing the
  // full trace from the beginning
  EXPECT_EQ(fullTrace.r, fullTraceOriginal.r);
}

TEST(DDPackageTest, PartialTraceKeepInnerQubits) {
  // Eliminate only qubits 0, 1, 6 and 7 and keep the four qubits in between.
  // This guards against storing non-eliminated nodes in the trace compute
  // table, which would prevent their proper elimination in later trace calls.
  constexpr std::size_t numQubits = 8;
  const auto package = std::make_unique<dd::Package<>>(numQubits);
  const auto swap = package->makeTwoQubitGateDD(dd::SWAP_MAT, 0, 1);

  // Chain four SWAP gates via the Kronecker product (one initial copy plus
  // three appended ones).
  auto swapChain = swap;
  for (std::size_t appended = 1; appended <= 3; ++appended) {
    swapChain = package->kronecker(swapChain, swap, 2);
  }

  // The full trace of the original DD is computed first on purpose: the
  // partial trace below then runs with an already populated compute table.
  auto traceOfOriginal = package->trace(swapChain, numQubits);
  auto reduced = package->partialTrace(
      swapChain, {true, true, false, false, false, false, true, true});
  auto traceOfReduced = package->trace(reduced, 4);

  EXPECT_EQ(dd::RealNumber::val(reduced.w.r), 0.25);
  EXPECT_EQ(traceOfReduced.r, 0.0625);
  // Tracing out subsystems one after another must agree with computing the
  // full trace in a single step.
  EXPECT_EQ(traceOfReduced.r, traceOfOriginal.r);
}

TEST(DDPackageTest, TraceComplexity) {
  // Thanks to the trace compute table, the number of lookups during a full
  // trace is expected to grow with the number of nodes rather than with the
  // number of paths in the DD.
  constexpr std::size_t maxQubits = 10;
  for (std::size_t numQubits = 1; numQubits <= maxQubits; ++numQubits) {
    const auto package = std::make_unique<dd::Package<>>(numQubits);

    // Build the Kronecker product of `numQubits` Hadamard gates.
    const auto hadamard = package->makeGateDD(dd::H_MAT, 0);
    auto hadamardChain = hadamard;
    for (std::size_t count = 1; count < numQubits; ++count) {
      hadamardChain = package->kronecker(hadamardChain, hadamard, 1);
    }

    package->trace(hadamardChain, numQubits);

    // Every package is freshly created, so the statistics only reflect the
    // single trace call above.
    const auto& stats =
        package->getTraceComputeTable<dd::mNode>().getStats();
    ASSERT_EQ(stats.lookups, 2 * numQubits - 1);
    ASSERT_EQ(stats.hits, numQubits - 1);
  }
}

TEST(DDPackageTest, KeepBottomQubitsPartialTraceComplexity) {
  // Only qubits 6 and 7 are eliminated. Once the recursion reaches a level
  // below which nothing is left to eliminate, it must return the current
  // CachedEdge<Node> immediately instead of descending further, which keeps
  // the number of compute table lookups at three.
  constexpr std::size_t numQubits = 8;
  const auto package = std::make_unique<dd::Package<>>(numQubits);

  // Build the Kronecker product of `numQubits` Hadamard gates.
  const auto hadamard = package->makeGateDD(dd::H_MAT, 0);
  auto hadamardChain = hadamard;
  for (std::size_t count = 1; count < numQubits; ++count) {
    hadamardChain = package->kronecker(hadamardChain, hadamard, 1);
  }

  package->partialTrace(
      hadamardChain, {false, false, false, false, false, false, true, true});

  const auto& stats = package->getTraceComputeTable<dd::mNode>().getStats();
  ASSERT_EQ(stats.lookups, 3);
}

TEST(DDPackageTest, StateGenerationManipulation) {
Expand Down
Loading