Skip to content

[AutoDiff] Derive 'EuclideanDifferentiable' vector view from members' vector views. #26890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ IDENTIFIER(x)
// Differentiable
IDENTIFIER(TangentVector)
IDENTIFIER(move)
IDENTIFIER(vectorView)
IDENTIFIER(differentiableVectorView)

// Kinds of layout constraints
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")
Expand Down
55 changes: 36 additions & 19 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ bool DerivedConformance::canDeriveEuclideanDifferentiable(
return false;
auto &C = nominal->getASTContext();
auto *lazyResolver = C.getLazyResolver();
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
auto *eucDiffProto =
C.getProtocol(KnownProtocolKind::EuclideanDifferentiable);
// Return true if all differentiation stored properties conform to
// `AdditiveArithmetic` and their `TangentVector` equals themselves.
SmallVector<VarDecl *, 16> diffProperties;
Expand All @@ -216,10 +217,8 @@ bool DerivedConformance::canDeriveEuclideanDifferentiable(
if (!member->hasInterfaceType())
return false;
auto varType = DC->mapTypeIntoContext(member->getValueInterfaceType());
if (!TypeChecker::conformsToProtocol(varType, addArithProto, DC, None))
return false;
auto memberAssocType = getTangentVectorType(member, DC);
return member->getType()->isEqual(memberAssocType);
return (bool)TypeChecker::conformsToProtocol(
varType, eucDiffProto, DC, None);
});
}

Expand Down Expand Up @@ -370,8 +369,8 @@ static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
{deriveBodyDifferentiable_move, nullptr});
}

/// Synthesize the `vectorView` property declaration.
static ValueDecl *deriveEuclideanDifferentiable_vectorView(
/// Synthesize the `differentiableVectorView` property declaration.
static ValueDecl *deriveEuclideanDifferentiable_differentiableVectorView(
DerivedConformance &derived) {
auto &C = derived.TC.Context;
auto *parentDC = derived.getConformanceContext();
Expand All @@ -383,8 +382,8 @@ static ValueDecl *deriveEuclideanDifferentiable_vectorView(
VarDecl *vectorViewDecl;
PatternBindingDecl *pbDecl;
std::tie(vectorViewDecl, pbDecl) = derived.declareDerivedProperty(
C.Id_vectorView, tangentType, tangentContextualType, /*isStatic*/ false,
/*isFinal*/ true);
C.Id_differentiableVectorView, tangentType, tangentContextualType,
/*isStatic*/ false, /*isFinal*/ true);

struct GetterSynthesizerContext {
StructDecl *tangentDecl;
Expand All @@ -397,7 +396,13 @@ static ValueDecl *deriveEuclideanDifferentiable_vectorView(
assert(context && "Invalid context");
auto *parentDC = getterDecl->getParent();
auto *nominal = parentDC->getSelfNominalTypeDecl();
auto *module = nominal->getModuleContext();
auto &C = nominal->getASTContext();
auto *eucDiffProto =
C.getProtocol(KnownProtocolKind::EuclideanDifferentiable);
auto *vectorViewReq =
eucDiffProto->lookupDirect(C.Id_differentiableVectorView).front();

SmallVector<VarDecl *, 8> diffProperties;
getStoredPropertiesForDifferentiation(nominal, nominal->getDeclContext(),
diffProperties);
Expand All @@ -419,20 +424,32 @@ static ValueDecl *deriveEuclideanDifferentiable_vectorView(

// Create a call:
// TangentVector.init(
// <property_name_1...>: self.<property_name_1>,
// <property_name_2...>: self.<property_name_2>,
// <property_name_1...>:
// self.differentiableVectorView.<property_name_1>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this comment be:

    //   TangentVector.init(
    //     <property_name_1>:
    //        self.<property_name_1>.differentiableVectorView,
    //     <property_name_2>:
    //        self.<property_name_2>.differentiableVectorView,
    //     ...

self.<property_name_1>.differentiableVectorView rather than self.differentiableVectorView.<property_name_1>.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Good catch.

// <property_name_2...>:
// self.differentiableVectorView.<property_name_2>,
// ...
// )
SmallVector<Identifier, 8> argLabels;
SmallVector<Expr *, 8> memberRefs;
auto *selfDRE = new (C) DeclRefExpr(getterDecl->getImplicitSelfDecl(),
DeclNameLoc(),
/*Implicit*/ true);
for (auto *member : diffProperties) {
auto *selfDRE = new (C) DeclRefExpr(getterDecl->getImplicitSelfDecl(),
DeclNameLoc(),
/*Implicit*/ true);
auto *memberExpr = new (C) MemberRefExpr(
selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);
auto memberType =
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
auto confRef = module->lookupConformance(memberType, eucDiffProto);
assert(confRef &&
"Member missing conformance to `EuclideanDifferentiable`");
ConcreteDeclRef memberDeclRef = vectorViewReq;
if (confRef->isConcrete())
memberDeclRef = confRef->getConcrete()->getWitnessDecl(vectorViewReq);
argLabels.push_back(member->getName());
memberRefs.push_back(
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
/*Implicit*/ true));
memberRefs.push_back(new (C) MemberRefExpr(
memberExpr, SourceLoc(), memberDeclRef, DeclNameLoc(),
/*Implicit*/ true));
}
assert(memberRefs.size() == argLabels.size());
CallExpr *callExpr =
Expand Down Expand Up @@ -875,8 +892,8 @@ ValueDecl *DerivedConformance::deriveEuclideanDifferentiable(
// Diagnose conformances in disallowed contexts.
if (checkAndDiagnoseDisallowedContext(requirement))
return nullptr;
if (requirement->getFullName() == TC.Context.Id_vectorView)
return deriveEuclideanDifferentiable_vectorView(*this);
if (requirement->getFullName() == TC.Context.Id_differentiableVectorView)
return deriveEuclideanDifferentiable_differentiableVectorView(*this);
TC.diagnose(requirement->getLoc(),
diag::broken_euclidean_differentiable_requirement);
return nullptr;
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
return getRequirement(KnownProtocolKind::AdditiveArithmetic);

// SWIFT_ENABLE_TENSORFLOW
// EuclideanDifferentiable.vectorView
if (name.isSimpleName(ctx.Id_vectorView))
// EuclideanDifferentiable.differentiableVectorView
if (name.isSimpleName(ctx.Id_differentiableVectorView))
return getRequirement(KnownProtocolKind::EuclideanDifferentiable);

// SWIFT_ENABLE_TENSORFLOW
Expand Down
15 changes: 15 additions & 0 deletions stdlib/public/core/Array.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,14 @@ extension Array where Element : Differentiable {
}
}

extension Array.DifferentiableView : EuclideanDifferentiable
where Element : EuclideanDifferentiable {
public var differentiableVectorView: Array.DifferentiableView.TangentVector {
Array.DifferentiableView.TangentVector(
base.map { $0.differentiableVectorView })
}
}

extension Array.DifferentiableView : Equatable where Element : Equatable {
public static func == (
lhs: Array.DifferentiableView,
Expand Down Expand Up @@ -2061,6 +2069,13 @@ extension Array : Differentiable where Element : Differentiable {
}
}

extension Array : EuclideanDifferentiable
where Element : EuclideanDifferentiable {
public var differentiableVectorView: TangentVector {
TangentVector(map { $0.differentiableVectorView })
}
}

extension Array where Element : Differentiable {
public func _vjpSubscript(index: Int) ->
(Element, (Element.TangentVector) -> TangentVector)
Expand Down
36 changes: 26 additions & 10 deletions stdlib/public/core/AutoDiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,16 @@ public protocol Differentiable {
""")
var zeroTangentVector: TangentVector { get }

@available(*, deprecated,
message: "'AllDifferentiableVariables' is now equal to 'Self' and will be removed")
@available(*, deprecated, message: """
'AllDifferentiableVariables' is now equal to 'Self' and will be removed
""")
typealias AllDifferentiableVariables = Self
}

public extension Differentiable {
@available(*, deprecated,
message: "'allDifferentiableVariables' is now equal to 'self' and will be removed")
@available(*, deprecated, message: """
'allDifferentiableVariables' is now equal to 'self' and will be removed
""")
var allDifferentiableVariables: AllDifferentiableVariables {
get { return self }
set { self = newValue }
Expand All @@ -204,8 +206,9 @@ public extension Differentiable where TangentVector == Self {
}
}

/// A type that consists of a differentiable vector space and some other
/// non-differentiable component.
/// A type that is differentiable in the Euclidean space.
/// The type may represent a vector space, or consist of a vector space and some
/// other non-differentiable component.
///
/// Mathematically, this represents a product manifold that consists of
/// a differentiable vector space and some arbitrary manifold, where the tangent
Expand All @@ -229,11 +232,11 @@ public extension Differentiable where TangentVector == Self {
/// `TangentVector` is equal to its vector space component.
public protocol EuclideanDifferentiable: Differentiable {
/// The differentiable vector component of `self`.
var vectorView: TangentVector { get }
var differentiableVectorView: TangentVector { get }
}

public extension EuclideanDifferentiable where TangentVector == Self {
var vectorView: TangentVector { _read { yield self } }
var differentiableVectorView: TangentVector { _read { yield self } }
}

/// Returns `x` like an identity function. When used in a context where `x` is
Expand Down Expand Up @@ -776,6 +779,9 @@ internal protocol _AnyDerivativeBox {
// `Differentiable` requirements.
mutating func _move(along direction: _AnyDerivativeBox)

// `EuclideanDifferentiable` requirements.
var _differentiableVectorView: _AnyDerivativeBox { get }

/// The underlying base value, type-erased to `Any`.
var _typeErasedBase: Any { get }

Expand Down Expand Up @@ -883,14 +889,19 @@ internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
}
_base.move(along: directionBase)
}

// `EuclideanDifferentiable` requirements.
var _differentiableVectorView: _AnyDerivativeBox {
return self
}
}

/// A type-erased derivative value.
///
/// The `AnyDerivative` type forwards its operations to an arbitrary underlying
/// base derivative value conforming to `Differentiable` and
/// `AdditiveArithmetic`, hiding the specifics of the underlying value.
public struct AnyDerivative : Differentiable & AdditiveArithmetic {
public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic {
internal var _box: _AnyDerivativeBox

internal init(_box: _AnyDerivativeBox) {
Expand Down Expand Up @@ -931,7 +942,7 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
/// Internal struct representing an opaque zero value.
@frozen
@usableFromInline
internal struct OpaqueZero : Differentiable & AdditiveArithmetic {}
internal struct OpaqueZero : EuclideanDifferentiable & AdditiveArithmetic {}

public static var zero: AnyDerivative {
return AnyDerivative(
Expand Down Expand Up @@ -974,6 +985,11 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
}
_box._move(along: direction._box)
}

// `EuclideanDifferentiable` requirements.
public var differentiableVectorView: TangentVector {
return self
}
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion stdlib/public/core/FloatingPointTypes.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,7 @@ extension ${Self} : VectorProtocol {
}
}

extension ${Self} : Differentiable {
extension ${Self} : EuclideanDifferentiable {
public typealias TangentVector = ${Self}

public mutating func move(along direction: TangentVector) {
Expand Down
8 changes: 4 additions & 4 deletions stdlib/public/core/SIMDVectorTypes.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public struct SIMD${n}<Scalar>: SIMD where Scalar: SIMDScalar {
/// Accesses the scalar at the specified position.
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpSubscript
where Scalar : Differentiable & BinaryFloatingPoint,
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
Scalar.TangentVector : BinaryFloatingPoint)
public subscript(index: Int) -> Scalar {
@_transparent get {
Expand Down Expand Up @@ -192,14 +192,14 @@ extension SIMD${n} where Scalar: BinaryFloatingPoint {
// SWIFT_ENABLE_TENSORFLOW
extension SIMD${n} : AdditiveArithmetic where Scalar : FloatingPoint {}

extension SIMD${n} : Differentiable
where Scalar : Differentiable & BinaryFloatingPoint,
extension SIMD${n} : Differentiable & EuclideanDifferentiable
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
Scalar.TangentVector : BinaryFloatingPoint {
public typealias TangentVector = SIMD${n}
}

extension SIMD${n}
where Scalar : Differentiable & BinaryFloatingPoint,
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
Scalar.TangentVector : BinaryFloatingPoint {
@usableFromInline
internal func _vjpSubscript(index: Int)
Expand Down
8 changes: 4 additions & 4 deletions test/AutoDiff/derived_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public struct Foo : EuclideanDifferentiable {
// CHECK-AST: internal init(a: Float)
// CHECK-AST: public struct TangentVector
// CHECK-AST: public typealias TangentVector = Foo.TangentVector
// CHECK-AST: public var vectorView: Foo.TangentVector { get }
// CHECK-AST: public var differentiableVectorView: Foo.TangentVector { get }

// CHECK-SILGEN-LABEL: // Foo.a.getter
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s22derived_differentiable3FooV1aSfvg : $@convention(method) (Foo) -> Float
Expand All @@ -32,8 +32,8 @@ let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in
// CHECK-AST: internal var dummy: PointwiseMultiplicativeDummy
// CHECK-AST: internal init(a: Float, dummy: PointwiseMultiplicativeDummy)
// CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf
// The following should not exist because when `Self == Self.TangentVector`, `vectorView` is not synthesized.
// CHECK-AST-NOT: internal var vectorView: AdditiveTangentIsSelf { get }
// The following should not exist because when `Self == Self.TangentVector`, `differentiableVectorView` is not synthesized.
// CHECK-AST-NOT: internal var differentiableVectorView: AdditiveTangentIsSelf { get }

struct TestNoDerivative : EuclideanDifferentiable {
var w: Float
Expand All @@ -46,7 +46,7 @@ struct TestNoDerivative : EuclideanDifferentiable {
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.TangentVector
// CHECK-AST: internal var vectorView: TestNoDerivative.TangentVector { get }
// CHECK-AST: internal var differentiableVectorView: TestNoDerivative.TangentVector { get }

struct TestPointwiseMultiplicative : Differentiable {
var w: PointwiseMultiplicativeDummy
Expand Down
4 changes: 2 additions & 2 deletions test/AutoDiff/derived_differentiable_runtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ DerivedConformanceTests.test("EuclideanVectorView") {
init() { x = [1, 2, 3, 4]; y = .zero }
}
let x = Foo()
expectEqual(Foo.TangentVector(x: [1, 2, 3, 4]), x.vectorView)
expectEqual(Foo.TangentVector(x: [1, 2, 3, 4]), x.differentiableVectorView)
}
do {
class FooClass: EuclideanDifferentiable {
Expand All @@ -47,7 +47,7 @@ DerivedConformanceTests.test("EuclideanVectorView") {
init() { x = [1, 2, 3, 4]; y = .zero }
}
let x = FooClass()
expectEqual(FooClass.TangentVector(x: [1, 2, 3, 4]), x.vectorView)
expectEqual(FooClass.TangentVector(x: [1, 2, 3, 4]), x.differentiableVectorView)
}
}

Expand Down
3 changes: 0 additions & 3 deletions test/Sema/class_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,6 @@ struct MyVector2 : ElementaryFunctions, Differentiable, EuclideanDifferentiable
self.b = b
}
}
// Won't derive `EuclideanDifferentiable` because `MyVector2.TangentVector != MyVector2`.
// expected-error @+2 {{type 'AllMembersElementaryFunctions' does not conform to protocol 'EuclideanDifferentiable'}}
// expected-note @+1 {{do you want to add protocol stubs?}}
class AllMembersElementaryFunctions : Differentiable, EuclideanDifferentiable {
var v1: MyVector2
var v2: MyVector2
Expand Down
11 changes: 4 additions & 7 deletions test/Sema/struct_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ struct GenericVectorSpacesEqualSelf<T>
func testGenericVectorSpacesEqualSelf() {
var genericSame = GenericVectorSpacesEqualSelf<Double>(w: 1, b: 1)
genericSame.move(along: genericSame)
genericSame.move(along: genericSame.vectorView)
genericSame.move(along: genericSame.differentiableVectorView)
}

// Test nested type.
Expand Down Expand Up @@ -130,7 +130,7 @@ struct AllMembersAdditiveArithmetic : Differentiable, EuclideanDifferentiable {

// Test type `AllMembersVectorProtocol` whose members conforms to `VectorProtocol`,
// in which case we should make `TangentVector` conform to `VectorProtocol`.
struct MyVector : VectorProtocol, Differentiable {
struct MyVector : VectorProtocol, Differentiable, EuclideanDifferentiable {
var w: Float
var b: Float
}
Expand All @@ -149,9 +149,6 @@ struct MyVector2 : ElementaryFunctions, Differentiable, EuclideanDifferentiable
var b: Float
}

// Won't derive `EuclideanDifferentiable` because `MyVector2.TangentVector != MyVector2`.
// expected-error @+2 {{type 'AllMembersElementaryFunctions' does not conform to protocol 'EuclideanDifferentiable'}}
// expected-note @+1 {{do you want to add protocol stubs?}}
struct AllMembersElementaryFunctions : Differentiable, EuclideanDifferentiable {
var v1: MyVector2
var v2: MyVector2
Expand Down Expand Up @@ -186,8 +183,8 @@ struct EuclideanDifferentiableSubset : EuclideanDifferentiable {
func testEuclideanDifferentiableSubset() {
let x = EuclideanDifferentiableSubset(w: 1, b: 2, flag: false)
let tan = EuclideanDifferentiableSubset.TangentVector(w: 1, b: 1)
_ = x.vectorView.w * tan.w
_ = x.vectorView.b * tan.b
_ = x.differentiableVectorView.w * tan.w
_ = x.differentiableVectorView.b * tan.b

_ = pullback(at: x) { model in
model.w + model.b
Expand Down