Skip to content

Commit

Permalink
Combining multiple where clauses with const generics into a single one (
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman authored Oct 19, 2022
1 parent f0cbf2f commit 50e81eb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
6 changes: 2 additions & 4 deletions src/nn/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ impl<
const W: usize,
> Module<Tensor3D<I, H, W, T>> for Conv2D<I, O, K, S, P>
where
[(); (W + 2 * P - K) / S + 1]:,
[(); (H + 2 * P - K) / S + 1]:,
[[[(); (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O]:,
{
type Output = Tensor3D<O, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }, T>;

Expand All @@ -92,8 +91,7 @@ impl<
const W: usize,
> Module<Tensor4D<B, I, H, W, T>> for Conv2D<I, O, K, S, P>
where
[(); (W + 2 * P - K) / S + 1]:,
[(); (H + 2 * P - K) / S + 1]:,
[[[[(); (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O]; B]:,
{
type Output = Tensor4D<B, O, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }, T>;

Expand Down
6 changes: 2 additions & 4 deletions src/nn/pool2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ macro_rules! impl_pools {
T: Tape,
> Module<Tensor3D<C, H, W, T>> for $PoolTy<K, S, P>
where
[(); (W + 2 * P - K) / S + 1]:,
[(); (H + 2 * P - K) / S + 1]:,
[[(); (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]:,
{
type Output = Tensor3D<C, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }, T>;

Expand All @@ -79,8 +78,7 @@ macro_rules! impl_pools {
T: Tape,
> Module<Tensor4D<B, C, H, W, T>> for $PoolTy<K, S, P>
where
[(); (W + 2 * P - K) / S + 1]:,
[(); (H + 2 * P - K) / S + 1]:,
[[(); (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]:,
{
type Output =
Tensor4D<B, C, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }, T>;
Expand Down

0 comments on commit 50e81eb

Please sign in to comment.