Skip to content

Enable the use of the more modern 'mpi_f08' module for MPI #1142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -1257,11 +1257,42 @@ endif
exit 1; \
fi


mpi_f08_test:
@#
@# MPAS_MPI_F08 will be set to:
@# 0 if no mpi_f08 module support was detected
@# 1 if the MPI library provides an mpi_f08 module
@#
$(info Checking for mpi_f08 support...)
$(eval MPAS_MPI_F08 := $(shell $\
printf "program main\n$\
& use mpi_f08, only : MPI_Init, MPI_Comm\n$\
& integer :: ierr\n$\
& type (MPI_Comm) :: comm\n$\
& call MPI_Init(ierr)\n$\
end program main\n" | sed 's/&/ /' > mpi_f08.f90; $\
$\
$(FC) mpi_f08.f90 -o mpi_f08.x $(FFLAGS) $(LDFLAGS) > /dev/null 2>&1; $\
mpi_f08_status=$$?; $\
rm -f mpi_f08.f90 mpi_f08.x; $\
if [ $$mpi_f08_status -eq 0 ]; then $\
printf "1"; $\
else $\
printf "0"; $\
fi $\
))
$(if $(findstring 0,$(MPAS_MPI_F08)), $(eval MPI_F08_MESSAGE = "Using the mpi module."), )
$(if $(findstring 0,$(MPAS_MPI_F08)), $(info No working mpi_f08 module detected; using mpi module.))
$(if $(findstring 1,$(MPAS_MPI_F08)), $(eval override CPPFLAGS += -DMPAS_USE_MPI_F08), )
$(if $(findstring 1,$(MPAS_MPI_F08)), $(eval MPI_F08_MESSAGE = "Using the mpi_f08 module."), )
$(if $(findstring 1,$(MPAS_MPI_F08)), $(info mpi_f08 module detected.))

ifneq "$(PIO)" ""
MAIN_DEPS = openmp_test openacc_test pio_test
MAIN_DEPS = openmp_test openacc_test pio_test mpi_f08_test
override CPPFLAGS += "-DMPAS_PIO_SUPPORT"
else
MAIN_DEPS = openmp_test openacc_test
MAIN_DEPS = openmp_test openacc_test mpi_f08_test
IO_MESSAGE = "Using the SMIOL library."
override CPPFLAGS += "-DMPAS_SMIOL_SUPPORT"
endif
Expand Down Expand Up @@ -1300,6 +1331,7 @@ endif
@echo $(PRECISION_MESSAGE)
@echo $(DEBUG_MESSAGE)
@echo $(PARALLEL_MESSAGE)
@echo $(MPI_F08_MESSAGE)
@echo $(PAPI_MESSAGE)
@echo $(TAU_MESSAGE)
@echo $(OPENMP_MESSAGE)
Expand Down
17 changes: 14 additions & 3 deletions src/driver/mpas_subdriver.F
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ module mpas_subdriver
contains


subroutine mpas_init(corelist, domain_ptr, mpi_comm, namelistFileParam, streamsFileParam)
subroutine mpas_init(corelist, domain_ptr, external_comm, namelistFileParam, streamsFileParam)

#ifdef MPAS_USE_MPI_F08
use mpi_f08, only : MPI_Comm
#endif
use mpas_stream_manager, only : MPAS_stream_mgr_init, MPAS_build_stream_filename, MPAS_stream_mgr_validate_streams
use iso_c_binding, only : c_char, c_loc, c_ptr, c_int
use mpas_c_interfacing, only : mpas_f_to_c_string, mpas_c_to_f_string
Expand All @@ -53,7 +56,11 @@ subroutine mpas_init(corelist, domain_ptr, mpi_comm, namelistFileParam, streamsF

type (core_type), intent(inout), pointer :: corelist
type (domain_type), intent(inout), pointer :: domain_ptr
integer, intent(in), optional :: mpi_comm
#ifdef MPAS_USE_MPI_F08
type (MPI_Comm), intent(in), optional :: external_comm
#else
integer, intent(in), optional :: external_comm
#endif
character(len=*), intent(in), optional :: namelistFileParam
character(len=*), intent(in), optional :: streamsFileParam

Expand Down Expand Up @@ -192,7 +199,7 @@ end subroutine xml_stream_get_attributes
!
! Initialize infrastructure
!
call mpas_framework_init_phase1(domain_ptr % dminfo, mpi_comm=mpi_comm)
call mpas_framework_init_phase1(domain_ptr % dminfo, external_comm=external_comm)


#ifdef CORE_ATMOSPHERE
Expand Down Expand Up @@ -303,7 +310,11 @@ end subroutine xml_stream_get_attributes

call mpas_f_to_c_string(domain_ptr % streams_filename, c_filename)
call mpas_f_to_c_string(mesh_stream, c_mesh_stream)
#ifdef MPAS_USE_MPI_F08
c_comm = domain_ptr % dminfo % comm % mpi_val
#else
c_comm = domain_ptr % dminfo % comm
#endif
call xml_stream_get_attributes(c_filename, c_mesh_stream, c_comm, &
c_mesh_filename_temp, c_ref_time_temp, &
c_filename_interval_temp, c_iotype, c_ierr)
Expand Down
6 changes: 5 additions & 1 deletion src/framework/mpas_abort.F
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ subroutine mpas_dmpar_global_abort(mesg, deferredAbort)!{{{
use mpas_kind_types, only : StrKIND
use mpas_io_units, only : mpas_new_unit
use mpas_threading, only : mpas_threading_get_thread_num

#ifdef _MPI
#ifndef NOMPIMOD
#ifdef MPAS_USE_MPI_F08
use mpi_f08
#else
use mpi
#endif
#endif
#endif

implicit none
Expand Down
4 changes: 4 additions & 0 deletions src/framework/mpas_derived_types.F
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ module mpas_derived_types
use smiolf, only : SMIOLf_context, SMIOLf_decomp, SMIOLf_file, SMIOL_offset_kind
#endif

#ifdef MPAS_USE_MPI_F08
use mpi_f08, only : MPI_Request, MPI_Comm, MPI_Info
#endif

use ESMF

#include "mpas_attlist_types.inc"
Expand Down
46 changes: 41 additions & 5 deletions src/framework/mpas_dmpar.F
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ module mpas_dmpar

#ifdef _MPI
#ifndef NOMPIMOD
#ifdef MPAS_USE_MPI_F08
use mpi_f08
#else
use mpi
#endif
#endif
#endif

implicit none
Expand All @@ -42,16 +46,31 @@ module mpas_dmpar
#ifdef NOMPIMOD
include 'mpif.h'
#endif
#ifdef MPAS_USE_MPI_F08
type (MPI_Datatype), parameter :: MPI_INTEGERKIND = MPI_INTEGER
type (MPI_Datatype), parameter :: MPI_2INTEGERKIND = MPI_2INTEGER
#else
integer, parameter :: MPI_INTEGERKIND = MPI_INTEGER
integer, parameter :: MPI_2INTEGERKIND = MPI_2INTEGER
#endif

#ifdef SINGLE_PRECISION
#ifdef MPAS_USE_MPI_F08
type (MPI_Datatype), parameter :: MPI_REALKIND = MPI_REAL
type (MPI_Datatype), parameter :: MPI_2REALKIND = MPI_2REAL
#else
integer, parameter :: MPI_REALKIND = MPI_REAL
integer, parameter :: MPI_2REALKIND = MPI_2REAL
#endif
#else
#ifdef MPAS_USE_MPI_F08
type (MPI_Datatype), parameter :: MPI_REALKIND = MPI_DOUBLE_PRECISION
type (MPI_Datatype), parameter :: MPI_2REALKIND = MPI_2DOUBLE_PRECISION
#else
integer, parameter :: MPI_REALKIND = MPI_DOUBLE_PRECISION
integer, parameter :: MPI_2REALKIND = MPI_2DOUBLE_PRECISION
#endif
#endif
#endif

integer, parameter, public :: IO_NODE = 0
Expand Down Expand Up @@ -232,12 +251,16 @@ module mpas_dmpar
!> It also setups of the domain information structure.
!
!-----------------------------------------------------------------------
subroutine mpas_dmpar_init(dminfo, mpi_comm)!{{{
subroutine mpas_dmpar_init(dminfo, external_comm)!{{{

implicit none

type (dm_info), intent(inout) :: dminfo !< Input/Output: Domain information
integer, intent(in), optional :: mpi_comm !< Input - Optional: externally-supplied MPI communicator
#ifdef MPAS_USE_MPI_F08
type (MPI_Comm), intent(in), optional :: external_comm !< Input - Optional: externally-supplied MPI communicator
#else
integer, intent(in), optional :: external_comm !< Input - Optional: externally-supplied MPI communicator
#endif

#ifdef _MPI
integer :: mpi_rank, mpi_size
Expand All @@ -246,13 +269,13 @@ subroutine mpas_dmpar_init(dminfo, mpi_comm)!{{{
integer :: desiredThreadLevel, threadLevel
#endif

if ( present(mpi_comm) ) then
if ( present(external_comm) ) then
dminfo % initialized_mpi = .false.
#ifdef MPAS_OPENMP
desiredThreadLevel = MPI_THREAD_FUNNELED
call MPI_Query_thread(threadLevel, mpi_ierr)
#endif
call MPI_Comm_dup(mpi_comm, dminfo % comm, mpi_ierr)
call MPI_Comm_dup(external_comm, dminfo % comm, mpi_ierr)
else
dminfo % initialized_mpi = .true.
#ifdef MPAS_OPENMP
Expand Down Expand Up @@ -1540,7 +1563,12 @@ subroutine mpas_dmpar_get_exch_list(haloLayer, ownedListField, neededListField,
type (field0dInteger), pointer :: offsetCursor, ownedLimitCursor
integer :: nOwnedBlocks, nNeededBlocks
integer :: nOwnedList, nNeededList
integer :: mpi_ierr, mpi_rreq, mpi_sreq
integer :: mpi_ierr
#ifdef MPAS_USE_MPI_F08
type (MPI_Request) :: mpi_rreq, mpi_sreq
#else
integer :: mpi_rreq, mpi_sreq
#endif

type (hashtable) :: neededHash
integer :: nUniqueNeededList, threadNum
Expand Down Expand Up @@ -10073,7 +10101,11 @@ subroutine mpas_dmpar_exch_group_print_buffers(exchangeGroup)!{{{
call mpas_log_write(' proc: $i', intArgs=(/commListPtr % procID/))
call mpas_log_write(' size check: $i $i', intArgs=(/commListPtr % nlist, size( commListPtr % rbuffer )/))
call mpas_log_write(' bufferOffset: $i', intArgs=(/commListPtr % bufferOffset/))
#ifdef MPAS_USE_MPI_F08
call mpas_log_write(' reqId: $i', intArgs=(/commListPtr % reqId % mpi_val/))
#else
call mpas_log_write(' reqId: $i', intArgs=(/commListPtr % reqId/))
#endif
call mpas_log_write(' ibuffer assc: $l', logicArgs=(/ associated( commListPtr % ibuffer ) /) )
call mpas_log_write(' rbuffer assc: $l', logicArgs=(/ associated( commListPtr % rbuffer ) /) )
call mpas_log_write(' next assc: $l', logicArgs=(/ associated( commListPtr % next ) /) )
Expand All @@ -10092,7 +10124,11 @@ subroutine mpas_dmpar_exch_group_print_buffers(exchangeGroup)!{{{
call mpas_log_write(' proc: $i', intArgs=(/ commListPtr % procID /) )
call mpas_log_write(' size check: $i $i', intArgs=(/ commListPtr % nlist, size( commListPtr % rbuffer ) /) )
call mpas_log_write(' bufferOffset: $i', intArgs=(/ commListPtr % bufferOffset /) )
#ifdef MPAS_USE_MPI_F08
call mpas_log_write(' reqId: $i', intArgs=(/ commListPtr % reqId % mpi_val /) )
#else
call mpas_log_write(' reqId: $i', intArgs=(/ commListPtr % reqId /) )
#endif
call mpas_log_write(' ibuffer assc: $l', logicArgs=(/ associated( commListPtr % ibuffer ) /) )
call mpas_log_write(' rbuffer assc: $l', logicArgs=(/ associated( commListPtr % rbuffer ) /) )
call mpas_log_write(' next assc: $l', logicArgs=(/ associated( commListPtr % next ) /) )
Expand Down
13 changes: 12 additions & 1 deletion src/framework/mpas_dmpar_types.inc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
integer, parameter :: MPAS_DMPAR_BUFFER_EXISTS = 6

type dm_info
integer :: nprocs, my_proc_id, comm, info
#ifdef MPAS_USE_MPI_F08
type (MPI_Comm) :: comm
type (MPI_Info) :: info
#else
integer :: comm
integer :: info
#endif
integer :: nprocs, my_proc_id
logical :: initialized_mpi

! Add variables specific to block decomposition. {{{
Expand Down Expand Up @@ -47,7 +54,11 @@
integer :: bufferOffset
real (kind=RKIND), dimension(:), pointer :: rbuffer => null()
integer, dimension(:), pointer :: ibuffer => null()
#ifdef MPAS_USE_MPI_F08
type (MPI_Request) :: reqID
#else
integer :: reqID
#endif
type (mpas_communication_list), pointer :: next => null()
integer :: commListSize
logical :: received
Expand Down
14 changes: 11 additions & 3 deletions src/framework/mpas_framework.F
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,23 @@ module mpas_framework
!> MPI, the log unit numbers.
!
!-----------------------------------------------------------------------
subroutine mpas_framework_init_phase1(dminfo, mpi_comm)!{{{
subroutine mpas_framework_init_phase1(dminfo, external_comm)!{{{

#ifdef MPAS_USE_MPI_F08
use mpi_f08, only : MPI_Comm
#endif

implicit none

type (dm_info), pointer :: dminfo
integer, intent(in), optional :: mpi_comm
#ifdef MPAS_USE_MPI_F08
type (MPI_Comm), intent(in), optional :: external_comm
#else
integer, intent(in), optional :: external_comm
#endif

allocate(dminfo)
call mpas_dmpar_init(dminfo, mpi_comm)
call mpas_dmpar_init(dminfo, external_comm)

end subroutine mpas_framework_init_phase1!}}}

Expand Down
27 changes: 26 additions & 1 deletion src/framework/mpas_halo.F
Original file line number Diff line number Diff line change
Expand Up @@ -485,16 +485,28 @@ end subroutine mpas_halo_exch_group_add_field
!-----------------------------------------------------------------------
subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)

#ifdef MPAS_USE_MPI_F08
use mpi_f08
#else
use mpi
#endif
use mpas_derived_types, only : domain_type, mpas_halo_group, MPAS_HALO_REAL, MPAS_LOG_CRIT
use mpas_pool_routines, only : mpas_pool_get_array
use mpas_log, only : mpas_log_write

! Parameters
#ifdef MPAS_USE_MPI_F08
#ifdef SINGLE_PRECISION
type (MPI_Datatype), parameter :: MPI_REALKIND = MPI_REAL
#else
type (MPI_Datatype), parameter :: MPI_REALKIND = MPI_DOUBLE_PRECISION
#endif
#else
#ifdef SINGLE_PRECISION
integer, parameter :: MPI_REALKIND = MPI_REAL
#else
integer, parameter :: MPI_REALKIND = MPI_DOUBLE_PRECISION
#endif
#endif

! Arguments
Expand All @@ -508,7 +520,12 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
integer :: i1, i2, j, iNeighbor, iReq
integer :: iHalo, iEndp
integer :: nHalos, nSendEndpts, nRecvEndpts
integer :: rank, comm
integer :: rank
#ifdef MPAS_USE_MPI_F08
type (MPI_Comm) :: comm
#else
integer :: comm
#endif
integer :: mpi_ierr
type (mpas_halo_group), pointer :: group
integer, dimension(:), pointer :: compactHaloInfo
Expand Down Expand Up @@ -550,7 +567,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
! the group; all fields should be using the same communicator, so this should not
! be problematic
!
#ifdef MPAS_USE_MPI_F08
comm % mpi_val = group % fields(1) % compactHaloInfo(7)
#else
comm = group % fields(1) % compactHaloInfo(7)
#endif
rank = group % fields(1) % compactHaloInfo(8)


Expand Down Expand Up @@ -992,7 +1013,11 @@ subroutine mpas_halo_compact_halo_info(domain, sendList, recvList, dimSizes, hal
! 7-8: Add MPI info
!
idx = 7
#ifdef MPAS_USE_MPI_F08
compactHaloInfo(idx) = domain % dminfo % comm % mpi_val
#else
compactHaloInfo(idx) = domain % dminfo % comm
#endif
idx = idx + 1
compactHaloInfo(idx) = domain % dminfo % my_proc_id
idx = idx + 1
Expand Down
12 changes: 10 additions & 2 deletions src/framework/mpas_halo_types.inc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
integer :: nGroupSendNeighbors = MPAS_HALO_INVALID ! Number of unique neighbors that we send to
integer :: groupSendBufSize = MPAS_HALO_INVALID ! Total number of elements to be sent in a group exchange
real (kind=RKIND), dimension(:), pointer :: sendBuf => null() ! Segmented buffer used for outgoing messages
integer, dimension(:), pointer :: sendRequests => null() ! Used internally - MPI request IDs
#ifdef MPAS_USE_MPI_F08
type (MPI_Request), dimension(:), pointer :: sendRequests => null() ! Used internally - MPI request IDs
#else
integer, dimension(:), pointer :: sendRequests => null() ! Used internally - MPI request IDs
#endif
integer, dimension(:,:), pointer :: groupPackOffsets => null() ! Offsets into sendBuf for each neighbor and each field
! dimensioned (nGroupSendNeighbors, nFields)
integer, dimension(:), pointer :: groupSendNeighbors => null() ! List of neighbors we send to
Expand All @@ -60,7 +64,11 @@
integer :: nGroupRecvNeighbors = MPAS_HALO_INVALID ! Number of unique neighbors that we recv from
integer :: groupRecvBufSize = MPAS_HALO_INVALID ! Total number of elements to be recvd in a group exchange
real (kind=RKIND), dimension(:), pointer :: recvBuf => null() ! Segmented buffer used for incoming messages
integer, dimension(:), pointer :: recvRequests => null() ! Used internally - MPI request IDs
#ifdef MPAS_USE_MPI_F08
type (MPI_Request), dimension(:), pointer :: recvRequests => null() ! Used internally - MPI request IDs
#else
integer, dimension(:), pointer :: recvRequests => null() ! Used internally - MPI request IDs
#endif
integer, dimension(:,:), pointer :: groupUnpackOffsets => null() ! Offsets into recvBuf for each neighbor and each field
! dimensioned (nGroupRecvNeighbors, nFields)
integer, dimension(:), pointer :: groupRecvNeighbors => null() ! List of neighbors we recv from
Expand Down
Loading