-
Notifications
You must be signed in to change notification settings - Fork 0
/
rwkv.cu
123 lines (99 loc) · 2.46 KB
/
rwkv.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include "atomic"
#include "chrono"
#include "thread"
#include "rwkv.h"
// #include "sampler/sample.hpp"
#include "tokenizer/tokenizer.hpp"
#include "sampler/sample.h"
#include "tensor/operators/threading/threading.h"
// Sampling temperature handed to dart() in run(); main() may override it from argv[4].
float temp = 1.0f;
// Tokenizer shared by run() and main() for encoding prompts and decoding sampled tokens.
// The vocabulary file name is passed as-is (presumably resolved relative to the
// working directory -- TODO confirm against RWKVTokenizer's constructor).
RWKVTokenizer worldTokenizer("rwkv_vocab_v20230424.txt");
/// One generation step: sample a token from `logitsin`, print it, and enqueue
/// the next step on the shared thread pool.
///
/// Generation is a chain of pool jobs rather than a loop: every call schedules
/// exactly one follow-up `run`, except when stdin is exhausted at a user turn,
/// which ends the chain.
///
/// A user turn starts when the sampled token is 0 (printed as a newline followed
/// by "User") or when a token decodes to exactly "User". The user's line is then
/// read from stdin and fed to the model wrapped as "User: ...\n\nAssistant:".
///
/// @param model    Model that produces the next logits. The scheduled job captures
///                 it by reference, so it must outlive every pending job.
/// @param logitsin Result of the previous model call; the entry at
///                 [0][shape[1] - 1] (the last position) is sampled.
void run(RWKV& model, Tensor logitsin)
{
    auto pool = get_threadpool();

    // Sample from the logits of the last position of the previous call.
    auto logs = logitsin[0][logitsin.shape[1] - 1];
    size_t sample = dart((float *)logs.cpu().data, temp);

    std::string output;
    if (sample == 0)
    {
        output = "User";
        pool->print("\n");
    }
    else
    {
        output = worldTokenizer.decode({sample});
    }

    const bool userTurn = (output == "User");
    std::string shown = output;
    if (userTurn)
    {
        shown += ": ";
    }
    pool->print(shown);

    // Enqueue the next step. `logits` is copied into the job; `model` is shared by reference.
    auto scheduleNext = [&pool, &model](Tensor logits)
    {
        pool->add_job(
            [logits, &model]()
            {
                run(model, logits);
            },
            0);
    };

    if (!userTurn)
    {
        // Feed the sampled token back in to continue the assistant's reply.
        scheduleNext(model({{sample}}));
        return;
    }

    std::string input;
    if (!std::getline(std::cin, input))
    {
        // stdin is closed or failed: end the chain instead of re-prompting
        // forever with an empty user message.
        return;
    }
    pool->print("\n");
    scheduleNext(model({worldTokenizer.encode("User: " + input + "\n\nAssistant:")}));
}
/// Interactive chat demo entry point.
///
/// Usage: <prog> [model_path] [threads] [device] [temperature] [hold_minutes]
///   model_path   model file to load (default "./model.safetensors").
///   threads      thread count passed to RWKV (default 8; must be >= 1).
///   device       any value moves the model to CUDA except the literal "cpu",
///                which keeps it on the CPU (so a temperature can be given
///                without enabling CUDA).
///   temperature  sampling temperature used by run() (default 1.0).
///   hold_minutes how long the process stays alive for the chat (default 1).
///
/// Returns 1 if the thread count is invalid. Otherwise the process ends through
/// std::_Exit(0) after the hold period (see the note at the end).
int main(int argc, char **argv)
{
    std::string path = "./model.safetensors";
    if (argc > 1)
    {
        path = argv[1];
    }

    size_t threads = 8;
    if (argc > 2)
    {
        // Parse as int first so a negative value is rejected rather than
        // wrapping to a huge size_t.
        const int requested = std::stoi(argv[2]);
        if (requested < 1)
        {
            std::cerr << "thread count must be >= 1, got " << requested << "\n";
            return 1;
        }
        threads = static_cast<size_t>(requested);
    }

    RWKV model(path, threads);
    if (argc > 3 && std::string(argv[3]) != "cpu")
    {
        model.cuda();
    }
    if (argc > 4)
    {
        temp = std::stof(argv[4]);
    }
    long holdMinutes = 1;
    if (argc > 5)
    {
        holdMinutes = std::stol(argv[5]);
    }

    std::string instruction = "System: You are a multi-lingual language model created by recursalAI and the RWKV group. Help the user with their tasks.\n\nUser: ";
    std::cout << instruction;
    std::string input = "";
    std::getline(std::cin, input);
    std::cout << "\n";
    auto tokens = worldTokenizer.encode(instruction + input + "\n\n" + "Assistant:");
    auto logitsstart = model({tokens});

    // Start the self-rescheduling generation chain on the pool.
    auto pool = get_threadpool();
    pool->add_job(
        [logitsstart, &model]()
        {
            run(model, logitsstart);
        },
        0);

    // Keep the process (and `model`) alive for the chat session.
    std::this_thread::sleep_for(std::chrono::minutes(holdMinutes));

    // Pool jobs capture `model` by reference and use the global tokenizer, and
    // nothing here can wait for them to finish. Returning normally would destroy
    // `model` and then the globals while a job may still be running. So flush
    // output and terminate without running destructors; _Exit does not flush streams.
    std::cout.flush();
    std::fflush(nullptr);
    std::_Exit(0);
}