Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
NKNaN committed Oct 8, 2023
1 parent 91f9e01 commit 4c1bfd4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
16 changes: 10 additions & 6 deletions python/paddle/distribution/binomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ def entropy(self):
numpy.ndarray: the entropy for the binomial r.v.
"""
values = self._enumerate_support()
eps = paddle.finfo(self.probability.dtype).eps
log_prob = paddle.nan_to_num(self.log_prob(values), neginf=eps)
log_prob = self.log_prob(values)
return -(paddle.exp(log_prob) * log_prob).sum(0)

def _enumerate_support(self):
Expand Down Expand Up @@ -233,11 +232,16 @@ def log_prob(self, value):
- paddle.lgamma(self.total_count - value + 1.0)
- paddle.lgamma(value + 1.0)
)
eps = paddle.finfo(self.probability.dtype).eps
probs = paddle.clip(self.probability, min=eps, max=1 - eps)
# log_p
return (
log_comb
+ value * paddle.log(self.probability)
+ (self.total_count - value) * paddle.log(1 - self.probability)
return paddle.nan_to_num(
(
log_comb
+ value * paddle.log(probs)
+ (self.total_count - value) * paddle.log(1 - probs)
),
neginf=-eps,
)

def prob(self, value):
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/distribution/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def log_prob(self, value):
+ value * paddle.log(self.rate)
- paddle.lgamma(value + 1)
),
neginf=eps,
neginf=-eps,
)

def prob(self, value):
Expand Down

0 comments on commit 4c1bfd4

Please sign in to comment.