Skip to content

Basic reference wrapper support #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions Source/LuaBridge/detail/CFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,22 @@ template <class T>
auto unwrap_argument_or_error(lua_State* L, std::size_t index, std::size_t start)
{
auto result = Stack<T>::get(L, static_cast<int>(index + start));
if (! result)
raise_lua_error(L, "Error decoding argument #%d: %s", static_cast<int>(index + 1), result.message().c_str());
if (result)
return std::move(*result);

// TODO - this might be costly, how to deal with it ?
if constexpr (! std::is_lvalue_reference_v<T>)
{
using U = std::reference_wrapper<std::remove_reference_t<T>>;

return std::move(*result);
auto resultRef = Stack<U>::get(L, static_cast<int>(index));
if (resultRef)
return (*resultRef).get();
}

raise_lua_error(L, "Error decoding argument #%d: %s", static_cast<int>(index + 1), result.message().c_str());

unreachable();
}

template <class ArgsPack, std::size_t Start, std::size_t... Indices>
Expand Down Expand Up @@ -567,7 +579,6 @@ struct property_getter<T, void>
return 1;
}
};

/**
* @brief lua_CFunction to get a class data member.
*
Expand Down
41 changes: 41 additions & 0 deletions Source/LuaBridge/detail/LuaHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,47 @@ void* lua_newuserdata_aligned(lua_State* L, Args&&... args)
return pointer;
}

/**
* @brief Deallocate lua userdata from pointer.
*/
template <class T>
int lua_deleteuserdata_pointer(lua_State* L)
{
assert(isfulluserdata(L, 1));

T** aligned = align<T*>(lua_touserdata(L, 1));
delete *aligned;

return 0;
}

/**
* @brief Allocate lua userdata from pointer.
*/
template <class T>
void* lua_newuserdata_pointer(lua_State* L, T* ptr)
{
#if LUABRIDGE_ON_LUAU
void* pointer = lua_newuserdatadtor(L, maximum_space_needed_to_align<T*>(), [](void* x)
{
T** aligned = align<T*>(x);
delete *aligned;
});
#else
void* pointer = lua_newuserdata_x<T*>(L, maximum_space_needed_to_align<T*>());

lua_newtable(L);
lua_pushcfunction_x(L, &lua_deleteuserdata_pointer<T*>);
rawsetfield(L, -2, "__gc");
lua_setmetatable(L, -2);
#endif

T** aligned = align<T*>(pointer);
*aligned = ptr;

return pointer;
}

/**
* @brief Safe error able to walk backwards for error reporting correctly.
*/
Expand Down
76 changes: 74 additions & 2 deletions Source/LuaBridge/detail/Stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -1343,8 +1343,81 @@ struct Stack<T[N]>
}
};

namespace detail {
//=================================================================================================
/**
* @brief Stack specialization for `std::reference_wrapper`.
*/
template <class T>
struct Stack<std::reference_wrapper<T>>
{
static Result push(lua_State* L, const std::reference_wrapper<T>& reference)
{
lua_newuserdata_aligned<std::reference_wrapper<T>>(L, reference.get());

luaL_newmetatable(L, typeName());
lua_pushvalue(L, -2);
lua_pushcclosure_x(L, &get_set_reference_value<T>, 1);
rawsetfield(L, -2, "__call");
lua_setmetatable(L, -2);

return {};
}

static TypeResult<std::reference_wrapper<T>> get(lua_State* L, int index)
{
auto ptr = luaL_testudata(L, index, typeName());
if (ptr == nullptr)
return makeErrorCode(ErrorCode::InvalidTypeCast);

auto reference = reinterpret_cast<std::reference_wrapper<T>*>(ptr);
if (reference == nullptr)
return makeErrorCode(ErrorCode::InvalidTypeCast);

return *reference;
}

static bool isInstance(lua_State* L, int index)
{
return luaL_testudata(L, index, typeName()) != nullptr;
}

private:
static const char* typeName()
{
static const std::string s{ detail::typeName<std::reference_wrapper<T>>() };
return s.c_str();
}

template <class U>
static int get_set_reference_value(lua_State* L)
{
LUABRIDGE_ASSERT(lua_isuserdata(L, lua_upvalueindex(1)));

std::reference_wrapper<U>* ptr = static_cast<std::reference_wrapper<U>*>(lua_touserdata(L, lua_upvalueindex(1)));
LUABRIDGE_ASSERT(ptr != nullptr);

if (lua_gettop(L) > 1)
{
auto result = Stack<U>::get(L, 2);
if (! result)
luaL_error(L, "%s", result.message().c_str());

ptr->get() = *result;

return 0;
}
else
{
auto result = Stack<U>::push(L, ptr->get());
if (! result)
luaL_error(L, "%s", result.message().c_str());

return 1;
}
}
};

namespace detail {
template <class T>
struct StackOpSelector<T&, false>
{
Expand Down Expand Up @@ -1398,7 +1471,6 @@ struct StackOpSelector<const T*, false>

static bool isInstance(lua_State* L, int index) { return Stack<T>::isInstance(L, index); }
};

} // namespace detail

template <class T>
Expand Down
4 changes: 3 additions & 1 deletion Source/LuaBridge/detail/Userdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class Userdata
lua_remove(L, -2); // Stack: rt, pot
}

// no return
unreachable();
}

static bool isInstance(lua_State* L, int index, const void* registryClassKey)
Expand Down Expand Up @@ -158,6 +158,8 @@ class Userdata

lua_remove(L, -2); // Stack: rt, pot
}

unreachable();
}

static Userdata* throwBadArg(lua_State* L, int index)
Expand Down
133 changes: 133 additions & 0 deletions Tests/Source/ClassTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2786,6 +2786,139 @@ TEST_F(ClassTests, NewIndexFallbackMetaMethodFreeFunctor)
ASSERT_EQ(246, result<int>());
}

TEST_F(ClassTests, ReferenceWrapperRead)
{
int x = 13;
std::reference_wrapper<int> ref_wrap_x(x);

luabridge::getGlobalNamespace(L)
.beginNamespace("test")
.addProperty("ref_wrap_x", &ref_wrap_x)
.addFunction("changeReference", [](std::reference_wrapper<int> r) { r.get() = 100; })
.endNamespace();

runLua(R"(
result = test.ref_wrap_x
test.changeReference(result)
)");

EXPECT_TRUE(result().isUserdata());
EXPECT_EQ(x, result().unsafe_cast<std::reference_wrapper<int>>().get());
EXPECT_EQ(100, x);
}

TEST_F(ClassTests, ReferenceWrapperWrite)
{
int x = 13;
std::reference_wrapper<int> ref_wrap_x(x);

luabridge::getGlobalNamespace(L)
.beginNamespace("test")
.addProperty("ref_wrap_x", &ref_wrap_x)
.endNamespace();

runLua(R"(
test.ref_wrap_x(100)
result = test.ref_wrap_x
)");

EXPECT_TRUE(result().isUserdata());
EXPECT_EQ(x, result().unsafe_cast<std::reference_wrapper<int>>().get());
EXPECT_EQ(100, x);
}

TEST_F(ClassTests, ReferenceWrapperRedirect)
{
int x = 13;
int y = 100;
std::reference_wrapper<int> ref_wrap_x(x);
std::reference_wrapper<int> ref_wrap_y(y);

luabridge::getGlobalNamespace(L)
.beginNamespace("test")
.addProperty("ref_wrap_x", &ref_wrap_x)
.addProperty("ref_wrap_y", &ref_wrap_y)
.endNamespace();

runLua(R"(
test.ref_wrap_x = test.ref_wrap_y
result = test.ref_wrap_x
)");

EXPECT_TRUE(result().isUserdata());
EXPECT_EQ(y, result().unsafe_cast<std::reference_wrapper<int>>().get());
}

TEST_F(ClassTests, ReferenceWrapperDecaysToType)
{
int x = 13;
std::reference_wrapper<int> ref_wrap_x(x);

luabridge::getGlobalNamespace(L)
.beginNamespace("test")
.addProperty("ref_wrap_x", &ref_wrap_x)
.addFunction("takeReference", [](int r) { return r * 10; })
.endNamespace();

runLua(R"(
result = test.takeReference(test.ref_wrap_x)
)");

EXPECT_EQ(130, result().unsafe_cast<int>());
}

TEST_F(ClassTests, ReferenceWrapperFailsOnInvalidType)
{
int x = 13;
std::reference_wrapper ref_wrap_x(x);

float y = 1.0f;
std::reference_wrapper ref_wrap_y(y);

luabridge::getGlobalNamespace(L)
.beginNamespace("test")
.addProperty("ref_wrap_x", &ref_wrap_x)
.addProperty("ref_wrap_y", &ref_wrap_y)
.addFunction("takeReference1", [](float r) { return r * 10; })
.addFunction("takeReference2", [](int r) { return r * 10; })
.addFunction("takeReference3", [](std::reference_wrapper<float> r) { return r.get() * 10; })
.addFunction("takeReference4", [](std::reference_wrapper<int> r) { return r.get() * 10; })
.endNamespace();

#if LUABRIDGE_HAS_EXCEPTIONS
EXPECT_THROW(runLua("result = test.takeReference1(test.ref_wrap_x)"), std::exception);
EXPECT_THROW(runLua("result = test.takeReference2(test.ref_wrap_y)"), std::exception);
EXPECT_THROW(runLua("result = test.takeReference3(test.ref_wrap_x)"), std::exception);
EXPECT_THROW(runLua("result = test.takeReference4(test.ref_wrap_y)"), std::exception);
#else
EXPECT_FALSE(runLua("result = test.takeReference1(test.ref_wrap_x)"));
EXPECT_FALSE(runLua("result = test.takeReference2(test.ref_wrap_y)"));
EXPECT_FALSE(runLua("result = test.takeReference3(test.ref_wrap_x)"));
EXPECT_FALSE(runLua("result = test.takeReference4(test.ref_wrap_y)"));
#endif
}

TEST_F(ClassTests, ReferenceWrapperAccessFromLua)
{
int x = 13;
std::reference_wrapper<int> ref_wrap_x(x);

luabridge::getGlobalNamespace(L)
.beginNamespace("test")
.addProperty("ref_wrap_x", &ref_wrap_x)
.endNamespace();

runLua(R"(
function xyz(x)
return x() * 10
end

result = xyz(test.ref_wrap_x)
)");

EXPECT_EQ(130, result().unsafe_cast<int>());
}

namespace {
struct ExtensibleBase
{
Expand Down