Skip to content

Commit

Permalink
[Memory] Adds typed ptr() and use_host_pointer for CPU modes
Browse files Browse the repository at this point in the history
  • Loading branch information
dmed256 committed Nov 30, 2019
1 parent 018d43d commit c61d636
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 25 deletions.
6 changes: 2 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,13 @@ before_install:
if [ -n "${MATRIX_INIT}" ]; then
eval "${MATRIX_INIT}"
fi
export BUILD_CXXFLAGS="${CXXFLAGS}"
export TEST_CXXFLAGS="${CXXFLAGS}"
script:
- cd ${TRAVIS_BUILD_DIR}
- CXXFLAGS="${BUILD_CXXFLAGS}" make -j 4
- make -j 4

- cd ${TRAVIS_BUILD_DIR}
- CXXFLAGS="${TEST_CXXFLAGS}" make -j 4 test
- make -j 4 test
- bash <(curl --no-buffer -s https://codecov.io/bash) > codecov_output
- head -n 100 codecov_output

Expand Down
6 changes: 6 additions & 0 deletions include/occa/core/base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ namespace occa {

void memcpy(memory dest, memory src,
const occa::properties &props);

namespace cpu {
occa::memory wrapMemory(void *ptr,
const udim_t bytes,
const occa::properties &props = occa::properties());
}
//====================================

//---[ Free Functions ]---------------
Expand Down
28 changes: 24 additions & 4 deletions include/occa/core/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,17 @@ namespace occa {

memory& swap(memory &m);

void* ptr();
const void* ptr() const;
template <class TM = void>
TM* ptr();

void* ptr(const occa::properties &props);
const void* ptr(const occa::properties &props) const;
template <class TM = void>
const TM* ptr() const;

template <class TM = void>
TM* ptr(const occa::properties &props);

template <class TM = void>
const TM* ptr(const occa::properties &props) const;

modeMemory_t* getModeMemory() const;
modeDevice_t* getModeDevice() const;
Expand Down Expand Up @@ -243,6 +249,20 @@ namespace occa {
const udim_t bytes,
const occa::properties &props = occa::properties());
}

template <>
void* memory::ptr<void>();

template <>
const void* memory::ptr<void>() const;

template <>
void* memory::ptr<void>(const occa::properties &props);

template <>
const void* memory::ptr<void>(const occa::properties &props) const;
}

#include "memory.tpp"

#endif
21 changes: 21 additions & 0 deletions include/occa/core/memory.tpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
namespace occa {
template <class TM>
TM* memory::ptr() {
return (TM*) ptr<void>();
}

template <class TM>
const TM* memory::ptr() const {
return (const TM*) ptr<void>();
}

template <class TM>
TM* memory::ptr(const occa::properties &props) {
return (TM*) ptr<void>(props);
}

template <class TM>
const TM* memory::ptr(const occa::properties &props) const {
return (const TM*) ptr<void>(props);
}
}
11 changes: 9 additions & 2 deletions src/c/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,25 @@ occaMemory OCCA_RFUNC occaWrapCpuMemory(occaDevice device,
void *ptr,
occaUDim_t bytes,
occaProperties props) {
occa::device device_ = (
occa::c::isDefault(device)
? occa::getDevice()
: occa::c::device(device)
);

occa::memory mem;
if (occa::c::isDefault(props)) {
mem = occa::cpu::wrapMemory(occa::c::device(device),
mem = occa::cpu::wrapMemory(device_,
ptr,
bytes);
} else {
mem = occa::cpu::wrapMemory(occa::c::device(device),
mem = occa::cpu::wrapMemory(device_,
ptr,
bytes,
occa::c::properties(props));
}
mem.dontUseRefs();

return occa::c::newOccaType(mem);
}

Expand Down
8 changes: 8 additions & 0 deletions src/core/base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@ namespace occa {
const occa::properties &props) {
memcpy(dest, src, -1, 0, 0, props);
}

namespace cpu {
occa::memory wrapMemory(void *ptr,
const udim_t bytes,
const occa::properties &props) {
return occa::cpu::wrapMemory(getDevice(), ptr, bytes, props);
}
}
//====================================

//---[ Free Functions ]---------------
Expand Down
23 changes: 11 additions & 12 deletions src/core/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,25 +143,29 @@ namespace occa {
return *this;
}

void* memory::ptr() {
template <>
void* memory::ptr<void>() {
return (modeMemory
? modeMemory->ptr
: NULL);
}

const void* memory::ptr() const {
template <>
const void* memory::ptr<void>() const {
return (modeMemory
? modeMemory->ptr
: NULL);
}

void* memory::ptr(const occa::properties &props) {
template <>
void* memory::ptr<void>(const occa::properties &props) {
return (modeMemory
? modeMemory->getPtr(props)
: NULL);
}

const void* memory::ptr(const occa::properties &props) const {
template <>
const void* memory::ptr<void>(const occa::properties &props) const {
return (modeMemory
? modeMemory->getPtr(props)
: NULL);
Expand Down Expand Up @@ -594,15 +598,10 @@ namespace occa {
void *ptr,
const udim_t bytes,
const occa::properties &props) {
serial::memory &mem = *(new serial::memory(device.getModeDevice(),
bytes,
props));
occa::properties memProps = props;
memProps["use_host_pointer"] = true;

mem.dontUseRefs();
mem.ptr = (char*) ptr;
mem.isOrigin = false;

return occa::memory(&mem);
return device.malloc(bytes, ptr, memProps);
}
}
}
11 changes: 8 additions & 3 deletions src/modes/serial/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,14 @@ namespace occa {
const occa::properties &props) {
memory *mem = new memory(this, bytes, props);

mem->ptr = (char*) sys::malloc(bytes);
if (src) {
::memcpy(mem->ptr, src, bytes);
if (props.get("use_host_pointer", false)) {
mem->ptr = (char*) const_cast<void*>(src);
mem->isOrigin = false;
} else {
mem->ptr = (char*) sys::malloc(bytes);
if (src) {
::memcpy(mem->ptr, src, bytes);
}
}

return mem;
Expand Down
12 changes: 12 additions & 0 deletions tests/src/c/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,25 @@ void testInteropMethods() {
bytes,
occaDefault);

mem2 = occaWrapCpuMemory(occaDefault,
ptr,
bytes,
occaDefault);

occaProperties props = (
occaCreatePropertiesFromString("foo: 'bar'")
);

mem2 = occaWrapCpuMemory(occaHost(),
ptr,
bytes,
props);

mem2 = occaWrapCpuMemory(occaDefault,
ptr,
bytes,
props);

occaFree(mem1);
occaFree(mem2);
occaFree(props);
Expand Down
42 changes: 42 additions & 0 deletions tests/src/core/memory.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include <occa.hpp>
#include <occa/tools/testing.hpp>

void testMalloc();
void testCpuWrapMemory();

int main(const int argc, const char **argv) {
testMalloc();
testCpuWrapMemory();

return 0;
}

void testMalloc() {
const occa::udim_t bytes = 1 * sizeof(int);
int value = 4660;
int *hostPtr = &value;

occa::memory mem = occa::malloc(bytes);

mem = occa::malloc(bytes, hostPtr);
ASSERT_EQ(((int*) mem.ptr())[0], value);
ASSERT_NEQ(mem.ptr<int>(), hostPtr);

mem = occa::malloc(bytes, hostPtr, "use_host_pointer: true");
ASSERT_EQ(mem.ptr<int>()[0], value);
ASSERT_EQ(mem.ptr<int>(), hostPtr);
}

void testCpuWrapMemory() {
const occa::udim_t bytes = 1 * sizeof(int);
int value = 4660;
int *hostPtr = &value;

occa::memory mem = occa::cpu::wrapMemory(hostPtr, bytes);
ASSERT_EQ(mem.ptr<int>()[0], value);
ASSERT_EQ(mem.ptr<int>(), hostPtr);

mem = occa::cpu::wrapMemory(hostPtr, bytes, "use_host_pointer: false");
ASSERT_EQ(mem.ptr<int>()[0], value);
ASSERT_EQ(mem.ptr<int>(), hostPtr);
}

0 comments on commit c61d636

Please sign in to comment.