@@ -51,7 +51,6 @@ int32 Nnet::LeftContext() const {
51
51
// non-negative left context. In addition, the NnetExample also stores data
52
52
// left context as positive integer. To be compatible with these other classes
53
53
// Nnet::LeftContext() returns a non-negative left context.
54
-
55
54
}
56
55
57
56
int32 Nnet::RightContext () const {
@@ -66,8 +65,8 @@ int32 Nnet::RightContext() const {
66
65
void Nnet::ComputeChunkInfo (int32 input_chunk_size,
67
66
int32 num_chunks,
68
67
std::vector<ChunkInfo> *chunk_info_out) const {
69
- // First compute the output-chunk indices for the last component in the network.
70
- // we assume that the numbering of the input starts from zero.
68
+ // First compute the output-chunk indices for the last component in the
69
+ // network. we assume that the numbering of the input starts from zero.
71
70
int32 output_chunk_size = input_chunk_size - LeftContext () - RightContext ();
72
71
KALDI_ASSERT (output_chunk_size > 0 );
73
72
std::vector<int32> current_output_inds;
@@ -88,7 +87,7 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size,
88
87
for (int32 i = NumComponents () - 1 ; i >= 0 ; i--) {
89
88
std::vector<int32> current_context = GetComponent (i).Context ();
90
89
std::set<int32> current_input_ind_set;
91
- for (size_t j = 0 ; j < current_context.size (); j++)
90
+ for (size_t j = 0 ; j < current_context.size (); j++)
92
91
for (size_t k = 0 ; k < current_output_inds.size (); k++)
93
92
current_input_ind_set.insert (current_context[j] +
94
93
current_output_inds[k]);
@@ -137,7 +136,6 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size,
137
136
(*chunk_info_out)[i].Check ();
138
137
// (*chunk_info_out)[i].ToString();
139
138
}
140
-
141
139
}
142
140
143
141
const Component& Nnet::GetComponent (int32 component) const {
@@ -359,29 +357,56 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) {
359
357
KALDI_ASSERT (new_num_pdfs > 0 );
360
358
KALDI_ASSERT (NumComponents () > 2 );
361
359
int32 nc = NumComponents ();
362
- SumGroupComponent *sgc = dynamic_cast <SumGroupComponent*>(components_[nc - 1 ]);
360
+ SumGroupComponent *sgc =
361
+ dynamic_cast <SumGroupComponent*>(components_[nc - 1 ]);
363
362
if (sgc != NULL ) {
364
363
// Remove it. We'll resize things later.
365
364
delete sgc;
366
365
components_.erase (components_.begin () + nc - 1 ,
367
366
components_.begin () + nc);
368
367
nc--;
369
368
}
370
-
371
369
SoftmaxComponent *sc;
372
370
if ((sc = dynamic_cast <SoftmaxComponent*>(components_[nc - 1 ])) == NULL )
373
371
KALDI_ERR << " Expected last component to be SoftmaxComponent." ;
374
372
373
+ // check if nc-1 has a FixedScaleComponent
374
+ bool has_fixed_scale_component = false ;
375
+ int32 fixed_scale_component_index = -1 ;
376
+ int32 final_affine_component_index = nc - 2 ;
377
+ int32 softmax_component_index = nc - 1 ;
378
+ FixedScaleComponent *fsc =
379
+ dynamic_cast <FixedScaleComponent*>(
380
+ components_[final_affine_component_index]);
381
+ if (fsc != NULL ) {
382
+ has_fixed_scale_component = true ;
383
+ fixed_scale_component_index = nc - 2 ;
384
+ final_affine_component_index = nc - 3 ;
385
+ }
375
386
// note: it could be child class of AffineComponent.
376
- AffineComponent *ac = dynamic_cast <AffineComponent*>(components_[nc - 2 ]);
387
+ AffineComponent *ac = dynamic_cast <AffineComponent*>(
388
+ components_[final_affine_component_index]);
377
389
if (ac == NULL )
378
390
KALDI_ERR << " Network doesn't have expected structure (didn't find final "
379
391
<< " AffineComponent)." ;
380
-
392
+ if (has_fixed_scale_component) {
393
+ // collapse the fixed_scale_component with the affine_component before it
394
+ AffineComponent *ac_new =
395
+ dynamic_cast <AffineComponent*>(ac->CollapseWithNext (*fsc));
396
+ KALDI_ASSERT (ac_new != NULL );
397
+ delete fsc;
398
+ delete ac;
399
+ components_.erase (components_.begin () + fixed_scale_component_index,
400
+ components_.begin () + (fixed_scale_component_index + 1 ));
401
+ components_[final_affine_component_index] = ac_new;
402
+ ac = ac_new;
403
+ softmax_component_index = softmax_component_index - 1 ;
404
+ }
381
405
ac->Resize (ac->InputDim (), new_num_pdfs);
382
406
// Remove the softmax component, and replace it with a new one
383
- delete components_[nc - 1 ];
384
- components_[nc - 1 ] = new SoftmaxComponent (new_num_pdfs);
407
+ delete components_[softmax_component_index];
408
+ components_[softmax_component_index] = new SoftmaxComponent (new_num_pdfs);
409
+ this ->SetIndexes (); // used for debugging
385
410
this ->Check ();
386
411
}
387
412
@@ -655,8 +680,9 @@ void Nnet::Vectorize(VectorBase<BaseFloat> *params) const {
655
680
KALDI_ASSERT (offset == GetParameterDim ());
656
681
}
657
682
658
- void Nnet::ResetGenerators () { // resets random-number generators for all random
659
- // components.
683
+ void Nnet::ResetGenerators () {
684
+ // resets random-number generators for all random
685
+ // components.
660
686
for (int32 c = 0 ; c < NumComponents (); c++) {
661
687
RandomComponent *rc = dynamic_cast <RandomComponent*>(
662
688
&(GetComponent (c)));
0 commit comments