Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/ASTDemangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ class ASTBuilder {

Type createImplFunctionType(
Demangle::ImplParameterConvention calleeConvention,
Demangle::ImplCoroutineKind coroutineKind,
ArrayRef<Demangle::ImplFunctionParam<Type>> params,
ArrayRef<Demangle::ImplFunctionYield<Type>> yields,
ArrayRef<Demangle::ImplFunctionResult<Type>> results,
std::optional<Demangle::ImplFunctionResult<Type>> errorResult,
ImplFunctionTypeFlags flags);
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,8 @@ NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
"cannot differentiate through multiple results", ())
NOTE(autodiff_cannot_differentiate_through_inout_arguments,none,
"cannot differentiate through 'inout' arguments", ())
NOTE(autodiff_cannot_differentiate_through_direct_yield,none,
"cannot differentiate through a direct yield result", ())
NOTE(autodiff_enums_unsupported,none,
"differentiating enum values is not yet supported", ())
NOTE(autodiff_stored_property_parent_not_differentiable,none,
Expand Down
4 changes: 3 additions & 1 deletion include/swift/AST/IndexSubset.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ class IndexSubset : public llvm::FoldingSetNode {
static IndexSubset *get(ASTContext &ctx, unsigned capacity,
ArrayRef<unsigned> indices) {
SmallBitVector indicesBitVec(capacity, false);
for (auto index : indices)
for (auto index : indices) {
assert(index < capacity);
indicesBitVec.set(index);
}
return IndexSubset::get(ctx, indicesBitVec);
}

Expand Down
5 changes: 4 additions & 1 deletion include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -5174,8 +5174,11 @@ class SILFunctionType final
/// Returns the number of function potential semantic results:
/// * Usual results
/// * Inout parameters
/// * yields
unsigned getNumAutoDiffSemanticResults() const {
return getNumResults() + getNumAutoDiffSemanticResultsParameters();
return getNumResults() +
getNumAutoDiffSemanticResultsParameters() +
getNumYields();
}

/// Get the generic signature that the component types are specified
Expand Down
1 change: 1 addition & 0 deletions include/swift/Demangling/Demangle.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ struct [[nodiscard]] ManglingError {
UnknownEncoding,
InvalidImplCalleeConvention,
InvalidImplDifferentiability,
InvalidImplCoroutineKind,
InvalidImplFunctionAttribute,
InvalidImplParameterConvention,
InvalidImplParameterTransferring,
Expand Down
1 change: 1 addition & 0 deletions include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ NODE(ImplFunctionAttribute)
NODE(ImplFunctionConvention)
NODE(ImplFunctionConventionName)
NODE(ImplFunctionType)
NODE(ImplCoroutineKind)
NODE(ImplInvocationSubstitutions)
CONTEXT_NODE(ImplicitClosure)
NODE(ImplParameter)
Expand Down
31 changes: 27 additions & 4 deletions include/swift/Demangling/TypeDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ enum class ImplMetatypeRepresentation {
ObjC,
};

enum class ImplCoroutineKind {
None,
YieldOnce,
YieldMany,
};

/// Describe a function parameter, parameterized on the type
/// representation.
template <typename BuiltType>
Expand Down Expand Up @@ -188,6 +194,9 @@ class ImplFunctionParam {
BuiltType getType() const { return Type; }
};

template<typename Type>
using ImplFunctionYield = ImplFunctionParam<Type>;

enum class ImplResultConvention {
Indirect,
Owned,
Expand Down Expand Up @@ -1023,9 +1032,11 @@ class TypeDecoder {
case NodeKind::ImplFunctionType: {
auto calleeConvention = ImplParameterConvention::Direct_Unowned;
llvm::SmallVector<ImplFunctionParam<BuiltType>, 8> parameters;
llvm::SmallVector<ImplFunctionYield<BuiltType>, 8> yields;
llvm::SmallVector<ImplFunctionResult<BuiltType>, 8> results;
llvm::SmallVector<ImplFunctionResult<BuiltType>, 8> errorResults;
ImplFunctionTypeFlags flags;
ImplCoroutineKind coroutineKind = ImplCoroutineKind::None;

for (unsigned i = 0; i < Node->getNumChildren(); i++) {
auto child = Node->getChild(i);
Expand Down Expand Up @@ -1066,6 +1077,15 @@ class TypeDecoder {
} else if (child->getText() == "@async") {
flags = flags.withAsync();
}
} else if (child->getKind() == NodeKind::ImplCoroutineKind) {
if (!child->hasText())
return MAKE_NODE_TYPE_ERROR0(child, "expected text");
if (child->getText() == "yield_once") {
coroutineKind = ImplCoroutineKind::YieldOnce;
} else if (child->getText() == "yield_many") {
coroutineKind = ImplCoroutineKind::YieldMany;
} else
return MAKE_NODE_TYPE_ERROR0(child, "failed to decode coroutine kind");
} else if (child->getKind() == NodeKind::ImplDifferentiabilityKind) {
ImplFunctionDifferentiabilityKind implDiffKind;
switch ((MangledDifferentiabilityKind)child->getIndex()) {
Expand All @@ -1088,10 +1108,14 @@ class TypeDecoder {
if (decodeImplFunctionParam(child, depth + 1, parameters))
return MAKE_NODE_TYPE_ERROR0(child,
"failed to decode function parameter");
} else if (child->getKind() == NodeKind::ImplYield) {
if (decodeImplFunctionParam(child, depth + 1, yields))
return MAKE_NODE_TYPE_ERROR0(child,
"failed to decode function yields");
} else if (child->getKind() == NodeKind::ImplResult) {
if (decodeImplFunctionParam(child, depth + 1, results))
return MAKE_NODE_TYPE_ERROR0(child,
"failed to decode function parameter");
"failed to decode function results");
} else if (child->getKind() == NodeKind::ImplErrorResult) {
if (decodeImplFunctionPart(child, depth + 1, errorResults))
return MAKE_NODE_TYPE_ERROR0(child,
Expand All @@ -1115,11 +1139,10 @@ class TypeDecoder {

// TODO: Some cases not handled above, but *probably* they cannot
// appear as the types of values in SIL (yet?):
// - functions with yield returns
// - functions with generic signatures
// - foreign error conventions
return Builder.createImplFunctionType(calleeConvention,
parameters, results,
return Builder.createImplFunctionType(calleeConvention, coroutineKind,
parameters, yields, results,
errorResult, flags);
}

Expand Down
2 changes: 2 additions & 0 deletions include/swift/RemoteInspection/TypeRefBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,9 @@ class TypeRefBuilder {

const FunctionTypeRef *createImplFunctionType(
Demangle::ImplParameterConvention calleeConvention,
Demangle::ImplCoroutineKind coroutineKind,
llvm::ArrayRef<Demangle::ImplFunctionParam<const TypeRef *>> params,
llvm::ArrayRef<Demangle::ImplFunctionYield<const TypeRef *>> yields,
llvm::ArrayRef<Demangle::ImplFunctionResult<const TypeRef *>> results,
std::optional<Demangle::ImplFunctionResult<const TypeRef *>> errorResult,
ImplFunctionTypeFlags flags) {
Expand Down
8 changes: 8 additions & 0 deletions include/swift/SIL/SILFunctionConventions.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,14 @@ class SILFunctionConventions {
idx < indirectResults + getNumIndirectSILErrorResults();
}

unsigned getNumAutoDiffSemanticResults() const {
return funcTy->getNumAutoDiffSemanticResults();
}

unsigned getNumAutoDiffSemanticResultParameters() const {
return funcTy->getNumAutoDiffSemanticResultsParameters();
}

/// Are any SIL results passed as address-typed arguments?
bool hasIndirectSILResults() const { return getNumIndirectSILResults() != 0; }
bool hasIndirectSILErrorResults() const { return getNumIndirectSILErrorResults() != 0; }
Expand Down
11 changes: 9 additions & 2 deletions include/swift/SILOptimizer/Differentiation/ADContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H

#include "swift/SIL/ApplySite.h"
#include "swift/SILOptimizer/Differentiation/Common.h"
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"

Expand Down Expand Up @@ -51,6 +52,12 @@ struct NestedApplyInfo {
/// The original pullback type before reabstraction. `None` if the pullback
/// type is not reabstracted.
std::optional<CanSILFunctionType> originalPullbackType;
/// Index of `apply` pullback in nested pullback call
unsigned pullbackIdx = -1U;
/// Pullback value itself that is memoized in some cases (e.g. pullback is
/// called by `begin_apply`, but should be destroyed after `end_apply`).
SILValue pullback = SILValue();
SILValue beginApplyToken = SILValue();
};

/// Per-module contextual information for the Differentiation pass.
Expand Down Expand Up @@ -97,7 +104,7 @@ class ADContext {

/// Mapping from original `apply` instructions to their corresponding
/// `NestedApplyInfo`s.
llvm::DenseMap<ApplyInst *, NestedApplyInfo> nestedApplyInfo;
llvm::DenseMap<FullApplySite, NestedApplyInfo> nestedApplyInfo;

/// List of generated functions (JVPs, VJPs, pullbacks, and thunks).
/// Saved for deletion during cleanup.
Expand Down Expand Up @@ -185,7 +192,7 @@ class ADContext {
invokers.insert({witness, DifferentiationInvoker(witness)});
}

llvm::DenseMap<ApplyInst *, NestedApplyInfo> &getNestedApplyInfo() {
llvm::DenseMap<FullApplySite, NestedApplyInfo> &getNestedApplyInfo() {
return nestedApplyInfo;
}

Expand Down
5 changes: 3 additions & 2 deletions include/swift/SILOptimizer/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "swift/AST/DiagnosticsSIL.h"
#include "swift/AST/Expr.h"
#include "swift/AST/SemanticAttrs.h"
#include "swift/SIL/ApplySite.h"
#include "swift/SIL/SILDifferentiabilityWitness.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/Projection.h"
Expand Down Expand Up @@ -112,15 +113,15 @@ void collectAllDirectResultsInTypeOrder(SILFunction &function,
/// Given a function call site, gathers all of its actual results (both direct
/// and indirect) in an order defined by its result type.
void collectAllActualResultsInTypeOrder(
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
FullApplySite fai, ArrayRef<SILValue> extractedDirectResults,
SmallVectorImpl<SILValue> &results);

/// For an `apply` instruction with active results, compute:
/// - The results of the `apply` instruction, in type order.
/// - The set of minimal parameter and result indices for differentiating the
/// `apply` instruction.
void collectMinimalIndicesForFunctionCall(
ApplyInst *ai, const AutoDiffConfig &parentConfig,
FullApplySite fai, const AutoDiffConfig &parentConfig,
const DifferentiableActivityInfo &activityInfo,
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
SmallVectorImpl<unsigned> &resultIndices);
Expand Down
24 changes: 12 additions & 12 deletions include/swift/SILOptimizer/Differentiation/LinearMapInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ class LinearMapInfo {
/// For differentials: these are successor enums.
llvm::DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;

/// Mapping from `apply` instructions in the original function to the
/// Mapping from `apply` / `begin_apply` instructions in the original function to the
/// corresponding linear map tuple type index.
llvm::DenseMap<ApplyInst *, unsigned> linearMapIndexMap;
llvm::DenseMap<FullApplySite, unsigned> linearMapIndexMap;

/// Mapping from predecessor-successor basic block pairs in the original
/// function to the corresponding branching trace enum case.
Expand Down Expand Up @@ -112,9 +112,9 @@ class LinearMapInfo {
void populateBranchingTraceDecl(SILBasicBlock *originalBB,
SILLoopInfo *loopInfo);

/// Given an `apply` instruction, conditionally gets a linear map tuple field
/// AST type for its linear map function if it is active.
Type getLinearMapType(ADContext &context, ApplyInst *ai);
/// Given an `apply` / `begin_apply` instruction, conditionally gets a linear
/// map tuple field AST type for its linear map function if it is active.
Type getLinearMapType(ADContext &context, FullApplySite fai);

/// Generates linear map struct and branching enum declarations for the given
/// function. Linear map structs are populated with linear map fields and a
Expand Down Expand Up @@ -180,18 +180,18 @@ class LinearMapInfo {
}

/// Finds the linear map index in the pullback tuple for the given
/// `apply` instruction in the original function.
unsigned lookUpLinearMapIndex(ApplyInst *ai) const {
assert(ai->getFunction() == original);
auto lookup = linearMapIndexMap.find(ai);
/// `apply` / `begin_apply` instruction in the original function.
unsigned lookUpLinearMapIndex(FullApplySite fas) const {
assert(fas->getFunction() == original);
auto lookup = linearMapIndexMap.find(fas);
assert(lookup != linearMapIndexMap.end() &&
"No linear map field corresponding to the given `apply`");
return lookup->getSecond();
}

Type lookUpLinearMapType(ApplyInst *ai) const {
unsigned idx = lookUpLinearMapIndex(ai);
return getLinearMapTupleType(ai->getParentBlock())->getElement(idx).getType();
Type lookUpLinearMapType(FullApplySite fas) const {
unsigned idx = lookUpLinearMapIndex(fas);
return getLinearMapTupleType(fas->getParent())->getElement(idx).getType();
}

bool hasHeapAllocatedContext() const {
Expand Down
5 changes: 5 additions & 0 deletions include/swift/SILOptimizer/Differentiation/Thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
CanSILFunctionType fromType,
CanSILFunctionType toType);

SILValue reabstractCoroutine(
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
SILValue fn, CanSILFunctionType toType,
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions);

/// Reabstracts the given function-typed value `fn` to the target type `toType`.
/// Remaps substitutions using `remapSubstitutions`.
SILValue reabstractFunction(
Expand Down
25 changes: 24 additions & 1 deletion lib/AST/ASTDemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,17 +571,33 @@ getResultOptions(ImplResultInfoOptions implOptions) {
return result;
}

static SILCoroutineKind
getCoroutineKind(ImplCoroutineKind kind) {
switch (kind) {
case ImplCoroutineKind::None:
return SILCoroutineKind::None;
case ImplCoroutineKind::YieldOnce:
return SILCoroutineKind::YieldOnce;
case ImplCoroutineKind::YieldMany:
return SILCoroutineKind::YieldMany;
}
llvm_unreachable("unknown coroutine kind");
}

Type ASTBuilder::createImplFunctionType(
Demangle::ImplParameterConvention calleeConvention,
Demangle::ImplCoroutineKind coroutineKind,
ArrayRef<Demangle::ImplFunctionParam<Type>> params,
ArrayRef<Demangle::ImplFunctionYield<Type>> yields,
ArrayRef<Demangle::ImplFunctionResult<Type>> results,
std::optional<Demangle::ImplFunctionResult<Type>> errorResult,
ImplFunctionTypeFlags flags) {
GenericSignature genericSig;

SILCoroutineKind funcCoroutineKind = SILCoroutineKind::None;
ParameterConvention funcCalleeConvention =
getParameterConvention(calleeConvention);
SILCoroutineKind funcCoroutineKind =
getCoroutineKind(coroutineKind);

SILFunctionTypeRepresentation representation;
switch (flags.getRepresentation()) {
Expand Down Expand Up @@ -644,6 +660,13 @@ Type ASTBuilder::createImplFunctionType(
funcParams.emplace_back(type, conv, options);
}

for (const auto &yield : yields) {
auto type = yield.getType()->getCanonicalType();
auto conv = getParameterConvention(yield.getConvention());
auto options = *getParameterOptions(yield.getOptions());
funcParams.emplace_back(type, conv, options);
}

for (const auto &result : results) {
auto type = result.getType()->getCanonicalType();
auto conv = getResultConvention(result.getConvention());
Expand Down
34 changes: 23 additions & 11 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
llvm_unreachable("invalid derivative kind");
}

void AutoDiffConfig::dump() const {
print(llvm::errs());
}

void AutoDiffConfig::print(llvm::raw_ostream &s) const {
s << "(parameters=";
parameterIndices->print(s);
Expand Down Expand Up @@ -354,22 +358,30 @@ GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
// Require differentiability results to conform to `Differentiable`.
SmallVector<SILResultInfo, 2> originalResults;
getSemanticResults(originalFnTy, diffParamIndices, originalResults);
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
unsigned firstYieldResultIndex = originalFnTy->getNumResults() +
originalFnTy->getNumAutoDiffSemanticResultsParameters();
for (unsigned resultIdx : diffResultIndices->getIndices()) {
// Handle formal original result.
if (resultIdx < originalFnTy->getNumResults()) {
if (resultIdx < firstSemanticParamResultIdx) {
auto resultType = originalResults[resultIdx].getInterfaceType();
addRequirement(resultType);
continue;
} else if (resultIdx < firstYieldResultIndex) {
// Handle original semantic result parameters.
auto resultParamIndex = resultIdx - originalFnTy->getNumResults();
auto resultParamIt = std::next(
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
resultParamIndex);
auto paramIndex =
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType());
} else {
// Handle formal original yields.
assert(originalFnTy->isCoroutine());
assert(originalFnTy->getCoroutineKind() == SILCoroutineKind::YieldOnce);
auto yieldResultIndex = resultIdx - firstYieldResultIndex;
addRequirement(originalFnTy->getYields()[yieldResultIndex].getInterfaceType());
}
// Handle original semantic result parameters.
// FIXME: Constraint generic yields when we will start supporting them
auto resultParamIndex = resultIdx - originalFnTy->getNumResults();
auto resultParamIt = std::next(
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
resultParamIndex);
auto paramIndex =
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType());
}

return buildGenericSignature(ctx, derivativeGenSig,
Expand Down
Loading