Skip to content

Commit babaa58

Browse files
committed
Adds minimum/maximum ops
1 parent 0c52e56 commit babaa58

File tree

2 files changed

+179
-4
lines changed

2 files changed

+179
-4
lines changed

src/lib.rs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,6 @@ impl FromConstant for Vec<f32> {
9696
}
9797
}
9898

99-
pub trait Pow<Rhs=Self> {
100-
type Output;
101-
fn pow(self, rhs: Rhs) -> Self::Output;
102-
}
10399

104100
impl Deref for ANode {
105101
type Target = Arc<dyn Node>;
@@ -280,6 +276,11 @@ impl Neg for &ANode {
280276
}
281277
}
282278

279+
pub trait Pow<Rhs=Self> {
280+
type Output;
281+
fn pow(self, rhs: Rhs) -> Self::Output;
282+
}
283+
283284
impl Pow for ANode {
284285
type Output = ANode;
285286
fn pow(self, rhs: ANode) -> Self::Output {
@@ -322,3 +323,35 @@ impl BulkOps for Vec<&ANode> {
322323
}
323324
}
324325

326+
pub trait MaximumOps<Rhs=Self> {
327+
type Output;
328+
fn maximum(self, rhs: Rhs) -> Self::Output;
329+
330+
}
331+
332+
impl MaximumOps for ANode {
333+
type Output = ANode;
334+
fn maximum(self, rhs: ANode) -> Self::Output {
335+
Maximum::new(self, rhs)
336+
}
337+
}
338+
339+
convert_binops! { impl MaximumOps, maximum for ANode, ANode }
340+
forward_ref_binop! { impl MaximumOps, maximum for ANode, ANode }
341+
342+
pub trait MinimumOps<Rhs=Self> {
343+
type Output;
344+
fn minimum(self, rhs: Rhs) -> Self::Output;
345+
346+
}
347+
348+
impl MinimumOps for ANode {
349+
type Output = ANode;
350+
fn minimum(self, rhs: ANode) -> Self::Output {
351+
Minimum::new(self, rhs)
352+
}
353+
}
354+
355+
convert_binops! { impl MinimumOps, minimum for ANode, ANode }
356+
forward_ref_binop! { impl MinimumOps, minimum for ANode, ANode }
357+

src/ops.rs

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,116 @@ impl Node for BulkSum {
735735
}
736736

737737

738+
pub(crate) struct Maximum(NodeIdx, Vec<ANode>, Computation);
739+
740+
impl Maximum {
741+
pub(crate) fn new(left: ANode, right:ANode) -> ANode {
742+
let idx = NodeIdx::new();
743+
let value = Maximum::compute(&left, &right);
744+
let node = Maximum(idx, vec![left, right], Computation::pooled(value));
745+
ANode::new(Arc::new(node))
746+
}
747+
748+
fn compute(left: &ANode, right: &ANode) -> MPVec {
749+
let (lv, rv) = Broadcast::from_pair(left.value(), right.value());
750+
let mut out = allocate_vec(lv.len);
751+
out.iter_mut().zip(lv.zip(rv)).for_each(|(oi, (lvi, rvi))| {
752+
*oi = lvi.max(*rvi)
753+
});
754+
out
755+
}
756+
}
757+
758+
impl Node for Maximum {
759+
fn get_id(&self) -> NodeIdx { self.0.clone() }
760+
761+
fn get_children(&self) -> Option<&[ANode]> {
762+
Some(self.1.as_slice())
763+
}
764+
765+
fn is_leaf(&self) -> bool { false }
766+
767+
fn value(&self) -> &[DType] {
768+
&self.2.get()
769+
}
770+
771+
fn requires_grad(&self) -> bool { false }
772+
773+
fn compute_grad(&self, grad: &[DType], child_grads: &mut [Vec<DType>]) {
774+
// f(x,y) = x.max(y)
775+
let left = self.1[0].value();
776+
let right = self.1[1].value();
777+
let (lv, rv) = Broadcast::from_pair(left, right);
778+
let (left_grad, right_grad) = child_grads.split_at_mut(1);
779+
let mut left_out = Updater::new(&mut left_grad[0], grad.len());
780+
let mut right_out = Updater::new(&mut right_grad[0], grad.len());
781+
grad.iter().zip(lv.zip(rv)).for_each(|(gi, (xi, yi))| {
782+
if xi >= yi {
783+
left_out.add(*gi);
784+
right_out.add(0f32);
785+
} else {
786+
right_out.add(*gi);
787+
left_out.add(0f32);
788+
}
789+
});
790+
}
791+
}
792+
793+
pub(crate) struct Minimum(NodeIdx, Vec<ANode>, Computation);
794+
795+
impl Minimum {
796+
pub(crate) fn new(left: ANode, right:ANode) -> ANode {
797+
let idx = NodeIdx::new();
798+
let value = Minimum::compute(&left, &right);
799+
let node = Minimum(idx, vec![left, right], Computation::pooled(value));
800+
ANode::new(Arc::new(node))
801+
}
802+
803+
fn compute(left: &ANode, right: &ANode) -> MPVec {
804+
let (lv, rv) = Broadcast::from_pair(left.value(), right.value());
805+
let mut out = allocate_vec(lv.len);
806+
out.iter_mut().zip(lv.zip(rv)).for_each(|(oi, (lvi, rvi))| {
807+
*oi = lvi.min(*rvi)
808+
});
809+
out
810+
}
811+
}
812+
813+
impl Node for Minimum {
814+
fn get_id(&self) -> NodeIdx { self.0.clone() }
815+
816+
fn get_children(&self) -> Option<&[ANode]> {
817+
Some(self.1.as_slice())
818+
}
819+
820+
fn is_leaf(&self) -> bool { false }
821+
822+
fn value(&self) -> &[DType] {
823+
&self.2.get()
824+
}
825+
826+
fn requires_grad(&self) -> bool { false }
827+
828+
fn compute_grad(&self, grad: &[DType], child_grads: &mut [Vec<DType>]) {
829+
// f(x,y) = x.max(y)
830+
let left = self.1[0].value();
831+
let right = self.1[1].value();
832+
let (lv, rv) = Broadcast::from_pair(left, right);
833+
let (left_grad, right_grad) = child_grads.split_at_mut(1);
834+
let mut left_out = Updater::new(&mut left_grad[0], grad.len());
835+
let mut right_out = Updater::new(&mut right_grad[0], grad.len());
836+
grad.iter().zip(lv.zip(rv)).for_each(|(gi, (xi, yi))| {
837+
if xi >= yi {
838+
right_out.add(*gi);
839+
left_out.add(0f32);
840+
} else {
841+
left_out.add(*gi);
842+
right_out.add(0f32);
843+
}
844+
});
845+
}
846+
}
847+
738848
#[cfg(test)]
739849
mod tests {
740850
use super::*;
@@ -900,6 +1010,38 @@ mod tests {
9001010
assert_eq!(grad, &[-1., -(-1f32).exp(), -(-2f32).exp()]);
9011011
}
9021012

1013+
#[test]
1014+
fn test_maximum() {
1015+
let x = Variable::new(vec![1., 2.]);
1016+
let y = Variable::new(vec![3., 5.]);
1017+
1018+
let out = (&x).pow(4f32).maximum(2f32 * &y);
1019+
1020+
let mut graph = Graph::new();
1021+
graph.backward(&out);
1022+
1023+
let x_grad = graph.get_grad(&x).unwrap();
1024+
let y_grad = graph.get_grad(&y).unwrap();
1025+
assert_eq!(x_grad, &[0f32, 32f32]);
1026+
assert_eq!(y_grad, &[2f32, 0f32]);
1027+
}
1028+
1029+
#[test]
1030+
fn test_minimum() {
1031+
let x = Variable::new(vec![1., 2.]);
1032+
let y = Variable::new(vec![3., 5.]);
1033+
1034+
let out = (&x).pow(4f32).minimum(2f32 * &y);
1035+
1036+
let mut graph = Graph::new();
1037+
graph.backward(&out);
1038+
1039+
let x_grad = graph.get_grad(&x).unwrap();
1040+
let y_grad = graph.get_grad(&y).unwrap();
1041+
assert_eq!(x_grad, &[4f32, 0f32]);
1042+
assert_eq!(y_grad, &[0f32, 2f32]);
1043+
}
1044+
9031045
#[test]
9041046
fn test_backward_pass_simple1() {
9051047
// 2x

0 commit comments

Comments
 (0)