-
Notifications
You must be signed in to change notification settings - Fork 145
Better global handling and fix caching bug #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,6 +38,10 @@ cl::opt<bool> printconst( | |
| "enzyme_printconst", cl::init(false), cl::Hidden, | ||
| cl::desc("Print constant detection algorithm")); | ||
|
|
||
| cl::opt<bool> nonmarkedglobals_inactive( | ||
| "enzyme_nonmarkedglobals_inactive", cl::init(false), cl::Hidden, | ||
| cl::desc("Consider all nonmarked globals to be inactive")); | ||
|
|
||
| bool isIntASecretFloat(Value* val) { | ||
| assert(val->getType()->isIntegerTy()); | ||
|
|
||
|
|
@@ -199,8 +203,7 @@ Type* isIntPointerASecretFloat(Value* val) { | |
| continue; | ||
| } | ||
| if (auto gep = dyn_cast<GetElementPtrInst>(v)) { | ||
| v = gep->getOperand(0); | ||
| continue; | ||
| trackPointer(gep->getOperand(0)); | ||
| } | ||
| if (auto phi = dyn_cast<PHINode>(v)) { | ||
| for(auto &a : phi->incoming_values()) { | ||
|
|
@@ -218,6 +221,10 @@ Type* isIntPointerASecretFloat(Value* val) { | |
| et = st->getTypeAtIndex((unsigned int)0); | ||
| continue; | ||
| } | ||
| if (auto st = dyn_cast<ArrayType>(et)) { | ||
| et = st->getElementType(); | ||
| continue; | ||
| } | ||
| break; | ||
| } while(1); | ||
| llvm::errs() << " for val " << *v << *et << "\n"; | ||
|
|
@@ -401,6 +408,7 @@ bool isconstantM(Instruction* inst, SmallPtrSetImpl<Value*> &constants, SmallPtr | |
| } | ||
| if (auto call = dyn_cast<CallInst>(a)) { | ||
| auto fnp = call->getCalledFunction(); | ||
| // For known library functions, special case how derivatives flow to allow for more aggressive active variable detection | ||
| if (fnp) { | ||
| auto fn = fnp->getName(); | ||
| // todo realloc consider? | ||
|
|
@@ -410,6 +418,8 @@ bool isconstantM(Instruction* inst, SmallPtrSetImpl<Value*> &constants, SmallPtr | |
| continue; | ||
| if (fnp->getIntrinsicID() == Intrinsic::memcpy && call->getArgOperand(0) != inst && call->getArgOperand(1) != inst) | ||
| continue; | ||
| if (fnp->getIntrinsicID() == Intrinsic::memmove && call->getArgOperand(0) != inst && call->getArgOperand(1) != inst) | ||
| continue; | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -556,6 +566,36 @@ bool isconstantValueM(Value* val, SmallPtrSetImpl<Value*> &constants, SmallPtrSe | |
| llvm::errs() << *val << "\n"; | ||
| assert(0 && "must've put arguments in constant/nonconstant"); | ||
| } | ||
|
|
||
| if (auto gi = dyn_cast<GlobalVariable>(val)) { | ||
| if (!hasMetadata(gi, "enzyme_shadow") && nonmarkedglobals_inactive) { | ||
| constants.insert(val); | ||
| return true; | ||
| } | ||
| //TODO consider this more | ||
| if (gi->isConstant() && isconstantValueM(gi->getInitializer(), constants, nonconstant, retvals, originalInstructions, directions)) { | ||
| constants.insert(val); | ||
| return true; | ||
| } | ||
| } | ||
|
|
||
| if (auto ce = dyn_cast<ConstantExpr>(val)) { | ||
| if (ce->isCast()) { | ||
| if (isconstantValueM(ce->getOperand(0), constants, nonconstant, retvals, originalInstructions, directions)) { | ||
| constants.insert(val); | ||
| return true; | ||
| } | ||
| } | ||
| if (ce->isGEPWithNoNotionalOverIndexing()) { | ||
| if (isconstantValueM(ce->getOperand(0), constants, nonconstant, retvals, originalInstructions, directions)) { | ||
| constants.insert(val); | ||
| return true; | ||
| } | ||
| if (auto gi = dyn_cast<GlobalVariable>(val)) { | ||
|
|
||
| } | ||
| } | ||
| } | ||
|
|
||
| if (auto inst = dyn_cast<Instruction>(val)) { | ||
| if (isconstantM(inst, constants, nonconstant, retvals, originalInstructions, directions)) return true; | ||
|
|
@@ -589,6 +629,8 @@ bool isconstantValueM(Value* val, SmallPtrSetImpl<Value*> &constants, SmallPtrSe | |
| continue; | ||
| if (fnp->getIntrinsicID() == Intrinsic::memcpy && call->getArgOperand(0) != val && call->getArgOperand(1) != val) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For skipped intrinsics, a comment should be present describing the property the intrinsic has that lets us skip it. This lets people understand when they can add a "missing" intrinsic.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is not skipping the intrinsic; rather, it special-cases memcpy/memmove so that the size argument is not made active even if the other arguments are active.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My point was: if you are handling a bunch of things similarly (e.g. intrinsics) and it's too cumbersome to give an explicit comment for each one, then an explanation for the group is good enough --- and it should exist anyway, so that folks understand what determines membership in the group of similar cases.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agreed on adding a comment; I just wanted to clarify what the code was doing. |
||
| continue; | ||
| if (fnp->getIntrinsicID() == Intrinsic::memmove && call->getArgOperand(0) != val && call->getArgOperand(1) != val) | ||
| continue; | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| ; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -sroa -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s | ||
|
|
||
| ; Function Attrs: noinline norecurse nounwind uwtable | ||
| define dso_local void @subf(i1 zeroext %z, double* nocapture %x) local_unnamed_addr #0 { | ||
| entry: | ||
| br i1 %z, label %if.then, label %if.end | ||
|
|
||
| if.then: ; preds = %entry | ||
| %0 = load double, double* %x, align 8 | ||
| %mul = fmul fast double %0, %0 | ||
| store double %mul, double* %x, align 8 | ||
| br label %if.end | ||
|
|
||
| if.end: ; preds = %if.then, %entry | ||
| ret void | ||
| } | ||
|
|
||
| ; Function Attrs: noinline norecurse nounwind uwtable | ||
| define dso_local void @f(i1 zeroext %z, double* nocapture %x) #0 { | ||
| entry: | ||
| tail call void @subf(i1 zeroext %z, double* %x) | ||
| %arrayidx = getelementptr inbounds double, double* %x, i64 1 | ||
| store double 2.000000e+00, double* %arrayidx, align 8 | ||
| ret void | ||
| } | ||
|
|
||
| ; Function Attrs: noinline nounwind uwtable | ||
| define dso_local double @dsumsquare(i1 zeroext %z, double* %x, double* %xp) local_unnamed_addr #1 { | ||
| entry: | ||
| %call = tail call fast double @__enzyme_autodiff(i8* bitcast (void (i1, double*)* @f to i8*), i1 zeroext %z, double* %x, double* %xp) | ||
| ret double %call | ||
| } | ||
|
|
||
| declare dso_local double @__enzyme_autodiff(i8*, i1 zeroext, double*, double*) | ||
|
|
||
| ; CHECK: define internal {} @diffef(i1 zeroext %z, double* nocapture %x, double* %"x'") { | ||
| ; CHECK-NEXT: entry: | ||
| ; CHECK-NEXT: %0 = call { { double } } @augmented_subf(i1 %z, double* %x, double* %"x'") | ||
| ; CHECK-NEXT: %1 = extractvalue { { double } } %0, 0 | ||
| ; CHECK-NEXT: %"arrayidx'ipge" = getelementptr inbounds double, double* %"x'", i64 1 | ||
| ; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1 | ||
| ; CHECK-NEXT: store double 2.000000e+00, double* %arrayidx, align 8 | ||
| ; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipge", align 8 | ||
| ; CHECK-NEXT: %2 = call {} @diffesubf(i1 %z, double* nonnull %x, double* %"x'", { double } %1) | ||
| ; CHECK-NEXT: ret {} undef | ||
| ; CHECK-NEXT: } | ||
|
|
||
| ; CHECK: define internal { { double } } @augmented_subf(i1 zeroext %z, double* nocapture %x, double* %"x'") | ||
| ; CHECK-NEXT: entry: | ||
| ; CHECK-NEXT: br i1 %z, label %if.then, label %if.end | ||
|
|
||
| ; CHECK: if.then: ; preds = %entry | ||
| ; CHECK-NEXT: %0 = load double, double* %x, align 8 | ||
| ; CHECK-NEXT: %mul = fmul fast double %0, %0 | ||
| ; CHECK-NEXT: store double %mul, double* %x, align 8 | ||
| ; CHECK-NEXT: br label %if.end | ||
|
|
||
| ; CHECK: if.end: ; preds = %if.then, %entry | ||
| ; CHECK-NEXT: %[[val:.+]] = phi double [ %0, %if.then ], [ undef, %entry ] | ||
| ; CHECK-NEXT: %[[toret:.+]] = insertvalue { { double } } undef, double %[[val]], 0, 0 | ||
| ; CHECK-NEXT: ret { { double } } %[[toret]] | ||
| ; CHECK-NEXT: } | ||
|
|
||
| ; CHECK: define internal {} @diffesubf(i1 zeroext %z, double* nocapture %x, double* %"x'", { double } %tapeArg) | ||
| ; CHECK-NEXT: entry: | ||
| ; CHECK-NEXT: br i1 %z, label %invertif.then, label %invertentry | ||
|
|
||
| ; CHECK: invertentry: ; preds = %entry, %invertif.then | ||
| ; CHECK-NEXT: ret {} undef | ||
|
|
||
| ; CHECK: invertif.then: ; preds = %entry | ||
| ; CHECK-NEXT: %0 = load double, double* %"x'" | ||
| ; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8 | ||
| ; CHECK-NEXT: %_unwrap = extractvalue { double } %tapeArg, 0 | ||
| ; CHECK-NEXT: %m0diffe = fmul fast double %0, %_unwrap | ||
| ; CHECK-NEXT: %m1diffe = fmul fast double %0, %_unwrap | ||
| ; CHECK-NEXT: %1 = fadd fast double %m0diffe, %m1diffe | ||
| ; CHECK-NEXT: %2 = load double, double* %"x'" | ||
| ; CHECK-NEXT: %3 = fadd fast double %2, %1 | ||
| ; CHECK-NEXT: store double %3, double* %"x'" | ||
| ; CHECK-NEXT: br label %invertentry | ||
| ; CHECK-NEXT: } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comments needed, especially before conditionals that involve recursive calls.