DDIM inversion on the pretrained T2V model #131

Open
qpc1611094 opened this issue Jun 27, 2024 · 0 comments
qpc1611094 commented Jun 27, 2024

I have tried DDIM inversion on the ModelScope T2V model, but I get some abnormal results.
My procedure is as follows:

1) get the original video latent:

import cv2
import torch
from PIL import Image
from einops import rearrange

# `cfg`, `data` (the transforms module) and `autoencoder` come from the repo's
# standard inference setup, the same as in the sampling scripts.

def load_video_frames(autoencoder, vid_path, train_trans, max_frames=16, double_frames_sr=False):
    # read the clip and sample up to `max_frames` frames at the native fps
    capture = cv2.VideoCapture(vid_path)
    _fps = capture.get(cv2.CAP_PROP_FPS)
    sample_fps = _fps
    _total_frame_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    stride = round(_fps / sample_fps)
    # always start from the first frame and cover the whole clip
    start_frame = 0
    end_frame = _total_frame_num

    pointer = 0
    frame_list = []
    while len(frame_list) < max_frames:
        ret, frame = capture.read()
        pointer += 1
        if (not ret) or (frame is None): break
        if pointer < start_frame: continue
        if pointer >= _total_frame_num + 1: break
        if (pointer - start_frame) % stride == 0:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            if double_frames_sr:
                frame_list.append(frame)
            frame_list.append(frame)

    capture.release()
    video_data = train_trans(frame_list)

    # resize to the model resolution and move to GPU: (f, c, h, w) -> (1, f, c, h, w)
    video_data = torch.nn.functional.interpolate(video_data, size=(256, 448), mode='bilinear')
    video_data = video_data.unsqueeze(0)
    video_data = video_data.cuda()

    batch_size, frames_num, _, _, _ = video_data.shape
    video_data = rearrange(video_data, 'b f c h w -> (b f) c h w')
    # encode the frames in chunks of 2 to limit VAE memory usage
    video_data_list = torch.chunk(video_data, video_data.shape[0] // 2, dim=0)

    with torch.no_grad():
        decode_data = []
        for vd_data in video_data_list:
            tmp = autoencoder.encode_firsr_stage(vd_data, cfg.scale_factor).detach()
            decode_data.append(tmp)
        video_data_feature = torch.cat(decode_data, dim=0)
        video_data_feature = rearrange(video_data_feature, '(b f) c h w -> b c f h w', b=batch_size)
    return video_data_feature

train_trans = data.Compose([
        data.ToTensor(),
        data.Normalize(mean=cfg.mean, std=cfg.std)])
video_data_feature = load_video_frames(autoencoder, cfg.test_video_path, train_trans)
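
Before running any diffusion steps, a plain VAE round trip on this latent is a useful sanity check: if decoding video_data_feature directly does not reproduce the input clip, the problem is in the encode/scale_factor path rather than in DDIM inversion. A minimal sketch, assuming encode_firsr_stage applies cfg.scale_factor at encode time (the mirror image of the 1. / cfg.scale_factor division used in step 3 below):

# decode the freshly encoded latent, skipping the diffusion model entirely
with torch.no_grad():
    latent = 1. / cfg.scale_factor * video_data_feature          # undo the encode-time scaling
    latent = rearrange(latent, 'b c f h w -> (b f) c h w')
    chunk_size = min(cfg.decoder_bs, latent.shape[0])
    chunks = torch.chunk(latent, latent.shape[0] // chunk_size, dim=0)
    frames = torch.cat([autoencoder.decode(c) for c in chunks], dim=0)
    round_trip = rearrange(frames, '(b f) c h w -> b c f h w', b=1)
# `round_trip` can then be saved with save_i2vgen_video_safe, exactly like the reconstruction in step 3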

2) obtain the noise latent by DDIM inversion:

# classifier-free guidance pair: conditional and unconditional (negative) text embeddings
model_kwargs = [{'y': y_words, 'fps': fps_tensor},
                {'y': zero_y_negative, 'fps': fps_tensor}]
noised_vid_feat = diffusion.ddim_reverse_sample_loop(video_data_feature,
                                                     model.eval(),
                                                     model_kwargs=model_kwargs,
                                                     clamp=None,
                                                     percentile=None,
                                                     guide_scale=cfg.guide_scale,
                                                     ddim_timesteps=cfg.ddim_timesteps)
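
For reference, ddim_reverse_sample_loop is conceptually just the deterministic DDIM update run in the opposite direction: estimate x_0 from the predicted noise, then re-noise it toward the next, larger timestep. A minimal single-step sketch with illustrative names (not the repo's API), assuming an epsilon-prediction model and the usual cumulative-alpha schedule:

import torch

def ddim_inversion_step(x_t, eps, alpha_bar_t, alpha_bar_next):
    """One deterministic DDIM inversion step (eta = 0).

    x_t            : latent at the current (smaller) timestep
    eps            : model's predicted noise at x_t
    alpha_bar_t    : cumulative alpha at the current timestep
    alpha_bar_next : cumulative alpha at the next, larger timestep
    """
    # estimate the clean latent from the epsilon prediction
    x0_pred = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
    # re-noise the estimate toward the next (noisier) timestep
    return alpha_bar_next.sqrt() * x0_pred + (1 - alpha_bar_next).sqrt() * eps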

3) reconstruct the original video:

# run standard DDIM sampling (eta = 0) starting from the inverted noise
video_reconstruct = diffusion.ddim_sample_loop(
                    noise=noised_vid_feat,
                    model=model.eval(),
                    model_kwargs=model_kwargs,
                    guide_scale=cfg.guide_scale,
                    ddim_timesteps=cfg.ddim_timesteps,
                    eta=0.0)
# undo the latent scale factor, then decode the frames in chunks with the VAE
video_reconstruct = 1. / cfg.scale_factor * video_reconstruct
video_reconstruct = rearrange(video_reconstruct, 'b c f h w -> (b f) c h w')
chunk_size = min(cfg.decoder_bs, video_reconstruct.shape[0])
video_reconstruct_list = torch.chunk(video_reconstruct, video_reconstruct.shape[0]//chunk_size, dim=0)
decode_reconstruct = []
for vd_data in video_reconstruct_list:
    gen_frames = autoencoder.decode(vd_data)
    decode_reconstruct.append(gen_frames)
video_reconstruct = torch.cat(decode_reconstruct, dim=0)
video_reconstruct = rearrange(video_reconstruct, '(b f) c h w -> b c f h w', b=1)
save_i2vgen_video_safe(local_path, video_reconstruct.cpu(), captions, cfg.mean, cfg.std, text_size)
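
To quantify the collapse instead of only judging the saved video, one can also compare the latent that comes out of ddim_sample_loop with the original latent from step 1. This snippet would go right after the ddim_sample_loop call, before the 1. / cfg.scale_factor rescale; it is only a debugging aid, not part of the repo:

# relative L2 error between the DDIM round-trip latent and the original latent;
# values near 0 mean inversion + sampling reproduce the input, large values mean divergence
with torch.no_grad():
    rel_err = ((video_reconstruct - video_data_feature).norm()
               / video_data_feature.norm()).item()
print(f'relative latent reconstruction error: {rel_err:.4f}')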

The original video is:
[video attachment: rank_02_01_0003_A_horse_running_on_the_road]

However, the reconstructed video is completely collapsed:
[video attachment]

I suspect the problem is related to the scale_factor, since when I set cfg.scale_factor=1.0 the result looks better:
[video attachment: rank_01_00_reconstruct]
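
For context on why the scale factor matters: the common latent-diffusion convention is to multiply the VAE latent by the scale factor right after encoding and divide by it right before decoding, so the diffusion model only ever sees scaled latents. A minimal sketch of that invariant, with hypothetical encode/decode callables standing in for the first-stage autoencoder:

import torch

def scale_factor_roundtrip_error(encode, decode, scale_factor, frames):
    """Measure how well encode/decode and the scale factor form a round trip.

    encode, decode : hypothetical callables standing in for the first-stage VAE
    scale_factor   : the value under test (e.g. cfg.scale_factor)
    frames         : pixel frames shaped (b*f, c, h, w)
    """
    latent = encode(frames) * scale_factor        # what the diffusion model should be given
    frames_rec = decode(latent / scale_factor)    # mirror-image division at decode time
    # should stay small for a healthy VAE round trip, regardless of the factor's value
    return (frames - frames_rec).abs().max().item()

If encode_firsr_stage does not apply cfg.scale_factor internally while the sampling code still divides by it before decoding, the latent handed to ddim_reverse_sample_loop would be off by that factor relative to what the model saw during training, which would fit the observation above.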

Looking forward to your reply, thanks very much!
