@@ -73,9 +73,19 @@ void Model::initMaxSeqLen() {
 
 void Model::exitSlaves() {
     if (decoder->getRank() == 0) {
-        configuration.numBeams = 0;
-        Messenger &messenger = decoder->getMessenger();
-        messenger.broadcast((int *)&configuration, sizeof(SearcherConfig) / sizeof(int));
+        if (searcher != nullptr) {
+            configuration.numBeams = 0;
+            Messenger &messenger = decoder->getMessenger();
+            messenger.broadcast((int *)&configuration, sizeof(SearcherConfig) / sizeof(int));
+            return;
+        } else {
+            // Only work for Model::set_input(std::vector<int32_t> &inputIds_, std::vector<int32_t> &seqLens_,
+            // std::vector<int> seqIDs, std::vector<int> &maxLen)
+            // TODO: Add support for other continuous batching interface
+            Messenger &messenger = decoder->getMessenger();
+            int dim[4] = {-1, -1, -1, -1};
+            messenger.broadcast(dim, 4);
+        }
     }
 }
 
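The new `else` branch covers the case where the model is driven through the continuous-batching `set_input` path and never built a `Searcher`, so the old exit signal (a `SearcherConfig` with `numBeams = 0`) would not be read by anyone. Instead, rank 0 reuses the dims broadcast that slave ranks are already blocked on and sends an all `-1` array as a sentinel; the matching check on the slave side is in the second hunk below. The following is a minimal, self-contained sketch of that handshake, with a hypothetical `stubBroadcast` standing in for `Messenger::broadcast` (illustrative only, not the xFasterTransformer API):

```cpp
// Sketch of the exit handshake: rank 0 broadcasts {-1, -1, -1, -1} instead of
// real batch dims, and slave ranks interpret a negative dim[0] as "terminate".
#include <array>
#include <cstdio>
#include <cstdlib>

static std::array<int, 4> g_wire{}; // stand-in for the broadcast transport

// Hypothetical broadcast stub: rank 0 writes to the "wire", other ranks read.
void stubBroadcast(int *buf, int count, int rank) {
    for (int i = 0; i < count; ++i) {
        if (rank == 0) g_wire[i] = buf[i];
        else buf[i] = g_wire[i];
    }
}

// Rank 0 side (no searcher): send the all -1 dims array as the stop signal.
void masterExitSlaves() {
    int dim[4] = {-1, -1, -1, -1};
    stubBroadcast(dim, 4, /*rank=*/0);
}

// Slave side of set_input: the same broadcast that normally carries batch
// sizes now carries the sentinel, so dim[0] < 0 means "exit the process".
void slaveSetInput(int rank) {
    int dim[4] = {0, 0, 0, 0};
    stubBroadcast(dim, 4, rank);
    if (dim[0] < 0) {
        std::printf("rank %d: exit sentinel received\n", rank);
        std::exit(0);
    }
    // ...otherwise resize inputIds_ / seqLens_ / seqIDs to dim[0..2]...
}

int main() {
    masterExitSlaves(); // rank 0 decides to shut down
    slaveSetInput(1);   // a slave rank sees dim[0] < 0 and exits
}
```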
@@ -584,6 +594,7 @@ std::vector<int> Model::set_input(std::vector<int32_t> &inputIds_, std::vector<i
     messenger.broadcast(dim, 4);
 
     if (messenger.getRank() != 0) {
+        if (dim[0] < 0) { exit(0); }
         inputIds_.resize(dim[0]);
         seqLens_.resize(dim[1]);
         seqIDs.resize(dim[2]);
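Reusing the dims broadcast as the shutdown channel keeps slave ranks waiting on the single collective they already block on in `set_input`, and the `dim[0] < 0` guard runs before any resize, so the `-1` sentinel is never misread as real batch sizes. As the TODO in `exitSlaves` notes, other continuous-batching input paths that broadcast their own shapes would need an equivalent check.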
@@ -791,7 +802,8 @@ std::tuple<float *, int, int> Model::forward(bool logits_all) {
         for (int i = 0; i < works; ++i) {
             for (int j = 0; j < totalSeqSize; ++j) {
                 memcpy(logits.data() + (i * offset + j * vocabSize),
-                        logitsRecvBuf.data() + offset * totalSeqSize + j * splitSizes[i], splitSizes[i] * sizeof(float));
+                        logitsRecvBuf.data() + offset * totalSeqSize + j * splitSizes[i],
+                        splitSizes[i] * sizeof(float));
             }
             offset += splitSizes[i];
         }