Description
My dataset consists of 32×32 images.
I ran:
python3 train.py --image_size 32 --exp exp --num_channels 2 --num_channels_dae 64 --ch_mult 1 1 2 2 4 4 --num_timesteps 4 --num_res_blocks 2 --batch_size 1 --contrast1 T1 --contrast2 T2 --num_epoch 50 --ngf 64 --embedding_type positional --r1_gamma 1. --z_emb_dim 256 --lr_d 1e-4 --lr_g 1.6e-4 --lazy_reg 10 --num_process_per_node 1 --save_content --local_rank 0 --input_path ./SynDiff_sample_data --output_path ./output
and got the following error:
module_path = /home/code/SynDiff/utils/op
padding in x-y with:0-52
padding in x-y with:0-52
padding in x-y with:0-52
padding in x-y with:0-52
train data size:25
val data size:25
initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
Traceback (most recent call last):
File "/home/code/SynDiff/train.py", line 1092, in <module>
init_processes(0, size, train_syndiff, args)
File "/home/code/SynDiff/train.py", line 724, in init_processes
fn(rank, gpu, args)
File "/home/code/SynDiff/train.py", line 453, in train_syndiff
x1_0_predict_diff = gen_diffusive_1(torch.cat((x1_tp1.detach(),x2_0_predict),axis=1), t1, latent_z1)
File "/home/miniconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/miniconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/miniconda3/envs/torch/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
else self._run_ddp_forward(*inputs, **kwargs)
File "/home/miniconda3/envs/torch/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
File "/home/miniconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/miniconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/code/SynDiff/backbones/ncsnpp_generator_adagn.py", line 312, in forward
h = modules[m_idx](hs[-1], temb, zemb)
File "/home/miniconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/miniconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given
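A likely mechanism for this `TypeError`, sketched under assumptions: NCSN++-style generators typically build their module list once from the configured `--image_size` (adding an attention block after each residual block at levels whose resolution is in the attention resolutions), but decide at runtime whether to call that attention block from the actual feature-map width. If the two disagree, the index lands on an attention block whose `forward(self, x)` takes one input while the loop passes three (`hs[-1], temb, zemb`), which produces exactly "takes 2 positional arguments but 4 were given". Whether `ncsnpp_generator_adagn.py` follows this layout exactly is an assumption. The classes and functions below (`Module`, `ResBlock`, `AttnBlock`, `Downsample`, `build_modules`, `run_forward`) are hypothetical stand-ins that model feature maps only by their width; `attn_resolutions=(16,)` and the runtime width 256 are illustrative values, not taken from SynDiff. The `padding in x-y with:0-52` log lines hint that the loader pads the slices, so the tensors reaching the generator may not be 32×32, but that is an inference from the log.

```python
# Hypothetical pure-Python model of an NCSN++-style generator's module list.
# Feature maps are represented only by their width (an int); no SynDiff code is used.


class Module:
    """Minimal stand-in for torch.nn.Module: calling the object runs forward()."""

    def __call__(self, *args):
        return self.forward(*args)


class ResBlock(Module):
    # Residual blocks receive the feature map plus time and latent embeddings.
    def forward(self, x, temb, zemb):
        return x


class AttnBlock(Module):
    # Attention blocks receive only the feature map.
    def forward(self, x):
        return x


class Downsample(Module):
    # Halves the width, as a 2x downsampling layer would.
    def forward(self, x):
        return x // 2


def build_modules(image_size, ch_mult, num_res_blocks, attn_resolutions):
    """Lay out modules once, using resolutions derived from the configured image_size."""
    modules, res = [], image_size
    for i_level in range(len(ch_mult)):
        for _ in range(num_res_blocks):
            modules.append(ResBlock())
            if res in attn_resolutions:
                modules.append(AttnBlock())
        if i_level != len(ch_mult) - 1:
            modules.append(Downsample())
            res //= 2
    return modules


def run_forward(modules, width, ch_mult, num_res_blocks, attn_resolutions):
    """Walk the list, deciding whether to call attention from the runtime width."""
    m_idx, h = 0, width
    for i_level in range(len(ch_mult)):
        for _ in range(num_res_blocks):
            h = modules[m_idx](h, None, None)  # analogous to modules[m_idx](hs[-1], temb, zemb)
            m_idx += 1
            if h in attn_resolutions:
                h = modules[m_idx](h)
                m_idx += 1
        if i_level != len(ch_mult) - 1:
            h = modules[m_idx](h)
            m_idx += 1
    return h


if __name__ == "__main__":
    ch_mult, attn = (1, 1, 2, 2, 4, 4), (16,)
    mods = build_modules(32, ch_mult, 2, attn)
    print(run_forward(mods, 32, ch_mult, 2, attn))  # 1: widths match, no error
    try:
        run_forward(mods, 256, ch_mult, 2, attn)
    except TypeError as err:
        # The message contains "forward() takes 2 positional arguments but 4 were given".
        print(err)
```

If this is the cause, the relevant parameter is `--image_size`: it has to equal the spatial size of the tensors the loader actually produces (or the loader has to produce tensors of that size).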
How can I change the parameters to fit my data?
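One way to catch a size mismatch before the first forward pass is a small pre-flight check on a single batch. The helpers `level_resolutions` and `check_generator_config` below are hypothetical, not part of SynDiff; they assume one 2× downsample between consecutive `--ch_mult` entries and a batch shape of `(N, C, H, W)`. The 256×256 batch in the usage lines is an illustrative value. With `--image_size 32` and the six-entry `--ch_mult 1 1 2 2 4 4`, the deepest level works out to 1×1.

```python
def level_resolutions(image_size, ch_mult):
    """Feature-map width at each generator level, assuming one 2x downsample
    between consecutive --ch_mult entries (NCSN++-style layout)."""
    return [image_size // (2 ** i) for i in range(len(ch_mult))]


def check_generator_config(image_size, ch_mult, batch_shape):
    """Hypothetical pre-flight check on one batch of shape (N, C, H, W).

    Returns a list of human-readable problems; an empty list means no mismatch found.
    """
    problems = []
    _, _, height, width = batch_shape
    if (height, width) != (image_size, image_size):
        problems.append(f"batch is {height}x{width} but --image_size is {image_size}")
    factor = 2 ** (len(ch_mult) - 1)
    if image_size % factor != 0:
        problems.append(
            f"--image_size {image_size} is not divisible by {factor} "
            f"(one 2x downsample per extra --ch_mult level)"
        )
    return problems


if __name__ == "__main__":
    ch_mult = (1, 1, 2, 2, 4, 4)
    print(level_resolutions(32, ch_mult))  # [32, 16, 8, 4, 2, 1]
    print(check_generator_config(32, ch_mult, (1, 2, 32, 32)))  # []
    print(check_generator_config(32, ch_mult, (1, 2, 256, 256)))
    # ['batch is 256x256 but --image_size is 32']
```

For a PyTorch tensor, `tuple(x.shape)` gives the `(N, C, H, W)` tuple to pass as `batch_shape`.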