Skip to content

Introduce optimizer_base_type in support of different optimizers #116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/nf/nf_network.f90
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module nf_network
!! This module provides the network type to create new models.

use nf_layer, only: layer
use nf_optimizers, only: sgd
use nf_optimizers, only: optimizer_base_type

implicit none

Expand Down Expand Up @@ -193,7 +193,7 @@ module subroutine train(self, input_data, output_data, batch_size, &
!! Set to `size(input_data, dim=2)` for a batch gradient descent.
integer, intent(in) :: epochs
!! Number of epochs to run
type(sgd), intent(in) :: optimizer
class(optimizer_base_type), intent(in) :: optimizer
!! Optimizer instance; any concrete extension of `optimizer_base_type`
!! (currently only the `sgd` type is supported by the implementation).
end subroutine train
Expand Down
49 changes: 27 additions & 22 deletions src/nf/nf_network_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
use nf_layer, only: layer
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
use nf_loss, only: quadratic_derivative
use nf_optimizers, only: sgd
use nf_optimizers, only: optimizer_base_type, sgd
use nf_parallel, only: tile_indices

implicit none
Expand Down Expand Up @@ -426,7 +426,7 @@ module subroutine train(self, input_data, output_data, batch_size, &
real, intent(in) :: output_data(:,:)
integer, intent(in) :: batch_size
integer, intent(in) :: epochs
type(sgd), intent(in) :: optimizer
class(optimizer_base_type), intent(in) :: optimizer

real :: pos
integer :: dataset_size
Expand All @@ -439,26 +439,31 @@ module subroutine train(self, input_data, output_data, batch_size, &
epoch_loop: do n = 1, epochs
batch_loop: do i = 1, dataset_size / batch_size

! Pull a random mini-batch from the dataset
call random_number(pos)
batch_start = int(pos * (dataset_size - batch_size + 1)) + 1
batch_end = batch_start + batch_size - 1

! FIXME shuffle in a way that doesn't require co_broadcast
call co_broadcast(batch_start, 1)
call co_broadcast(batch_end, 1)

! Distribute the batch in nearly equal pieces to all images
indices = tile_indices(batch_size)
istart = indices(1) + batch_start - 1
iend = indices(2) + batch_start - 1

do concurrent(j = istart:iend)
call self % forward(input_data(:,j))
call self % backward(output_data(:,j))
end do

call self % update(optimizer % learning_rate / batch_size)
! Pull a random mini-batch from the dataset
call random_number(pos)
batch_start = int(pos * (dataset_size - batch_size + 1)) + 1
batch_end = batch_start + batch_size - 1

! FIXME shuffle in a way that doesn't require co_broadcast
call co_broadcast(batch_start, 1)
call co_broadcast(batch_end, 1)

! Distribute the batch in nearly equal pieces to all images
indices = tile_indices(batch_size)
istart = indices(1) + batch_start - 1
iend = indices(2) + batch_start - 1

do concurrent(j = istart:iend)
call self % forward(input_data(:,j))
call self % backward(output_data(:,j))
end do

select type (optimizer)
type is (sgd)
call self % update(optimizer % learning_rate / batch_size)
class default
error stop 'Unsupported optimizer'
end select

end do batch_loop
end do epoch_loop
Expand Down
8 changes: 6 additions & 2 deletions src/nf/nf_optimizers.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ module nf_optimizers
implicit none

private
public :: sgd
public :: optimizer_base_type, sgd

type :: sgd
type, abstract :: optimizer_base_type
!! Abstract base type that all concrete optimizers (e.g. `sgd`) extend.
!! Callers accept `class(optimizer_base_type)` dummies and downcast with
!! `select type` to dispatch on the concrete optimizer.
character(:), allocatable :: name
!! Human-readable optimizer name; unallocated until assigned —
!! NOTE(review): no visible code sets this yet, confirm intended use.
end type optimizer_base_type

type, extends(optimizer_base_type) :: sgd
!! Stochastic Gradient Descent optimizer
real :: learning_rate
real :: momentum = 0 !TODO
Expand Down