Skip to content

Commit

Permalink
resize lists at the end for serial runs also
Browse files Browse the repository at this point in the history
  • Loading branch information
lmiq committed Jun 6, 2023
1 parent 65d015a commit 9f10f7c
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions src/neighborlists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ end
function reduce_lists(list::NeighborList{T}, list_threaded::Vector{<:NeighborList{T}}) where {T}
ranges = cumsum(nb.n for nb in list_threaded)
npairs = ranges[end]
# need to resize here for the case where length(list) < npairs
list = resize!(list, npairs)
@sync for it in eachindex(list_threaded)
lt = list_threaded[it]
Expand Down Expand Up @@ -370,6 +371,10 @@ function neighborlist!(system::InPlaceNeighborList)
output_threaded=system.nb_threaded,
show_progress=system.show_progress
)
# need to resize here to return the correct number of pairs for serial runs
# (this resizing is redundant for parallel runs, since it occurs at the reduction function)
# before updating
!system.parallel && resize!(system.nb, system.nb.n)
return system.nb.list
end

Expand Down Expand Up @@ -781,9 +786,16 @@ end

end

@testitem "lists match" begin
import CellListMap
@test CellListMap.TestingNeighborLists.test_threaded_lists()
@testitem "list buffer reduction" begin
using CellListMap, StaticArrays
x = [ SVector{3,Float64}(0,0,0), SVector{3,Float64}(0,0,0.05) ];
system = InPlaceNeighborList(x=x, cutoff=0.1, unitcell=[1,1,1], parallel=false)
list0 = neighborlist!(system) # correct
@test length(list0) == 1
xnew = [ SVector{3,Float64}(0,0,0), SVector{3,Float64}(0,0,0.2) ];
update!(system, xnew)
list1 = neighborlist!(system)
@test length(list1) == 0
end

#
Expand Down Expand Up @@ -964,19 +976,4 @@ function lists_match(
return lists_match
end

# Test the successive generation of lists and updates
function test_threaded_lists()
x = rand(SVector{3,Float64}, 10^3);
cutoff = 0.1
system = InPlaceNeighborList(x=x, cutoff=cutoff, unitcell=[1,1,1], parallel=false)
x_new = rand(SVector{3,Float64}, 10^3);
for _=1:1000
x_new = rand(SVector{3,Float64}, 10^3);
update!(system, x_new)
end
list1 = copy(neighborlist!(system))
list2 = copy(neighborlist(x_new, 0.1; unitcell = [1, 1, 1]))
return lists_match(list1, list2, cutoff; verbose = true)
end

end # module TestingNeighborLists

0 comments on commit 9f10f7c

Please sign in to comment.