
Why is it necessary to train before testing even in zero-shot learning? #3

Open
hexiangdong2020 opened this issue Mar 15, 2024 · 11 comments

Comments

@hexiangdong2020

In UniTS_zeroshot_newdata.sh, we found that we need to pre-train before we can run the zero-shot test. Is there a pre-trained model available for direct testing?

@gasvn
Member

gasvn commented Mar 17, 2024

The zero-shot version of UniTS has shared prompt, cls, and mask tokens for all tasks, which is different from the other settings, so we need to pretrain a separate model for this version.
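For intuition, here is a minimal sketch of the difference (the token names and shapes below are assumptions for illustration, not the actual UniTS implementation): a per-task setting can keep separate learnable tokens for each dataset, while the zero-shot variant reuses one shared set for every task, which is why it needs its own pretraining run.

import torch
import torch.nn as nn

d_model, prompt_num = 64, 10

# Per-task tokens (hypothetical supervised multi-task layout):
# each dataset gets its own learnable prompt tokens.
per_task_prompt = nn.ParameterDict({
    name: nn.Parameter(torch.zeros(1, prompt_num, d_model))
    for name in ["Solar", "Saugeen_River", "Hospital", "Web_Traffic"]
})

# Shared tokens (zero-shot layout): one set of prompt/cls/mask tokens is
# reused for every task, so it must be pretrained as a separate model.
shared_prompt = nn.Parameter(torch.zeros(1, prompt_num, d_model))
shared_cls = nn.Parameter(torch.zeros(1, 1, d_model))
shared_mask = nn.Parameter(torch.zeros(1, 1, d_model))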

@hexiangdong2020
Author

It seems to take a lot of time to train such a model. Are you able to provide pre-trained models applicable to this setting?

@gasvn
Member

gasvn commented Mar 18, 2024

Our code is still under internal administrative review, and we are not allowed to release new checkpoints for now. Training is pretty fast; it takes about 1-2 days on one GPU.

@hexiangdong2020
Author

We have spent two days pre-training with UniTS_zeroshot_newdata.sh. However, when executing its second command for testing, we got the following error:

/home/deeprob/UniTS/auto
no ckpt found!

@hexiangdong2020
Author

Therefore, we changed "auto" to "pretrain_checkpoint.pth" in the second command and moved the pre-trained model to the corresponding location. However, during testing, the results on the first three datasets deviated greatly, while the fourth dataset raised an error directly:
[screenshots of the test metrics and the error output]

The full log is below:

(units) deeprob@deeprob-MS-7E06:~/UniTS$ bash ./scripts/zero_shot/zeroshot_test.sh
/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/gluonts/json.py:101: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.
  warnings.warn(
Start running basic DDP on rank 0.
Args in experiment:
Namespace(task_name='ALL_task', is_training=0, model_id='UniTS_zeroshot_pretrain_x64', model='UniTS_zeroshot', data='All', features='M', target='OT', freq='h', task_data_config_path='data_provider/zeroshot_task.yaml', subsample_pct=None, local_rank=None, dist_url='env://', num_workers=0, memory_check=True, large_model=True, itr=1, train_epochs=10, prompt_tune_epoch=0, warmup_epochs=0, batch_size=32, acc_it=1, learning_rate=0.0001, min_lr=None, weight_decay=0.0, layer_decay=None, des='Exp', lradj='supervised', clip_grad=None, dropout=0.1, checkpoints='./checkpoints/', pretrained_weight='pretrain_checkpoint.pth', debug='online', project_name='zeroshot_newdata', d_model=64, n_heads=8, e_layers=3, share_embedding=False, patch_len=16, stride=16, prompt_num=10, fix_seed=None, inverse=False, mask_rate=0.25, anomaly_ratio=1.0, offset=0, max_offset=0, zero_shot_forecasting_new_length=None)
wandb: Currently logged in as: hexiangdong2020 (deeprob). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.4
wandb: Run data is saved locally in /home/deeprob/UniTS/wandb/run-20240320_021600-fvifbx40
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run ALL_task_UniTS_zeroshot_pretrain_x64_UniTS_zeroshot_All_ftM_dm64_el3_Exp
wandb: ?? View project at https://wandb.ai/deeprob/zeroshot_newdata
wandb: ?? View run at https://wandb.ai/deeprob/zeroshot_newdata/runs/fvifbx40
device id 0
>>>>>>>testing : ALL_task_UniTS_zeroshot_pretrain_x64_UniTS_zeroshot_All_ftM_dm64_el3_Exp_0<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
Solar 10449
Saugeen_River 4622
Hospital 2
Web_Traffic 82
/home/deeprob/UniTS/pretrain_checkpoint.pth
loading pretrained model: pretrain_checkpoint.pth
_IncompatibleKeys(missing_keys=[], unexpected_keys=['module.pretrain_head.proj_in.weight', 'module.pretrain_head.proj_in.bias', 'module.pretrain_head.mlp.fc1.weight', 'module.pretrain_head.mlp.fc1.bias', 'module.pretrain_head.mlp.fc2.weight', 'module.pretrain_head.mlp.fc2.bias', 'module.pretrain_head.proj_out.weight', 'module.pretrain_head.proj_out.bias', 'module.pretrain_head.pos_proj.weights', 'module.pretrain_head.pos_proj.bias'])
data_task_name: Solar mse:0.5455618500709534, mae:0.5404726266860962
data_task_name: Saugeen_River mse:1.3698701858520508, mae:0.7061453461647034
data_task_name: Hospital mse:1.0425046682357788, mae:0.7855547666549683
Traceback (most recent call last):
  File "/home/deeprob/UniTS/run.py", line 178, in <module>
    exp.test(setting, load_pretrain=True)
  File "/home/deeprob/UniTS/exp/exp_sup.py", line 574, in test
    mse, mae = self.test_long_term_forecast(
  File "/home/deeprob/UniTS/exp/exp_sup.py", line 649, in test_long_term_forecast
    outputs = self.model(
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/deeprob/UniTS/models/UniTS_zeroshot.py", line 947, in forward
    dec_out = self.forecast(x_enc, x_mark_enc, task_id)
  File "/home/deeprob/UniTS/models/UniTS_zeroshot.py", line 741, in forecast
    x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)
  File "/home/deeprob/UniTS/models/UniTS_zeroshot.py", line 724, in backbone
    x = block(x, prefix_seq_len=prefix_len +
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/deeprob/UniTS/models/UniTS_zeroshot.py", line 468, in forward
    x = self.seq_att_block(x, attn_mask)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/deeprob/UniTS/models/UniTS_zeroshot.py", line 348, in forward
    x = self.attn_seq(x, attn_mask)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/deeprob/UniTS/models/UniTS_zeroshot.py", line 241, in forward
    x = F.scaled_dot_product_attention(
RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

wandb: \ 0.016 MB of 0.018 MB uploaded
wandb: Run history:
wandb:      eval_LF-mae_Hospital ?
wandb: eval_LF-mae_Saugeen_River ?
wandb:         eval_LF-mae_Solar ?
wandb:      eval_LF-mse_Hospital ?
wandb: eval_LF-mse_Saugeen_River ?
wandb:         eval_LF-mse_Solar ?
wandb: 
wandb: Run summary:
wandb:      eval_LF-mae_Hospital 0.78555
wandb: eval_LF-mae_Saugeen_River 0.70615
wandb:         eval_LF-mae_Solar 0.54047
wandb:      eval_LF-mse_Hospital 1.0425
wandb: eval_LF-mse_Saugeen_River 1.36987
wandb:         eval_LF-mse_Solar 0.54556
wandb: 
wandb: ?? View run ALL_task_UniTS_zeroshot_pretrain_x64_UniTS_zeroshot_All_ftM_dm64_el3_Exp at: https://wandb.ai/deeprob/zeroshot_newdata/runs/fvifbx40
wandb: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240320_021600-fvifbx40/logs
[2024-03-20 02:17:03,424] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 1068567) of binary: /home/deeprob/anaconda3/envs/units/bin/python
Traceback (most recent call last):
  File "/home/deeprob/anaconda3/envs/units/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.2.1', 'console_scripts', 'torchrun')())
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/deeprob/anaconda3/envs/units/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
run.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-03-20_02:17:03
  host      : deeprob-MS-7E06
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 1068567)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
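One way to narrow down the F.scaled_dot_product_attention failure (a debugging sketch under assumed tensor shapes, not a confirmed fix) is to rerun with synchronous CUDA error reporting and force the plain math backend, which tells you whether a fused kernel is rejecting the launch configuration:

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # set before CUDA init so errors surface at the failing call

import torch
import torch.nn.functional as F

# Dummy tensors standing in for the failing layer's query/key/value
# (shapes are illustrative, not UniTS's real ones).
q = torch.randn(32, 8, 300, 8, device="cuda")
k = torch.randn(32, 8, 300, 8, device="cuda")
v = torch.randn(32, 8, 300, 8, device="cuda")

# Force the math backend; if the error disappears here but returns with the
# fused kernels, the flash/memory-efficient kernel rejected this configuration.
with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                    enable_mem_efficient=False,
                                    enable_math=True):
    out = F.scaled_dot_product_attention(q, k, v)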

@hexiangdong2020
Author

Our pre-trained model and logs are attached:
zero_shot.zip

@hexiangdong2020
Author

Therefore, how should we correctly reproduce the experimental results in Section 5.3 of the paper? @gasvn

@hexiangdong2020
Author

We have spent two days pre-training with UniTS_zeroshot_newdata.sh. However, when executing its second command for testing, we got the following error:

/home/deeprob/UniTS/auto
no ckpt found!

I read exp_sup.py. It seems that "auto" is only recognized in the train function, but not in the test function.
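A minimal sketch of the kind of fallback that would let the test path accept "auto" as well (the helper name and checkpoint layout are assumptions, not the repo's actual code):

import glob
import os

def resolve_ckpt(pretrained_weight: str, ckpt_dir: str) -> str:
    """Resolve 'auto' to the newest .pth file in ckpt_dir, mirroring what the
    train path already does (hypothetical helper, not UniTS code)."""
    if pretrained_weight != "auto":
        return pretrained_weight
    candidates = glob.glob(os.path.join(ckpt_dir, "*.pth"))
    if not candidates:
        raise FileNotFoundError(f"no ckpt found in {ckpt_dir}!")
    return max(candidates, key=os.path.getmtime)  # most recently saved checkpoint

# e.g. inside the test entry point:
# ckpt_path = resolve_ckpt(args.pretrained_weight, args.checkpoints)
# state = torch.load(ckpt_path, map_location="cpu")
# model.load_state_dict(state, strict=False)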

@hexiangdong2020
Author

Our pre-trained model and logs are attached: zero_shot.zip

The values of the loss function during training are as follows:
[screenshot of the training loss curve]

@gasvn
Member

gasvn commented Mar 20, 2024

Therefore, how should we correctly reproduce the experimental results in Section 5.3 of the paper? @gasvn

The results in the paper are only for one sample (the first sample of the dataset), because we need to compare with a previous zero-shot method that is very slow (they only use one example in their paper for comparison). The results you get from the current repo are the performance on the whole dataset.
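For anyone trying to match the paper's single-sample numbers, a rough sketch of scoring only the first test window of a dataset (the loader layout and model call are assumptions; the repo may expose this differently):

import torch

def eval_first_sample(model, test_loader, pred_len, device="cuda"):
    """Hypothetical helper: compute MSE/MAE on only the first test window,
    mimicking the single-example comparison described above."""
    model.eval()
    batch_x, batch_y = next(iter(test_loader))[:2]   # first batch; layout assumed
    x = batch_x[:1].float().to(device)               # keep only the first sample
    y = batch_y[:1, -pred_len:].float().to(device)
    with torch.no_grad():
        pred = model(x)                              # call signature is illustrative
    mse = torch.mean((pred - y) ** 2).item()
    mae = torch.mean(torch.abs(pred - y)).item()
    return mse, mae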

@gasvn
Member

gasvn commented Mar 20, 2024

From the loss curve, it seems the training is working well. We will figure out the bug you mentioned. Thank you for the feedback.
