Skip to content

Commit

Permalink
validate transformers being passed into audiolm
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 7, 2022
1 parent 25c5ce6 commit 98723ad
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,14 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d
- [x] simply not self attend to eos token on the prompting side (semantic for coarse transformer, coarse for fine transformer)
- [x] add structured dropout from forgetful causal masking, far better than traditional dropouts
- [x] figure out how to suppress logging in fairseq
- [x] assert that all three transformers passed into audiolm is compatible

- [ ] figure out how to do the normalization across each dimension mentioned in the paper, but ignore it for v1 of the framework
- [ ] offer option to weight tie coarse, fine, and semantic embeddings across the 3 hierarchical transformers
- [ ] DRY a little at the end
- [ ] test with speech synthesis for starters
- [ ] abstract out conditioning + classifier free guidance into external module or potentially a package
- [ ] add option to use flash attention
- [ ] simplify training even more within AudioLM class
- [ ] cli tool, something like `audiolm generate <wav.file | text>` and save generated wav file to local directory
- [ ] validation function within audiolm that ensures all the pieces are compatible
- [ ] return a list of waves in the case of variable lengthed audio

## Citations
Expand Down
6 changes: 6 additions & 0 deletions audiolm_pytorch/audiolm_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,8 @@ def __init__(
**kwargs
):
super().__init__()
self.num_semantic_tokens = num_semantic_tokens

self.has_condition = has_condition
self.embed_text = partial(t5_encode_text, name = t5_name)
self.cond_drop_prob = cond_drop_prob
Expand Down Expand Up @@ -1378,6 +1380,10 @@ def __init__(
):
super().__init__()

assert semantic_transformer.num_semantic_tokens == coarse_transformer.num_semantic_tokens
assert coarse_transformer.codebook_size == fine_transformer.codebook_size
assert coarse_transformer.num_coarse_quantizers == fine_transformer.num_coarse_quantizers

self.semantic = SemanticTransformerWrapper(
wav2vec = wav2vec,
transformer = semantic_transformer,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'audiolm-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.6',
version = '0.1.7',
license='MIT',
description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 98723ad

Please sign in to comment.