Skip to content

Commit

Permalink
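fix numerical error for cdsilu (#42)
The change computes exp(-x) / (1 + exp(-x)) as 1 - 1 / (1 + exp(-x)), so the gradient formulas no longer multiply an overflowed exp() result (inf) by a reciprocal that has underflowed to 0.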
Browse files (browse the repository at this point in the history)
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored and iProzd committed Feb 7, 2025
1 parent ed44850 commit 475fe1d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
23 changes: 13 additions & 10 deletions source/lib/src/cdsilu.cc
Original file line number Diff line number Diff line change
@@ -38,11 +38,13 @@ inline FPTYPE customdsilugrad(const FPTYPE x, const FPTYPE a, const FPTYPE b) {
 FPTYPE eax1 = std::exp(-xbar);
 FPTYPE eax1p1 = eax1 + (FPTYPE)1.0;
 FPTYPE eax1p1r = (FPTYPE)1.0 / eax1p1;
+FPTYPE eax1eax1p1r = 1 - eax1p1r;
 FPTYPE eaxb1 = std::exp(-xbar + b);
 FPTYPE eaxb1p1 = eaxb1 + (FPTYPE)1.0;
 FPTYPE eaxb1p1r = (FPTYPE)1.0 / eaxb1p1;
-return (-xbar * eax1 * eax1p1r * eax1p1r - eax1p1r) * eaxb1p1r +
-((FPTYPE)1.0 - xbar * eax1p1r) * eaxb1 * eaxb1p1r * eaxb1p1r +
+FPTYPE eaxb1eaxb1p1r = 1 - eaxb1p1r;
+return (-xbar * eax1eax1p1r * eax1p1r - eax1p1r) * eaxb1p1r +
+((FPTYPE)1.0 - xbar * eax1p1r) * eaxb1eaxb1p1r * eaxb1p1r +
 silugrad(x);
 }

@@ -54,17 +56,18 @@ inline FPTYPE customdsilugradgrad(const FPTYPE x,
 FPTYPE eax1 = std::exp(-xbar);
 FPTYPE eax1p1 = eax1 + (FPTYPE)1.0;
 FPTYPE eax1p1r = (FPTYPE)1.0 / eax1p1;
+FPTYPE eax1eax1p1r = 1 - eax1p1r;
 FPTYPE eaxb1 = std::exp(-xbar + b);
 FPTYPE eaxb1p1 = eaxb1 + (FPTYPE)1.0;
 FPTYPE eaxb1p1r = (FPTYPE)1.0 / eaxb1p1;
-return ((FPTYPE)2.0 * (-xbar * eax1 * eax1p1r * eax1p1r - eax1p1r) * eaxb1 -
-((FPTYPE)1.0 - xbar * eax1p1r) * eaxb1) *
-eaxb1p1r * eaxb1p1r +
-(xbar * eax1 - (FPTYPE)2.0 * xbar * eax1 * eax1 * eax1p1r -
-(FPTYPE)2.0 * eax1) *
-eax1p1r * eax1p1r * eaxb1p1r +
-(FPTYPE)2.0 * ((FPTYPE)1.0 - xbar * eax1p1r) * eaxb1 * eaxb1 *
-eaxb1p1r * eaxb1p1r * eaxb1p1r +
+FPTYPE eaxb1eaxb1p1r = 1 - eaxb1p1r;
+return ((FPTYPE)2.0 * (-xbar * eax1eax1p1r * eax1p1r - eax1p1r) -
+((FPTYPE)1.0 - xbar * eax1p1r)) *
+eaxb1eaxb1p1r * eaxb1p1r +
+(xbar - (FPTYPE)2.0 * xbar * eax1eax1p1r - (FPTYPE)2.0) * eax1eax1p1r *
+eax1p1r * eaxb1p1r +
+(FPTYPE)2.0 * ((FPTYPE)1.0 - xbar * eax1p1r) * eaxb1eaxb1p1r *
+eaxb1eaxb1p1r * eaxb1p1r +
 silugradgrad(x);
 }

Expand Down
23 changes: 13 additions & 10 deletions source/lib/src/gpu/cdsilu.cu
Original file line number Diff line number Diff line change
@@ -42,11 +42,13 @@ __device__ inline FPTYPE customdsilugrad(const FPTYPE x,
 FPTYPE eax1 = _exp(-xbar);
 FPTYPE eax1p1 = eax1 + (FPTYPE)1.0;
 FPTYPE eax1p1r = (FPTYPE)1.0 / eax1p1;
+FPTYPE eax1eax1p1r = 1 - eax1p1r;
 FPTYPE eaxb1 = _exp(-xbar + b);
 FPTYPE eaxb1p1 = eaxb1 + (FPTYPE)1.0;
 FPTYPE eaxb1p1r = (FPTYPE)1.0 / eaxb1p1;
-return (-xbar * eax1 * eax1p1r * eax1p1r - eax1p1r) * eaxb1p1r +
-((FPTYPE)1.0 - xbar * eax1p1r) * eaxb1 * eaxb1p1r * eaxb1p1r +
+FPTYPE eaxb1eaxb1p1r = 1 - eaxb1p1r;
+return (-xbar * eax1eax1p1r * eax1p1r - eax1p1r) * eaxb1p1r +
+((FPTYPE)1.0 - xbar * eax1p1r) * eaxb1eaxb1p1r * eaxb1p1r +
 silugrad(x);
 }

@@ -58,17 +60,18 @@ __device__ inline FPTYPE customdsilugradgrad(const FPTYPE x,
 FPTYPE eax1 = _exp(-xbar);
 FPTYPE eax1p1 = eax1 + (FPTYPE)1.0;
 FPTYPE eax1p1r = (FPTYPE)1.0 / eax1p1;
+FPTYPE eax1eax1p1r = 1 - eax1p1r;
 FPTYPE eaxb1 = _exp(-xbar + b);
 FPTYPE eaxb1p1 = eaxb1 + (FPTYPE)1.0;
 FPTYPE eaxb1p1r = (FPTYPE)1.0 / eaxb1p1;
-return ((FPTYPE)2.0 * (-xbar * eax1 * eax1p1r * eax1p1r - eax1p1r) * eaxb1 -
-((FPTYPE)1.0 - xbar * eax1p1r) * eaxb1) *
-eaxb1p1r * eaxb1p1r +
-(xbar * eax1 - (FPTYPE)2.0 * xbar * eax1 * eax1 * eax1p1r -
-(FPTYPE)2.0 * eax1) *
-eax1p1r * eax1p1r * eaxb1p1r +
-(FPTYPE)2.0 * ((FPTYPE)1.0 - xbar * eax1p1r) * eaxb1 * eaxb1 *
-eaxb1p1r * eaxb1p1r * eaxb1p1r +
+FPTYPE eaxb1eaxb1p1r = 1 - eaxb1p1r;
+return ((FPTYPE)2.0 * (-xbar * eax1eax1p1r * eax1p1r - eax1p1r) -
+((FPTYPE)1.0 - xbar * eax1p1r)) *
+eaxb1eaxb1p1r * eaxb1p1r +
+(xbar - (FPTYPE)2.0 * xbar * eax1eax1p1r - (FPTYPE)2.0) * eax1eax1p1r *
+eax1p1r * eaxb1p1r +
+(FPTYPE)2.0 * ((FPTYPE)1.0 - xbar * eax1p1r) * eaxb1eaxb1p1r *
+eaxb1eaxb1p1r * eaxb1p1r +
 silugradgrad(x);
 }

Expand Down

0 comments on commit 475fe1d

Please sign in to comment.