@@ -717,63 +717,131 @@ class GradientUtils {
717717 }
718718
719719 Value* unwrapM (Value* val, IRBuilder<>& BuilderM, const ValueToValueMapTy& available, bool lookupIfAble) {
720- assert (val);
720+ assert (val);
721+
722+ static std::map<std::pair<Value*, BasicBlock*>, Value*> cache;
723+ auto cidx = std::make_pair (val, BuilderM.GetInsertBlock ());
724+ if (cache.find (cidx) != cache.end ()) {
725+ return cache[cidx];
726+ }
727+
721728 if (available.count (val)) {
722729 return available.lookup (val);
723730 }
731+
732+ if (auto inst = dyn_cast<Instruction>(val)) {
733+ if (isOriginalBlock (*BuilderM.GetInsertBlock ())) {
734+ if (BuilderM.GetInsertBlock ()->size () && BuilderM.GetInsertPoint () != BuilderM.GetInsertBlock ()->end ()) {
735+ if (DT.dominates (inst, &*BuilderM.GetInsertPoint ())) {
736+ // llvm::errs() << "allowed " << *inst << "from domination\n";
737+ return inst;
738+ }
739+ } else {
740+ if (DT.dominates (inst, BuilderM.GetInsertBlock ())) {
741+ // llvm::errs() << "allowed " << *inst << "from block domination\n";
742+ return inst;
743+ }
744+ }
745+ }
746+ }
724747
725748 if (isa<Argument>(val) || isa<Constant>(val)) {
749+ cache[std::make_pair (val, BuilderM.GetInsertBlock ())] = val;
726750 return val;
727751 } else if (isa<AllocaInst>(val)) {
752+ cache[std::make_pair (val, BuilderM.GetInsertBlock ())] = val;
728753 return val;
729754 } else if (auto op = dyn_cast<CastInst>(val)) {
730755 auto op0 = unwrapM (op->getOperand (0 ), BuilderM, available, lookupIfAble);
731756 if (op0 == nullptr ) goto endCheck;
732- return BuilderM.CreateCast (op->getOpcode (), op0, op->getDestTy (), op->getName ()+" _unwrap" );
757+ auto toreturn = BuilderM.CreateCast (op->getOpcode (), op0, op->getDestTy (), op->getName ()+" _unwrap" );
758+ if (cache.find (std::make_pair ((Value*)op->getOperand (0 ), BuilderM.GetInsertBlock ())) != cache.end ()) {
759+ cache[cidx] = toreturn;
760+ }
761+ return toreturn;
733762 } else if (auto op = dyn_cast<ExtractValueInst>(val)) {
734763 auto op0 = unwrapM (op->getAggregateOperand (), BuilderM, available, lookupIfAble);
735764 if (op0 == nullptr ) goto endCheck;
736- return BuilderM.CreateExtractValue (op0, op->getIndices (), op->getName ()+" _unwrap" );
765+ auto toreturn = BuilderM.CreateExtractValue (op0, op->getIndices (), op->getName ()+" _unwrap" );
766+ if (cache.find (std::make_pair ((Value*)op->getOperand (0 ), BuilderM.GetInsertBlock ())) != cache.end ()) {
767+ cache[cidx] = toreturn;
768+ }
769+ return toreturn;
737770 } else if (auto op = dyn_cast<BinaryOperator>(val)) {
738771 auto op0 = unwrapM (op->getOperand (0 ), BuilderM, available, lookupIfAble);
739772 if (op0 == nullptr ) goto endCheck;
740773 auto op1 = unwrapM (op->getOperand (1 ), BuilderM, available, lookupIfAble);
741774 if (op1 == nullptr ) goto endCheck;
742- return BuilderM.CreateBinOp (op->getOpcode (), op0, op1);
775+ auto toreturn = BuilderM.CreateBinOp (op->getOpcode (), op0, op1);
776+ if (
777+ (cache.find (std::make_pair ((Value*)op->getOperand (0 ), BuilderM.GetInsertBlock ())) != cache.end ()) &&
778+ (cache.find (std::make_pair ((Value*)op->getOperand (1 ), BuilderM.GetInsertBlock ())) != cache.end ()) ) {
779+ cache[cidx] = toreturn;
780+ }
781+ return toreturn;
743782 } else if (auto op = dyn_cast<ICmpInst>(val)) {
744783 auto op0 = unwrapM (op->getOperand (0 ), BuilderM, available, lookupIfAble);
745784 if (op0 == nullptr ) goto endCheck;
746785 auto op1 = unwrapM (op->getOperand (1 ), BuilderM, available, lookupIfAble);
747786 if (op1 == nullptr ) goto endCheck;
748- return BuilderM.CreateICmp (op->getPredicate (), op0, op1);
787+ auto toreturn = BuilderM.CreateICmp (op->getPredicate (), op0, op1);
788+ if (
789+ (cache.find (std::make_pair ((Value*)op->getOperand (0 ), BuilderM.GetInsertBlock ())) != cache.end ()) &&
790+ (cache.find (std::make_pair ((Value*)op->getOperand (1 ), BuilderM.GetInsertBlock ())) != cache.end ()) ) {
791+ cache[cidx] = toreturn;
792+ }
793+ return toreturn;
749794 } else if (auto op = dyn_cast<FCmpInst>(val)) {
750795 auto op0 = unwrapM (op->getOperand (0 ), BuilderM, available, lookupIfAble);
751796 if (op0 == nullptr ) goto endCheck;
752797 auto op1 = unwrapM (op->getOperand (1 ), BuilderM, available, lookupIfAble);
753798 if (op1 == nullptr ) goto endCheck;
754- return BuilderM.CreateFCmp (op->getPredicate (), op0, op1);
799+ auto toreturn = BuilderM.CreateFCmp (op->getPredicate (), op0, op1);
800+ if (
801+ (cache.find (std::make_pair ((Value*)op->getOperand (0 ), BuilderM.GetInsertBlock ())) != cache.end ()) &&
802+ (cache.find (std::make_pair ((Value*)op->getOperand (1 ), BuilderM.GetInsertBlock ())) != cache.end ()) ) {
803+ cache[cidx] = toreturn;
804+ }
805+ return toreturn;
755806 } else if (auto op = dyn_cast<SelectInst>(val)) {
756807 auto op0 = unwrapM (op->getOperand (0 ), BuilderM, available, lookupIfAble);
757808 if (op0 == nullptr ) goto endCheck;
758809 auto op1 = unwrapM (op->getOperand (1 ), BuilderM, available, lookupIfAble);
759810 if (op1 == nullptr ) goto endCheck;
760811 auto op2 = unwrapM (op->getOperand (2 ), BuilderM, available, lookupIfAble);
761812 if (op2 == nullptr ) goto endCheck;
762- return BuilderM.CreateSelect (op0, op1, op2);
813+ auto toreturn = BuilderM.CreateSelect (op0, op1, op2);
814+ if (
815+ (cache.find (std::make_pair ((Value*)op->getOperand (0 ), BuilderM.GetInsertBlock ())) != cache.end ()) &&
816+ (cache.find (std::make_pair ((Value*)op->getOperand (1 ), BuilderM.GetInsertBlock ())) != cache.end ()) &&
817+ (cache.find (std::make_pair ((Value*)op->getOperand (2 ), BuilderM.GetInsertBlock ())) != cache.end ()) ) {
818+ cache[cidx] = toreturn;
819+ }
820+ return toreturn;
763821 } else if (auto inst = dyn_cast<GetElementPtrInst>(val)) {
764822 auto ptr = unwrapM (inst->getPointerOperand (), BuilderM, available, lookupIfAble);
765823 if (ptr == nullptr ) goto endCheck;
824+ bool cached = cache.find (std::make_pair (inst->getPointerOperand (), BuilderM.GetInsertBlock ())) != cache.end ();
766825 SmallVector<Value*,4 > ind;
767826 for (auto & a : inst->indices ()) {
768827 auto op = unwrapM (a, BuilderM,available, lookupIfAble);
769828 if (op == nullptr ) goto endCheck;
829+ cached &= cache.find (std::make_pair ((Value*)a, BuilderM.GetInsertBlock ())) != cache.end ();
770830 ind.push_back (op);
771831 }
772- return BuilderM.CreateGEP (ptr, ind);
832+ auto toreturn = BuilderM.CreateGEP (ptr, ind, inst->getName () + " _unwrap" );
833+ if (cached) {
834+ cache[cidx] = toreturn;
835+ }
836+ return toreturn;
773837 } else if (auto load = dyn_cast<LoadInst>(val)) {
774838 Value* idx = unwrapM (load->getOperand (0 ), BuilderM, available, lookupIfAble);
775839 if (idx == nullptr ) goto endCheck;
776- return BuilderM.CreateLoad (idx);
840+ auto toreturn = BuilderM.CreateLoad (idx);
841+ if (cache.find (std::make_pair ((Value*)load->getOperand (0 ), BuilderM.GetInsertBlock ())) != cache.end ()) {
842+ cache[cidx] = toreturn;
843+ }
844+ return toreturn;
777845 } else if (auto op = dyn_cast<IntrinsicInst>(val)) {
778846 switch (op->getIntrinsicID ()) {
779847 case Intrinsic::sin: {
@@ -839,7 +907,6 @@ class GradientUtils {
839907 if (!inLoop) {
840908 return entryBuilder.CreateAlloca (T, nullptr , name+" _cache" );
841909 } else {
842- Value* size = nullptr ;
843910
844911 BasicBlock* outermostPreheader = nullptr ;
845912
@@ -853,38 +920,45 @@ class GradientUtils {
853920
854921 IRBuilder <> allocationBuilder (&outermostPreheader->back ());
855922
856- for (LoopContext idx = lc; ; getContext (idx.parent ->getHeader (), idx) ) {
857- // TODO handle allocations for dynamic loops
858- if (idx.dynamic && idx.parent != nullptr ) {
859- assert (idx.var );
860- assert (idx.var ->getParent ());
861- assert (idx.var ->getParent ()->getParent ());
862- llvm::errs () << *idx.var ->getParent ()->getParent () << " \n "
863- << " idx.var=" <<*idx.var << " \n "
864- << " idx.limit=" <<*idx.limit << " \n " ;
865- llvm::errs () << " cannot handle non-outermost dynamic loop\n " ;
866- assert (0 && " cannot handle non-outermost dynamic loop" );
867- }
868- Value* ns = nullptr ;
869- Type* intT = idx.dynamic ? cast<PointerType>(idx.limit ->getType ())->getElementType () : idx.limit ->getType ();
870- if (idx.dynamic ) {
871- ns = ConstantInt::get (intT, 1 );
872- } else {
873- Value* limitm1 = nullptr ;
874- ValueToValueMapTy emptyMap;
875- limitm1 = unwrapM (idx.limit , allocationBuilder, emptyMap, /* lookupIfAble*/ false );
876- if (limitm1 == nullptr ) {
877- assert (outermostPreheader);
878- assert (outermostPreheader->getParent ());
879- llvm::errs () << *outermostPreheader->getParent () << " \n " ;
880- llvm::errs () << " needed value " << *idx.limit << " at " << allocationBuilder.GetInsertBlock ()->getName () << " \n " ;
923+ Value* size = nullptr ;
924+ static std::map<BasicBlock*, Value*> sizecache;
925+ if (sizecache.find (lc.header ) != sizecache.end ()) {
926+ size = sizecache[lc.header ];
927+ } else {
928+ for (LoopContext idx = lc; ; getContext (idx.parent ->getHeader (), idx) ) {
929+ // TODO handle allocations for dynamic loops
930+ if (idx.dynamic && idx.parent != nullptr ) {
931+ assert (idx.var );
932+ assert (idx.var ->getParent ());
933+ assert (idx.var ->getParent ()->getParent ());
934+ llvm::errs () << *idx.var ->getParent ()->getParent () << " \n "
935+ << " idx.var=" <<*idx.var << " \n "
936+ << " idx.limit=" <<*idx.limit << " \n " ;
937+ llvm::errs () << " cannot handle non-outermost dynamic loop\n " ;
938+ assert (0 && " cannot handle non-outermost dynamic loop" );
939+ }
940+ Value* ns = nullptr ;
941+ Type* intT = idx.dynamic ? cast<PointerType>(idx.limit ->getType ())->getElementType () : idx.limit ->getType ();
942+ if (idx.dynamic ) {
943+ ns = ConstantInt::get (intT, 1 );
944+ } else {
945+ Value* limitm1 = nullptr ;
946+ ValueToValueMapTy emptyMap;
947+ limitm1 = unwrapM (idx.limit , allocationBuilder, emptyMap, /* lookupIfAble*/ false );
948+ if (limitm1 == nullptr ) {
949+ assert (outermostPreheader);
950+ assert (outermostPreheader->getParent ());
951+ llvm::errs () << *outermostPreheader->getParent () << " \n " ;
952+ llvm::errs () << " needed value " << *idx.limit << " at " << allocationBuilder.GetInsertBlock ()->getName () << " \n " ;
953+ }
954+ assert (limitm1);
955+ ns = allocationBuilder.CreateNUWAdd (limitm1, ConstantInt::get (intT, 1 ));
956+ }
957+ if (size == nullptr ) size = ns;
958+ else size = allocationBuilder.CreateNUWMul (size, ns);
959+ if (idx.parent == nullptr ) break ;
881960 }
882- assert (limitm1);
883- ns = allocationBuilder.CreateNUWAdd (limitm1, ConstantInt::get (intT, 1 ));
884- }
885- if (size == nullptr ) size = ns;
886- else size = allocationBuilder.CreateNUWMul (size, ns);
887- if (idx.parent == nullptr ) break ;
961+ sizecache[lc.header ] = size;
888962 }
889963
890964 auto firstallocation = CallInst::CreateMalloc (
@@ -955,6 +1029,7 @@ class GradientUtils {
9551029 limits.push_back (lim);
9561030 }
9571031
1032+ /*
9581033 Value* idx = nullptr;
9591034 for(unsigned i=0; i<indices.size(); i++) {
9601035 if (i == 0) {
@@ -963,20 +1038,18 @@ class GradientUtils {
9631038 auto mul = v.CreateNUWMul(indices[i], limits[i-1]);
9641039 idx = v.CreateNUWAdd(idx, mul);
9651040 }
966- }
1041+ }*/
9671042
9681043 if (dynamicPHI != nullptr ) {
9691044 Type *BPTy = Type::getInt8PtrTy (v.GetInsertBlock ()->getContext ());
9701045 auto realloc = newFunc->getParent ()->getOrInsertFunction (" realloc" , BPTy, BPTy, size->getType ());
9711046 Value* allocation = v.CreateLoad (holderAlloc);
972- auto foo = v.CreateNUWAdd (dynamicPHI, ConstantInt::get (dynamicPHI->getType (), 1 ));
1047+ Value* foo = v.CreateNUWAdd (dynamicPHI, ConstantInt::get (dynamicPHI->getType (), 1 ));
1048+ Value* realloc_size = v.CreateNUWMul (size, foo);
9731049 Value* idxs[2 ] = {
9741050 v.CreatePointerCast (allocation, BPTy),
9751051 v.CreateNUWMul (
976- ConstantInt::get (size->getType (), newFunc->getParent ()->getDataLayout ().getTypeAllocSizeInBits (T)/8 ),
977- v.CreateNUWMul (
978- size, foo
979- )
1052+ ConstantInt::get (size->getType (), newFunc->getParent ()->getDataLayout ().getTypeAllocSizeInBits (T)/8 ), realloc_size
9801053 )
9811054 };
9821055
0 commit comments