Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change API to be able to pass array(:,:,:) kind of objects #772

Merged
merged 3 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
change api to be able to pass array(:,:,:) kind of objects
  • Loading branch information
toxa81 committed Oct 5, 2022
commit ba33fc2ec322708e69b9dd5fa8daa2140a7e4027
11 changes: 7 additions & 4 deletions src/api/generate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, attr_str):
idx1 = attr_str.find('(', idx)
idx2 = attr_str.find(')', idx1)
if idx1 != -1 and idx2 != -1 and idx2 > idx1:
self.dimension_ = attr_str[idx:idx2+1]
self.dimension_ = attr_str[idx1:idx2+1]
attr_str = attr_str[:idx] + attr_str[idx2+1:]
else:
raise Exception(f'wrong attribute string: {attr_str}')
Expand Down Expand Up @@ -138,10 +138,13 @@ def write_fortran_api_arg(self, out):
out.write(', value')
else:
out.write(', target')
if self.attr().dimension() != 'scalar':
out.write(f', {self.attr().dimension()}')
#if self.attr().dimension() != 'scalar':
# out.write(f', {self.attr().dimension()}')

out.write(f', intent({self.attr().intent()}) :: {self.name()}\n')
out.write(f', intent({self.attr().intent()}) :: {self.name()}')
if self.attr().dimension() != 'scalar':
out.write(f'{self.attr().dimension()}')
out.write('\n')

def write_interface_api_arg(self, out):
if self.type_id() == 'func':
Expand Down
116 changes: 58 additions & 58 deletions src/api/sirius.f90
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ subroutine sirius_set_parameters(handler,lmax_apw,lmax_rho,lmax_pot,num_fv_state
integer, optional, target, intent(in) :: num_mag_dims
real(8), optional, target, intent(in) :: pw_cutoff
real(8), optional, target, intent(in) :: gk_cutoff
integer, optional, target, dimension(3), intent(in) :: fft_grid_size
integer, optional, target, intent(in) :: fft_grid_size(3)
integer, optional, target, intent(in) :: auto_rmt
logical, optional, target, intent(in) :: gamma_point
logical, optional, target, intent(in) :: use_symmetry
Expand Down Expand Up @@ -805,7 +805,7 @@ subroutine sirius_get_parameters(handler,lmax_apw,lmax_rho,lmax_pot,num_fv_state
integer, optional, target, intent(out) :: num_mag_dims
real(8), optional, target, intent(out) :: pw_cutoff
real(8), optional, target, intent(out) :: gk_cutoff
integer, optional, target, dimension(3), intent(out) :: fft_grid_size
integer, optional, target, intent(out) :: fft_grid_size(3)
integer, optional, target, intent(out) :: auto_rmt
logical, optional, target, intent(out) :: gamma_point
logical, optional, target, intent(out) :: use_symmetry
Expand Down Expand Up @@ -1055,7 +1055,7 @@ subroutine sirius_set_mpi_grid_dims(handler,ndims,dims,error_code)
!
type(sirius_context_handler), target, intent(in) :: handler
integer, target, intent(in) :: ndims
integer, target, dimension(ndims), intent(in) :: dims
integer, target, intent(in) :: dims(ndims)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -1098,9 +1098,9 @@ subroutine sirius_set_lattice_vectors(handler,a1,a2,a3,error_code)
implicit none
!
type(sirius_context_handler), target, intent(in) :: handler
real(8), target, dimension(3), intent(in) :: a1
real(8), target, dimension(3), intent(in) :: a2
real(8), target, dimension(3), intent(in) :: a3
real(8), target, intent(in) :: a1(3)
real(8), target, intent(in) :: a2(3)
real(8), target, intent(in) :: a3(3)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -1274,8 +1274,8 @@ subroutine sirius_set_periodic_function_ptr(handler,label,f_mt,f_rg,error_code)
!
type(sirius_ground_state_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
real(8), optional, target, intent(in) :: f_mt
real(8), optional, target, intent(in) :: f_rg
real(8), optional, target, intent(in) :: f_mt(:,:,:)
real(8), optional, target, intent(in) :: f_rg(:)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -1332,7 +1332,7 @@ subroutine sirius_set_periodic_function(handler,label,f_rg,f_rg_global,error_cod
!
type(sirius_ground_state_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
real(8), optional, target, dimension(*), intent(in) :: f_rg
real(8), optional, target, intent(in) :: f_rg(:)
logical, optional, target, intent(in) :: f_rg_global
integer, optional, target, intent(out) :: error_code
!
Expand Down Expand Up @@ -1399,11 +1399,11 @@ subroutine sirius_get_periodic_function(handler,label,f_mt,lmmax,max_num_mt_poin
!
type(sirius_ground_state_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
real(8), optional, target, intent(out) :: f_mt
real(8), optional, target, intent(out) :: f_mt(:,:,:)
integer, optional, target, intent(in) :: lmmax
integer, optional, target, intent(in) :: max_num_mt_points
integer, optional, target, intent(in) :: num_atoms
real(8), optional, target, dimension(*), intent(out) :: f_rg
real(8), optional, target, intent(out) :: f_rg(:)
integer, optional, target, intent(in) :: num_rg_points
integer, optional, target, intent(out) :: error_code
!
Expand Down Expand Up @@ -1489,8 +1489,8 @@ subroutine sirius_create_kset(handler,num_kpoints,kpoints,kpoint_weights,init_ks
!
type(sirius_context_handler), target, intent(in) :: handler
integer, target, intent(in) :: num_kpoints
real(8), target, dimension(3,num_kpoints), intent(in) :: kpoints
real(8), target, dimension(num_kpoints), intent(in) :: kpoint_weights
real(8), target, intent(in) :: kpoints(3,num_kpoints)
real(8), target, intent(in) :: kpoint_weights(num_kpoints)
logical, target, intent(in) :: init_kset
type(sirius_kpoint_set_handler), target, intent(out) :: kset_handler
integer, optional, target, intent(out) :: error_code
Expand Down Expand Up @@ -1553,8 +1553,8 @@ subroutine sirius_create_kset_from_grid(handler,k_grid,k_shift,use_symmetry,kset
implicit none
!
type(sirius_context_handler), target, intent(in) :: handler
integer, target, dimension(3), intent(in) :: k_grid
integer, target, dimension(3), intent(in) :: k_shift
integer, target, intent(in) :: k_grid(3)
integer, target, intent(in) :: k_shift(3)
logical, target, intent(in) :: use_symmetry
type(sirius_kpoint_set_handler), target, intent(out) :: kset_handler
integer, optional, target, intent(out) :: error_code
Expand Down Expand Up @@ -1646,7 +1646,7 @@ subroutine sirius_initialize_kset(ks_handler,count,error_code)
implicit none
!
type(sirius_kpoint_set_handler), target, intent(in) :: ks_handler
integer, optional, target, dimension(*), intent(in) :: count
integer, optional, target, intent(in) :: count(*)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: ks_handler_ptr
Expand Down Expand Up @@ -2060,7 +2060,7 @@ subroutine sirius_set_atom_type_radial_grid(handler,label,num_radial_points,radi
type(sirius_context_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
integer, target, intent(in) :: num_radial_points
real(8), target, dimension(num_radial_points), intent(in) :: radial_points
real(8), target, intent(in) :: radial_points(num_radial_points)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -2116,7 +2116,7 @@ subroutine sirius_set_atom_type_radial_grid_inf(handler,label,num_radial_points,
type(sirius_context_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
integer, target, intent(in) :: num_radial_points
real(8), target, dimension(num_radial_points), intent(in) :: radial_points
real(8), target, intent(in) :: radial_points(num_radial_points)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -2178,7 +2178,7 @@ subroutine sirius_add_atom_type_radial_function(handler,atom_type,label,rf,num_p
type(sirius_context_handler), target, intent(in) :: handler
character(*), target, intent(in) :: atom_type
character(*), target, intent(in) :: label
real(8), target, dimension(num_points), intent(in) :: rf
real(8), target, intent(in) :: rf(num_points)
integer, target, intent(in) :: num_points
integer, optional, target, intent(in) :: n
integer, optional, target, intent(in) :: l
Expand Down Expand Up @@ -2369,7 +2369,7 @@ subroutine sirius_set_atom_type_dion(handler,label,num_beta,dion,error_code)
type(sirius_context_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
integer, target, intent(in) :: num_beta
real(8), target, dimension(num_beta, num_beta), intent(in) :: dion
real(8), target, intent(in) :: dion(num_beta, num_beta)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -2424,7 +2424,7 @@ subroutine sirius_set_atom_type_paw(handler,label,core_energy,occupations,num_oc
type(sirius_context_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
real(8), target, intent(in) :: core_energy
real(8), target, dimension(num_occ), intent(in) :: occupations
real(8), target, intent(in) :: occupations(num_occ)
integer, target, intent(in) :: num_occ
integer, optional, target, intent(out) :: error_code
!
Expand Down Expand Up @@ -2483,8 +2483,8 @@ subroutine sirius_add_atom(handler,label,position,vector_field,error_code)
!
type(sirius_context_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
real(8), target, dimension(3), intent(in) :: position
real(8), optional, target, dimension(3), intent(in) :: vector_field
real(8), target, intent(in) :: position(3)
real(8), optional, target, intent(in) :: vector_field(3)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -2537,7 +2537,7 @@ subroutine sirius_set_atom_position(handler,ia,position,error_code)
!
type(sirius_context_handler), target, intent(in) :: handler
integer, target, intent(in) :: ia
real(8), target, dimension(3), intent(in) :: position
real(8), target, intent(in) :: position(3)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -2585,10 +2585,10 @@ subroutine sirius_set_pw_coeffs(handler,label,pw_coeffs,transform_to_rg,ngv,gvl,
!
type(sirius_ground_state_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
complex(8), target, dimension(*), intent(in) :: pw_coeffs
complex(8), target, intent(in) :: pw_coeffs(*)
logical, optional, target, intent(in) :: transform_to_rg
integer, optional, target, intent(in) :: ngv
integer, optional, target, dimension(3, *), intent(in) :: gvl
integer, optional, target, intent(in) :: gvl(3, *)
integer, optional, target, intent(in) :: comm
integer, optional, target, intent(out) :: error_code
!
Expand Down Expand Up @@ -2669,9 +2669,9 @@ subroutine sirius_get_pw_coeffs(handler,label,pw_coeffs,ngv,gvl,comm,error_code)
!
type(sirius_ground_state_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
complex(8), target, dimension(*), intent(in) :: pw_coeffs
complex(8), target, intent(in) :: pw_coeffs(*)
integer, optional, target, intent(in) :: ngv
integer, optional, target, dimension(3, *), intent(in) :: gvl
integer, optional, target, intent(in) :: gvl(3, *)
integer, optional, target, intent(in) :: comm
integer, optional, target, intent(out) :: error_code
!
Expand Down Expand Up @@ -2979,7 +2979,7 @@ subroutine sirius_set_band_occupancies(ks_handler,ik,ispn,band_occupancies,error
type(sirius_kpoint_set_handler), target, intent(in) :: ks_handler
integer, target, intent(in) :: ik
integer, target, intent(in) :: ispn
real(8), target, dimension(*), intent(in) :: band_occupancies
real(8), target, intent(in) :: band_occupancies(*)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: ks_handler_ptr
Expand Down Expand Up @@ -3029,7 +3029,7 @@ subroutine sirius_get_band_occupancies(ks_handler,ik,ispn,band_occupancies,error
type(sirius_kpoint_set_handler), target, intent(in) :: ks_handler
integer, target, intent(in) :: ik
integer, target, intent(in) :: ispn
real(8), target, dimension(*), intent(out) :: band_occupancies
real(8), target, intent(out) :: band_occupancies(:)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: ks_handler_ptr
Expand Down Expand Up @@ -3079,7 +3079,7 @@ subroutine sirius_get_band_energies(ks_handler,ik,ispn,band_energies,error_code)
type(sirius_kpoint_set_handler), target, intent(in) :: ks_handler
integer, target, intent(in) :: ik
integer, target, intent(in) :: ispn
real(8), target, dimension(*), intent(out) :: band_energies
real(8), target, intent(out) :: band_energies(:)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: ks_handler_ptr
Expand Down Expand Up @@ -3174,7 +3174,7 @@ subroutine sirius_get_forces(handler,label,forces,error_code)
!
type(sirius_ground_state_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
real(8), target, dimension(3, *), intent(out) :: forces
real(8), target, intent(out) :: forces(3, *)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -3221,7 +3221,7 @@ subroutine sirius_get_stress_tensor(handler,label,stress_tensor,error_code)
!
type(sirius_ground_state_handler), target, intent(in) :: handler
character(*), target, intent(in) :: label
real(8), target, dimension(3, 3), intent(out) :: stress_tensor
real(8), target, intent(out) :: stress_tensor(3, 3)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -3320,10 +3320,10 @@ subroutine sirius_get_wave_functions(ks_handler,vkl,spin,num_gvec_loc,gvec_loc,e
implicit none
!
type(sirius_kpoint_set_handler), target, intent(in) :: ks_handler
real(8), optional, target, dimension(3), intent(in) :: vkl
real(8), optional, target, intent(in) :: vkl(3)
integer, optional, target, intent(in) :: spin
integer, optional, target, intent(in) :: num_gvec_loc
integer, optional, target, dimension(3, *), intent(in) :: gvec_loc
integer, optional, target, intent(in) :: gvec_loc(3, *)
complex(8), optional, target, intent(out) :: evec
integer, optional, target, intent(in) :: ld
integer, optional, target, intent(in) :: num_spin_comp
Expand Down Expand Up @@ -3637,7 +3637,7 @@ subroutine sirius_generate_coulomb_potential(handler,vh_el,error_code)
implicit none
!
type(sirius_ground_state_handler), target, intent(in) :: handler
real(8), optional, target, dimension(*), intent(out) :: vh_el
real(8), optional, target, intent(out) :: vh_el(*)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -3859,9 +3859,9 @@ subroutine sirius_get_gvec_arrays(handler,gvec,gvec_cart,gvec_len,index_by_gvec,
implicit none
!
type(sirius_context_handler), target, intent(in) :: handler
integer, optional, target, dimension(3, *), intent(in) :: gvec
real(8), optional, target, dimension(3, *), intent(in) :: gvec_cart
real(8), optional, target, dimension(*), intent(in) :: gvec_len
integer, optional, target, intent(in) :: gvec(3, *)
real(8), optional, target, intent(in) :: gvec_cart(3, *)
real(8), optional, target, intent(in) :: gvec_len(*)
integer, optional, target, intent(in) :: index_by_gvec
integer, optional, target, intent(out) :: error_code
!
Expand Down Expand Up @@ -4041,11 +4041,11 @@ subroutine sirius_get_gkvec_arrays(ks_handler,ik,num_gkvec,gvec_index,gkvec,gkve
type(sirius_kpoint_set_handler), target, intent(in) :: ks_handler
integer, target, intent(in) :: ik
integer, target, intent(out) :: num_gkvec
integer, target, dimension(*), intent(out) :: gvec_index
real(8), target, dimension(3, *), intent(out) :: gkvec
real(8), target, dimension(3, *), intent(out) :: gkvec_cart
real(8), target, dimension(*), intent(out) :: gkvec_len
real(8), target, dimension(2, *), intent(out) :: gkvec_tp
integer, target, intent(out) :: gvec_index(*)
real(8), target, intent(out) :: gkvec(3, *)
real(8), target, intent(out) :: gkvec_cart(3, *)
real(8), target, intent(out) :: gkvec_len(*)
real(8), target, intent(out) :: gkvec_tp(2, *)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: ks_handler_ptr
Expand Down Expand Up @@ -4110,8 +4110,8 @@ subroutine sirius_get_step_function(handler,cfunig,cfunrg,num_rg_points,error_co
implicit none
!
type(sirius_context_handler), target, intent(in) :: handler
complex(8), target, dimension(*), intent(out) :: cfunig
real(8), target, dimension(*), intent(out) :: cfunrg
complex(8), target, intent(out) :: cfunig(*)
real(8), target, intent(out) :: cfunrg(*)
integer, target, intent(in) :: num_rg_points
integer, optional, target, intent(out) :: error_code
!
Expand Down Expand Up @@ -4507,7 +4507,7 @@ subroutine sirius_set_equivalent_atoms(handler,equivalent_atoms,error_code)
implicit none
!
type(sirius_context_handler), target, intent(in) :: handler
integer, target, dimension(*), intent(in) :: equivalent_atoms
integer, target, intent(in) :: equivalent_atoms(*)
integer, optional, target, intent(out) :: error_code
!
type(C_PTR) :: handler_ptr
Expand Down Expand Up @@ -5044,7 +5044,7 @@ subroutine sirius_get_fv_eigen_values(handler,ik,fv_eval,num_fv_states,error_cod
!
type(sirius_kpoint_set_handler), target, intent(in) :: handler
integer, target, intent(in) :: ik
real(8), target, intent(out) :: fv_eval
real(8), target, intent(out) :: fv_eval(:)
integer, target, intent(in) :: num_fv_states
integer, optional, target, intent(out) :: error_code
!
Expand Down Expand Up @@ -5621,10 +5621,10 @@ subroutine sirius_add_hubbard_atom_pair(handler,atom_pair,translation,n,l,coupli
implicit none
!
type(sirius_context_handler), target, intent(in) :: handler
integer, target, dimension(2), intent(in) :: atom_pair
integer, target, dimension(3), intent(in) :: translation
integer, target, dimension(2), intent(in) :: n
integer, target, dimension(2), intent(in) :: l
integer, target, intent(in) :: atom_pair(2)
integer, target, intent(in) :: translation(3)
integer, target, intent(in) :: n(2)
integer, target, intent(in) :: l(2)
real(8), target, intent(in) :: coupling
integer, optional, target, intent(out) :: error_code
!
Expand Down Expand Up @@ -5690,13 +5690,13 @@ subroutine sirius_linear_solver(handler,vkq,num_gvec_kq_loc,gvec_kq_loc,dpsi,psi
implicit none
!
type(sirius_ground_state_handler), target, intent(in) :: handler
real(8), target, dimension(3), intent(in) :: vkq
real(8), target, intent(in) :: vkq(3)
integer, target, intent(in) :: num_gvec_kq_loc
integer, target, dimension(3, num_gvec_kq_loc), intent(in) :: gvec_kq_loc
complex(8), target, dimension(ld, num_spin_comp), intent(inout) :: dpsi
complex(8), target, dimension(ld, num_spin_comp), intent(in) :: psi
real(8), target, dimension(*), intent(in) :: eigvals
complex(8), target, dimension(ld, num_spin_comp), intent(inout) :: dvpsi
integer, target, intent(in) :: gvec_kq_loc(3, num_gvec_kq_loc)
complex(8), target, intent(inout) :: dpsi(ld, num_spin_comp)
complex(8), target, intent(in) :: psi(ld, num_spin_comp)
real(8), target, intent(in) :: eigvals(*)
complex(8), target, intent(inout) :: dvpsi(ld, num_spin_comp)
integer, target, intent(in) :: ld
integer, target, intent(in) :: num_spin_comp
real(8), target, intent(in) :: alpha_pv
Expand Down
Loading