Skip to content

Embedding Layer #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
embedding_layer: make integer input generics
  • Loading branch information
OneAdder committed Feb 23, 2025
commit fe02beb724e7b400dcc59ed341031b1499db421b
19 changes: 17 additions & 2 deletions src/nf/nf_network.f90
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,19 @@ module nf_network

procedure, private :: evaluate_batch_1d
procedure, private :: forward_1d
procedure, private :: forward_1d_int
procedure, private :: forward_2d
procedure, private :: forward_3d
procedure, private :: predict_1d
procedure, private :: predict_1d_int
procedure, private :: predict_2d
procedure, private :: predict_3d
procedure, private :: predict_batch_1d
procedure, private :: predict_batch_3d

generic :: evaluate => evaluate_batch_1d
generic :: forward => forward_1d, forward_2d, forward_3d
generic :: predict => predict_1d, predict_2d, predict_3d
generic :: forward => forward_1d, forward_1d_int, forward_2d, forward_3d
generic :: predict => predict_1d, predict_1d_int, predict_2d, predict_3d
generic :: predict_batch => predict_batch_1d, predict_batch_3d

end type network
Expand Down Expand Up @@ -95,6 +97,12 @@ module subroutine forward_1d(self, input)
!! 1-d input data
end subroutine forward_1d

module subroutine forward_1d_int(self, input)
!! Same as `forward_1d` except `integer`
class(network), intent(in out) :: self
integer, intent(in) :: input(:)
end subroutine forward_1d_int

module subroutine forward_2d(self, input)
!! Apply a forward pass through the network.
!!
Expand Down Expand Up @@ -137,6 +145,13 @@ module function predict_1d(self, input) result(res)
!! Output of the network
end function predict_1d

module function predict_1d_int(self, input) result(res)
!! Same as `predict_1d` except `integer`
class(network), intent(in out) :: self
integer, intent(in) :: input(:)
real, allocatable :: res(:)
end function predict_1d_int

module function predict_2d(self, input) result(res)
!! Return the output of the network given the input 1-d array.
class(network), intent(in out) :: self
Expand Down
42 changes: 40 additions & 2 deletions src/nf/nf_network_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,6 @@ module subroutine forward_1d(self, input)
select type(input_layer => self % layers(1) % p)
type is(input1d_layer)
call input_layer % set(input)
type is(embedding_layer)
call input_layer % forward(nint(input))
end select

do n = 2, size(self % layers)
Expand All @@ -221,6 +219,21 @@ module subroutine forward_1d(self, input)

end subroutine forward_1d

module subroutine forward_1d_int(self, input)
class(network), intent(in out) :: self
integer, intent(in) :: input(:)
integer :: n

select type(input_layer => self % layers(1) % p)
type is(embedding_layer)
call input_layer % forward(input)
end select

do n = 2, size(self % layers)
call self % layers(n) % forward(self % layers(n - 1))
end do

end subroutine forward_1d_int

module subroutine forward_2d(self, input)
class(network), intent(in out) :: self
Expand Down Expand Up @@ -285,6 +298,31 @@ module function predict_1d(self, input) result(res)

end function predict_1d

module function predict_1d_int(self, input) result(res)
class(network), intent(in out) :: self
integer, intent(in) :: input(:)
real, allocatable :: res(:)
integer :: n, num_layers

num_layers = size(self % layers)

call self % set_training_mode(.false.)
call self % forward(input)
call self % set_training_mode(.true.)

select type(output_layer => self % layers(num_layers) % p)
type is(dense_layer)
res = output_layer % output
type is(dropout_layer)
res = output_layer % output
type is(flatten_layer)
res = output_layer % output
class default
error stop 'network % output not implemented for ' // &
trim(self % layers(num_layers) % name) // ' layer'
end select

end function predict_1d_int

module function predict_2d(self, input) result(res)
class(network), intent(in out) :: self
Expand Down