Skip to content

Commit 68db741

Browse files
committed
Merge pull request #9 from vijayaditya/multilang-bugfix
Multilang bugfix
2 parents 36cad85 + 1805746 commit 68db741

File tree

3 files changed

+58
-15
lines changed

3 files changed

+58
-15
lines changed

src/nnet2/nnet-component.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,20 @@ Component *AffineComponent::CollapseWithNext(
14071407
return ans;
14081408
}
14091409

1410+
Component *AffineComponent::CollapseWithNext(
1411+
const FixedScaleComponent &next_component) const {
1412+
KALDI_ASSERT(this->OutputDim() == next_component.InputDim());
1413+
AffineComponent *ans =
1414+
dynamic_cast<AffineComponent*>(this->Copy());
1415+
KALDI_ASSERT(ans != NULL);
1416+
ans->linear_params_.MulRowsVec(next_component.scales_);
1417+
ans->bias_params_.MulElements(next_component.scales_);
1418+
1419+
return ans;
1420+
}
1421+
1422+
1423+
14101424
Component *AffineComponent::CollapseWithPrevious(
14111425
const FixedAffineComponent &prev_component) const {
14121426
// If at least one was non-updatable, make the whole non-updatable.

src/nnet2/nnet-component.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -707,8 +707,9 @@ class ScaleComponent: public Component {
707707

708708

709709

710-
class SumGroupComponent; // Forward declaration.
711-
class AffineComponent; // Forward declaration.
710+
class SumGroupComponent; // Forward declaration.
711+
class AffineComponent; // Forward declaration.
712+
class FixedScaleComponent; // Forward declaration.
712713

713714
class SoftmaxComponent: public NonlinearComponent {
714715
public:
@@ -803,6 +804,7 @@ class AffineComponent: public UpdatableComponent {
803804
// FixedLinearComponent yet.
804805
Component *CollapseWithNext(const AffineComponent &next) const ;
805806
Component *CollapseWithNext(const FixedAffineComponent &next) const;
807+
Component *CollapseWithNext(const FixedScaleComponent &next) const;
806808
Component *CollapseWithPrevious(const FixedAffineComponent &prev) const;
807809

808810
virtual std::string Info() const;
@@ -1473,6 +1475,7 @@ class FixedScaleComponent: public Component {
14731475
virtual void Write(std::ostream &os, bool binary) const;
14741476

14751477
protected:
1478+
friend class AffineComponent; // necessary for collapse
14761479
CuVector<BaseFloat> scales_;
14771480
KALDI_DISALLOW_COPY_AND_ASSIGN(FixedScaleComponent);
14781481
};

src/nnet2/nnet-nnet.cc

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ int32 Nnet::LeftContext() const {
5151
// non-negative left context. In addition, the NnetExample also stores data
5252
// left context as positive integer. To be compatible with these other classes
5353
// Nnet::LeftContext() returns a non-negative left context.
54-
5554
}
5655

5756
int32 Nnet::RightContext() const {
@@ -66,8 +65,8 @@ int32 Nnet::RightContext() const {
6665
void Nnet::ComputeChunkInfo(int32 input_chunk_size,
6766
int32 num_chunks,
6867
std::vector<ChunkInfo> *chunk_info_out) const {
69-
// First compute the output-chunk indices for the last component in the network.
70-
// we assume that the numbering of the input starts from zero.
68+
// First compute the output-chunk indices for the last component in the
69+
// network. we assume that the numbering of the input starts from zero.
7170
int32 output_chunk_size = input_chunk_size - LeftContext() - RightContext();
7271
KALDI_ASSERT(output_chunk_size > 0);
7372
std::vector<int32> current_output_inds;
@@ -88,7 +87,7 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size,
8887
for (int32 i = NumComponents() - 1; i >= 0; i--) {
8988
std::vector<int32> current_context = GetComponent(i).Context();
9089
std::set<int32> current_input_ind_set;
91-
for (size_t j = 0; j < current_context.size(); j++)
90+
for (size_t j = 0; j < current_context.size(); j++)
9291
for (size_t k = 0; k < current_output_inds.size(); k++)
9392
current_input_ind_set.insert(current_context[j] +
9493
current_output_inds[k]);
@@ -137,7 +136,6 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size,
137136
(*chunk_info_out)[i].Check();
138137
// (*chunk_info_out)[i].ToString();
139138
}
140-
141139
}
142140

143141
const Component& Nnet::GetComponent(int32 component) const {
@@ -359,29 +357,56 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) {
359357
KALDI_ASSERT(new_num_pdfs > 0);
360358
KALDI_ASSERT(NumComponents() > 2);
361359
int32 nc = NumComponents();
362-
SumGroupComponent *sgc = dynamic_cast<SumGroupComponent*>(components_[nc - 1]);
360+
SumGroupComponent *sgc =
361+
dynamic_cast<SumGroupComponent*>(components_[nc - 1]);
363362
if (sgc != NULL) {
364363
// Remove it. We'll resize things later.
365364
delete sgc;
366365
components_.erase(components_.begin() + nc - 1,
367366
components_.begin() + nc);
368367
nc--;
369368
}
370-
371369
SoftmaxComponent *sc;
372370
if ((sc = dynamic_cast<SoftmaxComponent*>(components_[nc - 1])) == NULL)
373371
KALDI_ERR << "Expected last component to be SoftmaxComponent.";
374372

373+
// check if nc-1 has a FixedScaleComponent
374+
bool has_fixed_scale_component = false;
375+
int32 fixed_scale_component_index = -1;
376+
int32 final_affine_component_index = nc - 2;
377+
int32 softmax_component_index = nc - 1;
378+
FixedScaleComponent *fsc =
379+
dynamic_cast<FixedScaleComponent*>(
380+
components_[final_affine_component_index]);
381+
if (fsc != NULL) {
382+
has_fixed_scale_component = true;
383+
fixed_scale_component_index = nc - 2;
384+
final_affine_component_index = nc - 3;
385+
}
375386
// note: it could be child class of AffineComponent.
376-
AffineComponent *ac = dynamic_cast<AffineComponent*>(components_[nc - 2]);
387+
AffineComponent *ac = dynamic_cast<AffineComponent*>(
388+
components_[final_affine_component_index]);
377389
if (ac == NULL)
378390
KALDI_ERR << "Network doesn't have expected structure (didn't find final "
379391
<< "AffineComponent).";
380-
392+
if (has_fixed_scale_component) {
393+
// collapse the fixed_scale_component with the affine_component before it
394+
AffineComponent *ac_new =
395+
dynamic_cast<AffineComponent*>(ac->CollapseWithNext(*fsc));
396+
KALDI_ASSERT(ac_new != NULL);
397+
delete fsc;
398+
delete ac;
399+
components_.erase(components_.begin() + fixed_scale_component_index,
400+
components_.begin() + (fixed_scale_component_index + 1));
401+
components_[final_affine_component_index] = ac_new;
402+
ac = ac_new;
403+
softmax_component_index = softmax_component_index - 1;
404+
}
381405
ac->Resize(ac->InputDim(), new_num_pdfs);
382406
// Remove the softmax component, and replace it with a new one
383-
delete components_[nc - 1];
384-
components_[nc - 1] = new SoftmaxComponent(new_num_pdfs);
407+
delete components_[softmax_component_index];
408+
components_[softmax_component_index] = new SoftmaxComponent(new_num_pdfs);
409+
this->SetIndexes(); // used for debugging
385410
this->Check();
386411
}
387412

@@ -655,8 +680,9 @@ void Nnet::Vectorize(VectorBase<BaseFloat> *params) const {
655680
KALDI_ASSERT(offset == GetParameterDim());
656681
}
657682

658-
void Nnet::ResetGenerators() { // resets random-number generators for all random
659-
// components.
683+
void Nnet::ResetGenerators() {
684+
// resets random-number generators for all random
685+
// components.
660686
for (int32 c = 0; c < NumComponents(); c++) {
661687
RandomComponent *rc = dynamic_cast<RandomComponent*>(
662688
&(GetComponent(c)));

0 commit comments

Comments
 (0)