Commit 2d08a3f

[AutoDiff upstream] Add SIL derivative function type calculation. (#29396)
Add `SILFunctionType::getAutoDiffDerivativeFunctionType`.

It computes the derivative `SILFunctionType` for an "original" `SILFunctionType`, given:
- Differentiability parameter indices
- Differentiability result index
- Derivative function kind
- Derivative function generic signature (optional)
- Other auxiliary parameters

Add doc comments explaining typing rules, preconditions, and other details.

Partially resolves TF-1124. Unblocks upstreaming other SIL differentiable programming infrastructure.
1 parent 7b611fc commit 2d08a3f

2 files changed: +274 -0 lines changed

include/swift/AST/Types.h

Lines changed: 83 additions & 0 deletions
@@ -4390,6 +4390,89 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,

   const clang::FunctionType *getClangFunctionType() const;

+  /// Returns the type of the derivative function for the given parameter
+  /// indices, result index, derivative function kind, derivative function
+  /// generic signature (optional), and other auxiliary parameters.
+  ///
+  /// Preconditions:
+  /// - Parameters corresponding to parameter indices must conform to
+  ///   `Differentiable`.
+  /// - The result corresponding to the result index must conform to
+  ///   `Differentiable`.
+  ///
+  /// Typing rules, given:
+  /// - Original function type: $(T0, T1, ...) -> (R0, R1, ...)
+  ///
+  /// Terminology:
+  /// - The derivative of a `Differentiable`-conforming type has the
+  ///   `TangentVector` associated type. `TangentVector` is abbreviated as `Tan`
+  ///   below.
+  /// - "wrt" parameters refers to parameters indicated by the parameter
+  ///   indices.
+  /// - "wrt" result refers to the result indicated by the result index.
+  ///
4414+
/// JVP derivative type:
4415+
/// - Takes original parameters.
4416+
/// - Returns original results, followed by a differential function, which
4417+
/// takes "wrt" parameter derivatives and returns a "wrt" result derivative.
4418+
///
4419+
/// $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan)
4420+
/// ^~~~~~~ ^~~~~~~~~~~~~~~~~~~ ^~~~~~
4421+
/// original results | derivatives wrt params | derivative wrt result
4422+
///
4423+
/// VJP derivative type:
4424+
/// - Takes original parameters.
4425+
/// - Returns original results, followed by a pullback function, which
4426+
/// takes a "wrt" result derivative and returns "wrt" parameter derivatives.
4427+
///
4428+
/// $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...))
4429+
/// ^~~~~~~ ^~~~~~ ^~~~~~~~~~~~~~~~~~~
4430+
/// original results | derivative wrt result | derivatives wrt params
4431+
///
4432+
/// A "constrained derivative generic signature" is computed from
4433+
/// `derivativeFunctionGenericSignature`, if specified. Otherwise, it is
4434+
/// computed from the original generic signature. A "constrained derivative
4435+
/// generic signature" requires all "wrt" parameters to conform to
4436+
/// `Differentiable`; this is important for correctness.
4437+
///
4438+
/// This "constrained derivative generic signature" is used for
4439+
/// parameter/result type lowering. It is used as the actual generic signature
4440+
/// of the derivative function type iff the original function type has a
4441+
/// generic signature and not all generic parameters are bound to concrete
4442+
/// types. Otherwise, no derivative generic signature is used.
4443+
///
4444+
/// Other properties of the original function type are copied exactly:
4445+
/// `ExtInfo`, coroutine kind, callee convention, yields, optional error
4446+
/// result, witness method conformance, etc.
4447+
///
4448+
/// Special cases:
4449+
/// - Reabstraction thunks have special derivative type calculation. The
4450+
/// original function-typed last parameter is transformed into a
4451+
/// `@differentiable` function-typed parameter in the derivative type. This
4452+
/// is necessary for the differentiation transform to support reabstraction
4453+
/// thunk differentiation because the function argument is opaque and cannot
4454+
/// be differentiated. Instead, the argument is made `@differentiable` and
4455+
/// reabstraction thunk JVP/VJP callers are responsible for passing a
4456+
/// `@differentiable` function.
4457+
/// - TODO(TF-1036): Investigate more efficient reabstraction thunk
4458+
/// derivative approaches. The last argument can simply be a
4459+
/// corresponding derivative function, instead of a `@differentiable`
4460+
/// function - this is more direct. It may be possible to implement
4461+
/// reabstraction thunk derivatives using "reabstraction thunks for
4462+
/// the original function's derivative", avoiding extra code generation.
4463+
///
4464+
/// Caveats:
4465+
/// - We may support multiple result indices instead of a single result index
4466+
/// eventually. At the SIL level, this enables differentiating wrt multiple
4467+
/// function results. At the Swift level, this enables differentiating wrt
4468+
/// multiple tuple elements for tuple-returning functions.
4469+
CanSILFunctionType getAutoDiffDerivativeFunctionType(
4470+
IndexSubset *parameterIndices, unsigned resultIndex,
4471+
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
4472+
LookupConformanceFn lookupConformance,
4473+
CanGenericSignature derivativeFunctionGenericSignature = nullptr,
4474+
bool isReabstractionThunk = false);
4475+
43934476
ExtInfo getExtInfo() const {
43944477
return ExtInfo(Bits.SILFunctionType.ExtInfoBits, getClangFunctionType());
43954478
}
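
The JVP and VJP typing rules in the doc comment above can be sketched outside the compiler. The following is a minimal, standalone C++ model, not the Swift compiler API: it treats types as plain strings, and `FnType`, `DerivativeKind`, `joinTuple`, and `derivativeType` are hypothetical names introduced only for illustration. It ignores conventions, generic signatures, and reabstraction thunks.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a function type: parameter and result type names.
struct FnType {
  std::vector<std::string> params;
  std::vector<std::string> results;
};

enum class DerivativeKind { JVP, VJP };

// Formats a list of type names as "(A, B, ...)".
static std::string joinTuple(const std::vector<std::string> &elts) {
  std::string out = "(";
  for (std::size_t i = 0; i < elts.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += elts[i];
  }
  return out + ")";
}

// Builds the derivative type string by following the doc comment:
// original parameters are kept, and a differential (JVP) or pullback (VJP)
// closure is appended after the original results.
std::string derivativeType(const FnType &orig,
                           const std::set<unsigned> &wrtParams,
                           unsigned resultIndex, DerivativeKind kind) {
  assert(resultIndex < orig.results.size() && "result index out of range");
  // Tangent types of the "wrt" parameters, in parameter order.
  std::vector<std::string> paramTans;
  for (unsigned i = 0; i < orig.params.size(); ++i)
    if (wrtParams.count(i))
      paramTans.push_back(orig.params[i] + ".Tan");
  // Tangent type of the "wrt" result.
  std::string resultTan = orig.results[resultIndex] + ".Tan";

  std::string closure =
      kind == DerivativeKind::JVP
          ? joinTuple(paramTans) + " -> " + resultTan
          : "(" + resultTan + ") -> " + joinTuple(paramTans);

  std::vector<std::string> newResults = orig.results;
  newResults.push_back(closure);
  return "$" + joinTuple(orig.params) + " -> " + joinTuple(newResults);
}
```

For an original type `$(T0, T1) -> (R0)` differentiated with respect to both parameters, this model yields `$(T0, T1) -> (R0, (T0.Tan, T1.Tan) -> R0.Tan)` for the JVP and `$(T0, T1) -> (R0, (R0.Tan) -> (T0.Tan, T1.Tan))` for the VJP, matching the shapes drawn in the doc comment.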

lib/SIL/SILFunctionType.cpp

Lines changed: 191 additions & 0 deletions
@@ -22,6 +22,7 @@
 #include "swift/AST/DiagnosticsSIL.h"
 #include "swift/AST/ForeignInfo.h"
 #include "swift/AST/GenericEnvironment.h"
+#include "swift/AST/GenericSignatureBuilder.h"
 #include "swift/AST/Module.h"
 #include "swift/AST/ModuleLoader.h"
 #include "swift/AST/ProtocolConformance.h"
@@ -190,6 +191,196 @@ SILFunctionType::getWitnessMethodClass(SILModule &M) const {
   return nullptr;
 }

+// Returns the canonical generic signature for an autodiff derivative function
+// given an existing derivative function generic signature. All
+// differentiability parameters are required to conform to `Differentiable`.
+static CanGenericSignature getAutoDiffDerivativeFunctionGenericSignature(
+    CanGenericSignature derivativeFnGenSig,
+    ArrayRef<SILParameterInfo> originalParameters,
+    IndexSubset *parameterIndices, ModuleDecl *module) {
+  if (!derivativeFnGenSig)
+    return nullptr;
+  auto &ctx = module->getASTContext();
+  GenericSignatureBuilder builder(ctx);
+  // Add derivative function generic signature.
+  builder.addGenericSignature(derivativeFnGenSig);
+  // All differentiability parameters are required to conform to
+  // `Differentiable`.
+  auto source =
+      GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
+  auto *differentiableProtocol =
+      ctx.getProtocol(KnownProtocolKind::Differentiable);
+  for (unsigned paramIdx : parameterIndices->getIndices()) {
+    auto paramType = originalParameters[paramIdx].getInterfaceType();
+    Requirement req(RequirementKind::Conformance, paramType,
+                    differentiableProtocol->getDeclaredType());
+    builder.addRequirement(req, source, module);
+  }
+  return std::move(builder)
+      .computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams*/ true)
+      ->getCanonicalSignature();
+}
+
+CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
+    IndexSubset *parameterIndices, unsigned resultIndex,
+    AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
+    LookupConformanceFn lookupConformance,
+    CanGenericSignature derivativeFnGenSig, bool isReabstractionThunk) {
+  auto &ctx = getASTContext();
+
+  // Returns true if `index` is a differentiability parameter index.
+  auto isDiffParamIndex = [&](unsigned index) -> bool {
+    return index < parameterIndices->getCapacity() &&
+           parameterIndices->contains(index);
+  };
+
+  // Calculate differentiability parameter infos.
+  SmallVector<SILParameterInfo, 4> diffParams;
+  for (auto valueAndIndex : enumerate(getParameters()))
+    if (isDiffParamIndex(valueAndIndex.index()))
+      diffParams.push_back(valueAndIndex.value());
+
+  // Get the canonical derivative function generic signature.
+  if (!derivativeFnGenSig)
+    derivativeFnGenSig = getSubstGenericSignature();
+  derivativeFnGenSig = getAutoDiffDerivativeFunctionGenericSignature(
+      derivativeFnGenSig, getParameters(), parameterIndices, &TC.M);
+
+  // Given a type, returns its formal SIL parameter info.
+  auto getTangentParameterInfoForOriginalResult =
+      [&](CanType tanType, ResultConvention origResConv) -> SILParameterInfo {
+    AbstractionPattern pattern(derivativeFnGenSig, tanType);
+    auto &tl =
+        TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
+    ParameterConvention conv;
+    switch (origResConv) {
+    case ResultConvention::Owned:
+    case ResultConvention::Autoreleased:
+      conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
+                            : ParameterConvention::Direct_Guaranteed;
+      break;
+    case ResultConvention::Unowned:
+    case ResultConvention::UnownedInnerPointer:
+      conv = ParameterConvention::Direct_Unowned;
+      break;
+    case ResultConvention::Indirect:
+      conv = ParameterConvention::Indirect_In_Guaranteed;
+      break;
+    }
+    return {tanType, conv};
+  };
+
+  // Given a type, returns its formal SIL result info.
+  auto getTangentResultInfoForOriginalParameter =
+      [&](CanType tanType, ParameterConvention origParamConv) -> SILResultInfo {
+    AbstractionPattern pattern(derivativeFnGenSig, tanType);
+    auto &tl =
+        TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
+    ResultConvention conv;
+    switch (origParamConv) {
+    case ParameterConvention::Direct_Owned:
+    case ParameterConvention::Direct_Guaranteed:
+    case ParameterConvention::Direct_Unowned:
+      conv =
+          tl.isTrivial() ? ResultConvention::Unowned : ResultConvention::Owned;
+      break;
+    case ParameterConvention::Indirect_In:
+    case ParameterConvention::Indirect_Inout:
+    case ParameterConvention::Indirect_In_Constant:
+    case ParameterConvention::Indirect_In_Guaranteed:
+    case ParameterConvention::Indirect_InoutAliasable:
+      conv = ResultConvention::Indirect;
+      break;
+    }
+    return {tanType, conv};
+  };
+
+  CanSILFunctionType closureType;
+  switch (kind) {
+  case AutoDiffDerivativeFunctionKind::JVP: {
+    SmallVector<SILParameterInfo, 8> differentialParams;
+    for (auto &param : diffParams) {
+      auto paramTan =
+          param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
+      assert(paramTan && "Parameter type does not have a tangent space?");
+      differentialParams.push_back(
+          {paramTan->getCanonicalType(), param.getConvention()});
+    }
+    SmallVector<SILResultInfo, 8> differentialResults;
+    auto &result = getResults()[resultIndex];
+    auto resultTan =
+        result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
+    assert(resultTan && "Result type does not have a tangent space?");
+    differentialResults.push_back(
+        {resultTan->getCanonicalType(), result.getConvention()});
+    closureType = SILFunctionType::get(
+        /*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
+        ParameterConvention::Direct_Guaranteed, differentialParams, {},
+        differentialResults, None, getSubstitutions(),
+        isGenericSignatureImplied(), ctx);
+    break;
+  }
+  case AutoDiffDerivativeFunctionKind::VJP: {
+    SmallVector<SILParameterInfo, 8> pullbackParams;
+    auto &origRes = getResults()[resultIndex];
+    auto resultTan =
+        origRes.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
+    assert(resultTan && "Result type does not have a tangent space?");
+    pullbackParams.push_back(getTangentParameterInfoForOriginalResult(
+        resultTan->getCanonicalType(), origRes.getConvention()));
+    SmallVector<SILResultInfo, 8> pullbackResults;
+    for (auto &param : diffParams) {
+      auto paramTan =
+          param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
+      assert(paramTan && "Parameter type does not have a tangent space?");
+      pullbackResults.push_back(getTangentResultInfoForOriginalParameter(
+          paramTan->getCanonicalType(), param.getConvention()));
+    }
+    closureType = SILFunctionType::get(
+        /*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
+        ParameterConvention::Direct_Guaranteed, pullbackParams, {},
+        pullbackResults, {}, getSubstitutions(), isGenericSignatureImplied(),
+        ctx);
+    break;
+  }
+  }
+
+  SmallVector<SILParameterInfo, 4> newParameters;
+  newParameters.reserve(getNumParameters());
+  for (auto &param : getParameters()) {
+    newParameters.push_back(param.getWithInterfaceType(
+        param.getInterfaceType()->getCanonicalType(derivativeFnGenSig)));
+  }
+  // TODO(TF-1124): Upstream reabstraction thunk derivative typing rules.
+  // Blocked by TF-1125: `SILFunctionType::getWithDifferentiability`.
+  SmallVector<SILResultInfo, 4> newResults;
+  newResults.reserve(getNumResults() + 1);
+  for (auto &result : getResults()) {
+    newResults.push_back(result.getWithInterfaceType(
+        result.getInterfaceType()->getCanonicalType(derivativeFnGenSig)));
+  }
+  newResults.push_back({closureType->getCanonicalType(derivativeFnGenSig),
+                        ResultConvention::Owned});
+  // Derivative function type has a generic signature only if the original
+  // function type does, and if `derivativeFnGenSig` does not have all concrete
+  // generic parameters.
+  CanGenericSignature canGenSig;
+  if (getSubstGenericSignature() && derivativeFnGenSig &&
+      !derivativeFnGenSig->areAllParamsConcrete())
+    canGenSig = derivativeFnGenSig;
+  // If original function is `@convention(c)`, the derivative function should
+  // have `@convention(thin)`. IRGen does not support `@convention(c)` functions
+  // with multiple results.
+  auto extInfo = getExtInfo();
+  if (getRepresentation() == SILFunctionTypeRepresentation::CFunctionPointer)
+    extInfo = extInfo.withRepresentation(SILFunctionTypeRepresentation::Thin);
+  return SILFunctionType::get(canGenSig, extInfo, getCoroutineKind(),
+                              getCalleeConvention(), newParameters, getYields(),
+                              newResults, getOptionalErrorResult(),
+                              getSubstitutions(), isGenericSignatureImplied(),
+                              ctx, getWitnessMethodConformanceOrInvalid());
+}
+
 static CanType getKnownType(Optional<CanType> &cacheSlot, ASTContext &C,
                             StringRef moduleName, StringRef typeName) {
   if (!cacheSlot) {
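
The two convention-mapping lambdas in `getAutoDiffDerivativeFunctionType` (`getTangentParameterInfoForOriginalResult` and `getTangentResultInfoForOriginalParameter`) amount to two small lookup tables. Here is a standalone C++ sketch of those tables, assuming only the `switch` cases visible in the diff: `ResultConv`, `ParamConv`, `tangentParamConv`, and `tangentResultConv` are hypothetical local names that mirror the compiler's enumerators, and the `isTrivial` flag stands in for `tl.isTrivial()` from type lowering.

```cpp
#include <cassert>

// Local stand-ins mirroring the convention names that appear in the diff.
enum class ResultConv {
  Owned, Autoreleased, Unowned, UnownedInnerPointer, Indirect
};
enum class ParamConv {
  Direct_Owned, Direct_Guaranteed, Direct_Unowned,
  Indirect_In, Indirect_Inout, Indirect_In_Constant,
  Indirect_In_Guaranteed, Indirect_InoutAliasable
};

// Convention of the pullback parameter that carries the tangent of an
// original result, given that result's convention.
ParamConv tangentParamConv(ResultConv origResConv, bool isTrivial) {
  switch (origResConv) {
  case ResultConv::Owned:
  case ResultConv::Autoreleased:
    return isTrivial ? ParamConv::Direct_Unowned
                     : ParamConv::Direct_Guaranteed;
  case ResultConv::Unowned:
  case ResultConv::UnownedInnerPointer:
    return ParamConv::Direct_Unowned;
  case ResultConv::Indirect:
    return ParamConv::Indirect_In_Guaranteed;
  }
  assert(false && "unhandled result convention");
  return ParamConv::Direct_Unowned;
}

// Convention of the pullback result that carries the tangent of an
// original parameter, given that parameter's convention.
ResultConv tangentResultConv(ParamConv origParamConv, bool isTrivial) {
  switch (origParamConv) {
  case ParamConv::Direct_Owned:
  case ParamConv::Direct_Guaranteed:
  case ParamConv::Direct_Unowned:
    return isTrivial ? ResultConv::Unowned : ResultConv::Owned;
  case ParamConv::Indirect_In:
  case ParamConv::Indirect_Inout:
  case ParamConv::Indirect_In_Constant:
  case ParamConv::Indirect_In_Guaranteed:
  case ParamConv::Indirect_InoutAliasable:
    return ResultConv::Indirect;
  }
  assert(false && "unhandled parameter convention");
  return ResultConv::Unowned;
}
```

In this model every indirect parameter convention collapses to an indirect result, and every direct one becomes `Unowned` or `Owned` depending only on triviality. Note that in the diff only the VJP case routes through these helpers; the JVP case reuses the original parameter and result conventions directly.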
