@@ -320,6 +320,32 @@ llvm::Value *CodeGenFunction::getTypeSize(QualType Ty) {
320
320
return CGM.getSize (SizeInChars);
321
321
}
322
322
323
+ void CodeGenFunction::GenerateOpenMPCapturedVarsAggregate (
324
+ const CapturedStmt &S, SmallVectorImpl<llvm::Value *> &CapturedVars) {
325
+ const RecordDecl *RD = S.getCapturedRecordDecl ();
326
+ QualType RecordTy = getContext ().getRecordType (RD);
327
+ // Create the aggregate argument struct for the outlined function.
328
+ LValue AggLV = MakeAddrLValue (
329
+ CreateMemTemp (RecordTy, " omp.outlined.arg.agg." ), RecordTy);
330
+
331
+ // Initialize the aggregate with captured values.
332
+ auto CurField = RD->field_begin ();
333
+ for (CapturedStmt::const_capture_init_iterator I = S.capture_init_begin (),
334
+ E = S.capture_init_end ();
335
+ I != E; ++I, ++CurField) {
336
+ LValue LV = EmitLValueForFieldInitialization (AggLV, *CurField);
337
+ // Initialize for VLA.
338
+ if (CurField->hasCapturedVLAType ()) {
339
+ EmitLambdaVLACapture (CurField->getCapturedVLAType (), LV);
340
+ } else
341
+ // Initialize for capturesThis, capturesVariableByCopy,
342
+ // capturesVariable
343
+ EmitInitializerForField (*CurField, LV, *I);
344
+ }
345
+
346
+ CapturedVars.push_back (AggLV.getPointer (*this ));
347
+ }
348
+
323
349
void CodeGenFunction::GenerateOpenMPCapturedVars (
324
350
const CapturedStmt &S, SmallVectorImpl<llvm::Value *> &CapturedVars) {
325
351
const RecordDecl *RD = S.getCapturedRecordDecl ();
@@ -420,6 +446,101 @@ struct FunctionOptions {
420
446
};
421
447
} // namespace
422
448
449
+ static llvm::Function *emitOutlinedFunctionPrologueAggregate (
450
+ CodeGenFunction &CGF, FunctionArgList &Args,
451
+ llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>>
452
+ &LocalAddrs,
453
+ llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>>
454
+ &VLASizes,
455
+ llvm::Value *&CXXThisValue, const CapturedStmt &CS, SourceLocation Loc,
456
+ StringRef FunctionName) {
457
+ const CapturedDecl *CD = CS.getCapturedDecl ();
458
+ const RecordDecl *RD = CS.getCapturedRecordDecl ();
459
+ assert (CD->hasBody () && " missing CapturedDecl body" );
460
+
461
+ CXXThisValue = nullptr ;
462
+ // Build the argument list.
463
+ CodeGenModule &CGM = CGF.CGM ;
464
+ ASTContext &Ctx = CGM.getContext ();
465
+ Args.append (CD->param_begin (), CD->param_end ());
466
+
467
+ // Create the function declaration.
468
+ const CGFunctionInfo &FuncInfo =
469
+ CGM.getTypes ().arrangeBuiltinFunctionDeclaration (Ctx.VoidTy , Args);
470
+ llvm::FunctionType *FuncLLVMTy = CGM.getTypes ().GetFunctionType (FuncInfo);
471
+
472
+ auto *F =
473
+ llvm::Function::Create (FuncLLVMTy, llvm::GlobalValue::InternalLinkage,
474
+ FunctionName, &CGM.getModule ());
475
+ CGM.SetInternalFunctionAttributes (CD, F, FuncInfo);
476
+ if (CD->isNothrow ())
477
+ F->setDoesNotThrow ();
478
+ F->setDoesNotRecurse ();
479
+
480
+ // Generate the function.
481
+ CGF.StartFunction (CD, Ctx.VoidTy , F, FuncInfo, Args, Loc, Loc);
482
+ Address ContextAddr = CGF.GetAddrOfLocalVar (CD->getContextParam ());
483
+ llvm::Value *ContextV = CGF.Builder .CreateLoad (ContextAddr);
484
+ LValue ContextLV = CGF.MakeNaturalAlignAddrLValue (
485
+ ContextV, CGM.getContext ().getTagDeclType (RD));
486
+ auto I = CS.captures ().begin ();
487
+ for (const FieldDecl *FD : RD->fields ()) {
488
+ LValue FieldLV = CGF.EmitLValueForFieldInitialization (ContextLV, FD);
489
+ // Do not map arguments if we emit function with non-original types.
490
+ Address LocalAddr = FieldLV.getAddress (CGF);
491
+ // If we are capturing a pointer by copy we don't need to do anything, just
492
+ // use the value that we get from the arguments.
493
+ if (I->capturesVariableByCopy () && FD->getType ()->isAnyPointerType ()) {
494
+ const VarDecl *CurVD = I->getCapturedVar ();
495
+ LocalAddrs.insert ({FD, {CurVD, LocalAddr}});
496
+ ++I;
497
+ continue ;
498
+ }
499
+
500
+ LValue ArgLVal =
501
+ CGF.MakeAddrLValue (LocalAddr, FD->getType (), AlignmentSource::Decl);
502
+ if (FD->hasCapturedVLAType ()) {
503
+ llvm::Value *ExprArg = CGF.EmitLoadOfScalar (ArgLVal, I->getLocation ());
504
+ const VariableArrayType *VAT = FD->getCapturedVLAType ();
505
+ VLASizes.try_emplace (FD, VAT->getSizeExpr (), ExprArg);
506
+ } else if (I->capturesVariable ()) {
507
+ const VarDecl *Var = I->getCapturedVar ();
508
+ QualType VarTy = Var->getType ();
509
+ Address ArgAddr = ArgLVal.getAddress (CGF);
510
+ if (ArgLVal.getType ()->isLValueReferenceType ()) {
511
+ ArgAddr = CGF.EmitLoadOfReference (ArgLVal);
512
+ } else if (!VarTy->isVariablyModifiedType () || !VarTy->isPointerType ()) {
513
+ assert (ArgLVal.getType ()->isPointerType ());
514
+ ArgAddr = CGF.EmitLoadOfPointer (
515
+ ArgAddr, ArgLVal.getType ()->castAs <PointerType>());
516
+ }
517
+ LocalAddrs.insert (
518
+ {FD, {Var, Address (ArgAddr.getPointer (), Ctx.getDeclAlign (Var))}});
519
+ } else if (I->capturesVariableByCopy ()) {
520
+ assert (!FD->getType ()->isAnyPointerType () &&
521
+ " Not expecting a captured pointer." );
522
+ const VarDecl *Var = I->getCapturedVar ();
523
+ Address CopyAddr = CGF.CreateMemTemp (FD->getType (), Ctx.getDeclAlign (FD),
524
+ Var->getName ());
525
+ LValue CopyLVal =
526
+ CGF.MakeAddrLValue (CopyAddr, FD->getType (), AlignmentSource::Decl);
527
+
528
+ RValue ArgRVal = CGF.EmitLoadOfLValue (ArgLVal, I->getLocation ());
529
+ CGF.EmitStoreThroughLValue (ArgRVal, CopyLVal);
530
+
531
+ LocalAddrs.insert ({FD, {Var, CopyAddr}});
532
+ } else {
533
+ // If 'this' is captured, load it into CXXThisValue.
534
+ assert (I->capturesThis ());
535
+ CXXThisValue = CGF.EmitLoadOfScalar (ArgLVal, I->getLocation ());
536
+ LocalAddrs.insert ({FD, {nullptr , ArgLVal.getAddress (CGF)}});
537
+ }
538
+ ++I;
539
+ }
540
+
541
+ return F;
542
+ }
543
+
423
544
static llvm::Function *emitOutlinedFunctionPrologue (
424
545
CodeGenFunction &CGF, FunctionArgList &Args,
425
546
llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>>
@@ -595,6 +716,37 @@ static llvm::Function *emitOutlinedFunctionPrologue(
595
716
return F;
596
717
}
597
718
719
+ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunctionAggregate (
720
+ const CapturedStmt &S, SourceLocation Loc) {
721
+ assert (
722
+ CapturedStmtInfo &&
723
+ " CapturedStmtInfo should be set when generating the captured function" );
724
+ const CapturedDecl *CD = S.getCapturedDecl ();
725
+ // Build the argument list.
726
+ FunctionArgList Args;
727
+ llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>> LocalAddrs;
728
+ llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>> VLASizes;
729
+ StringRef FunctionName = CapturedStmtInfo->getHelperName ();
730
+ llvm::Function *F = emitOutlinedFunctionPrologueAggregate (
731
+ *this , Args, LocalAddrs, VLASizes, CXXThisValue, S, Loc, FunctionName);
732
+ CodeGenFunction::OMPPrivateScope LocalScope (*this );
733
+ for (const auto &LocalAddrPair : LocalAddrs) {
734
+ if (LocalAddrPair.second .first ) {
735
+ LocalScope.addPrivate (LocalAddrPair.second .first , [&LocalAddrPair]() {
736
+ return LocalAddrPair.second .second ;
737
+ });
738
+ }
739
+ }
740
+ (void )LocalScope.Privatize ();
741
+ for (const auto &VLASizePair : VLASizes)
742
+ VLASizeMap[VLASizePair.second .first ] = VLASizePair.second .second ;
743
+ PGO.assignRegionCounters (GlobalDecl (CD), F);
744
+ CapturedStmtInfo->EmitBody (*this , CD->getBody ());
745
+ (void )LocalScope.ForceCleanup ();
746
+ FinishFunction (CD->getBodyRBrace ());
747
+ return F;
748
+ }
749
+
598
750
llvm::Function *
599
751
CodeGenFunction::GenerateOpenMPCapturedStmtFunction (const CapturedStmt &S,
600
752
SourceLocation Loc) {
@@ -1582,7 +1734,7 @@ static void emitCommonOMPParallelDirective(
1582
1734
// The following lambda takes care of appending the lower and upper bound
1583
1735
// parameters when necessary
1584
1736
CodeGenBoundParameters (CGF, S, CapturedVars);
1585
- CGF.GenerateOpenMPCapturedVars (*CS, CapturedVars);
1737
+ CGF.GenerateOpenMPCapturedVarsAggregate (*CS, CapturedVars);
1586
1738
CGF.CGM .getOpenMPRuntime ().emitParallelCall (CGF, S.getBeginLoc (), OutlinedFn,
1587
1739
CapturedVars, IfCond);
1588
1740
}
@@ -6050,7 +6202,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
6050
6202
6051
6203
OMPTeamsScope Scope (CGF, S);
6052
6204
llvm::SmallVector<llvm::Value *, 16 > CapturedVars;
6053
- CGF.GenerateOpenMPCapturedVars (*CS, CapturedVars);
6205
+ CGF.GenerateOpenMPCapturedVarsAggregate (*CS, CapturedVars);
6054
6206
CGF.CGM .getOpenMPRuntime ().emitTeamsCall (CGF, S, S.getBeginLoc (), OutlinedFn,
6055
6207
CapturedVars);
6056
6208
}
0 commit comments