rm_clip_and_add_channels.py
import torch

pretrained_model_path = '../checkpoints/original/model.ckpt'
ckpt_file = torch.load(pretrained_model_path, map_location='cpu')
state_dict = ckpt_file['state_dict']

# First UNet conv: insert a zero-initialized mask channel right after the 4 latent channels.
in_key = 'model.diffusion_model.input_blocks.0.0.weight'
mask_weight = torch.zeros(320, 1, 3, 3)
state_dict[in_key] = torch.cat(
    (torch.cat((state_dict[in_key][:, :4], mask_weight), dim=1), state_dict[in_key][:, 4:]),
    dim=1,
)

# First UNet conv: append 8 zero-initialized pose channels.
pose_weight = torch.zeros(320, 8, 3, 3)
state_dict[in_key] = torch.cat((state_dict[in_key], pose_weight), dim=1)

# Last UNet conv: append one zero-initialized output channel (and matching bias) for the mask.
out_weight_key = 'model.diffusion_model.out.2.weight'
out_bias_key = 'model.diffusion_model.out.2.bias'
state_dict[out_weight_key] = torch.cat((state_dict[out_weight_key], torch.zeros(1, 320, 3, 3)), dim=0)
state_dict[out_bias_key] = torch.cat((state_dict[out_bias_key], torch.zeros(1)), dim=0)

# Drop the CLIP text encoder (cond_stage_model.*) weights from the checkpoint.
new_state_dict = {key: value for key, value in state_dict.items()
                  if not key.startswith('cond_stage_model')}
ckpt_file['state_dict'] = new_state_dict

torch.save(ckpt_file, '../checkpoints/original/model_prepared.ckpt')
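
# Optional sanity check (a sketch, not part of the original script): reload the prepared
# checkpoint and confirm the expanded shapes. This assumes the stock SD UNet layout
# (4 latent input channels in the first conv, 4 output channels in the final conv), so the
# first conv should grow to 4 + 1 + 8 = 13 input channels and the final conv to 5 outputs,
# with all cond_stage_model.* keys removed.
prepared = torch.load('../checkpoints/original/model_prepared.ckpt', map_location='cpu')
sd = prepared['state_dict']
print(sd['model.diffusion_model.input_blocks.0.0.weight'].shape)  # expected: (320, 13, 3, 3)
print(sd['model.diffusion_model.out.2.weight'].shape)             # expected: (5, 320, 3, 3)
print(any(k.startswith('cond_stage_model') for k in sd))          # expected: False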