Skip to content

Commit 6e74f72

Browse files
committed
Merge branch 'drop-masked-neighborhoods'
* exclude self-neighborhoods at masked-away sites * test OverlapCalculator::neighborhoods with various masks Resolve #25
2 parents 23f2a6c + abfe929 commit 6e74f72

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

src/diffpy/srreal/OverlapCalculator.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -344,18 +344,33 @@ vector< unordered_set<int> > OverlapCalculator::neighborhoods() const
344344
rvptr[*idx1] = rvptr[j0];
345345
}
346346
}
347-
// replace any unassigned items with a self-neighborhoods
348-
for (int i = 0; i < cntsites; ++i)
347+
// helper function that initializes neighbor set at site i
348+
auto initsiteset = [&](int i, int& counter) {
349+
if (!rvptr[i]) {
350+
rvptr[i].reset(new SiteSet(&i, &i + 1));
351+
++counter;
352+
}
353+
};
354+
// count initialized items in rvptr
355+
int cnt = count_if(rvptr.begin(), rvptr.end(),
356+
[](SiteSetPointers::value_type& p) { return bool(p); });
357+
// create self-neighborhoods unless prohibited by mask
358+
for (int j0 = 0; j0 < cntsites && cnt < cntsites; ++j0)
349359
{
350-
if (rvptr[i].get()) continue;
351-
rvptr[i].reset(new SiteSet(&i, &i + 1));
360+
for (int j1 = j0; j1 < cntsites; ++j1)
361+
{
362+
if (rvptr[j0] && rvptr[j1]) continue;
363+
if (!this->getPairMask(j0, j1)) continue;
364+
initsiteset(j0, cnt);
365+
initsiteset(j1, cnt);
366+
}
352367
}
353368
unordered_set<const SiteSet*> duplicate;
354369
vector< unordered_set<int> > rv;
355370
SiteSetPointers::const_iterator ssp;
356371
for (ssp = rvptr.begin(); ssp != rvptr.end(); ++ssp)
357372
{
358-
if (duplicate.count(ssp->get())) continue;
373+
if (!ssp->get() || duplicate.count(ssp->get())) continue;
359374
duplicate.insert(ssp->get());
360375
rv.push_back(**ssp);
361376
}

src/tests/TestOverlapCalculator.hpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,16 +270,35 @@ class TestOverlapCalculator : public CxxTest::TestSuite
270270

271271
void test_neighborhoods()
272272
{
273+
auto tb = molc->getAtomRadiiTable();
273274
molc->eval(mnacl);
274275
auto nbhood = molc->neighborhoods();
275276
TS_ASSERT_EQUALS(1u, nbhood.size());
276277
TS_ASSERT_EQUALS(8u, nbhood[0].size());
277-
molc->getAtomRadiiTable()->resetAll();
278+
tb->resetAll();
278279
molc->eval(mnacl);
279280
auto nbsep = molc->neighborhoods();
280281
TS_ASSERT_EQUALS(8u, nbsep.size());
281282
int nb5site = *(nbsep[5].begin());
282283
TS_ASSERT_EQUALS(5, nb5site);
284+
molc->maskAllPairs(false);
285+
molc->eval(mnacl);
286+
auto nbdark = molc->neighborhoods();
287+
TS_ASSERT(nbdark.empty());
288+
molc->setTypeMask("Na1+", "Na1+", true);
289+
molc->eval(mnacl);
290+
auto nbnanasep = molc->neighborhoods();
291+
TS_ASSERT_EQUALS(4u, nbnanasep.size());
292+
int nb3site = *(nbnanasep[3].begin());
293+
TS_ASSERT_EQUALS(3, nb3site);
294+
tb->setCustom("Na1+", 5);
295+
tb->setCustom("Cl1-", 5);
296+
molc->eval(mnacl);
297+
auto nbnana = molc->neighborhoods();
298+
TS_ASSERT_EQUALS(1u, nbnana.size());
299+
TS_ASSERT_EQUALS(4u, nbnana[0].size());
300+
TS_ASSERT(nbnana[0].count(0));
301+
TS_ASSERT(!nbnana[0].count(4));
283302
}
284303

285304

0 commit comments

Comments
 (0)