forked from princeton-nlp/SWE-agent
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
223 lines (190 loc) · 7.76 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import json
import logging
import os
import re
import traceback
import yaml
from dataclasses import dataclass
from getpass import getuser
from pathlib import Path
from rich.logging import RichHandler
from simple_parsing import parse
from simple_parsing.helpers import FrozenSerializable, FlattenedAccess
from sweagent import (
Agent,
AgentArguments,
EnvironmentArguments,
ModelArguments,
SWEEnv,
get_data_path_name,
)
from swebench import KEY_INSTANCE_ID, KEY_MODEL, KEY_PREDICTION
from unidiff import PatchSet
# Script-level logger. Its records go only to the Rich handler attached
# below, because propagation to ancestor loggers is switched off.
logger = logging.getLogger("run_dev")
logger.setLevel(logging.DEBUG)
logger.propagate = False

# Rich console handler configured without the time and path columns.
handler = RichHandler(show_time=False, show_path=False)
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)

# Only let WARNING and above through from the "simple_parsing" logger.
logging.getLogger("simple_parsing").setLevel(logging.WARNING)
@dataclass(frozen=True)
class ScriptArguments(FlattenedAccess, FrozenSerializable):
    """Top-level arguments of this script: environment, agent and run options."""

    environment: EnvironmentArguments
    agent: AgentArguments
    instance_filter: str = ".*"  # Only run instances that completely match this regex
    skip_existing: bool = True  # Skip instances with existing trajectories
    suffix: str = ""

    @property
    def run_name(self):
        """Generate a unique name for this run based on the arguments.

        The name combines the model name (with ``:`` replaced by ``-``), the
        data path name, the config file stem, temperature, top_p, the
        per-instance cost limit and the install-environment flag, followed by
        ``suffix`` when it is non-empty.
        """
        # Read everything from ``self``. The previous implementation used the
        # module-level ``args`` global, which only exists when the file is run
        # as a script and may not be the instance the property was called on.
        model = self.agent.model
        model_name = model.model_name.replace(":", "-")
        data_stem = get_data_path_name(self.environment.data_path)
        config_stem = Path(self.agent.config_file).stem
        temp = model.temperature
        top_p = model.top_p
        per_instance_cost_limit = model.per_instance_cost_limit
        install_env = self.environment.install_environment
        return (
            f"{model_name}__{data_stem}__{config_stem}__t-{temp:.2f}__p-{top_p:.2f}"
            + f"__c-{per_instance_cost_limit:.2f}__install-{int(install_env)}"
            + (f"__{self.suffix}" if self.suffix else "")
        )
def main(args: ScriptArguments):
    """Run the agent on each dataset instance and record its prediction.

    Builds the agent and the environment from ``args``, creates the trajectory
    directory ``trajectories/<user>/<run_name>`` and saves the arguments there,
    then iterates over ``env.data``. Instances rejected by ``should_skip``, or
    for which ``env.reset`` returns ``info`` as ``None``, are passed over.

    A failure on a single instance is logged, the container is reset and the
    loop continues with the next instance. ``KeyboardInterrupt`` closes the
    environment and stops the run.
    """
    logger.info(f"📙 Arguments: {args.dumps_yaml()}")
    agent = Agent("primary", args.agent)
    env = SWEEnv(args.environment)

    traj_dir = Path("trajectories") / Path(getuser()) / args.run_name
    os.makedirs(traj_dir, exist_ok=True)

    save_arguments(traj_dir, args)

    for index in range(len(env.data)):
        # Tracked outside the ``try`` so the error handler can name the
        # instance even when the failure happens before ``env.reset``.
        instance_id = None
        try:
            # Reset environment
            instance_id = env.data[index]["instance_id"]
            if should_skip(args, traj_dir, instance_id):
                continue
            logger.info("▶️ Beginning task " + str(index))

            observation, info = env.reset(index)
            if info is None:
                continue

            # Get info, patch information
            issue = getattr(env, "query", None)
            # NOTE(review): ``files`` / ``test_files`` default to an empty list
            # but become a newline-joined string when the record has a patch,
            # whereas ``tests`` defaults to "". Confirm what the agent expects.
            files = []
            if "patch" in env.record:
                files = "\n".join(
                    [f"- {x.path}" for x in PatchSet(env.record["patch"]).modified_files]
                )

            # Get test files, F2P tests information
            test_files = []
            if "test_patch" in env.record:
                test_patch_obj = PatchSet(env.record["test_patch"])
                test_files = "\n".join(
                    [f"- {x.path}" for x in test_patch_obj.modified_files + test_patch_obj.added_files]
                )
            tests = ""
            if "FAIL_TO_PASS" in env.record:
                tests = "\n".join([f"- {x}" for x in env.record["FAIL_TO_PASS"]])

            setup_args = {
                "issue": issue,
                "files": files,
                "test_files": test_files,
                "tests": tests
            }
            info = agent.run(
                setup_args=setup_args,
                env=env,
                observation=observation,
                traj_dir=traj_dir,
                return_type="info",
            )
            save_predictions(traj_dir, instance_id, info)

        except KeyboardInterrupt:
            logger.info("Exiting InterCode environment...")
            env.close()
            break
        except Exception as e:
            traceback.print_exc()
            # Do not read ``env.record`` here: if the failure occurred before
            # ``env.reset`` it still refers to the previous instance (or does
            # not exist yet), which would mislabel the log line or raise
            # inside this handler and abort the whole run.
            failed_on = instance_id if instance_id is not None else f"index {index}"
            logger.warning(f"❌ Failed on {failed_on}: {e}")
            env.reset_container()
            continue
def save_arguments(traj_dir, args):
    """Write ``args`` as YAML to ``args.yaml`` inside ``traj_dir``.

    When the file already exists it is first loaded and its YAML dump is
    compared with that of ``args``; a difference, or a failure to load, only
    results in warnings. The file is (re)written in every case.
    """
    args_file = traj_dir / "args.yaml"

    if args_file.exists():
        try:
            previous = args.load_yaml(args_file)
            # Compare the serialized YAML instead of relying on object equality.
            differs = args.dumps_yaml() != previous.dumps_yaml()
            if differs:
                banner = "**************************************************"
                for message in (
                    banner,
                    "Found existing args.yaml with different arguments!",
                    banner,
                ):
                    logger.warning(message)
        except Exception as e:
            logger.warning(f"Failed to load existing args.yaml: {e}")

    with args_file.open("w") as out:
        args.dump_yaml(out)
def should_skip(args, traj_dir, instance_id):
    """Check if we should skip this instance based on the instance filter and skip_existing flag.

    Returns ``True`` when ``instance_id`` does not match ``args.instance_filter``
    or when ``args.skip_existing`` is set and a finished trajectory file
    ``<instance_id>.traj`` exists in ``traj_dir``. Returns ``False`` otherwise.

    An existing trajectory counts as incomplete, and is deleted so the
    instance is redone, when its exit status is missing or ``"early_exit"``,
    or when the file cannot be parsed as JSON or has no ``"info"`` mapping.
    """
    # Skip instances that don't match the instance filter
    # NOTE(review): re.match anchors only at the start of the string, so a
    # pattern without a trailing "$" also accepts longer instance ids; confirm
    # whether a full match (re.fullmatch) is intended.
    if re.match(args.instance_filter, instance_id) is None:
        logger.info(f"Instance filter not matched. Skipping instance {instance_id}")
        return True

    # If flag is set to False, don't skip
    if not args.skip_existing:
        return False

    # Check if there's an existing trajectory for this instance
    log_path = traj_dir / (instance_id + ".traj")
    if not log_path.exists():
        return False

    try:
        with log_path.open("r") as f:
            data = json.load(f)
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError
        # subclasses. A file that cannot be parsed (e.g. truncated by an
        # interrupted write) is handled as an incomplete trajectory instead of
        # letting the error mark the instance as failed.
        logger.warning(f"Could not parse existing trajectory {log_path}: {e}")
        data = None

    # Tolerate a missing or malformed "info" section the same way.
    info = data.get("info") if isinstance(data, dict) else None
    exit_status = info.get("exit_status", None) if isinstance(info, dict) else None

    # If the trajectory has no exit status, it's incomplete and we will redo it
    if exit_status == "early_exit" or exit_status is None:
        logger.info(f"Found existing trajectory with no exit status: {log_path}")
        logger.info("Removing incomplete trajectory...")
        os.remove(log_path)
        return False

    logger.info(f"⏭️ Skipping existing trajectory: {log_path}")
    return True
def save_predictions(traj_dir, instance_id, info):
    """Append the prediction for ``instance_id`` to ``all_preds.jsonl`` in ``traj_dir``.

    One JSON object is written per line, holding the name of the trajectory
    directory (as the model key), the instance id, and ``info["submission"]``
    (``None`` when ``info`` has no such key).
    """
    run_dir = Path(traj_dir)
    preds_path = run_dir / "all_preds.jsonl"

    if "submission" in info:
        patch = info["submission"]
    else:
        patch = None

    record = {
        KEY_MODEL: run_dir.name,
        KEY_INSTANCE_ID: instance_id,
        KEY_PREDICTION: patch,
    }

    # Append a single line and flush right away.
    with open(preds_path, "a+") as out:
        out.write(json.dumps(record) + "\n")
        out.flush()
    logger.info(f"Saved predictions to {preds_path}")
if __name__ == "__main__":
    # Default values for every argument; the command line can override them.
    env_defaults = EnvironmentArguments(
        image_name="swe-agent",
        data_path="princeton-nlp/SWE-bench_Lite",
        split="dev",
        verbose=True,
        install_environment=True,
    )
    model_defaults = ModelArguments(
        model_name="gpt4",
        total_cost_limit=0.0,
        per_instance_cost_limit=2.0,
        temperature=0.2,
        top_p=0.95,
    )
    agent_defaults = AgentArguments(
        model=model_defaults,
        config_file="config/default.yaml",
    )
    defaults = ScriptArguments(
        suffix="",
        environment=env_defaults,
        skip_existing=True,
        agent=agent_defaults,
    )

    # Nicer yaml dumping of multiline strings
    def multiline_representer(dumper, data):
        """Represent ``data`` as a YAML string, using block style ("|") if it spans lines.

        Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data
        """
        # ``style=None`` is what ``represent_scalar`` uses when no style is given.
        scalar_style = "|" if "\n" in data else None
        return dumper.represent_scalar("tag:yaml.org,2002:str", data, style=scalar_style)

    yaml.add_representer(str, multiline_representer)

    args = parse(ScriptArguments, default=defaults, add_config_path_arg=False)
    main(args)