|
20 | 20 |
|
21 | 21 | import torch |
22 | 22 |
|
23 | | -from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction |
| 23 | +from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel |
24 | 24 | from diffusers.utils import floats_tensor, slow, torch_device |
25 | 25 |
|
26 | 26 | from .test_modeling_common import ModelTesterMixin |
@@ -524,86 +524,3 @@ def test_output_pretrained(self): |
524 | 524 | def test_forward_with_norm_groups(self): |
525 | 525 | # Not implemented yet for this UNet |
526 | 526 | pass |
527 | | - |
528 | | - |
529 | | -class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): |
530 | | - model_class = ValueFunction |
531 | | - |
532 | | - @property |
533 | | - def dummy_input(self): |
534 | | - batch_size = 4 |
535 | | - num_features = 14 |
536 | | - seq_len = 16 |
537 | | - |
538 | | - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) |
539 | | - time_step = torch.tensor([10] * batch_size).to(torch_device) |
540 | | - |
541 | | - return {"sample": noise, "timestep": time_step} |
542 | | - |
543 | | - @property |
544 | | - def input_shape(self): |
545 | | - return (4, 14, 16) |
546 | | - |
547 | | - @property |
548 | | - def output_shape(self): |
549 | | - return (4, 14, 1) |
550 | | - |
551 | | - def test_ema_training(self): |
552 | | - pass |
553 | | - |
554 | | - def test_training(self): |
555 | | - pass |
556 | | - |
557 | | - def prepare_init_args_and_inputs_for_common(self): |
558 | | - init_dict = { |
559 | | - "block_out_channels": (32, 64, 128, 256), |
560 | | - "in_channels": 14, |
561 | | - "out_channels": 14, |
562 | | - } |
563 | | - inputs_dict = self.dummy_input |
564 | | - return init_dict, inputs_dict |
565 | | - |
566 | | - def test_from_pretrained_hub(self): |
567 | | - unet, loading_info = UNet1DModel.from_pretrained( |
568 | | - "bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True |
569 | | - ) |
570 | | - value_function, vf_loading_info = ValueFunction.from_pretrained( |
571 | | - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True |
572 | | - ) |
573 | | - self.assertIsNotNone(unet) |
574 | | - self.assertEqual(len(loading_info["missing_keys"]), 0) |
575 | | - self.assertIsNotNone(value_function) |
576 | | - self.assertEqual(len(vf_loading_info["missing_keys"]), 0) |
577 | | - |
578 | | - unet.to(torch_device) |
579 | | - value_function.to(torch_device) |
580 | | - image = value_function(**self.dummy_input) |
581 | | - |
582 | | - assert image is not None, "Make sure output is not None" |
583 | | - |
584 | | - def test_output_pretrained(self): |
585 | | - value_function, vf_loading_info = ValueFunction.from_pretrained( |
586 | | - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True |
587 | | - ) |
588 | | - torch.manual_seed(0) |
589 | | - if torch.cuda.is_available(): |
590 | | - torch.cuda.manual_seed_all(0) |
591 | | - |
592 | | - num_features = value_function.in_channels |
593 | | - seq_len = 14 |
594 | | - noise = torch.randn((1, seq_len, num_features)).permute( |
595 | | - 0, 2, 1 |
596 | | - ) # match original, we can update values and remove |
597 | | - time_step = torch.full((num_features,), 0) |
598 | | - |
599 | | - with torch.no_grad(): |
600 | | - output = value_function(noise, time_step).sample |
601 | | - |
602 | | - # fmt: off |
603 | | - expected_output_slice = torch.tensor([207.0272] * seq_len) |
604 | | - # fmt: on |
605 | | - self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) |
606 | | - |
607 | | - def test_forward_with_norm_groups(self): |
608 | | - # Not implemented yet for this UNet |
609 | | - pass |
0 commit comments