forked from ztxz16/fastllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.cpp
98 lines (90 loc) · 3.49 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include "model.h"
// Command-line settings that control how the model is loaded and run.
struct RunConfig {
    // Path to the model file.
    std::string path{"chatglm-6b-int4.bin"};
    // Number of threads to use.
    int threads{4};
    // Whether low-memory mode is enabled.
    bool lowMemMode{false};
};
// Prints the command-line help text to standard output, one entry per line.
void Usage() {
    // Help text in display order; each entry is written verbatim.
    static const char *const kHelpLines[] = {
        "Usage:",
        "[-h|--help]: 显示帮助",
        "<-p|--path> <args>: 模型文件的路径",
        "<-t|--threads> <args>: 使用的线程数量",
        "<-l|--low>: 使用低内存模式",
        "<--top_p> <args>: 采样参数top_p",
        "<--top_k> <args>: 采样参数top_k",
        "<--temperature> <args>: 采样参数温度,越高结果越不固定",
        "<--repeat_penalty> <args>: 采样参数重复惩罚",
    };
    for (const char *helpLine : kHelpLines) {
        // std::endl keeps the newline-plus-flush after every line.
        std::cout << helpLine << std::endl;
    }
}
// Parses the command line into `config` and `generationConfig`.
//
// argc / argv       the arguments as received by main(); argv[0] is skipped.
// config            receives the model path, thread count and low-memory flag.
// generationConfig  receives top_p, top_k, temperature and repeat_penalty.
//
// Terminates the process instead of returning when:
//   * -h / --help is given           -> prints usage, exit(0)
//   * an unknown option is given     -> prints usage, exit(-1)
//   * a value-taking option is last  -> prints usage, exit(-1)
void ParseArgs(int argc, char **argv, RunConfig &config, fastllm::GenerationConfig &generationConfig) {
    std::vector <std::string> sargv(argv, argv + argc);

    int i = 1;
    // Returns the value that follows the current option and advances past it.
    // Without this check an option given as the last argument would index one
    // element past the end of sargv.
    auto nextValue = [&]() -> const std::string & {
        if (i + 1 >= argc) {
            Usage();
            exit(-1);
        }
        return sargv[++i];
    };

    for (; i < argc; i++) {
        const std::string &arg = sargv[i];
        if (arg == "-h" || arg == "--help") {
            Usage();
            exit(0);
        } else if (arg == "-p" || arg == "--path") {
            config.path = nextValue();
        } else if (arg == "-t" || arg == "--threads") {
            config.threads = atoi(nextValue().c_str());
        } else if (arg == "-l" || arg == "--low") {
            config.lowMemMode = true;
        } else if (arg == "-m" || arg == "--model") {
            // Accepted but unused: the option and the argument after it are
            // skipped without being stored anywhere.
            // NOTE(review): presumably kept for compatibility with older
            // command lines - confirm.
            i++;
        } else if (arg == "--top_p") {
            generationConfig.top_p = atof(nextValue().c_str());
        } else if (arg == "--top_k") {
            // NOTE(review): parsed with atof like the other sampling options;
            // confirm this matches the declared type of top_k.
            generationConfig.top_k = atof(nextValue().c_str());
        } else if (arg == "--temperature") {
            generationConfig.temperature = atof(nextValue().c_str());
        } else if (arg == "--repeat_penalty") {
            generationConfig.repeat_penalty = atof(nextValue().c_str());
        } else {
            Usage();
            exit(-1);
        }
    }
}
// Interactive console chat loop.
//
// Parses the command line, applies the thread and low-memory settings, loads
// the model from config.path, then repeatedly reads one line from standard
// input and prints the model's reply as it is produced.
//
// Special inputs:
//   "reset"  clears the conversation history and the round counter
//   "stop"   leaves the loop
// The loop also ends when standard input is exhausted or fails.
//
// Returns 0 on normal termination.
int main(int argc, char **argv) {
    int round = 0;
    std::string history = "";

    RunConfig config;
    fastllm::GenerationConfig generationConfig;
    ParseArgs(argc, argv, config, generationConfig);

    fastllm::PrintInstructionInfo();
    fastllm::SetThreads(config.threads);
    fastllm::SetLowMemMode(config.lowMemMode);
    auto model = fastllm::CreateLLMModelFromFile(config.path);

    // Kept in a function-local static so the capture-less callback below can
    // read it.
    static std::string modelType = model->model_type;
    printf("欢迎使用 %s 模型. 输入内容对话,reset清空历史记录,stop退出程序.\n", model->model_type.c_str());

    while (true) {
        printf("用户: ");
        // The prompt has no trailing newline, so flush it explicitly to make
        // it visible even when stdout is not line-buffered.
        fflush(stdout);

        std::string input;
        if (!std::getline(std::cin, input)) {
            // End of input or a read error: without this check the loop would
            // spin forever, sending empty prompts to the model.
            break;
        }
        if (input == "reset") {
            history = "";
            round = 0;
            continue;
        }
        if (input == "stop") {
            break;
        }

        // The callback prints each piece of the reply as it arrives:
        //   index == 0   first piece, prefixed with the model type
        //   index  > 0   later pieces, printed as-is
        //   index == -1  prints the closing newline
        //                (presumably the end-of-reply signal - confirm)
        std::string ret = model->Response(model->MakeInput(history, round, input), [](int index, const char* content) {
            if (index == 0) {
                printf("%s:%s", modelType.c_str(), content);
                fflush(stdout);
            }
            if (index > 0) {
                printf("%s", content);
                fflush(stdout);
            }
            if (index == -1) {
                printf("\n");
            }
        }, generationConfig);

        history = model->MakeHistory(history, round, input, ret);
        round++;
    }
    return 0;
}