Skip to content

Commit e329f4b

Browse files
authored
fix trainable parameters in distributions (#520)
1 parent 057f3fd commit e329f4b

File tree

3 files changed

+21
-19
lines changed

3 files changed

+21
-19
lines changed

bayesflow/distributions/diagonal_normal.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def __init__(
         self.seed_generator = seed_generator or keras.random.SeedGenerator()

         self.dim = None
-        self.log_normalization_constant = None
         self._mean = None
         self._std = None

@@ -71,17 +70,18 @@ def build(self, input_shape: Shape) -> None:
         self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32")
         self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32")

-        self.log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self.std))
-
         if self.trainable_parameters:
             self._mean = self.add_weight(
                 shape=ops.shape(self.mean),
-                initializer=keras.initializers.get(self.mean),
+                initializer=keras.initializers.get(keras.ops.copy(self.mean)),
                 dtype="float32",
                 trainable=True,
             )
             self._std = self.add_weight(
-                shape=ops.shape(self.std), initializer=keras.initializers.get(self.std), dtype="float32", trainable=True
+                shape=ops.shape(self.std),
+                initializer=keras.initializers.get(keras.ops.copy(self.std)),
+                dtype="float32",
+                trainable=True,
             )
         else:
             self._mean = self.mean

@@ -91,7 +91,8 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
         result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)

         if normalize:
-            result += self.log_normalization_constant
+            log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
+            result += log_normalization_constant

         return result

bayesflow/distributions/diagonal_student_t.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def __init__(

         self.seed_generator = seed_generator or keras.random.SeedGenerator()

-        self.log_normalization_constant = None
         self.dim = None
         self._loc = None
         self._scale = None

@@ -78,21 +77,16 @@ def build(self, input_shape: Shape) -> None:
         self.loc = ops.cast(ops.broadcast_to(self.loc, (self.dim,)), "float32")
         self.scale = ops.cast(ops.broadcast_to(self.scale, (self.dim,)), "float32")

-        self.log_normalization_constant = (
-            -0.5 * self.dim * math.log(self.df)
-            - 0.5 * self.dim * math.log(math.pi)
-            - math.lgamma(0.5 * self.df)
-            + math.lgamma(0.5 * (self.df + self.dim))
-            - ops.sum(keras.ops.log(self.scale))
-        )
-
         if self.trainable_parameters:
             self._loc = self.add_weight(
-                shape=ops.shape(self.loc), initializer=keras.initializers.get(self.loc), dtype="float32", trainable=True
+                shape=ops.shape(self.loc),
+                initializer=keras.initializers.get(keras.ops.copy(self.loc)),
+                dtype="float32",
+                trainable=True,
             )
             self._scale = self.add_weight(
                 shape=ops.shape(self.scale),
-                initializer=keras.initializers.get(self.scale),
+                initializer=keras.initializers.get(keras.ops.copy(self.scale)),
                 dtype="float32",
                 trainable=True,
             )

@@ -105,7 +99,14 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
         result = -0.5 * (self.df + self.dim) * ops.log1p(mahalanobis_term / self.df)

         if normalize:
-            result += self.log_normalization_constant
+            log_normalization_constant = (
+                -0.5 * self.dim * math.log(self.df)
+                - 0.5 * self.dim * math.log(math.pi)
+                - math.lgamma(0.5 * self.df)
+                + math.lgamma(0.5 * (self.df + self.dim))
+                - ops.sum(keras.ops.log(self._scale))
+            )
+            result += log_normalization_constant

         return result

bayesflow/distributions/mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def build(self, input_shape: Shape) -> None:

         self._mixture_logits = self.add_weight(
             shape=(len(self.distributions),),
-            initializer=keras.initializers.get(self.mixture_logits),
+            initializer=keras.initializers.get(keras.ops.copy(self.mixture_logits)),
             dtype="float32",
             trainable=self.trainable_mixture,
         )

0 commit comments

Comments
 (0)