-
Notifications
You must be signed in to change notification settings - Fork 228
/
reformat_data_glaive.py
127 lines (105 loc) · 4.89 KB
/
reformat_data_glaive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
import argparse
import json
import os
import random
import string
def reformat_jsonl(input_file):
    """Rewrite a Glaive function-calling JSONL file in place as chat messages.

    Each input line must be a JSON object with a ``function_description``
    string and a ``conversations`` list of ``{"from": ..., "value": ...}``
    turns. It is converted to ``{"messages": [...], "tools": [...]}``.
    ``tools`` is omitted when the description is an empty object. Function
    calls become assistant ``tool_calls``, and function responses become
    ``tool`` messages linked back through ``tool_call_id``.

    The result is first written to ``<stem>_reformatted.jsonl`` beside the
    input and then moved over ``input_file``. Samples whose conversation
    cannot be converted are dropped. Their 0-based line indices are printed
    together with the skipped fraction.

    Args:
        input_file: Path to the JSONL file to convert; it is overwritten.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON, or if a
            ``function_description`` is not valid JSON even after the
            repair heuristic in ``_parse_tools``.
        KeyError: If a sample lacks one of the expected fields.
    """
    output_file = os.path.splitext(input_file)[0] + "_reformatted.jsonl"
    skipped_samples = []
    num_samples = 0
    with open(input_file, "r", encoding="utf-8") as infile, open(
        output_file, "w", encoding="utf-8"
    ) as outfile:
        for i, line in enumerate(infile):
            if not line.strip():
                continue  # tolerate blank lines, e.g. an extra trailing newline
            num_samples += 1
            data = json.loads(line)
            tools = _parse_tools(data["function_description"])
            messages = _convert_conversation(data["conversations"])
            if messages is None:
                skipped_samples.append(str(i))
                continue
            output_data = {"messages": messages}
            if tools is not None:
                output_data["tools"] = tools
            outfile.write(json.dumps(output_data) + "\n")
    # os.replace (unlike os.rename) also overwrites an existing target on Windows.
    os.replace(output_file, input_file)
    # Ratio over the number of samples actually read; an empty file yields 0%.
    skipped_ratio = len(skipped_samples) / num_samples if num_samples else 0.0
    print(
        f"Skipped {len(skipped_samples)} samples ({skipped_ratio:.2%}). The following samples are incorrectly formated: \n\n {', '.join(skipped_samples)}"
    )


def _parse_tools(raw_description):
    """Turn a ``function_description`` string into a list of tool entries, or None if empty."""
    try:
        function_desc = json.loads(raw_description)
    except json.decoder.JSONDecodeError:
        # Repair heuristic for several objects concatenated without separators
        # ("{...}\n{...}"): drop newlines and literal "\t" escapes, insert
        # commas between objects, and wrap everything in a JSON array.
        repaired = raw_description.replace("\n", "").replace("}{", "},{").replace("\\t", "")
        function_desc = json.loads("[{" + repaired[1:-1] + "}]")
    if not isinstance(function_desc, list):
        function_desc = [function_desc]
    # A lone empty object means the sample offers no functions.
    if len(function_desc) == 1 and function_desc[0] == {}:
        return None
    tools = []
    for function in function_desc:
        # Normalize a missing or null schema to an empty object.
        if function.get("parameters") is None:
            function["parameters"] = {}
        tools.append({"type": "function", "function": function})
    return tools


def _parse_function_calls(content):
    """Parse a ``function-call`` turn into a list of call dicts, or None if malformed."""
    try:
        parsed = json.loads(content)
    except json.decoder.JSONDecodeError:
        # Fallback: strip every apostrophe (presumably quotes wrapped around an
        # embedded JSON object), then turn remaining backslashes, such as those
        # left over from escaped apostrophes, into apostrophes.
        content = content.replace("'", "").replace("\\", "'")
        try:
            parsed = json.loads(content)
        except ValueError:
            return None
    calls = parsed if isinstance(parsed, list) else [parsed]
    # Nested lists or scalar values cannot be turned into tool calls.
    if not all(isinstance(call, dict) for call in calls):
        return None
    return calls


def _convert_conversation(conversations):
    """Map ``from``/``value`` turns to role-based messages; return None if the sample is malformed."""
    messages = []
    for msg in conversations:
        role = msg["from"]
        content = msg["value"]
        if role == "system":
            # Keep only the text before the first " -"; the remainder
            # (presumably the function listing) is carried by "tools" instead.
            messages.append({"role": "system", "content": content.split(" -")[0]})
        elif role == "human":
            messages.append({"role": "user", "content": content})
        elif role == "function-call":
            function_calls = _parse_function_calls(content)
            if function_calls is None:
                return None
            tool_calls = []
            for function_call in function_calls:
                tool_call_id = "".join(random.choices(string.ascii_letters + string.digits, k=9))
                arguments = function_call.get("arguments", "")
                # Arguments must be a JSON-encoded string; str() would emit a
                # Python repr (single quotes, True/None) that JSON parsers reject.
                if not isinstance(arguments, str):
                    arguments = json.dumps(arguments)
                function_call["arguments"] = arguments
                tool_calls.append({"id": tool_call_id, "type": "function", "function": function_call})
            messages.append({"role": "assistant", "tool_calls": tool_calls})
        elif role == "function-response":
            # A response must directly follow an assistant turn with exactly one call.
            previous = messages[-1] if messages else {}
            if len(previous.get("tool_calls", ())) != 1:
                return None
            messages.append(
                {
                    "role": "tool",
                    "content": content,
                    "tool_call_id": previous["tool_calls"][0]["id"],
                }
            )
        elif role == "gpt":
            messages.append({"role": "assistant", "content": content})
        # Any other role is ignored.
    return messages
if __name__ == "__main__":
    # Command-line entry point: a single positional argument naming the file
    # that reformat_jsonl converts (and overwrites).
    cli = argparse.ArgumentParser(description="Reformat a JSONL file.")
    cli.add_argument("file", type=str, help="The input JSONL file")
    reformat_jsonl(cli.parse_args().file)