2222
2323from transformers import (
2424 BarkCoarseConfig ,
25+ BarkConfig ,
2526 BarkFineConfig ,
2627 BarkSemanticConfig ,
2728 is_torch_available ,
3738from ...generation .test_utils import GenerationTesterMixin
3839from ...test_configuration_common import ConfigTester
3940from ...test_modeling_common import ModelTesterMixin , ids_tensor , random_attention_mask
41+ from ..encodec .test_modeling_encodec import EncodecModelTester
4042
4143
4244if is_torch_available ():
@@ -72,8 +74,6 @@ def __init__(
7274 initializer_range = 0.02 ,
7375 n_codes_total = 8 , # for BarkFineModel
7476 n_codes_given = 1 , # for BarkFineModel
75- config_class = None ,
76- model_class = None ,
7777 ):
7878 self .parent = parent
7979 self .batch_size = batch_size
@@ -98,8 +98,6 @@ def __init__(
9898 self .n_codes_given = n_codes_given
9999
100100 self .is_encoder_decoder = False
101- self .config_class = config_class
102- self .model_class = model_class
103101
104102 def prepare_config_and_inputs (self ):
105103 input_ids = ids_tensor ([self .batch_size , self .seq_length ], self .vocab_size )
@@ -121,7 +119,7 @@ def prepare_config_and_inputs(self):
121119 return config , inputs_dict
122120
123121 def get_config (self ):
124- return self . config_class (
122+ return BarkSemanticConfig (
125123 vocab_size = self .vocab_size ,
126124 output_vocab_size = self .output_vocab_size ,
127125 hidden_size = self .hidden_size ,
@@ -137,14 +135,15 @@ def get_config(self):
def get_pipeline_config(self):
    """Return the tester's config with both vocabulary sizes set to 300.

    Starts from `self.get_config()` and overrides `vocab_size` and
    `output_vocab_size` on the returned object before handing it back.
    """
    pipeline_config = self.get_config()
    for size_attr in ("vocab_size", "output_vocab_size"):
        setattr(pipeline_config, size_attr, 300)
    return pipeline_config
141140
def prepare_config_and_inputs_for_common(self):
    """Return the `(config, inputs_dict)` pair from `prepare_config_and_inputs`.

    The result is unpacked into exactly two values and re-packed as a tuple.
    """
    prepared = self.prepare_config_and_inputs()
    config, inputs_dict = prepared
    return (config, inputs_dict)
145144
146145 def create_and_check_decoder_model_past_large_inputs (self , config , inputs_dict ):
147- model = self . model_class (config = config ).to (torch_device ).eval ()
146+ model = BarkSemanticModel (config = config ).to (torch_device ).eval ()
148147
149148 input_ids = inputs_dict ["input_ids" ]
150149 attention_mask = inputs_dict ["attention_mask" ]
@@ -211,8 +210,6 @@ def __init__(
211210 initializer_range = 0.02 ,
212211 n_codes_total = 8 , # for BarkFineModel
213212 n_codes_given = 1 , # for BarkFineModel
214- config_class = None ,
215- model_class = None ,
216213 ):
217214 self .parent = parent
218215 self .batch_size = batch_size
@@ -237,8 +234,6 @@ def __init__(
237234 self .n_codes_given = n_codes_given
238235
239236 self .is_encoder_decoder = False
240- self .config_class = config_class
241- self .model_class = model_class
242237
243238 def prepare_config_and_inputs (self ):
244239 input_ids = ids_tensor ([self .batch_size , self .seq_length ], self .vocab_size )
@@ -260,7 +255,7 @@ def prepare_config_and_inputs(self):
260255 return config , inputs_dict
261256
262257 def get_config (self ):
263- return self . config_class (
258+ return BarkCoarseConfig (
264259 vocab_size = self .vocab_size ,
265260 output_vocab_size = self .output_vocab_size ,
266261 hidden_size = self .hidden_size ,
@@ -276,14 +271,15 @@ def get_config(self):
def get_pipeline_config(self):
    """Build the config via `get_config()` and set both vocab sizes to 300."""
    cfg = self.get_config()
    overrides = {"vocab_size": 300, "output_vocab_size": 300}
    for field_name, field_value in overrides.items():
        setattr(cfg, field_name, field_value)
    return cfg
280276
def prepare_config_and_inputs_for_common(self):
    """Forward the two-element result of `prepare_config_and_inputs` as a tuple."""
    # Unpacking enforces that exactly two values come back.
    cfg, model_inputs = self.prepare_config_and_inputs()
    common = (cfg, model_inputs)
    return common
284280
285281 def create_and_check_decoder_model_past_large_inputs (self , config , inputs_dict ):
286- model = self . model_class (config = config ).to (torch_device ).eval ()
282+ model = BarkCoarseModel (config = config ).to (torch_device ).eval ()
287283
288284 input_ids = inputs_dict ["input_ids" ]
289285 attention_mask = inputs_dict ["attention_mask" ]
@@ -350,8 +346,6 @@ def __init__(
350346 initializer_range = 0.02 ,
351347 n_codes_total = 8 , # for BarkFineModel
352348 n_codes_given = 1 , # for BarkFineModel
353- config_class = None ,
354- model_class = None ,
355349 ):
356350 self .parent = parent
357351 self .batch_size = batch_size
@@ -376,8 +370,6 @@ def __init__(
376370 self .n_codes_given = n_codes_given
377371
378372 self .is_encoder_decoder = False
379- self .config_class = config_class
380- self .model_class = model_class
381373
382374 def prepare_config_and_inputs (self ):
383375 input_ids = ids_tensor ([self .batch_size , self .seq_length , self .n_codes_total ], self .vocab_size )
@@ -403,7 +395,7 @@ def prepare_config_and_inputs(self):
403395 return config , inputs_dict
404396
405397 def get_config (self ):
406- return self . config_class (
398+ return BarkFineConfig (
407399 vocab_size = self .vocab_size ,
408400 output_vocab_size = self .output_vocab_size ,
409401 hidden_size = self .hidden_size ,
@@ -419,14 +411,15 @@ def get_config(self):
def get_pipeline_config(self):
    """Return `get_config()`'s result after forcing both vocab sizes to 300."""
    pipeline_vocab_size = 300
    tiny_config = self.get_config()
    tiny_config.vocab_size = tiny_config.output_vocab_size = pipeline_vocab_size
    return tiny_config
423416
def prepare_config_and_inputs_for_common(self):
    """Return the config and inputs dict produced by `prepare_config_and_inputs`."""
    result = self.prepare_config_and_inputs()
    # Two-target unpacking keeps the "exactly two items" requirement.
    model_config, model_inputs = result
    return model_config, model_inputs
427420
428421 def create_and_check_decoder_model_past_large_inputs (self , config , inputs_dict ):
429- model = self . model_class (config = config ).to (torch_device ).eval ()
422+ model = BarkFineModel (config = config ).to (torch_device ).eval ()
430423
431424 input_ids = inputs_dict ["input_ids" ]
432425 attention_mask = inputs_dict ["attention_mask" ]
@@ -473,6 +466,79 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
473466 self .parent .assertTrue (torch .allclose (output_from_past_slice , output_from_no_past_slice , atol = 1e-3 ))
474467
475468
class BarkModelTester:
    """Composite tester that assembles a `BarkConfig` from the sub-model testers.

    It owns one tester per sub-model (semantic, coarse acoustics, fine acoustics
    and the Encodec codec) and combines their configs through
    `BarkConfig.from_sub_model_configs`.
    """

    def __init__(
        self,
        parent,
        semantic_kwargs=None,
        coarse_acoustics_kwargs=None,
        fine_acoustics_kwargs=None,
        codec_kwargs=None,
        is_training=False,  # for now training is not supported
    ):
        """Create the four sub-model testers.

        Args:
            parent: Test case instance forwarded to every sub-model tester.
            semantic_kwargs: Extra keyword arguments for `BarkSemanticModelTester`.
            coarse_acoustics_kwargs: Extra keyword arguments for `BarkCoarseModelTester`.
            fine_acoustics_kwargs: Extra keyword arguments for `BarkFineModelTester`.
            codec_kwargs: Extra keyword arguments for `EncodecModelTester`.
            is_training: Stored on the tester as-is.
        """
        # `None` defaults avoid sharing one mutable dict between instances.
        if semantic_kwargs is None:
            semantic_kwargs = {}
        if coarse_acoustics_kwargs is None:
            coarse_acoustics_kwargs = {}
        if fine_acoustics_kwargs is None:
            fine_acoustics_kwargs = {}
        if codec_kwargs is None:
            codec_kwargs = {}

        self.parent = parent
        self.semantic_model_tester = BarkSemanticModelTester(parent, **semantic_kwargs)
        self.coarse_acoustics_model_tester = BarkCoarseModelTester(parent, **coarse_acoustics_kwargs)
        self.fine_acoustics_model_tester = BarkFineModelTester(parent, **fine_acoustics_kwargs)
        self.codec_model_tester = EncodecModelTester(parent, **codec_kwargs)

        self.is_training = is_training

    def prepare_config_and_inputs(self):
        """Return `(config, inputs_dict)`; the inputs dict is still empty."""
        # TODO: @Yoach: Prepare `inputs_dict`
        inputs_dict = {}
        config = self.get_config()

        return config, inputs_dict

    def get_config(self):
        """Build a `BarkConfig` from the configs of the four sub-model testers."""
        return BarkConfig.from_sub_model_configs(
            self.semantic_model_tester.get_config(),
            self.coarse_acoustics_model_tester.get_config(),
            self.fine_acoustics_model_tester.get_config(),
            self.codec_model_tester.get_config(),
        )

    def get_pipeline_config(self):
        """Return the composite config with every Bark sub-config's vocab sizes set to 300."""
        config = self.get_config()

        # follow the `get_pipeline_config` of the sub component models
        bark_sub_configs = (
            config.semantic_config,
            config.coarse_acoustics_config,
            config.fine_acoustics_config,
        )
        for sub_config in bark_sub_configs:
            sub_config.vocab_size = 300
            sub_config.output_vocab_size = 300

        return config

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)`, as the sub-model testers in this file do.

        Previously this returned `None`, which cannot be unpacked into a
        `(config, inputs_dict)` pair.
        """
        # TODO: @Yoach: `inputs_dict` is still empty, see `prepare_config_and_inputs`.
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict
529+
530+
# Need this class in order to create tiny model for `bark`
# TODO (@Yoach) Implement actual test methods
@require_torch
@unittest.skip("So far all tests will fail.")
class BarkModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """Placeholder test case for the composite `BarkModel`.

    The whole class is skipped for now. `@require_torch` matches the other
    model test classes in this file, since the class body references the
    torch-only `BarkModel`.
    """

    all_model_classes = (BarkModel,) if is_torch_available() else ()

    def setUp(self):
        """Create the composite model tester and the `BarkConfig` config tester."""
        self.model_tester = BarkModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BarkConfig, n_embd=37)
540+
541+
476542@require_torch
477543class BarkSemanticModelTest (ModelTesterMixin , GenerationTesterMixin , unittest .TestCase ):
478544 all_model_classes = (BarkSemanticModel ,) if is_torch_available () else ()
@@ -488,9 +554,7 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
488554 test_resize_embeddings = True
489555
490556 def setUp (self ):
491- self .model_tester = BarkSemanticModelTester (
492- self , config_class = BarkSemanticConfig , model_class = BarkSemanticModel
493- )
557+ self .model_tester = BarkSemanticModelTester (self )
494558 self .config_tester = ConfigTester (self , config_class = BarkSemanticConfig , n_embd = 37 )
495559
496560 def test_config (self ):
@@ -556,7 +620,7 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
556620 test_resize_embeddings = True
557621
558622 def setUp (self ):
559- self .model_tester = BarkCoarseModelTester (self , config_class = BarkCoarseConfig , model_class = BarkCoarseModel )
623+ self .model_tester = BarkCoarseModelTester (self )
560624 self .config_tester = ConfigTester (self , config_class = BarkCoarseConfig , n_embd = 37 )
561625
562626 def test_config (self ):
@@ -623,7 +687,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
623687 test_resize_embeddings = True
624688
625689 def setUp (self ):
626- self .model_tester = BarkFineModelTester (self , config_class = BarkFineConfig , model_class = BarkFineModel )
690+ self .model_tester = BarkFineModelTester (self )
627691 self .config_tester = ConfigTester (self , config_class = BarkFineConfig , n_embd = 37 )
628692
629693 def test_config (self ):
0 commit comments