Skip to content

[AutoDiff] Improve invalid stored property projection diagnostics. #32497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/ASTTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ SWIFT_TYPEID(PropertyWrapperTypeInfo)
SWIFT_TYPEID(Requirement)
SWIFT_TYPEID(ResilienceExpansion)
SWIFT_TYPEID(FragileFunctionKind)
SWIFT_TYPEID(TangentPropertyInfo)
SWIFT_TYPEID(Type)
SWIFT_TYPEID(TypePair)
SWIFT_TYPEID(TypeWitnessAndDecl)
Expand Down
9 changes: 5 additions & 4 deletions include/swift/AST/ASTTypeIDs.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "swift/Basic/LLVM.h"
#include "swift/Basic/TypeID.h"

namespace swift {

class AbstractFunctionDecl;
Expand Down Expand Up @@ -58,14 +59,14 @@ class Requirement;
enum class ResilienceExpansion : unsigned;
struct FragileFunctionKind;
class SourceFile;
struct TangentPropertyInfo;
class Type;
class ValueDecl;
class VarDecl;
class Witness;
class TypeAliasDecl;
class Type;
struct TypePair;
struct TypeWitnessAndDecl;
class ValueDecl;
class VarDecl;
class Witness;
enum class AncestryFlags : uint8_t;
enum class ImplicitMemberAction : uint8_t;
struct FingerprintAndMembers;
Expand Down
94 changes: 94 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class AnyFunctionType;
class SourceFile;
class SILFunctionType;
class TupleType;
class VarDecl;

/// A function type differentiability kind.
enum class DifferentiabilityKind : uint8_t {
Expand Down Expand Up @@ -459,6 +460,99 @@ class DerivativeFunctionTypeError
}
};

/// Describes the "tangent stored property" corresponding to an original stored
/// property in a `Differentiable`-conforming type.
///
/// The tangent stored property is the stored property in the `TangentVector`
/// struct of the `Differentiable`-conforming type, with the same name as the
/// original stored property and with the original stored property's
/// `TangentVector` type.
struct TangentPropertyInfo {
struct Error {
enum class Kind {
/// The original property is `@noDerivative`.
NoDerivativeOriginalProperty,
/// The nominal parent type does not conform to `Differentiable`.
NominalParentNotDifferentiable,
/// The original property's type does not conform to `Differentiable`.
OriginalPropertyNotDifferentiable,
/// The parent `TangentVector` type is not a struct.
ParentTangentVectorNotStruct,
/// The parent `TangentVector` struct does not declare a stored property
/// with the same name as the original property.
TangentPropertyNotFound,
/// The tangent property's type is not equal to the original property's
/// `TangentVector` type.
TangentPropertyWrongType,
/// The tangent property is not a stored property.
TangentPropertyNotStored
};

/// The error kind.
Kind kind;

private:
union Value {
Type type;
Value(Type type) : type(type) {}
Value() {}
} value;

public:
Error(Kind kind) : kind(kind), value() {
assert(kind == Kind::NoDerivativeOriginalProperty ||
kind == Kind::NominalParentNotDifferentiable ||
kind == Kind::OriginalPropertyNotDifferentiable ||
kind == Kind::ParentTangentVectorNotStruct ||
kind == Kind::TangentPropertyNotFound ||
kind == Kind::TangentPropertyNotStored);
};

Error(Kind kind, Type type) : kind(kind), value(type) {
assert(kind == Kind::TangentPropertyWrongType);
};

Type getType() const {
assert(kind == Kind::TangentPropertyWrongType);
return value.type;
}

friend bool operator==(const Error &lhs, const Error &rhs);
};

/// The tangent stored property.
VarDecl *tangentProperty = nullptr;

/// An optional error.
Optional<Error> error = None;

private:
TangentPropertyInfo(VarDecl *tangentProperty, Optional<Error> error)
: tangentProperty(tangentProperty), error(error) {}

public:
TangentPropertyInfo(VarDecl *tangentProperty)
: TangentPropertyInfo(tangentProperty, None) {}

TangentPropertyInfo(Error::Kind errorKind)
: TangentPropertyInfo(nullptr, Error(errorKind)) {}

TangentPropertyInfo(Error::Kind errorKind, Type errorType)
: TangentPropertyInfo(nullptr, Error(errorKind, errorType)) {}

/// Returns `true` iff this tangent property info is valid.
bool isValid() const { return tangentProperty && !error; }

explicit operator bool() const { return isValid(); }

friend bool operator==(const TangentPropertyInfo &lhs,
const TangentPropertyInfo &rhs) {
return lhs.tangentProperty == rhs.tangentProperty && lhs.error == rhs.error;
}
};

void simple_display(llvm::raw_ostream &OS, TangentPropertyInfo info);

/// The key type used for uniquing `SILDifferentiabilityWitness` in
/// `SILModule`: original function name, parameter indices, result indices, and
/// derivative generic signature.
Expand Down
21 changes: 19 additions & 2 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,26 @@ NOTE(autodiff_loadable_value_addressonly_tangent_unsupported,none,
"properties", (Type, Type))
NOTE(autodiff_enums_unsupported,none,
"differentiating enum values is not yet supported", ())
NOTE(autodiff_stored_property_parent_not_differentiable,none,
"cannot differentiate access to property '%0.%1' because '%0' does not "
"conform to 'Differentiable'", (StringRef, StringRef))
NOTE(autodiff_stored_property_not_differentiable,none,
"cannot differentiate access to property '%0.%1' because property type %2 "
"does not conform to 'Differentiable'", (StringRef, StringRef, Type))
NOTE(autodiff_stored_property_tangent_not_struct,none,
"cannot differentiate access to property '%0.%1' because "
"'%0.TangentVector' is not a struct", (StringRef, StringRef))
NOTE(autodiff_stored_property_no_corresponding_tangent,none,
"property cannot be differentiated because '%0.TangentVector' does not "
"have a member named '%1'", (StringRef, StringRef))
"cannot differentiate access to property '%0.%1' because "
"'%0.TangentVector' does not have a stored property named '%1'",
(StringRef, StringRef))
NOTE(autodiff_tangent_property_wrong_type,none,
"cannot differentiate access to property '%0.%1' because "
"'%0.TangentVector.%1' does not have expected type %2",
(StringRef, StringRef, /*originalPropertyTanType*/ Type))
NOTE(autodiff_tangent_property_not_stored,none,
"cannot differentiate access to property '%0.%1' because "
"'%0.TangentVector.%1' is not a stored property", (StringRef, StringRef))
NOTE(autodiff_coroutines_not_supported,none,
"differentiation of coroutine calls is not yet supported", ())
NOTE(autodiff_cannot_differentiate_writes_to_global_variables,none,
Expand Down
21 changes: 21 additions & 0 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -2191,6 +2191,27 @@ class DerivativeAttrOriginalDeclRequest
bool isCached() const { return true; }
};

/// Resolves the "tangent stored property" corresponding to an original stored
/// property in a `Differentiable`-conforming type.
class TangentStoredPropertyRequest
: public SimpleRequest<TangentStoredPropertyRequest,
TangentPropertyInfo(VarDecl *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

// Evaluation.
TangentPropertyInfo evaluate(Evaluator &evaluator,
VarDecl *originalField) const;

public:
// Caching.
bool isCached() const { return true; }
};

/// Checks whether a type eraser has a viable initializer.
class TypeEraserHasViableInitRequest
: public SimpleRequest<TypeEraserHasViableInitRequest,
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ SWIFT_REQUEST(TypeChecker, SuperclassTypeRequest,
SWIFT_REQUEST(TypeChecker, SynthesizeAccessorRequest,
AccessorDecl *(AbstractStorageDecl *, AccessorKind),
SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, TangentStoredPropertyRequest,
llvm::Expected<VarDecl *>(VarDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyRequest,
bool(AbstractFunctionDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyAtLocRequest,
Expand Down
7 changes: 7 additions & 0 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -5758,6 +5758,13 @@ class FieldIndexCacheBase : public SingleValueInstruction {
return s;
}

static bool classof(const SILNode *node) {
SILNodeKind kind = node->getKind();
return kind == SILNodeKind::StructExtractInst ||
kind == SILNodeKind::StructElementAddrInst ||
kind == SILNodeKind::RefElementAddrInst;
}

private:
unsigned cacheFieldIndex();
};
Expand Down
8 changes: 2 additions & 6 deletions include/swift/SILOptimizer/Differentiation/ADContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,9 @@ ADContext::emitNondifferentiabilityError(SILValue value,
getADDebugStream() << "For value:\n" << value;
getADDebugStream() << "With invoker:\n" << invoker << '\n';
});
auto valueLoc = value.getLoc().getSourceLoc();
// If instruction does not have a valid location, use the function location
// as a fallback. Improves diagnostics in some cases.
if (valueLoc.isInvalid())
valueLoc = value->getFunction()->getLocation().getSourceLoc();
auto valueLoc = getValidLocation(value).getSourceLoc();
return emitNondifferentiabilityError(valueLoc, invoker, diag,
std::forward<U>(args)...);
}
Expand All @@ -272,12 +270,10 @@ ADContext::emitNondifferentiabilityError(SILInstruction *inst,
getADDebugStream() << "For instruction:\n" << *inst;
getADDebugStream() << "With invoker:\n" << invoker << '\n';
});
auto instLoc = inst->getLoc().getSourceLoc();
// If instruction does not have a valid location, use the function location
// as a fallback. Improves diagnostics for `ref_element_addr` generated in
// synthesized stored property getters.
if (instLoc.isInvalid())
instLoc = inst->getFunction()->getLocation().getSourceLoc();
auto instLoc = getValidLocation(inst).getSourceLoc();
return emitNondifferentiabilityError(instLoc, invoker, diag,
std::forward<U>(args)...);
}
Expand Down
37 changes: 35 additions & 2 deletions include/swift/SILOptimizer/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,27 @@
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_COMMON_H
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_COMMON_H

#include "swift/AST/DiagnosticsSIL.h"
#include "swift/AST/Expr.h"
#include "swift/AST/SemanticAttrs.h"
#include "swift/SIL/SILDifferentiabilityWitness.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/SILModule.h"
#include "swift/SIL/TypeSubstCloner.h"
#include "swift/SILOptimizer/Analysis/ArraySemantic.h"
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"

namespace swift {

namespace autodiff {

class ADContext;

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

namespace autodiff {

/// Prints an "[AD] " prefix to `llvm::dbgs()` and returns the debug stream.
/// This is being used to print short debug messages within the AD pass.
raw_ostream &getADDebugStream();
Expand Down Expand Up @@ -136,6 +141,34 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// Diagnostic utilities
//===----------------------------------------------------------------------===//

// Returns `v`'s location if it is valid. Otherwise, returns `v`'s function's
// location as as a fallback. Used for diagnostics.
SILLocation getValidLocation(SILValue v);

// Returns `inst`'s location if it is valid. Otherwise, returns `inst`'s
// function's location as as a fallback. Used for diagnostics.
SILLocation getValidLocation(SILInstruction *inst);

//===----------------------------------------------------------------------===//
// Tangent property lookup utilities
//===----------------------------------------------------------------------===//

/// Returns the tangent stored property of `originalField`. On error, emits
/// diagnostic and returns nullptr.
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
SILLocation loc,
DifferentiationInvoker invoker);

/// Returns the tangent stored property of the original stored property
/// referenced by `inst`. On error, emits diagnostic and returns nullptr.
VarDecl *getTangentStoredProperty(ADContext &context,
FieldIndexCacheBase *projectionInst,
DifferentiationInvoker invoker);

//===----------------------------------------------------------------------===//
// Code emission utilities
//===----------------------------------------------------------------------===//
Expand Down
Loading