Skip to content
Merged
5 changes: 5 additions & 0 deletions mooncake-common/common.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ if (USE_CUDA)
)
endif()

if (USE_CXL)
add_compile_definitions(USE_CXL)
message(STATUS "CXL support is enabled")
endif()

if (USE_TCP)
add_compile_definitions(USE_TCP)
endif()
Expand Down
4 changes: 4 additions & 0 deletions mooncake-transfer-engine/include/transfer_metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class TransferMetadata {
std::vector<uint32_t> lkey; // for rdma
std::vector<uint32_t> rkey; // for rdma
std::string shm_name; // for nvlink
uint64_t offset; // for cxl
};

struct NVMeoFBufferDesc {
Expand Down Expand Up @@ -86,6 +87,9 @@ class TransferMetadata {
std::vector<BufferDesc> buffers;
// this is for nvmeof.
std::vector<NVMeoFBufferDesc> nvmeof_buffers;
// this is for cxl.
std::string cxl_name;
uint64_t cxl_base_addr;
// TODO : make these two a union or a std::variant
std::string timestamp;
// this is for ascend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#ifndef CXL_TRANSPORT_H_
#define CXL_TRANSPORT_H_

#include <infiniband/verbs.h>

#include <atomic>
#include <cstddef>
Expand All @@ -37,47 +36,59 @@ class CxlTransport : public Transport {
public:
using BufferDesc = TransferMetadata::BufferDesc;
using SegmentDesc = TransferMetadata::SegmentDesc;
using HandShakeDesc = TransferMetadata::HandShakeDesc;

public:
CxlTransport();

~CxlTransport();

BatchID allocateBatchID(size_t batch_size) override;
Status submitTransfer(BatchID batch_id,
const std::vector<TransferRequest> &entries) override;

int submitTransfer(BatchID batch_id,
const std::vector<TransferRequest> &entries) override;
Status submitTransferTask(
const std::vector<TransferTask *> &task_list) override;

Status getTransferStatus(BatchID batch_id, size_t task_id,
TransferStatus &status) override;

Status freeBatchID(BatchID batch_id) override;
void* getCxlBaseAddr() { return cxl_base_addr; }

private:
int install(std::string &local_server_name,
std::shared_ptr<TransferMetadata> meta,
std::shared_ptr<Topology> topo) override;
std::shared_ptr<Topology> topo);

int allocateLocalSegmentID();

int registerLocalMemory(void *addr, size_t length,
const std::string &location, bool remote_accessible,
bool update_metadata) override;
bool update_metadata);

int unregisterLocalMemory(void *addr,
bool update_metadata = false) override;
int unregisterLocalMemory(void *addr, bool update_metadata = false);

int registerLocalMemoryBatch(
const std::vector<Transport::BufferEntry> &buffer_list,
const std::string &location) override {
return 0;
}
const std::string &location);

int unregisterLocalMemoryBatch(
const std::vector<void *> &addr_list) override {
return 0;
}
const std::vector<void *> &addr_list) override;

const char *getName() const override { return "cxl"; }

int cxlDevInit();

size_t cxlGetDeviceSize();

int cxlMemcpy(void *dest_addr, void *source_addr, size_t size);

bool isAddressInCxlRange(void *addr);

bool validateMemoryBounds(void *dest, void *src, size_t size);

private:
void* cxl_base_addr;
size_t cxl_dev_size;
char* cxl_dev_path;
};
} // namespace mooncake

Expand Down
4 changes: 1 addition & 3 deletions mooncake-transfer-engine/include/transport/transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ class Transport {
const char *file_path;
} nvmeof;
struct {
void *remote_filename;
void *remote_addr;
size_t remote_offset;
void *dest_addr;
} cxl;
struct {
uint64_t dest_addr;
Expand Down
8 changes: 8 additions & 0 deletions mooncake-transfer-engine/src/multi_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
#ifdef USE_MNNVL
#include "transport/nvlink_transport/nvlink_transport.h"
#endif
#ifdef USE_CXL
#include "transport/cxl_transport/cxl_transport.h"
#endif

#include <cassert>

Expand Down Expand Up @@ -208,6 +211,11 @@ Transport *MultiTransport::installTransport(const std::string &proto,
transport = new NvlinkTransport();
}
#endif
#ifdef USE_CXL
else if (std::string(proto) == "cxl") {
transport = new CxlTransport();
}
#endif

if (!transport) {
LOG(ERROR) << "Unsupported transport " << proto
Expand Down
11 changes: 11 additions & 0 deletions mooncake-transfer-engine/src/transfer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ int TransferEngine::init(const std::string &metadata_conn_string,
return -1;
}
#else

#if defined(USE_CXL) && !defined(USE_ASCEND)
if (std::getenv("MC_CXL_DEV_PATH") != nullptr && std::getenv("MC_CXL_DEV_SIZE") != nullptr) {
Transport* cxl_transport = multi_transports_->installTransport("cxl", local_topology_);
if (!cxl_transport) {
LOG(ERROR) << "Failed to install CXL transport";
return -1;
}
}
#endif

if (auto_discover_) {
LOG(INFO) << "Auto-discovering topology...";
if (getenv("MC_CUSTOM_TOPO_JSON")) {
Expand Down
35 changes: 33 additions & 2 deletions mooncake-transfer-engine/src/transfer_metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,18 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc,
buffersJSON.append(bufferJSON);
}
segmentJSON["buffers"] = buffersJSON;
} else if (segmentJSON["protocol"] == "cxl") {
segmentJSON["cxl_name"] = desc.cxl_name;
segmentJSON["cxl_base_addr"] = static_cast<Json::UInt64>(desc.cxl_base_addr);
Json::Value buffersJSON(Json::arrayValue);
for (const auto &buffer : desc.buffers) {
Json::Value bufferJSON;
bufferJSON["name"] = buffer.name;
bufferJSON["offset"] = static_cast<Json::UInt64>(buffer.offset);
bufferJSON["length"] = static_cast<Json::UInt64>(buffer.length);
buffersJSON.append(bufferJSON);
}
segmentJSON["buffers"] = buffersJSON;
} else {
LOG(ERROR) << "Unsupported segment descriptor for register, name "
<< desc.name << " protocol " << desc.protocol;
Expand Down Expand Up @@ -382,6 +394,21 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON,
desc->rank_info.deviceIp = rankInfoJSON["deviceIp"].asString();
desc->rank_info.devicePort = rankInfoJSON["devicePort"].asUInt64();
desc->rank_info.pid = rankInfoJSON["pid"].asUInt64();
} else if (desc->protocol == "cxl") {
desc->cxl_name = segmentJSON["cxl_name"].asString();
desc->cxl_base_addr = segmentJSON["cxl_base_addr"].asUInt64();
for (const auto &bufferJSON : segmentJSON["buffers"]) {
BufferDesc buffer;
buffer.name = bufferJSON["name"].asString();
buffer.offset = bufferJSON["offset"].asUInt64();
buffer.length = bufferJSON["length"].asUInt64();
if (buffer.name.empty() || !buffer.length) {
LOG(WARNING) << "Corrupted segment descriptor, name "
<< segment_name << " protocol " << desc->protocol;
return nullptr;
}
desc->buffers.push_back(buffer);
}
} else {
LOG(ERROR) << "Unsupported segment descriptor, name " << segment_name
<< " protocol " << desc->protocol;
Expand Down Expand Up @@ -557,8 +584,12 @@ int TransferMetadata::removeLocalMemoryBuffer(void *addr,
*new_segment_desc = *segment_desc;
segment_desc = new_segment_desc;
for (auto iter = segment_desc->buffers.begin();
iter != segment_desc->buffers.end(); ++iter) {
if (iter->addr == (uint64_t)addr) {
iter != segment_desc->buffers.end(); ++iter) {
if (iter->addr == (uint64_t)addr
#ifdef USE_CXL
|| (iter->offset + segment_desc->cxl_base_addr) == (uint64_t)addr
#endif
) {
segment_desc->buffers.erase(iter);
addr_exist = true;
break;
Expand Down
Loading
Loading