Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/nnet2/nnet-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,20 @@ Component *AffineComponent::CollapseWithNext(
return ans;
}

// Builds a single affine component equivalent to this component followed by
// next_component (a FixedScaleComponent).
//
// The result starts as a copy of *this (so it keeps this object's dynamic
// type, which may be a subclass of AffineComponent), then has its linear
// parameters and its bias multiplied by next_component's scales_ vector.
// Presumably next_component multiplies output element i by scales_(i), in
// which case  s .* (W x + b) == (diag(s) W) x + (s .* b)  -- confirm against
// FixedScaleComponent::Propagate.
//
// Requires this->OutputDim() == next_component.InputDim().
// Returns the object produced by Copy(); neither *this nor next_component is
// modified.  NOTE(review): the caller appears to own the returned pointer.
Component *AffineComponent::CollapseWithNext(
    const FixedScaleComponent &next_component) const {
  KALDI_ASSERT(this->OutputDim() == next_component.InputDim());

  // Copy() hands back a base-class pointer; recover the affine view of it.
  Component *copied = this->Copy();
  AffineComponent *ans = dynamic_cast<AffineComponent*>(copied);
  KALDI_ASSERT(ans != NULL);

  // The friend declaration in FixedScaleComponent gives access to scales_.
  const CuVector<BaseFloat> &scales = next_component.scales_;
  ans->bias_params_.MulElements(scales);
  ans->linear_params_.MulRowsVec(scales);
  return ans;
}



Component *AffineComponent::CollapseWithPrevious(
const FixedAffineComponent &prev_component) const {
// If at least one was non-updatable, make the whole non-updatable.
Expand Down
7 changes: 5 additions & 2 deletions src/nnet2/nnet-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,9 @@ class ScaleComponent: public Component {



class SumGroupComponent; // Forward declaration.
class AffineComponent; // Forward declaration.
class SumGroupComponent; // Forward declaration.
class AffineComponent; // Forward declaration.
class FixedScaleComponent; // Forward declaration.

class SoftmaxComponent: public NonlinearComponent {
public:
Expand Down Expand Up @@ -803,6 +804,7 @@ class AffineComponent: public UpdatableComponent {
// FixedLinearComponent yet.
Component *CollapseWithNext(const AffineComponent &next) const ;
Component *CollapseWithNext(const FixedAffineComponent &next) const;
Component *CollapseWithNext(const FixedScaleComponent &next) const;
Component *CollapseWithPrevious(const FixedAffineComponent &prev) const;

virtual std::string Info() const;
Expand Down Expand Up @@ -1473,6 +1475,7 @@ class FixedScaleComponent: public Component {
virtual void Write(std::ostream &os, bool binary) const;

protected:
friend class AffineComponent; // necessary for collapse
CuVector<BaseFloat> scales_;
KALDI_DISALLOW_COPY_AND_ASSIGN(FixedScaleComponent);
};
Expand Down
52 changes: 39 additions & 13 deletions src/nnet2/nnet-nnet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ int32 Nnet::LeftContext() const {
// non-negative left context. In addition, the NnetExample also stores data
// left context as positive integer. To be compatible with these other classes
// Nnet::LeftContext() returns a non-negative left context.

}

int32 Nnet::RightContext() const {
Expand All @@ -66,8 +65,8 @@ int32 Nnet::RightContext() const {
void Nnet::ComputeChunkInfo(int32 input_chunk_size,
int32 num_chunks,
std::vector<ChunkInfo> *chunk_info_out) const {
// First compute the output-chunk indices for the last component in the network.
// we assume that the numbering of the input starts from zero.
// First compute the output-chunk indices for the last component in the
// network. We assume that the numbering of the input starts from zero.
int32 output_chunk_size = input_chunk_size - LeftContext() - RightContext();
KALDI_ASSERT(output_chunk_size > 0);
std::vector<int32> current_output_inds;
Expand All @@ -88,7 +87,7 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size,
for (int32 i = NumComponents() - 1; i >= 0; i--) {
std::vector<int32> current_context = GetComponent(i).Context();
std::set<int32> current_input_ind_set;
for (size_t j = 0; j < current_context.size(); j++)
for (size_t j = 0; j < current_context.size(); j++)
for (size_t k = 0; k < current_output_inds.size(); k++)
current_input_ind_set.insert(current_context[j] +
current_output_inds[k]);
Expand Down Expand Up @@ -137,7 +136,6 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size,
(*chunk_info_out)[i].Check();
// (*chunk_info_out)[i].ToString();
}

}

const Component& Nnet::GetComponent(int32 component) const {
Expand Down Expand Up @@ -359,29 +357,56 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) {
// Changes the output dimension of the network to new_num_pdfs.
// Steps, as visible below:
//  1. If the last component is a SumGroupComponent, delete and drop it.
//  2. Require that the (now) last component is a SoftmaxComponent.
//  3. If a FixedScaleComponent sits just before the softmax, fold it into the
//     component before it via AffineComponent::CollapseWithNext() and drop it.
//  4. Resize the final affine component to new_num_pdfs outputs and replace
//     the softmax with a fresh SoftmaxComponent(new_num_pdfs).
// Raises KALDI_ERR if the softmax or the final AffineComponent is not found.
//
// NOTE(review): this text comes from a diff view, so a few statements appear
// twice (old line followed by its re-wrapped/new replacement); those pairs
// are marked below.
KALDI_ASSERT(new_num_pdfs > 0);
KALDI_ASSERT(NumComponents() > 2);
int32 nc = NumComponents();
// NOTE(review): the next line and the two after it declare the same variable;
// this looks like the removed/added pair of the diff.
SumGroupComponent *sgc = dynamic_cast<SumGroupComponent*>(components_[nc - 1]);
SumGroupComponent *sgc =
dynamic_cast<SumGroupComponent*>(components_[nc - 1]);
if (sgc != NULL) {
// Remove it. We'll resize things later.
delete sgc;
components_.erase(components_.begin() + nc - 1,
components_.begin() + nc);
// nc now counts the remaining components; it can be as small as 2 here.
nc--;
}

// sc is only used to verify the type of the last component.
SoftmaxComponent *sc;
if ((sc = dynamic_cast<SoftmaxComponent*>(components_[nc - 1])) == NULL)
KALDI_ERR << "Expected last component to be SoftmaxComponent.";

// Check whether the component just before the softmax (index nc - 2) is a
// FixedScaleComponent; if so the affine component is expected at nc - 3.
bool has_fixed_scale_component = false;
int32 fixed_scale_component_index = -1;
int32 final_affine_component_index = nc - 2;
int32 softmax_component_index = nc - 1;
FixedScaleComponent *fsc =
dynamic_cast<FixedScaleComponent*>(
components_[final_affine_component_index]);
if (fsc != NULL) {
has_fixed_scale_component = true;
fixed_scale_component_index = nc - 2;
// NOTE(review): if a SumGroupComponent was removed above and nc == 2, this
// becomes -1 and the components_[...] access below is out of range; the
// assertion on NumComponents() at the top does not rule that out.
final_affine_component_index = nc - 3;
}
// note: it could be child class of AffineComponent.
// NOTE(review): the next line and the two after it declare the same variable;
// the first form (fixed index nc - 2) looks like the diff's removed line.
AffineComponent *ac = dynamic_cast<AffineComponent*>(components_[nc - 2]);
AffineComponent *ac = dynamic_cast<AffineComponent*>(
components_[final_affine_component_index]);
if (ac == NULL)
KALDI_ERR << "Network doesn't have expected structure (didn't find final "
<< "AffineComponent).";

if (has_fixed_scale_component) {
// collapse the fixed_scale_component with the affine_component before it
AffineComponent *ac_new =
dynamic_cast<AffineComponent*>(ac->CollapseWithNext(*fsc));
KALDI_ASSERT(ac_new != NULL);
// The collapsed component replaces both originals, so free them, remove the
// scale component's slot, and store ac_new where the affine component was.
delete fsc;
delete ac;
components_.erase(components_.begin() + fixed_scale_component_index,
components_.begin() + (fixed_scale_component_index + 1));
components_[final_affine_component_index] = ac_new;
ac = ac_new;
// Erasing one entry before the softmax shifts it down by one position.
softmax_component_index = softmax_component_index - 1;
}
// Keep the current input dimension; presumably the second argument is the
// new output dimension -- confirm against AffineComponent::Resize.
ac->Resize(ac->InputDim(), new_num_pdfs);
// Remove the softmax component, and replace it with a new one
// NOTE(review): the two lines indexing with nc - 1 and the two indexing with
// softmax_component_index are the same statements; the nc - 1 form looks like
// the diff's removed pair (it would be stale after the erase above).
delete components_[nc - 1];
components_[nc - 1] = new SoftmaxComponent(new_num_pdfs);
delete components_[softmax_component_index];
components_[softmax_component_index] = new SoftmaxComponent(new_num_pdfs);
this->SetIndexes(); // used for debugging
this->Check();
}

Expand Down Expand Up @@ -655,8 +680,9 @@ void Nnet::Vectorize(VectorBase<BaseFloat> *params) const {
KALDI_ASSERT(offset == GetParameterDim());
}

void Nnet::ResetGenerators() { // resets random-number generators for all random
// components.
void Nnet::ResetGenerators() {
// Resets the random-number generators of all random
// components.
for (int32 c = 0; c < NumComponents(); c++) {
RandomComponent *rc = dynamic_cast<RandomComponent*>(
&(GetComponent(c)));
Expand Down