-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlabel_const.go
103 lines (94 loc) · 3.37 KB
/
label_const.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package sticker
import (
"encoding/gob"
"fmt"
"io"
"log"
)
// LabelConst is the multi-label constant model.
//
// The model ignores the feature values at prediction time: PredictAll returns the
// same leading entries of LabelList for every data entry.
type LabelConst struct {
// LabelList holds every label observed in the training set, in the order returned by
// RankTopK over the per-label occurrence counts (presumably most frequent first;
// confirm against RankTopK).
// LabelFreqList is parallel to LabelList: TrainLabelConst sets LabelFreqList[rank] to the
// number of training-set occurrences of LabelList[rank] (a raw count, not a normalized ratio).
LabelList LabelVector
LabelFreqList []float32
}
// TrainLabelConst builds a LabelConst from the label sets in ds.Y.
//
// It counts how many times each label occurs across all entries of ds.Y, ranks
// all distinct labels with RankTopK, and stores the ranked labels together with
// their occurrence counts in matching positions.
//
// debug is accepted as part of the signature but is not used by this function.
// The returned error is always nil.
func TrainLabelConst(ds *Dataset, debug *log.Logger) (*LabelConst, error) {
	// Occurrence count of each label over the whole dataset.
	counts := map[uint32]float32{}
	for _, labels := range ds.Y {
		for _, l := range labels {
			counts[l] += 1
		}
	}

	// Rank every distinct label, then line up the counts with the ranking.
	nlabels := uint(len(counts))
	ranked := RankTopK(counts, nlabels)
	freqs := make([]float32, nlabels)
	for i := range ranked {
		freqs[i] = counts[ranked[i]]
	}

	model := &LabelConst{
		LabelList:     ranked,
		LabelFreqList: freqs,
	}
	return model, nil
}
// DecodeLabelConstWithGobDecoder decodes a LabelConst from decoder into model.
//
// The model is read as two consecutive gob values: LabelList first, then
// LabelFreqList. This is the same order in which EncodeLabelConstWithGobEncoder
// writes them. model must be non-nil.
//
// It returns an error naming the field that could not be decoded. If the error
// concerns LabelFreqList, model.LabelList may already have been modified.
func DecodeLabelConstWithGobDecoder(model *LabelConst, decoder *gob.Decoder) error {
	if err := decoder.Decode(&model.LabelList); err != nil {
		return fmt.Errorf("DecodeLabelConst: LabelList: %s", err)
	}
	if err := decoder.Decode(&model.LabelFreqList); err != nil {
		return fmt.Errorf("DecodeLabelConst: LabelFreqList: %s", err)
	}
	return nil
}
// DecodeLabelConst reads a LabelConst from r into model.
//
// A fresh gob.Decoder is created on top of r for the call. Do not pass a reader
// (for example an *os.File) that is already being consumed by another
// gob.Decoder: doing so is reported to cause mysterious errors. Callers that
// already own a gob.Decoder should call DecodeLabelConstWithGobDecoder with it
// instead.
//
// It returns any error encountered while decoding.
func DecodeLabelConst(model *LabelConst, r io.Reader) error {
	decoder := gob.NewDecoder(r)
	return DecodeLabelConstWithGobDecoder(model, decoder)
}
// EncodeLabelConstWithGobEncoder encodes model using encoder.
//
// The model is written as two consecutive gob values: LabelList first, then
// LabelFreqList. DecodeLabelConstWithGobDecoder reads them back in the same
// order. model must be non-nil.
//
// It returns an error naming the field that could not be encoded. If the error
// concerns LabelFreqList, LabelList has already been written to the stream.
func EncodeLabelConstWithGobEncoder(model *LabelConst, encoder *gob.Encoder) error {
	if err := encoder.Encode(model.LabelList); err != nil {
		return fmt.Errorf("EncodeLabelConst: LabelList: %s", err)
	}
	if err := encoder.Encode(model.LabelFreqList); err != nil {
		return fmt.Errorf("EncodeLabelConst: LabelFreqList: %s", err)
	}
	return nil
}
// EncodeLabelConst writes model to w.
//
// A fresh gob.Encoder is created on top of w for the call. Do not pass a writer
// (for example an *os.File) that is already being used by another gob.Encoder:
// doing so is reported to cause mysterious errors. Callers that already own a
// gob.Encoder should call EncodeLabelConstWithGobEncoder with it instead.
//
// It returns any error encountered while encoding.
func EncodeLabelConst(model *LabelConst, w io.Writer) error {
	encoder := gob.NewEncoder(w)
	return EncodeLabelConstWithGobEncoder(model, encoder)
}
// GobEncode always fails with an error and never produces any bytes.
//
// It exists to steer callers away from gob-encoding a LabelConst as a single
// value; large LabelConst objects should be written with EncodeLabelConst.
func (model *LabelConst) GobEncode() ([]byte, error) {
	err := fmt.Errorf("LabelConst should be encoded with EncodeLabelConst")
	return nil, err
}
// PredictAll returns the top-K labels for each data entry in X.
//
// The prediction does not depend on the feature values: every entry receives
// its own copy of the first K labels of model.LabelList. When K exceeds the
// number of labels in the model, the remaining ranks are filled with
// ^uint32(0) as a placeholder. The result has one LabelVector of length K per
// entry of X.
func (model *LabelConst) PredictAll(X FeatureVectors, K uint) LabelVectors {
	// Build the shared answer once; it is copied per entry below.
	nknown := uint(len(model.LabelList))
	template := make(LabelVector, K)
	for rank := range template {
		if uint(rank) < nknown {
			template[rank] = model.LabelList[rank]
		} else {
			// No label is available for this rank.
			template[rank] = ^uint32(0)
		}
	}

	Y := make(LabelVectors, 0, len(X))
	for range X {
		// Each entry gets an independent slice so callers may modify one
		// prediction without affecting the others.
		row := make(LabelVector, K)
		copy(row, template)
		Y = append(Y, row)
	}
	return Y
}