Skip to content

Commit 88eac18

Browse files
authored
Merge pull request #20 from open-sciencelab/split-model-providers
feat(webui): support M_synth & M_train from different providers
2 parents 813215a + eb2b047 commit 88eac18

File tree

2 files changed

+31 -14 lines changed

webui/app.py

Lines changed: 25 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -105,17 +105,20 @@ def sum_tokens(client):
105105
env = {
106106
"SYNTHESIZER_BASE_URL": arguments[12],
107107
"SYNTHESIZER_MODEL": arguments[13],
108-
"TRAINEE_BASE_URL": arguments[12],
108+
"TRAINEE_BASE_URL": arguments[20],
109109
"TRAINEE_MODEL": arguments[14],
110110
"SYNTHESIZER_API_KEY": arguments[15],
111-
"TRAINEE_API_KEY": arguments[15],
111+
"TRAINEE_API_KEY": arguments[21],
112112
"RPM": arguments[17],
113113
"TPM": arguments[18],
114114
}
115115

116116
# Test API connection
117117
test_api_connection(env["SYNTHESIZER_BASE_URL"],
118118
env["SYNTHESIZER_API_KEY"], env["SYNTHESIZER_MODEL"])
119+
if config['if_trainee_model']:
120+
test_api_connection(env["TRAINEE_BASE_URL"],
121+
env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"])
119122

120123
# Initialize GraphGen
121124
graph_gen = init_graph_gen(config, env)
@@ -278,20 +281,32 @@ def sum_tokens(client):
278281
interactive=True)
279282

280283
with gr.Accordion(label=_("Model Config"), open=False):
281-
base_url = gr.Textbox(label="Base URL",
284+
synthesizer_url = gr.Textbox(label="Synthesizer URL",
282285
value="https://api.siliconflow.cn/v1",
283-
info=_("Base URL Info"),
286+
info=_("Synthesizer URL Info"),
284287
interactive=True)
285288
synthesizer_model = gr.Textbox(label="Synthesizer Model",
286289
value="Qwen/Qwen2.5-7B-Instruct",
287290
info=_("Synthesizer Model Info"),
288291
interactive=True)
292+
trainee_url = gr.Textbox(label="Trainee URL",
293+
value="https://api.siliconflow.cn/v1",
294+
info=_("Trainee URL Info"),
295+
interactive=True,
296+
visible=if_trainee_model.value is True)
289297
trainee_model = gr.Textbox(
290298
label="Trainee Model",
291299
value="Qwen/Qwen2.5-7B-Instruct",
292300
info=_("Trainee Model Info"),
293301
interactive=True,
294302
visible=if_trainee_model.value is True)
303+
trainee_api_key = gr.Textbox(
304+
label=_("SiliconCloud Token for Trainee Model"),
305+
type="password",
306+
value="",
307+
info="https://cloud.siliconflow.cn/account/ak",
308+
visible=if_trainee_model.value is True)
309+
295310

296311
with gr.Accordion(label=_("Generation Config"), open=False):
297312
chunk_size = gr.Slider(label="Chunk Size",
@@ -428,12 +443,12 @@ def sum_tokens(client):
428443
# Test Connection
429444
test_connection_btn.click(
430445
test_api_connection,
431-
inputs=[base_url, api_key, synthesizer_model],
446+
inputs=[synthesizer_url, api_key, synthesizer_model],
432447
outputs=[])
433448

434449
if if_trainee_model.value:
435450
test_connection_btn.click(test_api_connection,
436-
inputs=[base_url, api_key, trainee_model],
451+
inputs=[trainee_url, api_key, trainee_model],
437452
outputs=[])
438453

439454
expand_method.change(lambda method:
@@ -443,11 +458,9 @@ def sum_tokens(client):
443458
outputs=[max_extra_edges, max_tokens])
444459

445460
if_trainee_model.change(
446-
lambda use_trainee: (gr.update(visible=use_trainee is True),
447-
gr.update(visible=use_trainee is True),
448-
gr.update(visible=use_trainee is True)),
461+
lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
449462
inputs=if_trainee_model,
450-
outputs=[trainee_model, quiz_samples, edge_sampling])
463+
outputs=[trainee_url, trainee_model, quiz_samples, edge_sampling, trainee_api_key])
451464

452465
# 计算上传文件的token数
453466
upload_file.change(
@@ -471,8 +484,8 @@ def sum_tokens(client):
471484
if_trainee_model, upload_file, tokenizer, qa_form,
472485
bidirectional, expand_method, max_extra_edges, max_tokens,
473486
max_depth, edge_sampling, isolated_node_strategy,
474-
loss_strategy, base_url, synthesizer_model, trainee_model,
475-
api_key, chunk_size, rpm, tpm, quiz_samples, token_counter
487+
loss_strategy, synthesizer_url, synthesizer_model, trainee_model,
488+
api_key, chunk_size, rpm, tpm, quiz_samples, trainee_url, trainee_api_key, token_counter
476489
],
477490
outputs=[output, token_counter],
478491
)

webui/translation.json

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
"Title": "✨Easy-to-use LLM Training Data Generation Framework✨",
44
"Intro": "is a framework for synthetic data generation guided by knowledge graphs, designed to tackle challenges for knowledge-intensive QA generation. \n\nBy uploading your text chunks (such as knowledge in agriculture, healthcare, or marine science) and filling in the LLM API key, you can generate the training data required by **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)** and **[xtuner](https://github.com/InternLM/xtuner)** online. We will automatically delete user information after completion.",
55
"Use Trainee Model": "Use Trainee Model to identify knowledge blind spots, please keep disable for SiliconCloud",
6-
"Base URL Info": "Base URL for the API, use SiliconFlow as default",
6+
"Synthesizer URL Info": "Base URL for the Synthesizer Model API, use SiliconFlow as default",
7+
"Trainee URL Info": "Base URL for the Trainee Model API, use SiliconFlow as default",
78
"Synthesizer Model Info": "Model for constructing KGs and generating QAs",
89
"Trainee Model Info": "Model for training",
910
"Model Config": "Model Configuration",
1011
"Generation Config": "Generation Config",
1112
"SiliconCloud Token": "SiliconCloud API Key",
13+
"SiliconCloud Token for Trainee Model": "SiliconCloud API Key for Trainee Model",
1214
"Test Connection": "Test Connection",
1315
"Run GraphGen": "Run GraphGen",
1416
"Upload File": "Upload File",
@@ -18,12 +20,14 @@
1820
"Title": "✨开箱即用的LLM训练数据生成框架✨",
1921
"Intro": "是一个基于知识图谱的数据合成框架,旨在知识密集型任务中生成问答。\n\n 上传你的文本块(如农业、医疗、海洋知识),填写 LLM api key,即可在线生成 **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)**、**[xtuner](https://github.com/InternLM/xtuner)** 所需训练数据。结束后我们将自动删除用户信息。",
2022
"Use Trainee Model": "使用Trainee Model来识别知识盲区,使用硅基流动时请保持禁用",
21-
"Base URL Info": "调用模型API的URL,默认使用硅基流动",
23+
"Synthesizer URL Info": "调用合成模型API的URL,默认使用硅基流动",
24+
"Trainee URL Info": "调用学生模型API的URL,默认使用硅基流动",
2225
"Synthesizer Model Info": "用于构建知识图谱和生成问答的模型",
2326
"Trainee Model Info": "用于训练的模型",
2427
"Model Config": "模型配置",
2528
"Generation Config": "生成配置",
2629
"SiliconCloud Token": "硅基流动 API Key",
30+
"SiliconCloud Token for Trainee Model": "硅基流动 API Key (学生模型)",
2731
"Test Connection": "测试接口",
2832
"Run GraphGen": "运行GraphGen",
2933
"Upload File": "上传文件",

0 commit comments

Comments
 (0)