Skip to content

Commit a758e78

Browse files
committed
The outputs of a tree-based model should be in the range [0, 1]^{# of label}, which is corresponding to the probability estimates.
Moreover, if we want to use sparse matrix to store the prediction values of a tree model, the value ``-inf'' will be a trouble issue.
1 parent 6e8fa71 commit a758e78

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

libmultilabel/linear/tree.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def predict_values(
5858
x: sparse.csr_matrix,
5959
beam_width: int = 10,
6060
) -> np.ndarray:
61-
"""Calculates the decision values associated with x.
61+
"""Calculates the probability estimates associated with x.
6262
6363
Args:
6464
x (sparse.csr_matrix): A matrix with dimension number of instances * number of features.
@@ -72,10 +72,10 @@ def predict_values(
7272
return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])])
7373

7474
def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray:
75-
"""Predict with beam search using cached decision values for a single instance.
75+
"""Predict with beam search using cached probability estimates for a single instance.
7676
7777
Args:
78-
instance_preds (np.ndarray): A vector of cached decision values of each node, has dimension number of labels + total number of metalabels.
78+
instance_preds (np.ndarray): A vector of cached probability estimates of each node, has dimension number of labels + total number of metalabels.
7979
beam_width (int): Number of candidates considered.
8080
8181
Returns:
@@ -101,7 +101,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra
101101
next_level = []
102102

103103
num_labels = len(self.root.label_map)
104-
scores = np.full(num_labels, -np.inf)
104+
scores = np.full(num_labels, 0)
105105
for node, score in cur_level:
106106
slice = np.s_[self.weight_map[node.index] : self.weight_map[node.index + 1]]
107107
pred = instance_preds[slice]

0 commit comments

Comments
 (0)