main.py
import os
from typing import Literal

from tap import Tap

from trainers.gfn_trainer import GFNTrainer
from trainers.mle_trainer import MLETrainer
from trainers.safety_trainer import SafetyTrainer
from trainers.sft_trainer import SFTTrainer
from utils import load_victim_config, seed
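
# Keep the HuggingFace tokenizers library single-threaded; this avoids the
# fork/parallelism warning (and potential hangs) when worker processes are
# spawned later in training.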
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class Argument(Tap):
    baseline: bool = False
    mode: Literal["gfn", "sft", "mle", "safety"] = "gfn"
    model_name: str = "gpt2"
    victim_model: str = "vicgalle/gpt2-alpaca"
    sft_ckpt: str = "save/gpt2-sft-position-final/latest"
    save_dir: str = "./save"
    prompt_file: str = "prompts/attack_prompt.jsonl"
    few_shot_file: str = "prompts/sft_dataset.json"

    epochs: int = 1
    lr: float = 1e-4
    max_norm: float = 1.0
    weight_decay: float = 0.1
    num_warmup_steps: int = 100
    train_steps: int = 5000
    batch_size: int = 16
    grad_acc_steps: int = 8

    len_norm: bool = False
    max_len: int = 20
    min_len: int = 5

    victim_top_p: float = 0.92
    victim_max_len: int = 30
    victim_temp: float = 0.7
    use_4bit: bool = False

    load_buffer: bool = False
    buffer_size: int = 1000
    sim_tolerance: float = 0.25
    prioritization: Literal["c_reward", "reward", "uniform"] = "c_reward"
    buffer_ckpt: str = ""
    compare: str = "reward"
    metric: Literal["edit", "cosine"] = "edit"

    dtype: str = "float32"
    seed: int = 42
    eval_period: int = 500
    eval_batch_size: int = 1024

    # lora hparams
    lora: bool = False
    lora_r: int = 32
    lora_alpha: int = 16
    lora_dropout: float = 0.0

    # reward scaling
    beta: float = 0.1
    lm_sched_end: float = 1.0
    lm_sched_start: float = 1.0
    lm_sched_horizon: int = 2000

    # reward temperature
    reward_sched_start: float = 2.0
    reward_sched_end: float = 1.0
    reward_sched_horizon: int = 500

    # sampling temperature
    temp_low: float = 0.5
    temp_high: float = 2.0

    # victim model
    num_r_samples: int = 5
    do_sample: bool = True

    # wandb
    exp_name: str = "debug"
    wandb_project: str = "red-team"
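
# Example invocation (illustrative values only; flag names mirror the Argument
# fields above, and explicit_bool=True means boolean flags take an explicit
# True/False value on the command line):
#
#   python main.py --mode gfn --model_name gpt2 --victim_model vicgalle/gpt2-alpaca \
#       --lr 1e-4 --train_steps 5000 --batch_size 16 --lora True --exp_name debug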

if __name__ == "__main__":
    args = Argument(explicit_bool=True).parse_args()
    load_victim_config(args)
    seed(args.seed)

    # Pick the trainer implementation that matches the requested training mode;
    # anything other than gfn/mle/safety falls back to supervised fine-tuning.
    if args.mode == "gfn":
        trainer = GFNTrainer(args)
    elif args.mode == "mle":
        trainer = MLETrainer(args)
    elif args.mode == "safety":
        trainer = SafetyTrainer(args)
    else:
        trainer = SFTTrainer(args)

    trainer.train()
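
# The dispatch above only assumes that every trainer shares the same minimal
# interface: construction from the parsed Argument instance and a single
# train() entry point. A sketch of that assumed contract (not the actual
# trainer implementations):
#
#   class SomeTrainer:
#       def __init__(self, args: Argument) -> None: ...
#       def train(self) -> None: ...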