Skip to content

Commit

Permalink
minor change for nn lib
Browse files Browse the repository at this point in the history
  • Loading branch information
Miraj98 committed Nov 18, 2022
1 parent dd72440 commit 12508e2
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/neural_net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl<T: Dataloader> Model<T> {
}

// Find loss and call backward on it
let loss = nn::cross_entropy(&a, &yt);
let loss = nn::loss::CrossEntropy(&a, &yt);
loss.backward();
loss
}
Expand Down
50 changes: 27 additions & 23 deletions src/nn/mod.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
use crate::tensor::{
ops::{binary_ops::BinaryOps, unary_ops::UnaryOps, reduce_ops::ReduceOps},
Tensor,
};
use std::rc::Rc;
pub mod loss {
use crate::tensor::{
ops::{binary_ops::BinaryOps, reduce_ops::ReduceOps, unary_ops::UnaryOps},
Tensor,
};
use std::rc::Rc;

pub fn cross_entropy(input: &Rc<Tensor>, target: &Rc<Tensor>) -> Rc<Tensor> {
assert_eq!(input.dim(), target.dim());
let input_ln = input.ln();
let lhs = target.mul(&input_ln);
#[allow(non_snake_case)]
pub fn CrossEntropy(input: &Rc<Tensor>, target: &Rc<Tensor>) -> Rc<Tensor> {
assert_eq!(input.dim(), target.dim());
let input_ln = input.ln();
let lhs = target.mul(&input_ln);

let ones = Tensor::ones(input.dim(), Some(true));
let _a = ones.sub(&input);
let _y = ones.sub(&target);
let rhs = _y.mul(&_a.ln());
let ones = Tensor::ones(input.dim(), Some(true));
let _a = ones.sub(&input);
let _y = ones.sub(&target);
let rhs = _y.mul(&_a.ln());

let l = lhs.add(&rhs).mul_scalar(-1.);
let loss = l.mean();
loss
}
let l = lhs.add(&rhs).mul_scalar(-1.);
let loss = l.mean();
loss
}

/// Quadratic (mean-squared-error) loss: mean(0.5 * (target - input)^2).
/// Returns a scalar loss tensor; panics if `input` and `target` dims differ.
pub fn quadratic_loss(input: &Rc<Tensor>, target: &Rc<Tensor>) -> Rc<Tensor> {
    // Elementwise ops below require matching dimensions.
    assert_eq!(input.dim(), target.dim());
    let diff = target.sub(input);
    let scaled = diff.square().mul_scalar(0.5);
    scaled.mean()
}
/// Quadratic (mean-squared-error) loss: mean(0.5 * (target - input)^2).
/// Returns a scalar loss tensor; panics if `input` and `target` dims differ.
/// Name is intentionally UpperCamelCase to match the other loss constructors.
#[allow(non_snake_case)]
pub fn QuadraticLoss(input: &Rc<Tensor>, target: &Rc<Tensor>) -> Rc<Tensor> {
    // Elementwise ops below require matching dimensions.
    assert_eq!(input.dim(), target.dim());
    target.sub(input).square().mul_scalar(0.5).mean()
}
}
8 changes: 4 additions & 4 deletions src/tensor/ops/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ impl BinaryOpType {
pub trait BinaryOps {
type Value;
fn add(&self, x: &Self::Value) -> Rc<Tensor>;
fn mul(&self, x: &Self::Value) -> Self::Value;
fn mul_scalar(&self, x: f64) -> Self::Value;
fn sub(&self, x: &Self::Value) -> Self::Value;
fn matmul(&self, x: &Self::Value) -> Self::Value;
fn mul(&self, x: &Self::Value) -> Rc<Tensor>;
fn mul_scalar(&self, x: f64) -> Rc<Tensor>;
fn sub(&self, x: &Self::Value) -> Rc<Tensor>;
fn matmul(&self, x: &Self::Value) -> Rc<Tensor>;
}

#[derive(Debug)]
Expand Down

0 comments on commit 12508e2

Please sign in to comment.