Skip to content

[AutoDiff upstream] Serialize derivative function configurations. #30672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,19 @@ class ASTContext final {
unsigned previousGeneration,
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods);

/// Load derivative function configurations for the given
/// AbstractFunctionDecl.
///
/// \param originalAFD The declaration whose derivative function
/// configurations should be loaded.
///
/// \param previousGeneration The previous generation number. The AST already
/// contains derivative function configurations loaded from any generation up
/// to and including this one.
void loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
llvm::SetVector<AutoDiffConfig> &results);

/// Retrieve the Clang module loader for this ASTContext.
///
/// If there is no Clang module loader, returns a null pointer.
Expand Down
14 changes: 8 additions & 6 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5796,6 +5796,7 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
private:
ParameterList *Params;

private:
/// The generation at which we last loaded derivative function configurations.
unsigned DerivativeFunctionConfigGeneration = 0;
/// Prepare to traverse the list of derivative function configurations.
Expand All @@ -5810,6 +5811,13 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
struct DerivativeFunctionConfigurationList;
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;

public:
/// Get all derivative function configurations.
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();

/// Add the given derivative function configuration.
void addDerivativeFunctionConfiguration(AutoDiffConfig config);

protected:
// If a function has a body at all, we have either a parsed body AST node or
// we have saved the end location of the unparsed body.
Expand Down Expand Up @@ -6129,12 +6137,6 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
/// constructor.
bool hasDynamicSelfResult() const;

/// Get all derivative function configurations.
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();

/// Add the given derivative function configuration.
void addDerivativeFunctionConfiguration(AutoDiffConfig config);

using DeclContext::operator new;
using Decl::getASTContext;
};
Expand Down
18 changes: 18 additions & 0 deletions include/swift/AST/ModuleLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class DependencyCollector;
namespace swift {

class AbstractFunctionDecl;
struct AutoDiffConfig;
class ClangImporterOptions;
class ClassDecl;
class FileUnit;
Expand Down Expand Up @@ -153,6 +154,23 @@ class ModuleLoader {
unsigned previousGeneration,
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) = 0;

/// Load derivative function configurations for the given
/// AbstractFunctionDecl.
///
/// \param originalAFD The declaration whose derivative function
/// configurations should be loaded.
///
/// \param previousGeneration The previous generation number. The AST already
/// contains derivative function configurations loaded from any generation up
/// to and including this one.
///
/// \param results The result list of derivative function configurations.
/// This list will be extended with any methods found in subsequent
/// generations.
virtual void loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
llvm::SetVector<AutoDiffConfig> &results) {};

/// Verify all modules loaded by this loader.
virtual void verifyAllModules() { }

Expand Down
4 changes: 4 additions & 0 deletions include/swift/Serialization/SerializedModuleLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ class SerializedModuleLoaderBase : public ModuleLoader {
unsigned previousGeneration,
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) override;

virtual void loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
llvm::SetVector<AutoDiffConfig> &results) override;

virtual void verifyAllModules() override;
};

Expand Down
11 changes: 11 additions & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,17 @@ void ASTContext::loadObjCMethods(
}
}

void ASTContext::loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
llvm::SetVector<AutoDiffConfig> &results) {
PrettyStackTraceDecl stackTrace(
"loading derivative function configurations for", originalAFD);
for (auto &loader : getImpl().ModuleLoaders) {
loader->loadDerivativeFunctionConfigurations(originalAFD,
previousGeneration, results);
}
}

void ASTContext::verifyAllLoadedModules() const {
#ifndef NDEBUG
FrontendStatsTracer tracer(Stats, "verify-all-loaded-modules");
Expand Down
6 changes: 4 additions & 2 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7099,8 +7099,10 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
prepareDerivativeFunctionConfigurations();
auto &ctx = getASTContext();
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {
// TODO(TF-1100): Upstream derivative function configuration serialization
// logic.
unsigned previousGeneration = DerivativeFunctionConfigGeneration;
DerivativeFunctionConfigGeneration = ctx.getCurrentGeneration();
ctx.loadDerivativeFunctionConfigurations(this, previousGeneration,
*DerivativeFunctionConfigs);
}
return DerivativeFunctionConfigs->getArrayRef();
}
Expand Down
10 changes: 10 additions & 0 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3970,6 +3970,10 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
return nullptr;
}
getterDecl->getAttrs().add(newAttr);
// Register derivative function configuration.
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
getterDecl->addDerivativeFunctionConfiguration(
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
return resolvedDiffParamIndices;
}
// Reject duplicate `@differentiable` attributes.
Expand Down Expand Up @@ -4341,6 +4345,12 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
return true;
}

// Register derivative function configuration.
auto *resultIndices = IndexSubset::get(Ctx, 1, {0});
originalAFD->addDerivativeFunctionConfiguration(
{resolvedDiffParamIndices, resultIndices,
derivative->getGenericSignature()});

return false;
}

Expand Down
5 changes: 5 additions & 0 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,11 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
} else {
witness->getAttrs().add(newAttr);
success = true;
// Register derivative function configuration.
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
witnessAFD->addDerivativeFunctionConfiguration(
{newAttr->getParameterIndices(), resultIndices,
newAttr->getDerivativeGenericSignature()});
}
}
if (!success) {
Expand Down
2 changes: 2 additions & 0 deletions lib/Serialization/DeclTypeRecordNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ OTHER(XREF_OPAQUE_RETURN_TYPE_PATH_PIECE, 252)

OTHER(CLANG_TYPE, 253)

OTHER(DERIVATIVE_FUNCTION_CONFIGURATION, 254)

#undef RECORD
#undef DECLTYPERECORDNODES_HAS_RECORD_VAL
#undef RECORD_VAL
Expand Down
92 changes: 92 additions & 0 deletions lib/Serialization/ModuleFile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,66 @@ ModuleFile::readObjCMethodTable(ArrayRef<uint64_t> fields, StringRef blobData) {
base + sizeof(uint32_t), base));
}

/// Used to deserialize entries in the on-disk derivative function configuration
/// table.
class ModuleFile::DerivativeFunctionConfigTableInfo {
public:
using internal_key_type = StringRef;
using external_key_type = internal_key_type;
using data_type = SmallVector<std::pair<std::string, GenericSignatureID>, 8>;
using hash_value_type = uint32_t;
using offset_type = unsigned;

external_key_type GetExternalKey(internal_key_type ID) { return ID; }

internal_key_type GetInternalKey(external_key_type ID) { return ID; }

hash_value_type ComputeHash(internal_key_type key) {
return llvm::djbHash(key, SWIFTMODULE_HASH_SEED);
}

static bool EqualKey(internal_key_type lhs, internal_key_type rhs) {
return lhs == rhs;
}

static std::pair<unsigned, unsigned> ReadKeyDataLength(const uint8_t *&data) {
unsigned keyLength = endian::readNext<uint16_t, little, unaligned>(data);
unsigned dataLength = endian::readNext<uint16_t, little, unaligned>(data);
return {keyLength, dataLength};
}

static internal_key_type ReadKey(const uint8_t *data, unsigned length) {
return StringRef(reinterpret_cast<const char *>(data), length);
}

static data_type ReadData(internal_key_type key, const uint8_t *data,
unsigned length) {
data_type result;
const uint8_t *limit = data + length;
while (data < limit) {
DeclID genSigId = endian::readNext<uint32_t, little, unaligned>(data);
int32_t nameLength = endian::readNext<int32_t, little, unaligned>(data);
StringRef mangledName(reinterpret_cast<const char *>(data), nameLength);
data += nameLength;
result.push_back({mangledName, genSigId});
}
return result;
}
};

std::unique_ptr<ModuleFile::SerializedDerivativeFunctionConfigTable>
ModuleFile::readDerivativeFunctionConfigTable(ArrayRef<uint64_t> fields,
StringRef blobData) {
uint32_t tableOffset;
index_block::DerivativeFunctionConfigTableLayout::readRecord(fields,
tableOffset);
auto base = reinterpret_cast<const uint8_t *>(blobData.data());

using OwnedTable = std::unique_ptr<SerializedDerivativeFunctionConfigTable>;
return OwnedTable(SerializedDerivativeFunctionConfigTable::Create(
base + tableOffset, base + sizeof(uint32_t), base));
}

bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) {
if (llvm::Error Err = cursor.EnterSubBlock(INDEX_BLOCK_ID)) {
// FIXME this drops the error on the floor.
Expand Down Expand Up @@ -1015,6 +1075,10 @@ bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) {
case index_block::OBJC_METHODS:
ObjCMethods = readObjCMethodTable(scratch, blobData);
break;
case index_block::DERIVATIVE_FUNCTION_CONFIGURATIONS:
DerivativeFunctionConfigurations =
readDerivativeFunctionConfigTable(scratch, blobData);
break;
case index_block::ENTRY_POINT:
assert(blobData.empty());
setEntryPointClassID(scratch.front());
Expand Down Expand Up @@ -2405,6 +2469,34 @@ void ModuleFile::loadObjCMethods(
}
}

void ModuleFile::loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD,
llvm::SetVector<AutoDiffConfig> &results) {
if (!DerivativeFunctionConfigurations)
return;
auto &ctx = originalAFD->getASTContext();
Mangle::ASTMangler Mangler;
auto mangledName = Mangler.mangleDeclAsUSR(originalAFD, "");
auto configs = DerivativeFunctionConfigurations->find(mangledName);
if (configs == DerivativeFunctionConfigurations->end())
return;
for (auto entry : *configs) {
auto *parameterIndices = IndexSubset::getFromString(ctx, entry.first);
auto derivativeGenSigOrError = getGenericSignatureChecked(entry.second);
if (!derivativeGenSigOrError) {
if (!getContext().LangOpts.EnableDeserializationRecovery)
fatal(derivativeGenSigOrError.takeError());
llvm::consumeError(derivativeGenSigOrError.takeError());
}
auto derivativeGenSig = derivativeGenSigOrError.get();
// NOTE(TF-1038): Result indices are currently unsupported in derivative
// registration attributes. In the meantime, always use `{0}` (wrt the
// first and only result).
auto resultIndices = IndexSubset::get(ctx, 1, {0});
results.insert({parameterIndices, resultIndices, derivativeGenSig});
}
}

TinyPtrVector<ValueDecl *>
ModuleFile::loadNamedMembers(const IterableDeclContext *IDC, DeclBaseName N,
uint64_t contextData) {
Expand Down
18 changes: 18 additions & 0 deletions lib/Serialization/ModuleFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,12 @@ class ModuleFile
llvm::OnDiskIterableChainedHashTable<DeclUSRTableInfo>;
std::unique_ptr<SerializedDeclUSRTable> DeclUSRsTable;

class DerivativeFunctionConfigTableInfo;
using SerializedDerivativeFunctionConfigTable =
llvm::OnDiskIterableChainedHashTable<DerivativeFunctionConfigTableInfo>;
std::unique_ptr<SerializedDerivativeFunctionConfigTable>
DerivativeFunctionConfigurations;

/// A blob of 0 terminated string segments referenced in \c SourceLocsTextData
StringRef SourceLocsTextData;

Expand Down Expand Up @@ -550,6 +556,12 @@ class ModuleFile
std::unique_ptr<SerializedDeclMembersTable>
readDeclMembersTable(ArrayRef<uint64_t> fields, StringRef blobData);

/// Read an on-disk derivative function configuration table stored in
/// index_block::DerivativeFunctionConfigTableLayout format.
std::unique_ptr<ModuleFile::SerializedDerivativeFunctionConfigTable>
readDerivativeFunctionConfigTable(ArrayRef<uint64_t> fields,
StringRef blobData);

/// Reads the index block, which contains global tables.
///
/// Returns false if there was an error.
Expand Down Expand Up @@ -774,6 +786,12 @@ class ModuleFile
bool isInstanceMethod,
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods);

/// Loads all derivative function configurations for the given
/// AbstractFunctionDecl.
void loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD,
llvm::SetVector<AutoDiffConfig> &results);

/// Reports all class members in the module to the given consumer.
///
/// This is intended for use with id-style lookup and code completion.
Expand Down
12 changes: 11 additions & 1 deletion lib/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
const uint16_t SWIFTMODULE_VERSION_MINOR = 550; // linear_function, linear_function_extract
const uint16_t SWIFTMODULE_VERSION_MINOR = 551; // derivative function configurations

/// A standard hash seed used for all string hashes in a serialized module.
///
Expand Down Expand Up @@ -1934,6 +1934,10 @@ namespace index_block {
/// produce Objective-C methods.
OBJC_METHODS,

/// The derivative function configuration table, which maps original
/// function declaration names to derivative function configurations.
DERIVATIVE_FUNCTION_CONFIGURATIONS,

ENTRY_POINT,
LOCAL_DECL_CONTEXT_OFFSETS,
LOCAL_TYPE_DECLS,
Expand Down Expand Up @@ -1998,6 +2002,12 @@ namespace index_block {
BCBlob // map from member DeclBaseNames to offsets of DECL_MEMBERS records
>;

using DerivativeFunctionConfigTableLayout = BCRecordLayout<
DERIVATIVE_FUNCTION_CONFIGURATIONS, // record ID
BCVBR<16>, // table offset within the blob (see below)
BCBlob // map from original declaration names to derivative configs
>;

using EntryPointLayout = BCRecordLayout<
ENTRY_POINT,
DeclIDField // the ID of the main class; 0 if there was a main source file
Expand Down
Loading