Skip to content

Commit a87a55b

Browse files
authored
[Generation] Get logits output. (#319)
1 parent 15451f2 commit a87a55b

File tree

7 files changed

+40
-10
lines changed

7 files changed

+40
-10
lines changed

include/models.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,8 @@ class Model {
3838

3939
bool isDone();
4040

41+
std::tuple<float *, int, int> forward();
42+
4143
std::vector<int32_t> generate();
4244

4345
void createSearcher(SearcherConfig &config_);
@@ -50,6 +52,10 @@ class Model {
5052

5153
int getSeqLen() { return seqLen; }
5254

55+
void setVocabSize(int vocabSize) { this->vocabSize = vocabSize; }
56+
57+
int getVocabSize() { return this->vocabSize; }
58+
5359
SearcherConfig getConfig() { return configuration; }
5460

5561
void setDecoder(AbstractDecoder *dec);
@@ -70,6 +76,7 @@ class Model {
7076
std::vector<int32_t> inputIds;
7177
int batchSize;
7278
int seqLen;
79+
int vocabSize;
7380
SearcherConfig configuration;
7481
bool isNewInput;
7582
};

src/common/transformer_ctx.h

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -191,13 +191,6 @@ struct DecoderContext {
191191
}
192192

193193
void ResetConfigReader(std::string _configPath, std::string _sectionName = "") {
194-
fs::path filePath(_configPath);
195-
196-
if (!fs::exists(filePath)) {
197-
printf("Config File %s does not exist!", configPath.c_str());
198-
exit(-1);
199-
}
200-
201194
this->configPath = _configPath;
202195
this->configReader = INIReader(_configPath);
203196
if (this->configReader.ParseError() < 0) {

src/models/models.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,11 @@ bool Model::isDone() {
125125
return !isNewInput && searcher->isDone();
126126
}
127127

128+
std::tuple<float *, int, int> Model::forward() {
129+
int64_t dims[3] = {batchSize, 1, seqLen};
130+
return decoder->forward(inputIds.data(), dims, 0, true);
131+
}
132+
128133
std::vector<int32_t> Model::generate() {
129134
if (inputIds.empty()) {
130135
printf("Please set input tokens by model.input().\n");
@@ -261,6 +266,7 @@ AutoModel::AutoModel(std::string modelPath, xft::DataType datatype) : Model() {
261266
exit(-1);
262267
}
263268
std::string modeltype = *reader.Sections().begin();
269+
setVocabSize(reader.GetInteger(modeltype, "vocab_size"));
264270

265271
if (modeltype == "gpt") {
266272
switch (datatype) {

src/pytorch/auto_model.h

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,25 @@ struct TorchAutoModel : torch::CustomClassHolder {
129129
doSample, temperature, topK, topP, repetitionPenalty, stopWordsList_int32);
130130
}
131131

132+
torch::Tensor forward(torch::Tensor &inputIds) {
133+
int batchSize = inputIds.size(0);
134+
int seqLen = inputIds.size(1);
135+
int vocabSize = model->getVocabSize();
136+
int logitsN = batchSize * seqLen * vocabSize;
137+
138+
if (model->getRank() == 0) { input(inputIds); }
139+
140+
std::tuple<float *, int, int> result = model->forward();
141+
float *outBuf = std::get<0>(result);
142+
int sampleOffset = std::get<1>(result);
143+
int sampleSize = std::get<2>(result);
144+
145+
// Create a torch::Tensor from the C array
146+
int64_t tdims[3] = {batchSize, seqLen, vocabSize};
147+
torch::Tensor ret = torch::from_blob(outBuf, tdims, torch::kFloat32);
148+
return ret;
149+
}
150+
132151
torch::Tensor generate() {
133152
auto nextTokens = model->generate();
134153

src/pytorch/pytorch_warpper.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ TORCH_LIBRARY(xfastertransformer, m) {
2222
.def("input", &TorchAutoModel::input)
2323
.def("config", &TorchAutoModel::config)
2424
.def("is_done", &TorchAutoModel::isDone)
25+
.def("forward", &TorchAutoModel::forward)
2526
.def("generate", &TorchAutoModel::generate)
2627
.def("finalize", &TorchAutoModel::finalize)
2728
.def("set_prefix", &TorchAutoModel::setPrefix)

src/utils/weight_util.h

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -158,9 +158,10 @@ int loadWeight(std::string filename, T *&ptr, int size, DataType w_type = DataTy
158158
// By default, read the config.ini configuration file
159159
// in the same directory as the model file to determine the data type of the file.
160160
if (w_type == DataType::unknown) {
161-
std::filesystem::path pathObj(filename);
162-
std::filesystem::path folderPath = pathObj.parent_path();
163-
w_type = getWeightType(folderPath.append("config.ini").string());
161+
std::size_t pos = filename.find_last_of("/\\");
162+
std::string dirPath = filename.substr(0, pos);
163+
std::string configFilePath = dirPath + "/config.ini";
164+
w_type = getWeightType(configFilePath);
164165
}
165166
//1 uint4x2 stores 2 uint4 value, so load size is halfed.
166167
if constexpr (std::is_same_v<T, uint4x2_t>) { size = size / 2; }

src/xfastertransformer/automodel.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,9 @@ def __init__(self, path, dtype: str = "fp16"):
3838
else:
3939
raise Exception(f"{self.__class__.__name__} don't support {dtype}.")
4040

41+
def __call__(self, inputs, **kwargs):
42+
return self.model.forward(inputs)
43+
4144
@classmethod
4245
def from_pretrained(cls, path, dtype: str = "fp16"):
4346
return cls(path, dtype)

0 commit comments

Comments
 (0)