forked from lm-sys/FastChat
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add scripts for chat data cleaning and analysis (lm-sys#2335)
- Loading branch information
1 parent
42be87e
commit 2fbfcbc
Showing
12 changed files
with
398 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
## Chatbot Arena Conversations | ||
|
||
1. Gather battles | ||
``` | ||
python3 clean_battle_data.py --max-num 10 --mode conv_release | ||
``` | ||
|
||
2. Tag OpenAI moderation | ||
``` | ||
python3 tag_openai_moderation.py --in clean_battle_conv_20230814.json | ||
``` | ||
|
||
3. Clean PII | ||
|
||
4. Filter additional blocked words | ||
|
||
``` | ||
python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json | ||
``` | ||
|
||
5. Add additional toxicity tag | ||
|
||
|
||
## All Conversations | ||
|
||
1. Gather chats | ||
``` | ||
python3 clean_chat_data.py | ||
``` | ||
|
||
2. Sample | ||
``` | ||
python3 conv_release_scripts/sample.py | ||
``` | ||
|
||
|
||
## Prompt distribution | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,5 +87,3 @@ deepspeed fastchat/train/train_lora_t5.py \ | |
--deepspeed playground/deepspeed_config_s2.json | ||
|
||
``` | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
""" | ||
Usage: | ||
python3 replace_model_name.py --in-file clean_conv_20230809_10k.json | ||
""" | ||
|
||
import argparse | ||
import json | ||
|
||
from fastchat.serve.monitor.clean_battle_data import replace_model_name | ||
|
||
if __name__ == "__main__":
    # Rewrite the "model" field of every conversation in a JSON file, in place.
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-file", type=str, required=True)
    args = parser.parse_args()

    # Read through a context manager so the input handle is closed before the
    # same path is reopened for writing below.
    with open(args.in_file, encoding="utf-8") as fin:
        convs = json.load(fin)

    # Pass each conversation's model name through the project's
    # replace_model_name helper (imported at the top of the file).
    for x in convs:
        x["model"] = replace_model_name(x["model"])

    # ensure_ascii=False emits raw non-ASCII characters, so the output
    # encoding is pinned to UTF-8 instead of relying on the locale default.
    with open(args.in_file, "w", encoding="utf-8") as fout:
        json.dump(convs, fout, indent=2, ensure_ascii=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
""" | ||
Usage: | ||
python3 summarize_cluster.py --input-file results_c20_kmeans_cluster.pkl --model gpt-4 | ||
""" | ||
import argparse | ||
import pickle | ||
|
||
from fastchat.llm_judge.common import ( | ||
chat_compeletion_openai, | ||
chat_compeletion_anthropic, | ||
) | ||
from fastchat.conversation import get_conv_template | ||
|
||
|
||
def truncate_string(s, l):
    """Shorten ``s`` to at most ``l`` characters by keeping both of its ends.

    Strings whose length does not exceed ``l`` are returned unchanged.
    Longer strings are reduced to their first ``l // 2`` characters followed
    by their last ``l // 2`` characters; the middle is dropped.

    Args:
        s: The string to shorten.
        l: Maximum allowed length of the result.

    Returns:
        ``s`` itself when it already fits, otherwise the head and tail joined
        together (``2 * (l // 2)`` characters).
    """
    if len(s) <= l:
        return s

    half = max(int(l // 2), 0)
    # Slice the tail by absolute index: ``s[-half:]`` would return the whole
    # string when ``half`` is 0, defeating the truncation.
    return s[:half] + s[len(s) - half :]
|
||
|
||
if __name__ == "__main__":
    # Ask an LLM for a short topic label for each prompt cluster and report
    # each cluster's share of the total number of prompts.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-file", type=str, required=True)
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo")
    parser.add_argument("--num-prompts", type=int, default=100)
    args = parser.parse_args()

    model = args.model

    # The backend depends only on the model name, so pick it once up front.
    # An unrecognized name is rejected here instead of failing later with a
    # NameError inside the loop.
    if "gpt" in model:
        template_name = "chatgpt"
        completion_func = chat_compeletion_openai
    elif "claude" in model:
        template_name = "claude"
        completion_func = chat_compeletion_anthropic
    else:
        parser.error(
            f"unsupported model {model!r}: expected a name containing "
            "'gpt' or 'claude'"
        )

    # NOTE: pickle.load can execute arbitrary code; only pass files that come
    # from a trusted source.
    with open(args.input_file, "rb") as fin:
        cluster_infos = pickle.load(fin)
    # Each entry is unpacked below as (num_samples, prompts).
    num_total_prompts = sum(x[0] for x in cluster_infos)

    instruct = "Given a list of user messages, use less than 8 words to summarize a central topic for all messages in English. Your output should only include a single line. Try to be specific."

    topics = []
    percentages = []
    for i, info in enumerate(cluster_infos):
        num_samples, prompts = info
        # Guard against a zero total so empty clusters do not crash the run.
        percentage = num_samples / num_total_prompts if num_total_prompts else 0.0
        print(
            f"cluster {i}, #prompts {num_samples}, percentage: {percentage * 100:.2f}%"
        )

        # Use at most --num-prompts messages, each capped at 200 characters,
        # to bound the size of the request.
        prompt = "\n".join(
            [truncate_string(x, l=200) for x in prompts[: args.num_prompts]]
        )
        prompt = "BEGIN OF THE MESSAGE LIST\n" + prompt + "\nEND OF THE MESSAGE LIST."

        # A fresh conversation per cluster keeps requests independent.
        conv = get_conv_template(template_name)
        conv.set_system_message(instruct)
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)

        topic = completion_func(model, conv, temperature=0, max_tokens=256)
        print(topic)

        topics.append(topic)
        percentages.append(round(percentage, 6))

    print()
    print(f"topics: {topics}")
    print(f"percentages: {percentages}")
Oops, something went wrong.