Skip to content

Batch inference #90

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/nf/nf_network.f90
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ module nf_network
procedure, private :: forward_3d
procedure, private :: output_1d
procedure, private :: output_3d
procedure, private :: output_batch_1d
procedure, private :: output_batch_3d

generic :: forward => forward_1d, forward_3d
generic :: output => output_1d, output_3d
generic :: output => output_1d, output_3d, output_batch_1d, output_batch_3d

end type network

Expand Down Expand Up @@ -107,6 +109,26 @@ module function output_3d(self, input) result(res)
!! Output of the network
end function output_3d

module function output_batch_1d(self, input) result(res)
!! Return the output of the network given an input batch of 3-d data.
class(network), intent(in out) :: self
!! Network instance
real, intent(in) :: input(:,:)
!! Input data; the last dimension is the batch
real, allocatable :: res(:,:)
!! Output of the network; the last dimension is the batch
end function output_batch_1d

module function output_batch_3d(self, input) result(res)
!! Return the output of the network given an input batch of 3-d data.
class(network), intent(in out) :: self
!! Network instance
real, intent(in) :: input(:,:,:,:)
!! Input data; the last dimension is the batch
real, allocatable :: res(:,:)
!! Output of the network; the last dimension is the batch
end function output_batch_3d

end interface output

interface
Expand Down
63 changes: 63 additions & 0 deletions src/nf/nf_network_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,69 @@ module function output_3d(self, input) result(res)
end function output_3d


module function output_batch_1d(self, input) result(res)
class(network), intent(in out) :: self
real, intent(in) :: input(:,:)
real, allocatable :: res(:,:)
integer :: i, batch_size, num_layers, output_size

num_layers = size(self % layers)
batch_size = size(input, dim=rank(input))
output_size = product(self % layers(num_layers) % layer_shape)

allocate(res(output_size, batch_size))

batch: do concurrent(i = 1:size(res, dim=2))

call self % forward(input(:,i))

select type(output_layer => self % layers(num_layers) % p)
type is(dense_layer)
res(:,i) = output_layer % output
type is(flatten_layer)
res(:,i) = output_layer % output
class default
error stop 'network % output not implemented for this output layer'
end select

end do batch

end function output_batch_1d


module function output_batch_3d(self, input) result(res)
class(network), intent(in out) :: self
real, intent(in) :: input(:,:,:,:)
real, allocatable :: res(:,:)
integer :: i, batch_size, num_layers, output_size

num_layers = size(self % layers)
batch_size = size(input, dim=rank(input))
output_size = product(self % layers(num_layers) % layer_shape)

allocate(res(output_size, batch_size))

batch: do concurrent(i = 1:batch_size)

call self % forward(input(:,:,:,i))

select type(output_layer => self % layers(num_layers) % p)
type is(conv2d_layer)
!FIXME flatten the result for now; find a better solution
res(:,i) = pack(output_layer % output, .true.)
type is(dense_layer)
res(:,i) = output_layer % output
type is(flatten_layer)
res(:,i) = output_layer % output
class default
error stop 'network % output not implemented for this output layer'
end select

end do batch

end function output_batch_3d


module subroutine print_info(self)
class(network), intent(in) :: self
call self % layers % print_info()
Expand Down