@@ -235,53 +235,77 @@ def migrate_jsonl_to_parquet(
235235 if "finish_reason" in msg :
236236 # Choice format - extract inner message, mark as trainable
237237 inner = msg .get ("message" , {})
238- messages .append ({
239- "role" : inner .get ("role" ),
240- "content" : inner .get ("content" ),
241- "tool_calls" : json .dumps (inner .get ("tool_calls" )) if inner .get ("tool_calls" ) else None ,
242- "tool_call_id" : None ,
243- "trainable" : True ,
244- })
238+ messages .append (
239+ {
240+ "role" : inner .get ("role" ),
241+ "content" : inner .get ("content" ),
242+ "tool_calls" : json .dumps (inner .get ("tool_calls" ))
243+ if inner .get ("tool_calls" )
244+ else None ,
245+ "tool_call_id" : None ,
246+ "trainable" : True ,
247+ }
248+ )
245249 else :
246250 # Regular message
247- messages .append ({
248- "role" : msg .get ("role" ),
249- "content" : msg .get ("content" ),
250- "tool_calls" : json .dumps (msg .get ("tool_calls" )) if msg .get ("tool_calls" ) else None ,
251- "tool_call_id" : msg .get ("tool_call_id" ),
252- "trainable" : False ,
253- })
254-
255- rows .append ({
256- "group_index" : group_index ,
257- "reward" : traj .get ("reward" ),
258- "metrics" : json .dumps (traj .get ("metrics" )) if traj .get ("metrics" ) else None ,
259- "metadata" : json .dumps (traj .get ("metadata" )) if traj .get ("metadata" ) else None ,
260- "tools" : json .dumps (traj .get ("tools" )) if traj .get ("tools" ) else None ,
261- "logs" : traj .get ("logs" ),
262- "additional_histories" : json .dumps (traj .get ("additional_histories" )) if traj .get ("additional_histories" ) else None ,
263- "messages" : messages ,
264- })
251+ messages .append (
252+ {
253+ "role" : msg .get ("role" ),
254+ "content" : msg .get ("content" ),
255+ "tool_calls" : json .dumps (msg .get ("tool_calls" ))
256+ if msg .get ("tool_calls" )
257+ else None ,
258+ "tool_call_id" : msg .get ("tool_call_id" ),
259+ "trainable" : False ,
260+ }
261+ )
262+
263+ rows .append (
264+ {
265+ "group_index" : group_index ,
266+ "reward" : traj .get ("reward" ),
267+ "metrics" : json .dumps (traj .get ("metrics" ))
268+ if traj .get ("metrics" )
269+ else None ,
270+ "metadata" : json .dumps (traj .get ("metadata" ))
271+ if traj .get ("metadata" )
272+ else None ,
273+ "tools" : json .dumps (traj .get ("tools" ))
274+ if traj .get ("tools" )
275+ else None ,
276+ "logs" : traj .get ("logs" ),
277+ "additional_histories" : json .dumps (
278+ traj .get ("additional_histories" )
279+ )
280+ if traj .get ("additional_histories" )
281+ else None ,
282+ "messages" : messages ,
283+ }
284+ )
265285
266286 # Define the message struct schema
267- message_type = pa .struct ([
268- ("role" , pa .string ()),
269- ("content" , pa .string ()),
270- ("tool_calls" , pa .string ()),
271- ("tool_call_id" , pa .string ()),
272- ("trainable" , pa .bool_ ()),
273- ])
274-
275- schema = pa .schema ([
276- ("group_index" , pa .int64 ()),
277- ("reward" , pa .float64 ()),
278- ("metrics" , pa .string ()),
279- ("metadata" , pa .string ()),
280- ("tools" , pa .string ()),
281- ("logs" , pa .list_ (pa .string ())),
282- ("additional_histories" , pa .string ()),
283- ("messages" , pa .list_ (message_type )),
284- ])
287+ message_type = pa .struct (
288+ [
289+ ("role" , pa .string ()),
290+ ("content" , pa .string ()),
291+ ("tool_calls" , pa .string ()),
292+ ("tool_call_id" , pa .string ()),
293+ ("trainable" , pa .bool_ ()),
294+ ]
295+ )
296+
297+ schema = pa .schema (
298+ [
299+ ("group_index" , pa .int64 ()),
300+ ("reward" , pa .float64 ()),
301+ ("metrics" , pa .string ()),
302+ ("metadata" , pa .string ()),
303+ ("tools" , pa .string ()),
304+ ("logs" , pa .list_ (pa .string ())),
305+ ("additional_histories" , pa .string ()),
306+ ("messages" , pa .list_ (message_type )),
307+ ]
308+ )
285309
286310 # Handle empty case
287311 if not rows :
0 commit comments