Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ajet/backbone/trainer_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from ajet.schema.task import Task
from ajet.task_reader import dict_to_ajet_task
from ajet.task_rollout.native_parallel_worker import VerlRolloutManager

from ajet.utils.metric_helper import save_trajectory_as_json_file, update_metrics

def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | torch.Tensor:
"""
Expand Down Expand Up @@ -602,6 +602,8 @@ def fit(self): # noqa: C901
),
}
)
save_trajectory_as_json_file(context_tracker_arr, self.global_steps, self.config, prefix="train")
update_metrics(context_tracker_arr, metrics)
if self.config.ajet.execute_test: # apply a test probe
from swanlab.data.run.main import get_run

Expand Down Expand Up @@ -1044,6 +1046,8 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch):
f"TGC@{pass_n}-all-pass": num_all_success_tasks / num_tasks,
"mean_reward": sum(rewards) / len(rewards) if rewards else 0,
}
save_trajectory_as_json_file(ctx_trackers, self.global_steps, self.config, prefix="eval")
update_metrics(ctx_trackers, val_metrics)
print_dict(
val_metrics,
narrow=True,
Expand Down
15 changes: 11 additions & 4 deletions ajet/context_tracker/base_tracker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List, Tuple, Union
from typing import List, Union, Tuple, Dict, Optional, Any
Comment on lines 1 to +2

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The typing module is imported twice, and the first import is now a subset of the second. To improve readability and avoid redundancy, please consolidate these into a single import statement.

Suggested change
from typing import List, Tuple, Union
from typing import List, Union, Tuple, Dict, Optional, Any
from typing import List, Union, Tuple, Dict, Optional, Any

from ajet.schema.task import WorkflowTask

from ajet.schema.extended_msg import (
INVALID_LOG_PROB_VALUE,
Expand Down Expand Up @@ -110,10 +112,14 @@ def replace_token_ids(


class BaseTracker(object):
def __init__(self, config, tokenizer, **kwargs):
self.task_batch_index = kwargs.get("task_batch_index", "undefined")
self.task_tag = kwargs.get("task_tag", "undefined")
self.task_id = kwargs.get("task_id", "undefined")
def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs):

self.workflow_task = workflow_task
self.task_batch_index = self.workflow_task.task_batch_index
self.task_tag = self.workflow_task.task_tag
self.task_id = self.workflow_task.task_id
self.episode_uuid = self.workflow_task.episode_uuid

self.config = config
self.tokenizer = tokenizer
self.saved_timelines: List[List[ExtendedMessage]] = []
Expand All @@ -135,6 +141,7 @@ def __init__(self, config, tokenizer, **kwargs):
self.already_mad_flag: bool = False
self.round_cnt = 0
self.generation_prompt_token = None
self.log_metrics: Optional[Dict[str, Union[float, List[float]]]] = None  # Initialize log_metrics to store tool statistics

assert (
self.config.ajet.data.max_prompt_length
Expand Down
2 changes: 2 additions & 0 deletions ajet/context_tracker/basic_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List:
}
if ext_msg.tool_calls:
d.update({"tool_calls": ext_msg.tool_calls})
if ext_msg.tool_call_id:
d.update({"tool_call_id": ext_msg.tool_call_id})
result.append(d)
return result

Expand Down
69 changes: 48 additions & 21 deletions ajet/context_tracker/multiagent_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(
config,
should_interrupt_fn,
generated_token_callback_fn,
episode_uuid: str,
**kwargs,
):
super().__init__(config, tokenizer, **kwargs)
Expand All @@ -61,7 +60,6 @@ def __init__(
self.output_kwargs = {}
self.input_kwargs = {}
self.timeline_cache = {}
self.episode_uuid = episode_uuid


def preprocess_tools_field(self, tools: List[dict] = [], disable_toolcalls: bool = False):
Expand All @@ -74,6 +72,40 @@ def preprocess_tools_field(self, tools: List[dict] = [], disable_toolcalls: bool
tools[i]["function"]["parameters"] = tools[i]["function"].pop("parameters")
return tools

def extract_text_content_from_content_dict(self, msg):
    """Flatten a structured (list-of-parts) message ``content`` into a plain string.

    Expected input shape (OpenAI-style content parts)::

        msg = {
            "role": "assistant",
            "content": [
                {"type": "text", "text": "some text"},
            ],
        }

    Args:
        msg: A chat message dict whose ``content`` is a list of part dicts.

    Returns:
        A ``(str_content, should_skip_message)`` tuple. ``str_content`` is the
        concatenation of all string ``"text"`` parts seen so far;
        ``should_skip_message`` is True when a non-text part (no ``"text"``
        key, e.g. an image) was found, signalling the caller to drop the
        whole message.
    """
    str_content = ""
    for item in msg["content"]:
        # Each part must itself be a dict like {"type": "text", "text": "..."}.
        assert isinstance(item, dict), f"Unsupported non-dict item in message content: {item}. Full message: {msg}"

        if "text" not in item:
            # Non-text modality — we cannot turn it into a string, so tell
            # the caller to skip this message entirely.
            logger.warning(
                f"Non-text content in message content detected: {item}. Ignoring."
            )
            should_skip_message = True
            return str_content, should_skip_message

        if isinstance(item["text"], str):
            str_content += item["text"]
        # Bugfix: a non-str "text" value used to reset the accumulator to ""
        # (wiping text collected from earlier items). Such malformed items are
        # now simply skipped so previously accumulated text is preserved.

    should_skip_message = False
    return str_content, should_skip_message

def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_toolcalls: bool = False) -> List[ExtendedMessage]:
"""Spawn a timeline from messages.
Expand All @@ -93,39 +125,32 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
consider_roles.remove("tool")

for i, msg in enumerate(messages):

if (disable_toolcalls) and (not isinstance(msg["content"], str)):
continue

if msg["role"] not in consider_roles:
continue

if not isinstance(msg["content"], str):
author = "env"
ignore = False
str_content = ""
should_skip_message = False

# fix msg content
if msg["content"] is None:
msg["content"] = ""

elif isinstance(msg["content"], list):
for item in msg["content"]:
if "text" not in item:
logger.warning(
f"Non-text content in message content detected: {item}. Ignoring."
)
ignore = True
break
if isinstance(item["text"], str):
str_content += str(item["text"])
else:
str_content = ""
msg["content"] = str_content
msg["content"], should_skip_message = self.extract_text_content_from_content_dict(msg)

else:
raise ValueError(
f"Unsupported non-str message content type: {type(msg['content'])}, Message:\n {msg}"
)
raise ValueError(f"Unsupported non-str message content type: {type(msg['content'])}, Message:\n {msg}")

if ignore:
if should_skip_message:
continue
msg["content"] = str(msg["content"]) # TODO: better handling mm data

if not isinstance(msg["content"], str):
msg["content"] = str(msg["content"]) # TODO: better handling mm data

if msg["role"] == "system":
author = "initialization"
Expand All @@ -143,7 +168,9 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
tokenizer=self.tokenizer,
tools=tools,
tool_calls=(msg["tool_calls"] if "tool_calls" in msg else []),
tool_call_id=(msg["tool_call_id"] if "tool_call_id" in msg else ""),
token_generator="auto",
name = (msg["name"] if "name" in msg else ""),
first_message=(i == 0),
)
]
Expand Down
28 changes: 28 additions & 0 deletions ajet/default_config/ajet_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -231,29 +231,57 @@ ajet:

# trainer common configurations
trainer_common:

# validation before training
val_before_train: False
val_pass_n: 4

# save and test frequency (in step)
save_freq: 20
test_freq: 20

# total training epochs
total_epochs: 50

nnodes: 1
n_gpus_per_node: 8

# logger selection
logger: swanlab

# algorithm setting
algorithm:
adv_estimator: grpo
use_kl_in_reward: False

# number of optimizer.step per big batch
mini_batch_num: 1

# verl offload configs
fsdp_config:
param_offload: True
optimizer_offload: True

# learning rate
optim:
lr: 1e-6

# enable KL loss regularization
use_kl_loss: True

# kl divergence loss coefficient
kl_loss_coef: 0.002
kl_loss_type: low_var_kl

# Ulysses specific configs
ulysses_sequence_parallel_size: 1

# base directory to save checkpoints
checkpoint_base_dir: ./saved_checkpoints

# whether to save train/eval trajectories to JSON files
save_trajectory_as_json_file: False




Expand Down
17 changes: 14 additions & 3 deletions ajet/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ajet.utils.pty import pty_launch

set_loguru_default_color()
load_dotenv()
load_dotenv(override=False)


def parse_args():
Expand Down Expand Up @@ -59,6 +59,12 @@ def parse_args():
default=False,
help="Launch appworld",
)
parser.add_argument(
"--with-finworld",
action="store_true",
default=False,
help="Launch finworld",
)
parser.add_argument(
"--with-webshop",
action="store_true",
Expand All @@ -79,6 +85,7 @@ def parse_args():
help="Launch Crafters Env Simulation",
)
parser.add_argument("--reboot", action="store_true", default=False, help="reboot flag")
parser.add_argument("--skip-check-avail-gpu", action="store_true", default=False, help="Skip GPU availability check")
parser.add_argument(
"--kill",
type=str,
Expand Down Expand Up @@ -247,8 +254,9 @@ def main():
args = parse_args()

# Enforce GPU availability and free memory threshold before proceeding
if (args.backbone != "debug") and (not args.kill) and (not args.autokill):
check_avail_gpu(min_free_ratio=0.95)
if not args.skip_check_avail_gpu:
if (args.backbone != "debug") and (not args.kill) and (not args.autokill):
check_avail_gpu(min_free_ratio=0.95)

if args.autokill:
args.kill = "ray|vllm|VLLM|python"
Expand Down Expand Up @@ -295,6 +303,9 @@ def main():
if args.with_appworld:
pty_launch("appworld")

if args.with_finworld:
pty_launch("finworld")

if args.with_crafters:
pty_launch("crafters")

Expand Down
6 changes: 6 additions & 0 deletions ajet/schema/extended_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def __init__(
build_from_uuid="",
tools=[],
tool_calls=[],
tool_call_id="",
token_logprob_arr=[],
name="", # preserved field, not used currently
first_message=False,
):
self.author = author
Expand All @@ -88,6 +90,8 @@ def __init__(
self.clip = clip
self.tools = tools
self.tool_calls = tool_calls
self.tool_call_id = tool_call_id
self.name = name # preserved field, not used currently
if not isinstance(self.tool_calls, list):
# agent scope sometimes gives weird type for tool_calls, which is against OpenAI schema
self.tool_calls = list(self.tool_calls)
Expand Down Expand Up @@ -146,6 +150,8 @@ def auto_tokenize_non_first_message(self, tokenizer, tools):
}
if self.tool_calls:
auto_tokenize_target.update({"tool_calls": self.tool_calls})
if self.tool_call_id:
auto_tokenize_target.update({"tool_call_id": self.tool_call_id})
text_frag_to = ajet_apply_chat_template(
tokenizer=tokenizer,
conversation=DUMMY_MSG + [auto_tokenize_target],
Expand Down
1 change: 1 addition & 0 deletions ajet/schema/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ class WorkflowOutput(BaseModel):
reward: Union[float, List[float], None] = Field(default=None)
is_success: Union[bool, None] = Field(default=None)
metadata: Dict[str, Any] = Field(default_factory=dict)
log_metrics: Dict[str, Union[float, List[float]]] = Field(default_factory=dict)
23 changes: 16 additions & 7 deletions ajet/task_rollout/resource_keeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def _initialize_environment_and_messages(self) -> List[dict]:
params=self.env_params,
)
state_message: dict = init_response["state"]
_, init_messages = self._get_init_messages(state_message)
query, init_messages = self._get_init_messages(state_message)
# Update main_query with actual query from environment
self.workflow_task.task.main_query = query
except Exception as e:
logger.bind(exception=True).exception(
f"encounter exception in env_worker.create_instance~ error={e.args}"
Expand Down Expand Up @@ -176,16 +178,23 @@ def step(self, action: dict) -> Tuple[str, float, bool, dict]:
)
obs = ""
assert isinstance(env_output, dict)
if ("content" not in env_output["state"]) and ("error" in env_output["state"]):
obs = f"[Error from environment: {env_output['error']}]"
elif env_output["state"]["content"] == "":
obs = "Warning: the environment does not provide any feedback, please provide valid inpu and try again."

if isinstance(env_output["state"], list):
# 1. If state is a list (new standard format), pass through directly
obs = env_output["state"]
else:
obs = env_output["state"]["content"]
# 2. If state is a dict (old format or error)
if ("content" not in env_output["state"]) and ("error" in env_output["state"]):
obs = f"[Error from environment: {env_output['error']}]"
elif env_output["state"].get("content", "") == "":
obs = "Warning: the environment does not provide any feedback, please provide valid input and try again."
else:
obs = env_output["state"]["content"]

reward = 0
info = {}
terminate = env_output["is_terminated"]
return obs, reward, terminate, info
return obs, reward, terminate, info # type: ignore

def reset(self) -> str:
"""Reset gym environment."""
Expand Down
Loading
Loading