@@ -1012,11 +1012,6 @@ void RecurrentGradientMachine::generateSequence() {
10121012 /* width */ resultNum,
10131013 false ,
10141014 /* useGpu */ false );
1015- Matrix::resizeOrCreate (generator_.outArg .value ,
1016- /* height */ maxGenWordCount,
1017- /* width */ 1 ,
1018- false ,
1019- /* useGpu */ false );
10201015 }
10211016 ICpuGpuVector::resizeOrCreate (generator_.outArg .sequenceStartPositions ,
10221017 numSequences + 1 ,
@@ -1026,7 +1021,7 @@ void RecurrentGradientMachine::generateSequence() {
10261021 } else {
10271022 oneWaySearch (numSequences);
10281023 }
1029- if (dataArgsSize_) createDataOutlink (batchMachineIdVec_ );
1024+ if (dataArgsSize_) createDataOutlink ();
10301025
10311026 size_t size = generator_.ids .size ();
10321027 generator_.outArg .ids ->resize (size);
@@ -1106,6 +1101,7 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
11061101 }
11071102
11081103 batchMachineIdVec_.clear ();
1104+ batchMachineStartPos_.clear ();
11091105 int * starts = generator_.outArg .sequenceStartPositions ->getMutableData (false );
11101106 starts[0 ] = 0 ;
11111107 generator_.ids .clear ();
@@ -1312,13 +1308,20 @@ void RecurrentGradientMachine::fillGenOutputs() {
13121308 finalPaths_[i].resize (minFinalPathsSize);
13131309 }
13141310
1315- batchMachineIdVec_.clear ();
13161311 generator_.ids .clear ();
13171312 int * starts = generator_.outArg .sequenceStartPositions ->getMutableData (false );
13181313 starts[0 ] = 0 ;
13191314 if (numResults > 1 ) {
1320- real* probs = generator_.outArg .in ->getData ();
1315+ int idsProbSaveSize = 0 ;
1316+ for (auto inSeq : finalPaths_) {
1317+ for (auto path : inSeq) idsProbSaveSize += path.ids .size ();
1318+ idsProbSaveSize += inSeq.size ();
1319+ }
1320+ Matrix::resizeOrCreate (
1321+ generator_.outArg .value , idsProbSaveSize, 1 , false , false );
13211322 real* idsProb = generator_.outArg .value ->getData ();
1323+
1324+ real* probs = generator_.outArg .in ->getData ();
13221325 size_t curPos = 0 ;
13231326 for (size_t i = 0 ; i < finalPaths_.size (); ++i) {
13241327 for (size_t j = 0 ; j < finalPaths_[i].size (); ++j) {
@@ -1333,24 +1336,16 @@ void RecurrentGradientMachine::fillGenOutputs() {
13331336 curPos += genLen;
13341337 idsProb[curPos++] = -1.0 ;
13351338 probs[i * numResults + j] = path.logProb ;
1336-
1337- if (!j && dataArgsSize_) {
1338- // in beam search, here only reserved the top 1 generated result
1339- // for out_links that are not the generated word indices.
1340- batchMachineIdVec_.insert (batchMachineIdVec_.end (),
1341- path.machineIdVec .begin (),
1342- path.machineIdVec .end ());
1343- }
13441339 }
13451340 starts[i + 1 ] = generator_.ids .size ();
13461341 }
13471342 } else {
13481343 for (size_t i = 0 ; i < finalPaths_.size (); ++i) {
13491344 CHECK (!finalPaths_[i].empty ());
1350- generator_. ids . insert (generator_. ids . begin (),
1351- finalPaths_[i][ 0 ] .ids .begin (),
1352- finalPaths_[i][ 0 ] .ids .end ());
1353- starts[i + 1 ] = starts[i] + finalPaths_[i][ 0 ] .ids .size ();
1345+ Path& path = finalPaths_[i][ 0 ];
1346+ generator_ .ids .insert (
1347+ generator_. ids . begin (), path. ids . begin (), path .ids .end ());
1348+ starts[i + 1 ] = starts[i] + path .ids .size ();
13541349 }
13551350 }
13561351}
@@ -1364,25 +1359,76 @@ void RecurrentGradientMachine::copyDataOutlinkFrame(size_t machineCur) {
13641359 }
13651360}
13661361
1367- void RecurrentGradientMachine::createDataOutlink (
1368- std::vector<int >& machineIdVec) {
1369- size_t seqNum =
1370- getBeamSize () > 1UL ? finalPaths_.size () : finalPaths_[0 ].size ();
1371- std::vector<int > starts (seqNum + 1 , 0 );
1372- for (size_t i = 0 ; i < seqNum; ++i) {
1373- size_t seqLen = getBeamSize () > 1UL ? finalPaths_[i][0 ].ids .size ()
1374- : finalPaths_[0 ][i].ids .size ();
1375- starts[i + 1 ] = starts[i] + seqLen;
1362+ void RecurrentGradientMachine::createDataOutlinkSelRowsInfo (
1363+ bool isSeq, std::vector<Argument>& outArgs) {
1364+ batchMachineIdVec_.clear ();
1365+
1366+ size_t seqIdx = 0 ;
1367+ for (size_t i = 0 ; i < finalPaths_.size (); ++i) {
1368+ for (size_t j = 0 ; j < finalPaths_[i].size (); ++j) {
1369+ std::vector<int >& machineIdVec = finalPaths_[i][j].machineIdVec ;
1370+ if (isSeq) {
1371+ for (size_t i = 0 ; i < machineIdVec.size (); ++i) {
1372+ size_t rowId = machineIdVec[i];
1373+ int * seqPos =
1374+ outArgs[i].sequenceStartPositions ->getMutableData (false );
1375+ batchMachineIdVec_.push_back (seqPos[rowId]);
1376+ }
1377+ } else {
1378+ batchMachineIdVec_.insert (
1379+ batchMachineIdVec_.end (), machineIdVec.begin (), machineIdVec.end ());
1380+ }
1381+ seqIdx++;
1382+ }
1383+ }
1384+ }
1385+
1386+ void RecurrentGradientMachine::createDataOutlinkCopySizeInfo (
1387+ bool isSeq, std::vector<Argument>& outArgs, std::vector<int >& copySize) {
1388+ size_t totalSeqNum = std::accumulate (
1389+ finalPaths_.begin (),
1390+ finalPaths_.end (),
1391+ 0UL ,
1392+ [](size_t a, const std::vector<Path>& b) { return a + b.size (); });
1393+ copySize.resize (totalSeqNum, 1 );
1394+
1395+ batchMachineStartPos_.resize (totalSeqNum + 1 , 0 );
1396+ if (isSeq) {
1397+ ICpuGpuVectorPtr inputSeqStartPos = outArgs[0 ].sequenceStartPositions ;
1398+ CHECK_EQ (static_cast <size_t >(inputSeqStartPos->getSize () - 1 ),
1399+ getBeamSize () > 1 ? finalPaths_.size () : finalPaths_[0 ].size ());
1400+ int * starts = inputSeqStartPos->getMutableData (false );
1401+ int seqId = 0 ;
1402+ for (int i = 0 ; i < finalPaths_.size (); ++i) {
1403+ for (int j = 0 ; j < finalPaths_[i].size (); ++j) {
1404+ copySize[seqId] = getBeamSize () > 1 ? starts[i + 1 ] - starts[i]
1405+ : starts[j + 1 ] - starts[j];
1406+ batchMachineStartPos_[seqId + 1 ] =
1407+ batchMachineStartPos_[seqId] + finalPaths_[i][j].ids .size ();
1408+ seqId++;
1409+ }
1410+ }
1411+ } else {
1412+ for (size_t i = 0 ; i < finalPaths_[0 ].size (); ++i)
1413+ batchMachineStartPos_[i + 1 ] =
1414+ batchMachineStartPos_[i] + finalPaths_[0 ][i].ids .size ();
13761415 }
1416+ }
13771417
1418+ void RecurrentGradientMachine::createDataOutlink () {
13781419 for (size_t i = 0 ; i < dataArgsSize_; i++) {
1420+ bool isSeq = dataArgsFrame_[i][0 ].hasSeq ();
1421+ std::vector<int > copySize;
1422+ createDataOutlinkCopySizeInfo (isSeq, dataArgsFrame_[i], copySize);
1423+ createDataOutlinkSelRowsInfo (isSeq, dataArgsFrame_[i]);
1424+
13791425 dataArgs_[i].concat (dataArgsFrame_[i],
1380- machineIdVec,
1381- starts,
1426+ batchMachineIdVec_,
1427+ batchMachineStartPos_,
1428+ copySize,
13821429 useGpu_,
13831430 HPPL_STREAM_1,
13841431 PASS_TEST);
1385-
13861432 auto dataAgent =
13871433 dynamic_cast <DataLayer*>(outFrameLines_[i + 1 ].agentLayer .get ());
13881434 CHECK_NOTNULL (dataAgent);
0 commit comments