Skip to content

Commit 0090140

Browse files
committed
Fix data race in nestkernel/per_thread_bool_indicator.cpp
1 parent ff23f32 commit 0090140

File tree

5 files changed

+62
-46
lines changed

5 files changed

+62
-46
lines changed

nestkernel/connection_manager.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -876,13 +876,13 @@ nest::ConnectionManager::connect_( Node& source,
876876
{
877877
#pragma omp atomic write
878878
has_primary_connections_ = true;
879-
check_primary_connections_[ tid ].set_true();
879+
check_primary_connections_.set_true( tid );
880880
}
881881
else if ( check_secondary_connections_[ tid ].is_false() and not is_primary )
882882
{
883883
#pragma omp atomic write
884884
secondary_connections_exist_ = true;
885-
check_secondary_connections_[ tid ].set_true();
885+
check_secondary_connections_.set_true( tid );
886886
}
887887
}
888888

nestkernel/event_delivery_manager.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ EventDeliveryManager::gather_target_data( const size_t tid )
790790
assert( not kernel().connection_manager.is_source_table_cleared() );
791791

792792
// assume all threads have some work to do
793-
gather_completed_checker_[ tid ].set_false();
793+
gather_completed_checker_.set_false( tid );
794794
assert( gather_completed_checker_.all_false() );
795795

796796
const AssignedRanks assigned_ranks = kernel().vp_manager.get_assigned_ranks( tid );
@@ -802,7 +802,7 @@ EventDeliveryManager::gather_target_data( const size_t tid )
802802
{
803803
// assume this is the last gather round and change to false
804804
// otherwise
805-
gather_completed_checker_[ tid ].set_true();
805+
gather_completed_checker_.set_true( tid );
806806

807807
#pragma omp master
808808
{
@@ -819,7 +819,7 @@ EventDeliveryManager::gather_target_data( const size_t tid )
819819
assigned_ranks, kernel().mpi_manager.get_send_recv_count_target_data_per_rank() );
820820

821821
const bool gather_completed = collocate_target_data_buffers_( tid, assigned_ranks, send_buffer_position );
822-
gather_completed_checker_[ tid ].logical_and( gather_completed );
822+
gather_completed_checker_.logical_and( tid, gather_completed );
823823

824824
if ( gather_completed_checker_.all_true() )
825825
{
@@ -842,7 +842,7 @@ EventDeliveryManager::gather_target_data( const size_t tid )
842842
#pragma omp barrier
843843

844844
const bool distribute_completed = distribute_target_data_buffers_( tid );
845-
gather_completed_checker_[ tid ].logical_and( distribute_completed );
845+
gather_completed_checker_.logical_and( tid, distribute_completed );
846846

847847
// resize mpi buffers, if necessary and allowed
848848
if ( gather_completed_checker_.any_false() and kernel().mpi_manager.adaptive_target_buffers() )
@@ -864,7 +864,7 @@ EventDeliveryManager::gather_target_data_compressed( const size_t tid )
864864
assert( not kernel().connection_manager.is_source_table_cleared() );
865865

866866
// assume all threads have some work to do
867-
gather_completed_checker_[ tid ].set_false();
867+
gather_completed_checker_.set_false( tid );
868868
assert( gather_completed_checker_.all_false() );
869869

870870
const AssignedRanks assigned_ranks = kernel().vp_manager.get_assigned_ranks( tid );
@@ -874,7 +874,7 @@ EventDeliveryManager::gather_target_data_compressed( const size_t tid )
874874
while ( gather_completed_checker_.any_false() )
875875
{
876876
// assume this is the last gather round and change to false otherwise
877-
gather_completed_checker_[ tid ].set_true();
877+
gather_completed_checker_.set_true( tid );
878878

879879
#pragma omp master
880880
{
@@ -891,7 +891,7 @@ EventDeliveryManager::gather_target_data_compressed( const size_t tid )
891891
const bool gather_completed =
892892
collocate_target_data_buffers_compressed_( tid, assigned_ranks, send_buffer_position );
893893

894-
gather_completed_checker_[ tid ].logical_and( gather_completed );
894+
gather_completed_checker_.logical_and( tid, gather_completed );
895895

896896
if ( gather_completed_checker_.all_true() )
897897
{
@@ -916,7 +916,7 @@ EventDeliveryManager::gather_target_data_compressed( const size_t tid )
916916
// all data it is responsible for to buffers. Now combine with information on whether other ranks
917917
// have sent all their data. Note: All threads will return the same value for distribute_completed.
918918
const bool distribute_completed = distribute_target_data_buffers_( tid );
919-
gather_completed_checker_[ tid ].logical_and( distribute_completed );
919+
gather_completed_checker_.logical_and( tid, distribute_completed );
920920

921921
// resize mpi buffers, if necessary and allowed
922922
if ( gather_completed_checker_.any_false() and kernel().mpi_manager.adaptive_target_buffers() )

nestkernel/per_thread_bool_indicator.cpp

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -50,62 +50,47 @@ PerThreadBoolIndicator::initialize( const size_t num_threads, const bool status
5050
kernel().vp_manager.assert_single_threaded();
5151
per_thread_status_.clear();
5252
per_thread_status_.resize( num_threads, BoolIndicatorUInt64( status ) );
53+
size_ = num_threads;
54+
if ( status )
55+
are_true_ = num_threads;
56+
else
57+
are_true_ = 0;
5358
}
5459

5560
bool
5661
PerThreadBoolIndicator::all_false() const
5762
{
5863
#pragma omp barrier
59-
for ( auto it = per_thread_status_.begin(); it < per_thread_status_.end(); ++it )
60-
{
61-
if ( it->is_true() )
62-
{
63-
return false;
64-
}
65-
}
66-
return true;
64+
bool ret = ( are_true_ == 0 );
65+
#pragma omp barrier
66+
return ret;
6767
}
6868

6969
bool
7070
PerThreadBoolIndicator::all_true() const
7171
{
7272
#pragma omp barrier
73-
for ( auto it = per_thread_status_.begin(); it < per_thread_status_.end(); ++it )
74-
{
75-
if ( it->is_false() )
76-
{
77-
return false;
78-
}
79-
}
80-
return true;
73+
bool ret = ( are_true_ == size_ );
74+
#pragma omp barrier
75+
return ret;
8176
}
8277

8378
bool
8479
PerThreadBoolIndicator::any_false() const
8580
{
8681
#pragma omp barrier
87-
for ( auto it = per_thread_status_.begin(); it < per_thread_status_.end(); ++it )
88-
{
89-
if ( it->is_false() )
90-
{
91-
return true;
92-
}
93-
}
94-
return false;
82+
bool ret = ( are_true_ < size_ );
83+
#pragma omp barrier
84+
return ret;
9585
}
9686

9787
bool
9888
PerThreadBoolIndicator::any_true() const
9989
{
10090
#pragma omp barrier
101-
for ( auto it = per_thread_status_.begin(); it < per_thread_status_.end(); ++it )
102-
{
103-
if ( it->is_true() )
104-
{
105-
return true;
106-
}
107-
}
108-
return false;
91+
bool ret = ( are_true_ > 0 );
92+
#pragma omp barrier
93+
return ret;
10994
}
11095

11196
} // namespace nest

nestkernel/per_thread_bool_indicator.h

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,17 @@ class BoolIndicatorUInt64
5252
bool is_true() const;
5353
bool is_false() const;
5454

55+
56+
protected:
5557
void set_true();
5658
void set_false();
57-
5859
void logical_and( const bool status );
5960

6061
private:
6162
static constexpr std::uint_fast64_t true_uint64 = true;
6263
static constexpr std::uint_fast64_t false_uint64 = false;
6364
std::uint_fast64_t status_;
65+
friend class PerThreadBoolIndicator;
6466
};
6567

6668
inline bool
@@ -106,6 +108,34 @@ class PerThreadBoolIndicator
106108

107109
BoolIndicatorUInt64& operator[]( const size_t tid );
108110

111+
void
112+
set_true( const size_t tid )
113+
{
114+
if ( per_thread_status_[ tid ].is_false() )
115+
{
116+
are_true_++;
117+
per_thread_status_[ tid ].set_true();
118+
}
119+
}
120+
void
121+
set_false( const size_t tid )
122+
{
123+
if ( per_thread_status_[ tid ].is_true() )
124+
{
125+
are_true_--;
126+
per_thread_status_[ tid ].set_false();
127+
}
128+
}
129+
void
130+
logical_and( const size_t tid, const bool status )
131+
{
132+
if ( per_thread_status_[ tid ].is_true() && !status )
133+
{
134+
are_true_--;
135+
per_thread_status_[ tid ].set_false();
136+
}
137+
}
138+
109139
/**
110140
* Resize to the given number of threads and set all elements to false.
111141
*/
@@ -133,6 +163,7 @@ class PerThreadBoolIndicator
133163

134164
private:
135165
std::vector< BoolIndicatorUInt64 > per_thread_status_;
166+
std::atomic< int > size_ { 0 }, are_true_ { 0 };
136167
};
137168

138169
} // namespace nest

nestkernel/source_table.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ SourceTable::clear( const size_t tid )
372372
it->clear();
373373
}
374374
sources_[ tid ].clear();
375-
is_cleared_[ tid ].set_true();
375+
is_cleared_.set_true( tid );
376376
}
377377

378378
inline void
@@ -412,15 +412,15 @@ SourceTable::save_entry_point( const size_t tid )
412412
assert( current_positions_[ tid ].lcid == -1 );
413413
saved_positions_[ tid ].lcid = -1;
414414
}
415-
saved_entry_point_[ tid ].set_true();
415+
saved_entry_point_.set_true( tid );
416416
}
417417
}
418418

419419
inline void
420420
SourceTable::restore_entry_point( const size_t tid )
421421
{
422422
current_positions_[ tid ] = saved_positions_[ tid ];
423-
saved_entry_point_[ tid ].set_false();
423+
saved_entry_point_.set_false( tid );
424424
}
425425

426426
inline void

0 commit comments

Comments
 (0)