-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Directly SILGen @derivative
attributes to diff witnesses.
#28621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Previously, `@derivative` attribute type-checking created implicit `@differentiable` attributes on the original declaration. This was a longstanding hack powering `@derivative` attribute derivative registration. #28608 made these changes: - Derivative function configurations (from `@differentiable` and `@derivative` attributes) are serialized in modules and are loaded from imported modules. - The differentiation transform uses these derivative function configurations for derivative function lookup instead of `@differentiable` attributes. Now, `@derivative` attributes are directly lowered to differentiability witnesses during SILGen, and implicit `@differentiable` attribute generation is removed. Type-checking changes: - "Overlapping" `@differentiable` and `@derivative` attributes (for the same original declaration and parameter indices) are now disallowed. They semantically conflict because the first "requests derivative generation" while the second "registers a derivative". - "Overlapping" `@differentiable` and `@derivative` attributes are allowed for protocol requirements. Requirement `@differentiable` attributes mean "add JVP/VJP witness table entries" - not "request derivative generation", because there is no function body. - Note that relaxing the "overlapping" condition to consider derivative generic signatures is possible after derivative generic signature mangling for derivative functions: TF-680. Resolves TF-835. Unblocks TF-1021: lifting the "same-file derivative registration only" limitation in `@derivative` attribute type-checking. This should be possible without much work, but needs testing! Exposes TF-1040: `@differentiable` attribute limitations for class methods. Exposes TF-1041: untested protocol requirement `@differentiable` attribute type-checking logic.
Fix "a derivative already exists" error for `Layer.inferring(from:)`.
This doesn't make sense to me. In protocol conformances, it is perfectly legal to declare a conformance and then provide implementations in a different extension in the same module. Differentiable functions should behave the same way. |
The semantics of |
That's a good point, I think I agree! This type-checking change seems misguided - the original motivation was to avoid a crash regarding " Reproducer: protocol P: Differentiable {}
extension P {
@differentiable
func foo() -> Float { 1 }
}
extension P {
@derivative(of: foo)
func vjpFoo() -> (value: Float, pullback: (Float) -> (TangentVector)) {
fatalError()
}
} Previously,
This generated a single differentiability witness, no problem. However, with the current logic in this patch: the This causes two differentiability witnesses to be SILGen'd, each with a different derivative generic signature:
Since derivative generic signature mangling is not yet implemented for derivative functions (TF-680), there's a name conflict issue for derivative functions with the same parameter indices but different derivative generic signatures. SILGen generates a VJP thunk for the second differentiability witness above called
According to the semantics of Tangents:
|
If there's a parameter indices mismatch or generic signature mismatch between a If parameter indices and generic signatures do match, then preventively check whether such a differentiability witness already exists before adding a new one. |
llvm::DenseMap< | ||
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>, | ||
DerivativeAttr *> | ||
DerivativeAttrs; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be possible to use the derivative config list that you created in the previous serialization PR, so that we do not have to maintain another list?
Could the derivative config list actually also replace DifferentiableAttrs
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wondered the same. The derivative function configuration doesn't currently store whether the configuration came from (a @differentiable
attribute or @derivative
attribute), and that information is important for users of DifferentiableAttrs
and DerivativeAttrs
, which check duplicate attributes of a specific kind, not just configurations.
I haven't thought deeply about this. I'll file an issue tracking this question if it isn't resolved by the time this PR is merged.
Edit: filed TF-1042 to track Investigate removing/moving ASTContext::{Differentiable,Derivative}Attrs
.
Using AbstractFunctionDecl::getDerivativeFunctionConfigurations
to detect duplicate @differentiable
and @derivative
attributes may be significant for cross-file duplicate derivative registration (TF-1021).
If you import a derivative for func foo
, you shouldn't be able to register a new derivative for func foo
with the same configuration.
This requires changing AbstractFunctionDecl::getDerivativeFunctionConfigurations
to return more information than ArrayRef<AutoDiffConfig>
.
- Minimally, it needs to return an
OptionSet
perAutoDiffConfig
, specifying where theAutoDiffConfig
came from:@differentiable
attribute@derivative
JVP@derivative
VJP- Any combination of the above (three bits)
- For good "duplicate attribute" diagnostics, it also needs to store sth from which we can get an
@differentiable
/@derivative
attributeSourceLoc
.
dup.swift:1:2: error: duplicate '@differentiable' attribute with same parameters
@differentiable
~^~~~~~~~~~~~~~
dup.swift:2:2: note: other attribute declared here << need SourceLoc to generate this note
@differentiable
^
- Revert `@differentiable` + `@derivative` attribute type-checking changes. `@differentiable` + `@derivative` attribute with the same original declaration and parameter indices are not in conflict. - Simplify TBDGen for AutoDiff symbols using `AbstractFunctionDecl::getDerivativeFunctionConfigurations`. - Add TF-1037 negative tests.
Extended test suite fails weirdly:
This may be because this PR updates the commit hash for |
The fix is: protocol witness Ideally, the diagnostics should have read |
auto implicitDiffAttr = false; | ||
if (reqDiffAttrSupersetMatch) { | ||
auto *witnessAFD = cast<AbstractFunctionDecl>(witness); | ||
(void)reqDiffAttr->getParameterIndices(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: I found that calling DifferentiableAttr::getParameterIndices
on witness attributes was sufficient to make test/AutoDiff/differentiable_attr_type_checking_primary_file.swift
pass. For good measure, I called it on requirement attributes too.
@swift-ci Please clean test tensorflow |
Fix test/Serialization/derivative_attr.swift.
@swift-ci Please test tensorflow |
@swift-ci Please test tensorflow |
Does this implementation handle duplicate but private-to-separate-files derivatives? |
I believe the The same-file derivative registration limitation hasn't been lifted from Sema yet (TF-1021), but I verified that it can be lifted trivially after this patch:
Lifting the restriction requires more testing and will be the next follow-up after this patch. |
This patch causes a fastai/fastai_dev breakage:
The issue seems similar to #28621 (comment) and should be fixable with a patch to fastai/fastai_dev. I'll investigate. Workaround in fastai/fastai_dev#309. |
Work around issues with `@differentiable` + `@derivative` attributes with different derivative generic signatures. Related discussion: swiftlang/swift#28621 (comment) https://bugs.swift.org/browse/TF-1037 tracks this issue.
…_:)`. Work around issues with `@differentiable` + `@derivative` attributes with different derivative generic signatures. Automatic differentiation can handle this enum `switch` now, so a custom derivative is no longer necessary. https://bugs.swift.org/browse/TF-1037 tracks this issue. Related discussion: swiftlang/swift#28621 (comment)
…_:)`. Work around issues with `@differentiable` + `@derivative` attributes with different derivative generic signatures. Automatic differentiation can handle this enum `switch` now, so a custom derivative is no longer necessary. https://bugs.swift.org/browse/TF-1037 tracks this issue. Related discussion: swiftlang/swift#28621 (comment)
…_:)`. Work around issues with `@differentiable` + `@derivative` attributes with different derivative generic signatures. Automatic differentiation can handle this enum `switch` now, so a custom derivative is no longer necessary. https://bugs.swift.org/browse/TF-1037 tracks this issue. Related discussion: swiftlang/swift#28621 (comment)
…_:)`. (#309) Work around issues with `@differentiable` + `@derivative` attributes with different derivative generic signatures. Automatic differentiation can handle this enum `switch` now, so a custom derivative is no longer necessary. https://bugs.swift.org/browse/TF-1037 tracks this issue. Related discussion: swiftlang/swift#28621 (comment)
Your test cases didn’t cover the scenario where duplicate ‘@Derivative(of:)’s for the same function are defined in different files in the same Swift module. |
Once the restriction is lifted, the following should compile without issues: // File A.swift
func foo(x: Float) -> Float
// File B.swift
@derivative(of: foo)
private func foo(x: Float) -> (value: Float, differential: ...)
// File C.swift
@derivative(of: foo)
private func foo(x: Float) -> (value: Float, differential: ...) Only derivative registrations in the same Swift module that have overlapping access levels should be diagnosed. |
Thanks for sharing the nuance, that makes sense! Let's definitely add a test for this in the proper patch for TF-1021. |
Confirmed that extended test suite passes. Merging! |
Previously,
@derivative
attribute type-checking created implicit@differentiable
attributes on the original declaration. This was alongstanding hack powering
@derivative
attribute derivative registration.#28608 made these changes:
@differentiable
and@derivative
attributes) are serialized in modules and are loaded from imported modules.
for derivative function lookup instead of
@differentiable
attributes.Now,
@derivative
attributes are directly lowered to differentiabilitywitnesses during SILGen, and implicit
@differentiable
attribute generationis removed.
Resolves TF-835.
Unblocks TF-1021: lifting the "same-file derivative registration only"
limitation in
@derivative
attribute type-checking. This should be possiblewithout much work, but needs testing!
Exposes TF-1037: crash due to no derivative generic signature mangling (TF-680).
Exposes TF-1040:
@differentiable
attribute limitations for class methods.Exposes TF-1041: untested protocol requirement
@differentiable
attributetype-checking logic.