Skip to content

[IRGen] Add direct error return support for async functions #75221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 199 additions & 60 deletions lib/IRGen/GenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2114,10 +2114,40 @@ void SignatureExpansion::expandAsyncReturnType() {
}
};

auto resultType = getSILFuncConventions().getSILResultType(
IGM.getMaximalTypeExpansionContext());
auto fnConv = getSILFuncConventions();

auto resultType =
fnConv.getSILResultType(IGM.getMaximalTypeExpansionContext());
auto &ti = IGM.getTypeInfo(resultType);
auto &native = ti.nativeReturnValueSchema(IGM);

if (!fnConv.hasIndirectSILResults() && !fnConv.hasIndirectSILErrorResults() &&
!native.requiresIndirect() && fnConv.funcTy->hasErrorResult() &&
fnConv.isTypedError()) {
auto errorType = getSILFuncConventions().getSILErrorType(
IGM.getMaximalTypeExpansionContext());
auto &errorTi = IGM.getTypeInfo(errorType);
auto &nativeError = errorTi.nativeReturnValueSchema(IGM);
if (!nativeError.shouldReturnTypedErrorIndirectly()) {
auto combined = combineResultAndTypedErrorType(IGM, native, nativeError);

if (combined.combinedTy->isVoidTy()) {
addErrorResult();
return;
}

if (auto *structTy = dyn_cast<llvm::StructType>(combined.combinedTy)) {
for (auto *elem : structTy->elements()) {
ParamIRTypes.push_back(elem);
}
} else {
ParamIRTypes.push_back(combined.combinedTy);
}
}
addErrorResult();
return;
}

if (native.requiresIndirect() || native.empty()) {
addErrorResult();
return;
Expand All @@ -2135,11 +2165,23 @@ void SignatureExpansion::expandAsyncReturnType() {
void SignatureExpansion::addIndirectThrowingResult() {
if (getSILFuncConventions().funcTy->hasErrorResult() &&
getSILFuncConventions().isTypedError()) {
auto resultType = getSILFuncConventions().getSILErrorType(
IGM.getMaximalTypeExpansionContext());
const TypeInfo &resultTI = IGM.getTypeInfo(resultType);
auto storageTy = resultTI.getStorageType();
ParamIRTypes.push_back(storageTy->getPointerTo());
auto resultType = getSILFuncConventions().getSILResultType(
IGM.getMaximalTypeExpansionContext());
auto &ti = IGM.getTypeInfo(resultType);
auto &native = ti.nativeReturnValueSchema(IGM);

auto errorType = getSILFuncConventions().getSILErrorType(
IGM.getMaximalTypeExpansionContext());
const TypeInfo &errorTI = IGM.getTypeInfo(errorType);
auto &nativeError = errorTI.nativeReturnValueSchema(IGM);

if (getSILFuncConventions().hasIndirectSILResults() ||
getSILFuncConventions().hasIndirectSILErrorResults() ||
native.requiresIndirect() ||
nativeError.shouldReturnTypedErrorIndirectly()) {
auto errorStorageTy = errorTI.getStorageType();
ParamIRTypes.push_back(errorStorageTy->getPointerTo());
}
}

}
Expand Down Expand Up @@ -2265,6 +2307,36 @@ void SignatureExpansion::expandAsyncAwaitType() {
IGM.getMaximalTypeExpansionContext());
auto &ti = IGM.getTypeInfo(resultType);
auto &native = ti.nativeReturnValueSchema(IGM);

if (!getSILFuncConventions().hasIndirectSILResults() &&
!getSILFuncConventions().hasIndirectSILErrorResults() &&
getSILFuncConventions().funcTy->hasErrorResult() &&
!native.requiresIndirect() && getSILFuncConventions().isTypedError()) {
auto errorType = getSILFuncConventions().getSILErrorType(
IGM.getMaximalTypeExpansionContext());
auto &errorTi = IGM.getTypeInfo(errorType);
auto &nativeError = errorTi.nativeReturnValueSchema(IGM);
if (!nativeError.shouldReturnTypedErrorIndirectly()) {
auto combined = combineResultAndTypedErrorType(IGM, native, nativeError);

if (combined.combinedTy->isVoidTy()) {
addErrorResult();
return;
}

if (auto *structTy = dyn_cast<llvm::StructType>(combined.combinedTy)) {
for (auto *elem : structTy->elements()) {
components.push_back(elem);
}
} else {
components.push_back(combined.combinedTy);
}
addErrorResult();
ResultIRType = llvm::StructType::get(IGM.getLLVMContext(), components);
return;
}
}

if (native.requiresIndirect() || native.empty()) {
addErrorResult();
ResultIRType = llvm::StructType::get(IGM.getLLVMContext(), components);
Expand All @@ -2278,7 +2350,6 @@ void SignatureExpansion::expandAsyncAwaitType() {
});

addErrorResult();

ResultIRType = llvm::StructType::get(IGM.getLLVMContext(), components);
}

Expand Down Expand Up @@ -2950,9 +3021,22 @@ class AsyncCallEmission final : public CallEmission {
setIndirectTypedErrorResultSlotArgsIndex(--LastArgWritten);
Args[LastArgWritten] = nullptr;
} else {
auto buf = IGF.getCalleeTypedErrorResultSlot(
fnConv.getSILErrorType(IGF.IGM.getMaximalTypeExpansionContext()));
Args[--LastArgWritten] = buf.getAddress();
auto silResultTy =
fnConv.getSILResultType(IGF.IGM.getMaximalTypeExpansionContext());
auto silErrorTy =
fnConv.getSILErrorType(IGF.IGM.getMaximalTypeExpansionContext());

auto &nativeSchema =
IGF.IGM.getTypeInfo(silResultTy).nativeReturnValueSchema(IGF.IGM);
auto &errorSchema =
IGF.IGM.getTypeInfo(silErrorTy).nativeReturnValueSchema(IGF.IGM);

if (nativeSchema.requiresIndirect() ||
errorSchema.shouldReturnTypedErrorIndirectly()) {
// Return the error indirectly.
auto buf = IGF.getCalleeTypedErrorResultSlot(silErrorTy);
Args[--LastArgWritten] = buf.getAddress();
}
}
}

Expand Down Expand Up @@ -3134,7 +3218,22 @@ class AsyncCallEmission final : public CallEmission {
errorType =
substConv.getSILErrorType(IGM.getMaximalTypeExpansionContext());

if (resultTys.size() == 1) {
SILFunctionConventions fnConv(getCallee().getOrigFunctionType(),
IGF.getSILModule());

// Get the natural IR type in the body of the function that makes
// the call. This may be different than the IR type returned by the
// call itself due to ABI type coercion.
auto resultType =
fnConv.getSILResultType(IGF.IGM.getMaximalTypeExpansionContext());
auto &nativeSchema =
IGF.IGM.getTypeInfo(resultType).nativeReturnValueSchema(IGF.IGM);

bool mayReturnErrorDirectly = mayReturnTypedErrorDirectly();
if (mayReturnErrorDirectly && !nativeSchema.requiresIndirect()) {
return emitToUnmappedExplosionWithDirectTypedError(resultType, result,
out);
} else if (resultTys.size() == 1) {
result = Builder.CreateExtractValue(result, numAsyncContextParams);
if (hasError) {
Address errorAddr = IGF.getCalleeErrorResultSlot(errorType,
Expand Down Expand Up @@ -3166,17 +3265,6 @@ class AsyncCallEmission final : public CallEmission {
result = resultAgg;
}

SILFunctionConventions fnConv(getCallee().getOrigFunctionType(),
IGF.getSILModule());

// Get the natural IR type in the body of the function that makes
// the call. This may be different than the IR type returned by the
// call itself due to ABI type coercion.
auto resultType =
fnConv.getSILResultType(IGF.IGM.getMaximalTypeExpansionContext());
auto &nativeSchema =
IGF.IGM.getTypeInfo(resultType).nativeReturnValueSchema(IGF.IGM);

// For ABI reasons the result type of the call might not actually match the
// expected result type.
//
Expand Down Expand Up @@ -3315,7 +3403,7 @@ void CallEmission::emitToUnmappedMemory(Address result) {
#ifndef NDEBUG
LastArgWritten = 0; // appease an assert
#endif

auto call = emitCallSite();

// Async calls need to store the error result that is passed as a parameter.
Expand Down Expand Up @@ -4403,32 +4491,21 @@ void CallEmission::emitToUnmappedExplosionWithDirectTypedError(
extractScalarResults(IGF, result->getType(), result, nativeExplosion);
auto values = nativeExplosion.claimAll();

auto convertIfNecessary = [&](llvm::Type *nativeTy,
llvm::Value *elt) -> llvm::Value * {
auto *eltTy = elt->getType();
if (nativeTy->isIntOrPtrTy() && eltTy->isIntOrPtrTy() &&
nativeTy->getPrimitiveSizeInBits() != eltTy->getPrimitiveSizeInBits()) {
if (nativeTy->isPointerTy() && eltTy == IGF.IGM.IntPtrTy) {
return IGF.Builder.CreateIntToPtr(elt, nativeTy);
}
return IGF.Builder.CreateTruncOrBitCast(elt, nativeTy);
}
return elt;
};

Explosion errorExplosion;
if (!errorSchema.empty()) {
if (auto *structTy =
dyn_cast<llvm::StructType>(errorSchema.getExpandedType(IGF.IGM))) {
for (unsigned i = 0, e = structTy->getNumElements(); i < e; ++i) {
llvm::Value *elt = values[combined.errorValueMapping[i]];
auto *nativeTy = structTy->getElementType(i);
elt = convertIfNecessary(nativeTy, elt);
elt = convertForAsyncDirect(IGF, elt, nativeTy, /*forExtraction*/ true);
errorExplosion.add(elt);
}
} else {
errorExplosion.add(convertIfNecessary(
combined.combinedTy, values[combined.errorValueMapping[0]]));
auto *converted =
convertForAsyncDirect(IGF, values[combined.errorValueMapping[0]],
combined.combinedTy, /*forExtraction*/ true);
errorExplosion.add(converted);
}

typedErrorExplosion =
Expand All @@ -4444,10 +4521,14 @@ void CallEmission::emitToUnmappedExplosionWithDirectTypedError(
dyn_cast<llvm::StructType>(nativeSchema.getExpandedType(IGF.IGM))) {
for (unsigned i = 0, e = structTy->getNumElements(); i < e; ++i) {
auto *nativeTy = structTy->getElementType(i);
resultExplosion.add(convertIfNecessary(nativeTy, values[i]));
auto *converted = convertForAsyncDirect(IGF, values[i], nativeTy,
/*forExtraction*/ true);
resultExplosion.add(converted);
}
} else {
resultExplosion.add(convertIfNecessary(combined.combinedTy, values[0]));
auto *converted = convertForAsyncDirect(
IGF, values[0], combined.combinedTy, /*forExtraction*/ true);
resultExplosion.add(converted);
}
out = nativeSchema.mapFromNative(IGF.IGM, IGF, resultExplosion, resultType);
}
Expand Down Expand Up @@ -5313,6 +5394,33 @@ llvm::Value* IRGenFunction::coerceValue(llvm::Value *value, llvm::Type *toTy,
return loaded;
}

llvm::Value *irgen::convertForAsyncDirect(IRGenFunction &IGF,
llvm::Value *value, llvm::Type *toTy,
bool forExtraction) {
auto &Builder = IGF.Builder;
auto *fromTy = value->getType();
if (toTy->isIntOrPtrTy() && fromTy->isIntOrPtrTy() && toTy != fromTy) {

if (toTy->isPointerTy()) {
if (fromTy->isPointerTy())
return Builder.CreateBitCast(value, toTy);
if (fromTy == IGF.IGM.IntPtrTy)
return Builder.CreateIntToPtr(value, toTy);
} else if (fromTy->isPointerTy()) {
if (toTy == IGF.IGM.IntPtrTy) {
return Builder.CreatePtrToInt(value, toTy);
}
}

if (forExtraction) {
return Builder.CreateTruncOrBitCast(value, toTy);
} else {
return Builder.CreateZExtOrBitCast(value, toTy);
}
}
return value;
}

void IRGenFunction::emitScalarReturn(llvm::Type *resultType,
Explosion &result) {
if (result.empty()) {
Expand Down Expand Up @@ -5754,32 +5862,18 @@ void IRGenFunction::emitScalarReturn(SILType returnResultType,
return;
}

auto convertIfNecessary = [&](llvm::Type *nativeTy,
llvm::Value *elt) -> llvm::Value * {
auto *eltTy = elt->getType();
if (nativeTy->isIntOrPtrTy() && eltTy->isIntOrPtrTy() &&
nativeTy->getPrimitiveSizeInBits() !=
eltTy->getPrimitiveSizeInBits()) {
assert(nativeTy->getPrimitiveSizeInBits() >
eltTy->getPrimitiveSizeInBits());
if (eltTy->isPointerTy()) {
return Builder.CreatePtrToInt(elt, nativeTy);
}
return Builder.CreateZExt(elt, nativeTy);
}
return elt;
};

if (auto *structTy = dyn_cast<llvm::StructType>(combinedTy)) {
nativeAgg = llvm::UndefValue::get(combinedTy);
for (unsigned i = 0, e = native.size(); i != e; ++i) {
llvm::Value *elt = native.claimNext();
auto *nativeTy = structTy->getElementType(i);
elt = convertIfNecessary(nativeTy, elt);
elt = convertForAsyncDirect(*this, elt, nativeTy,
/*forExtraction*/ false);
nativeAgg = Builder.CreateInsertValue(nativeAgg, elt, i);
}
} else {
nativeAgg = convertIfNecessary(combinedTy, native.claimNext());
nativeAgg = convertForAsyncDirect(*this, native.claimNext(), combinedTy,
/*forExtraction*/ false);
}
}

Expand Down Expand Up @@ -6089,6 +6183,51 @@ void irgen::emitAsyncReturn(IRGenFunction &IGF, AsyncContextLayout &asyncLayout,
SILFunctionConventions conv(fnType, IGF.getSILModule());
auto &nativeSchema =
IGM.getTypeInfo(funcResultTypeInContext).nativeReturnValueSchema(IGM);

if (fnType->hasErrorResult() && !conv.hasIndirectSILResults() &&
!conv.hasIndirectSILErrorResults() && !nativeSchema.requiresIndirect() &&
conv.isTypedError()) {
auto errorType = conv.getSILErrorType(IGM.getMaximalTypeExpansionContext());
auto &errorTI = IGM.getTypeInfo(errorType);
auto &nativeError = errorTI.nativeReturnValueSchema(IGM);
if (!nativeError.shouldReturnTypedErrorIndirectly()) {
assert(!error.empty() && "Direct error return must have error value");
auto *combinedTy =
combineResultAndTypedErrorType(IGM, nativeSchema, nativeError)
.combinedTy;

if (combinedTy->isVoidTy()) {
assert(result.empty() && "Unexpected result values");
} else {
if (auto *structTy = dyn_cast<llvm::StructType>(combinedTy)) {
llvm::Value *nativeAgg = llvm::UndefValue::get(structTy);
for (unsigned i = 0, e = result.size(); i != e; ++i) {
llvm::Value *elt = result.claimNext();
auto *nativeTy = structTy->getElementType(i);
elt = convertForAsyncDirect(IGF, elt, nativeTy,
/*forExtraction*/ false);
nativeAgg = IGF.Builder.CreateInsertValue(nativeAgg, elt, i);
}
Explosion out;
IGF.emitAllExtractValues(nativeAgg, structTy, out);
while (!out.empty()) {
nativeResultsStorage.push_back(out.claimNext());
}
} else {
auto *converted = convertForAsyncDirect(
IGF, result.claimNext(), combinedTy, /*forExtraction*/ false);
nativeResultsStorage.push_back(converted);
}
}

nativeResultsStorage.push_back(error.claimNext());
nativeResults = nativeResultsStorage;

emitAsyncReturn(IGF, asyncLayout, fnType, nativeResults);
return;
}
}

if (result.empty() && !nativeSchema.empty()) {
if (!nativeSchema.requiresIndirect())
// When we throw, we set the return values to undef.
Expand Down
4 changes: 4 additions & 0 deletions lib/IRGen/GenCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ namespace irgen {
void forwardAsyncCallResult(IRGenFunction &IGF, CanSILFunctionType fnType,
AsyncContextLayout &layout, llvm::CallInst *call);

/// Converts a value for async direct errors.
llvm::Value *convertForAsyncDirect(IRGenFunction &IGF, llvm::Value *value,
llvm::Type *toTy, bool forExtraction);

} // end namespace irgen
} // end namespace swift

Expand Down
Loading