TCPGen in Conformer RNN-T #2890
Conversation
```
parser.add_argument(
    "--global-stats-path",
    default=pathlib.Path("global_stats.json"),
```
The default value here could be global_stats_100.json so that users don't need to pass --global-stats-path in the SLURM script.
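A minimal sketch of what the suggested default would look like in the argument parser; the type and help text here are illustrative assumptions, not copied from the recipe:
```python
import argparse
import pathlib

parser = argparse.ArgumentParser()
parser.add_argument(
    "--global-stats-path",
    type=pathlib.Path,
    default=pathlib.Path("global_stats_100.json"),  # suggested default, so the flag can be omitted in SLURM scripts
    help="Path to JSON file containing global feature statistics.",
)
print(parser.parse_args([]).global_stats_path)  # global_stats_100.json
```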
```
dir=experiments/librispeech_clean100_suffix600_tcpgen500_sche30_nodrop/decode_test_clean_b10_KB1000/
```
The path could be taken from an input argument like $1. Could you also add a comment on how to use this script?
```
model = ConformerRNNTModule(str(args.sp_model_path), args.biasing)
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path),
```
I found it hard to tune batch_size or max_token in the dataloader when GPU memory is limited, as in my use case. @hwangjeff would it be better to provide an API for tuning those?
You mean exposing max_token to the outside, right? I agree.
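A rough sketch of how such knobs could be exposed on the command line; the flag names, defaults, and the idea of forwarding them into get_data_module are assumptions for illustration, not the recipe's actual API:
```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flags so batch size / max_token can be tuned without editing the code.
parser.add_argument("--batch-size", type=int, default=None,
                    help="Fixed batch size; leave unset to use dynamic batching.")
parser.add_argument("--max-tokens", type=int, default=700,
                    help="Upper bound on frames per dynamic batch; lower it when GPU memory is tight.")
args = parser.parse_args(["--max-tokens", "400"])

# These values would then be forwarded to the data module / sampler, e.g.
# get_data_module(..., batch_size=args.batch_size, max_tokens=args.max_tokens)
print(args.batch_size, args.max_tokens)  # None 400
```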
Sample SLURM command:
```
srun --cpus-per-task=16 --gpus-per-node=1 -N 1 --ntasks-per-node=1 python train.py --exp-dir <Path_to_exp> --librispeech-path <Path_to_librispeech_data> --global-stats-path ./global_stats_100.json --sp-model-path ./spm_unigram_600_100suffix.model --biasing --biasing-list ./blists/rareword_f15.txt --droprate 0.1 --maxsize 200 --epochs 90
```
Could you change spm_unigram_600_100suffix.model to ./spm_unigram_1023.model, which is the default output filename from train_spm.py?
The current example uses 600 wordpiece tokens to replicate what I had in the paper, so maybe we should keep it like this to stay consistent with the paper? I will also change train_spm.py to be consistent.
Otherwise all addressed and pushed. Thank you!
cool, thanks!
Sample SLURM command:
```
srun --cpus-per-task=16 --gpus-per-node=1 -N 1 --ntasks-per-node=1 python eval.py --checkpoint-path <Path_to_model_checkpoint> --librispeech-path <Path_to_librispeech_data> --sp-model-path ./spm_unigram_600_100suffix.model --expdir <Path_to_exp> --use-cuda --biasing --biasing-list ./blists/all_rare_words.txt --droprate 0.0 --maxsize 1000
```
Same here.
help="Run using CUDA.", | ||
) | ||
parser.add_argument( | ||
"--biasinglist", |
"--biasinglist", | |
"--biasing-list", |
```
parser.add_argument(
    "--biasing",
    type=str,
    help="Use biasing",
    required=True,
)
```
Suggested change:
```
-parser.add_argument(
-    "--biasing",
-    type=str,
-    help="Use biasing",
-    required=True,
-)
+parser.add_argument(
+    "--biasing",
+    action="store_true",
+    help="Use biasing",
+)
```
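A small standalone example of the behaviour this suggestion relies on: with action="store_true" the attribute defaults to False and flips to True when the flag is present, so no string comparison is needed downstream.
```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--biasing", action="store_true", help="Use biasing")

print(parser.parse_args([]).biasing)             # False
print(parser.parse_args(["--biasing"]).biasing)  # True
```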
```
def run_eval(args):
    usebiasing = True if args.biasing == 'true' else False
```
Suggested change:
```
-    usebiasing = True if args.biasing == 'true' else False
+    usebiasing = args.biasing
```
```
    model = ConformerRNNTModule.load_from_checkpoint(
        args.checkpoint_path, sp_model=str(args.sp_model_path), biasing=usebiasing).eval()
    data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path),
                                  biasinglist=args.biasinglist, droprate=args.droprate, maxsize=args.maxsize)
```
Suggested change:
```
-                                  biasinglist=args.biasinglist, droprate=args.droprate, maxsize=args.maxsize)
+                                  biasinglist=args.biasing_list, droprate=args.droprate, maxsize=args.maxsize)
```
```
parser.add_argument(
    "--global-stats-path",
    default=pathlib.Path("global_stats.json"),
```
Suggested change:
```
-    default=pathlib.Path("global_stats.json"),
+    default=pathlib.Path("global_stats_100.json"),
```
This is the cleaned up version of pytorch#2890:
> This pull request contains the implementation of the tree-constrained pointer generator (TCPGen) for contextual biasing. An example for LibriSpeech can be found in audio/examples/asr/librispeech_biasing.
@mthrok has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Hi @BriansIDP, we found a way to land this PR without rebasing, so we are almost good to go. A couple of requests before we merge:
Thanks,
Hi @mthrok. Thank you for the instructions. I have now run: (1) pre-commit and then pre-commit run -a, and all tests passed; (2) flake8, which did not find errors in any of my modified files. For the biasing lists, rareword_f15.txt and rareword_f30.txt are generated by thresholding the train-clean-100 word frequencies at 15 and 30 respectively (i.e. including any words appearing fewer than 15/30 times). all_rare_words.txt is obtained from here: https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias. Please let me know what I should do with those. Thank you!
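For illustration, a sketch of how such a rare-word list could be built by thresholding train-clean-100 word frequencies as described above; this is not the exact script used in the PR, and the "count < threshold" reading follows the comment above.
```python
from collections import Counter
from pathlib import Path

def build_rareword_list(librispeech_root: str, threshold: int, out_path: str) -> None:
    counts = Counter()
    # LibriSpeech transcript files contain lines of the form "<utt-id> WORD WORD ..."
    for trans in Path(librispeech_root, "train-clean-100").rglob("*.trans.txt"):
        for line in trans.read_text().splitlines():
            counts.update(line.split()[1:])
    # Keep words appearing fewer than `threshold` times in the training set.
    rare = sorted(w for w, c in counts.items() if c < threshold)
    Path(out_path).write_text("\n".join(rare) + "\n")

# e.g. build_rareword_list("/data/LibriSpeech", 15, "blists/rareword_f15.txt")
```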
Thanks. @mthrok, according to @BriansIDP's reply, I've confirmed those files only contain word statistics from LibriSpeech data, so it's safe to put them on S3. Thanks again.
@mthrok has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@xiaohui-zhang @BriansIDP What about
Hi @mthrok, this is a word-count file containing the frequencies of all training-set words in train_clean_100. The file is only needed when calculating OOV word error rates (anything not in this file is counted as an OOV word). Should I add a line in get_error_word_count.py explaining what it is, and then you can move it if needed?
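Purely as an illustration of the role described above, here is how a word-count file could be used to flag OOV words during error analysis; the "WORD COUNT" per-line format and the helper names are assumptions, not the actual format or code in error_analysis/get_error_word_count.py.
```python
def load_training_vocab(word_count_path: str) -> set[str]:
    # Assumed format: one "WORD COUNT" pair per line; only the word is needed here.
    vocab = set()
    with open(word_count_path) as f:
        for line in f:
            if line.strip():
                vocab.add(line.split()[0])
    return vocab

def split_ref_by_oov(reference_words: list[str], vocab: set[str]):
    # Words absent from the training vocabulary are counted as OOV.
    in_vocab = [w for w in reference_words if w in vocab]
    oov = [w for w in reference_words if w not in vocab]
    return in_vocab, oov

vocab = {"THE", "QUICK"}  # in practice: load_training_vocab("<word-count file>")
print(split_ref_by_oov("THE QUICK ZYZZYVA".split(), vocab))  # (['THE', 'QUICK'], ['ZYZZYVA'])
```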
In this case I think it's OK to simply keep this file on S3 as well. cc @mthrok
Hey @mthrok.
@BriansIDP I merged this PR in 1ed330b. Thank you for the contribution, and congrats! Regarding the txt files: I removed the rare-word lists and uploaded them to torchaudio's CDN. 1ed330b#diff-a464de23e8e5d28a210663f87eff1a7fb55b6fcfcb6df611ef5013515b59c554 I did not add
Thank you @mthrok so much for helping me throughout this PR! I learned a lot! It would be good to mention that this file only keeps track of the training-set vocabulary, used to calculate OOV word error rates in error_analysis/get_error_word_count.py. Thank you!
Hi @mthrok. I am planning to write a tutorial about the new biasing module, and I wonder if it is possible to upload my biasing model to the CDN so that I can load it in the tutorial (maybe name it https://download.pytorch.org/torchaudio/models/conformer_rnnt_biasing_librispeech.pt). If possible, what would be the best way for me to send you the model (e.g. via Google Drive)? Thank you so much for your help!
Hi @BriansIDP, Google Drive works for us. Make it public or share it with moto@meta.com.
This pull request contains the implementation of the tree-constrained pointer generator (TCPGen) for contextual biasing. An example for Librispeech can be found in audio/examples/asr/librispeech_biasing.
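As a toy illustration of the "tree-constrained" part of TCPGen (not the PR's actual implementation): the biasing words are tokenized into wordpieces and stored in a prefix tree, and at each decoding step the children of the current tree node define which wordpieces the pointer distribution is allowed to cover. A character-level tokenizer stands in below for the SentencePiece model used in the recipe.
```python
def build_prefix_tree(biasing_words, tokenize):
    root = {}
    for word in biasing_words:
        node = root
        for piece in tokenize(word):
            node = node.setdefault(piece, {})
        node["<end>"] = {}  # marks a complete biasing word
    return root

tree = build_prefix_tree(["TURIN", "TURING"], tokenize=list)
node = tree["T"]["U"]["R"]["I"]["N"]
print(sorted(node.keys()))  # ['<end>', 'G'] -> valid continuations under the tree
```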