Skip to content

Commit fcf7f4c

Browse files
ydshiehEduardoPach
authored andcommitted
Make bark could have tiny model (huggingface#25290)
* temp * update * update * update * small dim * small dim * small dim * fix * update * fix * fix * fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent f307068 commit fcf7f4c

File tree

3 files changed

+95
-26
lines changed

3 files changed

+95
-26
lines changed

src/transformers/models/bark/configuration_bark.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from ...configuration_utils import PretrainedConfig
2121
from ...utils import add_start_docstrings, logging
22-
from ..auto import AutoConfig
22+
from ..auto import CONFIG_MAPPING
2323

2424

2525
logger = logging.get_logger(__name__)
@@ -299,7 +299,8 @@ def __init__(
299299
self.semantic_config = BarkSemanticConfig(**semantic_config)
300300
self.coarse_acoustics_config = BarkCoarseConfig(**coarse_acoustics_config)
301301
self.fine_acoustics_config = BarkFineConfig(**fine_acoustics_config)
302-
self.codec_config = AutoConfig.for_model(**codec_config)
302+
codec_model_type = codec_config["model_type"] if "model_type" in codec_config else "encodec"
303+
self.codec_config = CONFIG_MAPPING[codec_model_type](**codec_config)
303304

304305
self.initializer_range = initializer_range
305306

@@ -311,7 +312,7 @@ def from_sub_model_configs(
311312
semantic_config: BarkSemanticConfig,
312313
coarse_acoustics_config: BarkCoarseConfig,
313314
fine_acoustics_config: BarkFineConfig,
314-
codec_config: AutoConfig,
315+
codec_config: PretrainedConfig,
315316
**kwargs,
316317
):
317318
r"""

tests/models/bark/test_modeling_bark.py

Lines changed: 87 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from transformers import (
2424
BarkCoarseConfig,
25+
BarkConfig,
2526
BarkFineConfig,
2627
BarkSemanticConfig,
2728
is_torch_available,
@@ -37,6 +38,7 @@
3738
from ...generation.test_utils import GenerationTesterMixin
3839
from ...test_configuration_common import ConfigTester
3940
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
41+
from ..encodec.test_modeling_encodec import EncodecModelTester
4042

4143

4244
if is_torch_available():
@@ -72,8 +74,6 @@ def __init__(
7274
initializer_range=0.02,
7375
n_codes_total=8, # for BarkFineModel
7476
n_codes_given=1, # for BarkFineModel
75-
config_class=None,
76-
model_class=None,
7777
):
7878
self.parent = parent
7979
self.batch_size = batch_size
@@ -98,8 +98,6 @@ def __init__(
9898
self.n_codes_given = n_codes_given
9999

100100
self.is_encoder_decoder = False
101-
self.config_class = config_class
102-
self.model_class = model_class
103101

104102
def prepare_config_and_inputs(self):
105103
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -121,7 +119,7 @@ def prepare_config_and_inputs(self):
121119
return config, inputs_dict
122120

123121
def get_config(self):
124-
return self.config_class(
122+
return BarkSemanticConfig(
125123
vocab_size=self.vocab_size,
126124
output_vocab_size=self.output_vocab_size,
127125
hidden_size=self.hidden_size,
@@ -137,14 +135,15 @@ def get_config(self):
137135
def get_pipeline_config(self):
138136
config = self.get_config()
139137
config.vocab_size = 300
138+
config.output_vocab_size = 300
140139
return config
141140

142141
def prepare_config_and_inputs_for_common(self):
143142
config, inputs_dict = self.prepare_config_and_inputs()
144143
return config, inputs_dict
145144

146145
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
147-
model = self.model_class(config=config).to(torch_device).eval()
146+
model = BarkSemanticModel(config=config).to(torch_device).eval()
148147

149148
input_ids = inputs_dict["input_ids"]
150149
attention_mask = inputs_dict["attention_mask"]
@@ -211,8 +210,6 @@ def __init__(
211210
initializer_range=0.02,
212211
n_codes_total=8, # for BarkFineModel
213212
n_codes_given=1, # for BarkFineModel
214-
config_class=None,
215-
model_class=None,
216213
):
217214
self.parent = parent
218215
self.batch_size = batch_size
@@ -237,8 +234,6 @@ def __init__(
237234
self.n_codes_given = n_codes_given
238235

239236
self.is_encoder_decoder = False
240-
self.config_class = config_class
241-
self.model_class = model_class
242237

243238
def prepare_config_and_inputs(self):
244239
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -260,7 +255,7 @@ def prepare_config_and_inputs(self):
260255
return config, inputs_dict
261256

262257
def get_config(self):
263-
return self.config_class(
258+
return BarkCoarseConfig(
264259
vocab_size=self.vocab_size,
265260
output_vocab_size=self.output_vocab_size,
266261
hidden_size=self.hidden_size,
@@ -276,14 +271,15 @@ def get_config(self):
276271
def get_pipeline_config(self):
277272
config = self.get_config()
278273
config.vocab_size = 300
274+
config.output_vocab_size = 300
279275
return config
280276

281277
def prepare_config_and_inputs_for_common(self):
282278
config, inputs_dict = self.prepare_config_and_inputs()
283279
return config, inputs_dict
284280

285281
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
286-
model = self.model_class(config=config).to(torch_device).eval()
282+
model = BarkCoarseModel(config=config).to(torch_device).eval()
287283

288284
input_ids = inputs_dict["input_ids"]
289285
attention_mask = inputs_dict["attention_mask"]
@@ -350,8 +346,6 @@ def __init__(
350346
initializer_range=0.02,
351347
n_codes_total=8, # for BarkFineModel
352348
n_codes_given=1, # for BarkFineModel
353-
config_class=None,
354-
model_class=None,
355349
):
356350
self.parent = parent
357351
self.batch_size = batch_size
@@ -376,8 +370,6 @@ def __init__(
376370
self.n_codes_given = n_codes_given
377371

378372
self.is_encoder_decoder = False
379-
self.config_class = config_class
380-
self.model_class = model_class
381373

382374
def prepare_config_and_inputs(self):
383375
input_ids = ids_tensor([self.batch_size, self.seq_length, self.n_codes_total], self.vocab_size)
@@ -403,7 +395,7 @@ def prepare_config_and_inputs(self):
403395
return config, inputs_dict
404396

405397
def get_config(self):
406-
return self.config_class(
398+
return BarkFineConfig(
407399
vocab_size=self.vocab_size,
408400
output_vocab_size=self.output_vocab_size,
409401
hidden_size=self.hidden_size,
@@ -419,14 +411,15 @@ def get_config(self):
419411
def get_pipeline_config(self):
420412
config = self.get_config()
421413
config.vocab_size = 300
414+
config.output_vocab_size = 300
422415
return config
423416

424417
def prepare_config_and_inputs_for_common(self):
425418
config, inputs_dict = self.prepare_config_and_inputs()
426419
return config, inputs_dict
427420

428421
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
429-
model = self.model_class(config=config).to(torch_device).eval()
422+
model = BarkFineModel(config=config).to(torch_device).eval()
430423

431424
input_ids = inputs_dict["input_ids"]
432425
attention_mask = inputs_dict["attention_mask"]
@@ -473,6 +466,79 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
473466
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
474467

475468

469+
class BarkModelTester:
470+
def __init__(
471+
self,
472+
parent,
473+
semantic_kwargs=None,
474+
coarse_acoustics_kwargs=None,
475+
fine_acoustics_kwargs=None,
476+
codec_kwargs=None,
477+
is_training=False, # for now training is not supported
478+
):
479+
if semantic_kwargs is None:
480+
semantic_kwargs = {}
481+
if coarse_acoustics_kwargs is None:
482+
coarse_acoustics_kwargs = {}
483+
if fine_acoustics_kwargs is None:
484+
fine_acoustics_kwargs = {}
485+
if codec_kwargs is None:
486+
codec_kwargs = {}
487+
488+
self.parent = parent
489+
self.semantic_model_tester = BarkSemanticModelTester(parent, **semantic_kwargs)
490+
self.coarse_acoustics_model_tester = BarkCoarseModelTester(parent, **coarse_acoustics_kwargs)
491+
self.fine_acoustics_model_tester = BarkFineModelTester(parent, **fine_acoustics_kwargs)
492+
self.codec_model_tester = EncodecModelTester(parent, **codec_kwargs)
493+
494+
self.is_training = is_training
495+
496+
def prepare_config_and_inputs(self):
497+
# TODO: @Yoach: Preapre `inputs_dict`
498+
inputs_dict = {}
499+
config = self.get_config()
500+
501+
return config, inputs_dict
502+
503+
def get_config(self):
504+
return BarkConfig.from_sub_model_configs(
505+
self.semantic_model_tester.get_config(),
506+
self.coarse_acoustics_model_tester.get_config(),
507+
self.fine_acoustics_model_tester.get_config(),
508+
self.codec_model_tester.get_config(),
509+
)
510+
511+
def get_pipeline_config(self):
512+
config = self.get_config()
513+
514+
# follow the `get_pipeline_config` of the sub component models
515+
config.semantic_config.vocab_size = 300
516+
config.coarse_acoustics_config.vocab_size = 300
517+
config.fine_acoustics_config.vocab_size = 300
518+
519+
config.semantic_config.output_vocab_size = 300
520+
config.coarse_acoustics_config.output_vocab_size = 300
521+
config.fine_acoustics_config.output_vocab_size = 300
522+
523+
return config
524+
525+
def prepare_config_and_inputs_for_common(self):
526+
# TODO: @Yoach
527+
pass
528+
# return config, inputs_dict
529+
530+
531+
# Need this class in oder to create tiny model for `bark`
532+
# TODO (@Yoach) Implement actual test methods
533+
@unittest.skip("So far all tests will fail.")
534+
class BarkModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
535+
all_model_classes = (BarkModel,) if is_torch_available() else ()
536+
537+
def setUp(self):
538+
self.model_tester = BarkModelTester(self)
539+
self.config_tester = ConfigTester(self, config_class=BarkConfig, n_embd=37)
540+
541+
476542
@require_torch
477543
class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
478544
all_model_classes = (BarkSemanticModel,) if is_torch_available() else ()
@@ -488,9 +554,7 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
488554
test_resize_embeddings = True
489555

490556
def setUp(self):
491-
self.model_tester = BarkSemanticModelTester(
492-
self, config_class=BarkSemanticConfig, model_class=BarkSemanticModel
493-
)
557+
self.model_tester = BarkSemanticModelTester(self)
494558
self.config_tester = ConfigTester(self, config_class=BarkSemanticConfig, n_embd=37)
495559

496560
def test_config(self):
@@ -556,7 +620,7 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
556620
test_resize_embeddings = True
557621

558622
def setUp(self):
559-
self.model_tester = BarkCoarseModelTester(self, config_class=BarkCoarseConfig, model_class=BarkCoarseModel)
623+
self.model_tester = BarkCoarseModelTester(self)
560624
self.config_tester = ConfigTester(self, config_class=BarkCoarseConfig, n_embd=37)
561625

562626
def test_config(self):
@@ -623,7 +687,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
623687
test_resize_embeddings = True
624688

625689
def setUp(self):
626-
self.model_tester = BarkFineModelTester(self, config_class=BarkFineConfig, model_class=BarkFineModel)
690+
self.model_tester = BarkFineModelTester(self)
627691
self.config_tester = ConfigTester(self, config_class=BarkFineConfig, n_embd=37)
628692

629693
def test_config(self):

utils/create_dummy_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,10 @@ def get_token_id_from_tokenizer(token_id_name, tokenizer, original_token_id):
974974

975975

976976
def get_config_overrides(config_class, processors):
977+
# `Bark` configuration is too special. Let's just not handle this for now.
978+
if config_class.__name__ == "BarkConfig":
979+
return {}
980+
977981
config_overrides = {}
978982

979983
# Check if there is any tokenizer (prefer fast version if any)

0 commit comments

Comments
 (0)