TCPGen in Conformer RNN-T #2890

Closed
wants to merge 120 commits into from

Conversation

BriansIDP
Contributor

This pull request contains the implementation of the tree-constrained pointer generator (TCPGen) for contextual biasing. An example for Librispeech can be found in audio/examples/asr/librispeech_biasing.

)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
Member

The default value here can be global_stats_100.json so that users don't need to pass --global-stats-path in the SLURM script.

@@ -0,0 +1,2 @@
dir=experiments/librispeech_clean100_suffix600_tcpgen500_sche30_nodrop/decode_test_clean_b10_KB1000/
Member

The path could be made configurable via an input argument such as $1. Could you also add a comment on how to use this script?

)

model = ConformerRNNTModule(str(args.sp_model_path), args.biasing)
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path),
Member

I found it hard to tune batch_size or max_token in the dataloader when GPU memory is limited, as in my use case. @hwangjeff, would it be better to provide an API for tuning those?

Contributor

You mean exposing max_token externally, right? I agree.
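
For context, a minimal sketch of what exposing max_token as a command-line option could look like; the argument name, default value, and the get_data_module keyword below are illustrative assumptions, not the actual API in this PR:

```python
# Hypothetical sketch: expose max_token so users with limited GPU memory can tune it.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
    "--max-token",
    default=700,
    type=int,
    help="Maximum number of tokens per batch in the training dataloader. (Default: 700)",
)
args = parser.parse_args()

# The parsed value would then be forwarded to the data module, for example:
# data_module = get_data_module(..., max_token=args.max_token)
```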


Sample SLURM command:
```
srun --cpus-per-task=16 --gpus-per-node=1 -N 1 --ntasks-per-node=1 python train.py --exp-dir <Path_to_exp> --librispeech-path <Path_to_librispeech_data> --global-stats-path ./global_stats_100.json --sp-model-path ./spm_unigram_600_100suffix.model --biasing --biasing-list ./blists/rareword_f15.txt --droprate 0.1 --maxsize 200 --epochs 90
```
Member

Could you change spm_unigram_600_100suffix.model to ./spm_unigram_1023.model, which is the default output filename from train_spm.py?

Contributor Author

The current example uses 600 wordpiece tokens to replicate what I had in the paper, so maybe we should keep it like this to stay consistent with the paper? I will also change train_spm.py to match.

Contributor Author

Otherwise all addressed and pushed. Thank you!

Member

cool, thanks!


Sample SLURM command:
```
srun --cpus-per-task=16 --gpus-per-node=1 -N 1 --ntasks-per-node=1 python eval.py --checkpoint-path <Path_to_model_checkpoint> --librispeech-path <Path_to_librispeech_data> --sp-model-path ./spm_unigram_600_100suffix.model --expdir <Path_to_exp> --use-cuda --biasing --biasing-list ./blists/all_rare_words.txt --droprate 0.0 --maxsize 1000
```
Member

Same here.

help="Run using CUDA.",
)
parser.add_argument(
"--biasinglist",
Member

Suggested change
"--biasinglist",
"--biasing-list",

Comment on lines 108 to 113
parser.add_argument(
"--biasing",
type=str,
help="Use biasing",
required=True,
)
Member

Suggested change
parser.add_argument(
"--biasing",
type=str,
help="Use biasing",
required=True,
)
parser.add_argument(
"--biasing",
action="store_true",
help="Use biasing",
)



def run_eval(args):
usebiasing = True if args.biasing == 'true' else False
Member

Suggested change
usebiasing = True if args.biasing == 'true' else False
usebiasing = args.biasing

model = ConformerRNNTModule.load_from_checkpoint(
args.checkpoint_path, sp_model=str(args.sp_model_path), biasing=usebiasing).eval()
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path),
biasinglist=args.biasinglist, droprate=args.droprate, maxsize=args.maxsize)
Member

Suggested change
biasinglist=args.biasinglist, droprate=args.droprate, maxsize=args.maxsize)
biasinglist=args.biasing_list, droprate=args.droprate, maxsize=args.maxsize)

)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
Member

Suggested change
default=pathlib.Path("global_stats.json"),
default=pathlib.Path("global_stats_100.json"),

mthrok added a commit to mthrok/audio that referenced this pull request Feb 9, 2023
This is the cleaned up version of pytorch#2890

> This pull request contains the implementation of
> the tree-constrained pointer generator (TCPGen) for contextual biasing.
> An example for Librispeech can be found in
> audio/examples/asr/librispeech_biasing.
@facebook-github-bot
Contributor

@mthrok has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@mthrok
Collaborator

mthrok commented Feb 13, 2023

Hi @BriansIDP

We found a way to land this PR without rebasing, so we are almost good to go. A couple of requests before we merge:

  1. Please run the lint tools:
    Please run pre-commit run -a at the root directory.
    Please also run flake8.
  2. We cannot check in the text files in the blists directory as they are huge. We can put them in our S3, if the license permits. Can you tell us how the files (all_rare_words.txt, rareword_f15.txt and rareword_f30.txt) were obtained?

Thanks,

@BriansIDP
Contributor Author

> Hi @BriansIDP
>
> We found a way to land this PR without rebasing, so we are almost good to go. A couple of requests before we merge:
>
> 1. Please run the lint tools: run pre-commit run -a at the root directory, and also run flake8.
> 2. We cannot check in the text files in the blists directory as they are huge. We can put them in our S3, if the license permits. Can you tell us how the files (all_rare_words.txt, rareword_f15.txt and rareword_f30.txt) were obtained?
>
> Thanks,

Hi @mthrok. Thank you for the instructions. I have now run: (1) pre-commit and then pre-commit run -a, and all checks passed; (2) flake8, which did not find errors in any of my modified files.

For the biasing lists, rareword_f15.txt and rareword_f30.txt are generated by thresholding the train-clean-100 set word frequencies at 15 and 30 respectively (so including any words appearing fewer than 15/30 times). The all_rare_words.txt is obtained from here: https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias. Please let me know what I should do with those. Thank you!
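
For reference, a minimal sketch of how such a frequency-thresholded list could be produced; the transcript path and file format below are assumptions for illustration, not the actual script used to build rareword_f15.txt:

```python
# Hypothetical sketch: build a rare-word biasing list by thresholding word
# frequencies computed over the train-clean-100 transcripts.
from collections import Counter

THRESHOLD = 15  # use 30 for a rareword_f30-style list

counts = Counter()
with open("train_clean_100_transcripts.txt") as f:  # assumed: one transcript per line
    for line in f:
        counts.update(line.strip().upper().split())

rare_words = sorted(word for word, count in counts.items() if count < THRESHOLD)
with open("rareword_f15.txt", "w") as f:
    f.write("\n".join(rare_words) + "\n")
```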

@xiaohui-zhang
Contributor

Thanks. @mthrok, according to @BriansIDP's reply, I've confirmed those files only contain word statistics from LibriSpeech data, so it's safe to put them on S3. Thanks again.

@facebook-github-bot
Contributor

@mthrok has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@mthrok
Collaborator

mthrok commented Feb 14, 2023

@xiaohui-zhang @BriansIDP What about error_analysis/word_freq.txt? It seems this is used during the generation of the blists, but I cannot quite figure out what word_freq.txt was generated from.

@BriansIDP
Contributor Author

> @xiaohui-zhang @BriansIDP What about error_analysis/word_freq.txt? It seems this is used during the generation of the blists, but I cannot quite figure out what word_freq.txt was generated from.

Hi @mthrok, this is a word count file containing all training set word frequencies for train_clean_100. It is only needed when calculating OOV word error rates (anything not in this file is counted as an OOV word). Should I add a line explaining what it is in get_error_word_count.py, and then you can move the file if needed?
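
As an illustration, a hedged sketch of how word_freq.txt could be used to flag OOV words when computing OOV word error rates; the file format (one word and its count per line) and the snippet below are assumptions, not the actual logic in get_error_word_count.py:

```python
# Hypothetical sketch: any hypothesis word absent from the training vocabulary
# recorded in word_freq.txt is treated as an OOV word.
def load_train_vocab(path="error_analysis/word_freq.txt"):
    vocab = set()
    with open(path) as f:
        for line in f:
            word = line.split()[0]  # assumed format: "<word> <count>"
            vocab.add(word.upper())
    return vocab

train_vocab = load_train_vocab()
hypothesis = "THE QUOKKA SAT ON THE MAT".split()
oov_words = [word for word in hypothesis if word not in train_vocab]
print(f"{len(oov_words)} OOV word(s): {oov_words}")
```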

@xiaohui-zhang
Contributor

In this case I think it's OK to simply keep this file in S3 as well. cc @mthrok

@facebook-github-bot
Contributor

@mthrok merged this pull request in 1ed330b.

@github-actions

Hey @mthrok.
You merged this PR, but labels were not properly added. Please add a primary and secondary label (See https://github.com/pytorch/audio/blob/main/.github/process_commit.py)

@mthrok
Collaborator

mthrok commented Feb 23, 2023

@BriansIDP I merged this PR in 1ed330b. Thank you for the contribution and congrats!

Regarding the txt files: I removed the rare-word lists from the repository and uploaded them to torchaudio's CDN.

1ed330b#diff-a464de23e8e5d28a210663f87eff1a7fb55b6fcfcb6df611ef5013515b59c554

I did not add error_analysis/word_freq.txt to the CDN, as I was not sure whether it should be accessible. Let me know if it would be better to mention it in the README.

@BriansIDP
Contributor Author

> @BriansIDP I merged this PR in 1ed330b. Thank you for the contribution and congrats!
>
> Regarding the txt files: I removed the rare-word lists from the repository and uploaded them to torchaudio's CDN.
>
> 1ed330b#diff-a464de23e8e5d28a210663f87eff1a7fb55b6fcfcb6df611ef5013515b59c554
>
> I did not add error_analysis/word_freq.txt to the CDN, as I was not sure whether it should be accessible. Let me know if it would be better to mention it in the README.

Thank you @mthrok so much for helping me throughout this PR!

I learned a lot! It would be good to mention that this file only keeps track of the training set vocabulary for calculating OOV word error rates in error_analysis/get_error_word_count.py. Thank you!

@BriansIDP
Contributor Author

Hi @mthrok. I am planning to write a tutorial about the new biasing module, and I wonder if it is possible to upload my biasing model to the CDN so that I can load it in the tutorial (maybe named https://download.pytorch.org/torchaudio/models/conformer_rnnt_biasing_librispeech.pt). If possible, what would be the best way for me to send the model to you (e.g. via Google Drive)? Thank you so much for your help!

@mthrok
Collaborator

mthrok commented Mar 2, 2023

Hi @BriansIDP

Google Drive works for us. Make it public or share it with moto@meta.com.
Can you confirm that the model is trained on a publicly available dataset (i.e. LibriSpeech)?
