Skip to content

[u]int64_t parameters; decimal->number cast removed in favor of original textual representation #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ Connect to a database.
- `db` - a database name
- `conn_string` (mutual exclusive with host, port, user, pass, db) - PostgreSQL
[connection string][PQconnstring]
- `dec_cast` - an option that switches casting types for `NUMERIC` PostgreSQL
type. Possible values: `n` (`number`), `s` (`string`), `d` (`decimal`).

*Returns*:

Expand Down
122 changes: 110 additions & 12 deletions pg/driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,31 @@
#undef PACKAGE_VERSION
#include <module.h>

/**
* The fallthrough attribute with a null statement serves as a fallthrough
* statement. It hints to the compiler that a statement that falls through
* to another case label, or user-defined label in a switch statement is
* intentional and thus the -Wimplicit-fallthrough warning must not trigger.
* The fallthrough attribute may appear at most once in each attribute list,
* and may not be mixed with other attributes. It can only be used in a switch
* statement (the compiler will issue an error otherwise), after a preceding
* statement and before a logically succeeding case label, or user-defined
* label.
*/
#if defined(__cplusplus) && __has_cpp_attribute(fallthrough)
# define FALLTHROUGH [[fallthrough]]
#elif __has_attribute(fallthrough) || (defined(__GNUC__) && __GNUC__ >= 7)
# define FALLTHROUGH __attribute__((fallthrough))
#else
# define FALLTHROUGH
#endif

struct dec_opt {
char cast;
int dnew_index;
};
typedef struct dec_opt dec_opt_t;

/**
* Infinity timeout from tarantool_ev.c. I mean, this should be in
* a module.h file.
Expand Down Expand Up @@ -97,7 +122,7 @@ lua_push_error(struct lua_State *L)
* Parse pg values to lua
*/
static int
parse_pg_value(struct lua_State *L, PGresult *res, int row, int col)
parse_pg_value(struct lua_State *L, PGresult *res, int row, int col, dec_opt_t *dopt)
{
if (PQgetisnull(res, row, col))
return false;
Expand All @@ -107,9 +132,26 @@ parse_pg_value(struct lua_State *L, PGresult *res, int row, int col)
int len = PQgetlength(res, row, col);

switch (PQftype(res, col)) {
case INT2OID:
case INT4OID:
case NUMERICOID: {
if (dopt->cast == 's') {
lua_pushlstring(L, val, len);
break;
}
else if (dopt->cast == 'd' && dopt->dnew_index != -1) {
lua_rawgeti(L, LUA_REGISTRYINDEX, dopt->dnew_index);
lua_pushlstring(L, val, len);
int fail = lua_pcall(L, 1, 1, 0);
if (fail) {
lua_pop(L, 2);
return false;
}
break;
}
/* 'n': fallthrough */
FALLTHROUGH;
}
case INT2OID:
case INT4OID: {
lua_pushlstring(L, val, len);
double v = lua_tonumber(L, -1);
lua_pop(L, 1);
Expand Down Expand Up @@ -141,14 +183,15 @@ static int
safe_pg_parsetuples(struct lua_State *L)
{
PGresult *res = (PGresult *)lua_topointer(L, 1);
dec_opt_t *dopt = (dec_opt_t *)lua_topointer(L, 2);
int row, rows = PQntuples(res);
int col, cols = PQnfields(res);
lua_newtable(L);
for (row = 0; row < rows; ++row) {
lua_pushnumber(L, row + 1);
lua_newtable(L);
for (col = 0; col < cols; ++col)
parse_pg_value(L, res, row, col);
parse_pg_value(L, res, row, col, dopt);
lua_settable(L, -3);
}
return 1;
Expand Down Expand Up @@ -205,7 +248,7 @@ pg_wait_for_result(PGconn *conn)
* Appends result fom postgres to lua table
*/
static int
pg_resultget(struct lua_State *L, PGconn *conn, int *res_no, int status_ok)
pg_resultget(struct lua_State *L, PGconn *conn, int *res_no, int status_ok, dec_opt_t *dopt)
{
int wait_res = pg_wait_for_result(conn);
if (wait_res != 1)
Expand Down Expand Up @@ -235,9 +278,13 @@ pg_resultget(struct lua_State *L, PGconn *conn, int *res_no, int status_ok)
lua_pushinteger(L, (*res_no)++);
lua_pushcfunction(L, safe_pg_parsetuples);
lua_pushlightuserdata(L, pg_res);
fail = lua_pcall(L, 1, 1, 0);
if (!fail)
lua_pushlightuserdata(L, dopt);
fail = lua_pcall(L, 2, 1, 0);
if (!fail) {
lua_settable(L, -3);
break;
}
break;
case PGRES_COMMAND_OK:
res = 1;
break;
Expand Down Expand Up @@ -269,6 +316,15 @@ static void
lua_parse_param(struct lua_State *L,
int idx, const char **value, int *length, Oid *type)
{
/* Serialized [u]int64_t */
static char buf[512];
static char *pos = NULL;
/* lua_parse_param(L, idx + 5, ...) */
if (idx == 5) {
*buf = '\0';
pos = buf;
}

if (lua_isnil(L, idx)) {
*value = NULL;
*length = 0;
Expand All @@ -293,6 +349,27 @@ lua_parse_param(struct lua_State *L,
return;
}

if (luaL_iscdata(L, idx)) {
uint32_t ctypeid = 0;
void *cdata = luaL_checkcdata(L, idx, &ctypeid);
int len = 0;
if (ctypeid == luaL_ctypeid(L, "int64_t")) {
len = snprintf(pos, sizeof(buf) - (pos - buf), "%ld", *(int64_t*)cdata);
*type = INT8OID;
}
else if (ctypeid == luaL_ctypeid(L, "uint64_t")) {
len = snprintf(pos, sizeof(buf) - (pos - buf), "%lu", *(uint64_t*)cdata);
*type = NUMERICOID;
}

if (len > 0) {
*value = pos;
*length = len;
pos += len + 1;
return;
}
}

// We will pass all other types as strings
size_t len;
*value = lua_tolstring(L, idx, &len);
Expand All @@ -307,12 +384,28 @@ static int
lua_pg_execute(struct lua_State *L)
{
PGconn *conn = lua_check_pgconn(L, 1);
if (!lua_isstring(L, 2)) {

dec_opt_t dopt = {'n', -1};
if (lua_isstring(L, 2)) {
const char *dec_cast_type = lua_tostring(L, 2);
if (*dec_cast_type == 'n' ||
*dec_cast_type == 's' ||
*dec_cast_type == 'd')
dopt.cast = *dec_cast_type;
}

if (!lua_isstring(L, 4)) {
safe_pushstring(L, "Second param should be a sql command");
return lua_push_error(L);
}
const char *sql = lua_tostring(L, 2);
int paramCount = lua_gettop(L) - 2;

if (lua_isfunction(L, 3)) {
lua_pushvalue(L, 3);
dopt.dnew_index = luaL_ref(L, LUA_REGISTRYINDEX);
}

const char *sql = lua_tostring(L, 4);
int paramCount = lua_gettop(L) - 4;

const char **paramValues = NULL;
int *paramLengths = NULL;
Expand All @@ -333,7 +426,7 @@ lua_pg_execute(struct lua_State *L)

int idx;
for (idx = 0; idx < paramCount; ++idx) {
lua_parse_param(L, idx + 3, paramValues + idx,
lua_parse_param(L, idx + 5, paramValues + idx,
paramLengths + idx, paramTypes + idx);
}
res = PQsendQueryParams(conn, sql, paramCount, paramTypes,
Expand All @@ -345,14 +438,19 @@ lua_pg_execute(struct lua_State *L)
if (res == -1) {
lua_pushinteger(L, PQstatus(conn) == CONNECTION_BAD ? -1: 0);
lua_pushstring(L, PQerrorMessage(conn));
if (dopt.dnew_index != -1)
luaL_unref(L, LUA_REGISTRYINDEX, dopt.dnew_index);
return 2;
}
lua_pushinteger(L, 0);
lua_newtable(L);

int res_no = 1;
int status_ok = 1;
while ((status_ok = pg_resultget(L, conn, &res_no, status_ok)));
while ((status_ok = pg_resultget(L, conn, &res_no, status_ok, &dopt)));

if (dopt.dnew_index != -1)
luaL_unref(L, LUA_REGISTRYINDEX, dopt.dnew_index);

return 2;
}
Expand Down
11 changes: 10 additions & 1 deletion pg/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ local fiber = require('fiber')
local driver = require('pg.driver')
local ffi = require('ffi')

local has_decimal, dec = pcall(require, 'decimal')
if has_decimal then
dnew = dec.new
end

local pool_mt
local conn_mt

Expand All @@ -15,6 +20,10 @@ local function conn_create(pg_conn)
usable = true,
conn = pg_conn,
queue = queue,
dec_cast = 'n' -- Defined in pg/driver.c:
-- 'n' - number,
-- 's' - string,
-- 'd' - decimal.
}, conn_mt)

return conn
Expand Down Expand Up @@ -60,7 +69,7 @@ conn_mt = {
self.queue:put(false)
return get_error(self.raise.pool, 'Connection is broken')
end
local status, datas = self.conn:execute(sql, ...)
local status, datas = self.conn:execute(self.dec_cast, dnew, sql, ...)
if status ~= 0 then
self.queue:put(status > 0)
return error(datas)
Expand Down
36 changes: 36 additions & 0 deletions test/pg.test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,47 @@ function test_pg_int64(t, p)
p:put(conn)
end

function test_pg_decimal(t, p)
t:plan(8)

-- Setup
conn = p:get()
t:isnt(conn, nil, 'connection is established')
local num = 4500
conn:execute('CREATE TABLE dectest (num NUMERIC(7,2))')
conn:execute(('INSERT INTO dectest VALUES(%d)'):format(num))

local res, r, _
-- dec_cast is 'n'
t:is(conn.dec_cast, 'n', 'decimal casting type is "n" by default')
r, _ = conn:execute('SELECT num FROM dectest')
res = r[1][1]['num']
t:is(type(res), 'number', 'type is "number"')
t:is(res, num, 'decimal number is correct')
-- dec_cast is 's'
conn.dec_cast = 's'
r, _ = conn:execute('SELECT num FROM dectest')
res = r[1][1]['num']
t:is(type(res), 'string', 'type is "string"')
t:is(res, '4500.00', 'decimal number is correct')
-- dec_cast is 'd'
conn.dec_cast = 'd'
r, _ = conn:execute('SELECT num FROM dectest')
res = r[1][1]['num']
t:is(type(res), 'cdata', 'type is "decimal"')
t:is(res, num, 'decimal number is correct')

-- Teardown
conn:execute('DROP TABLE dectest')
p:put(conn)
end

tap.test('connection old api', test_old_api, conn)
local pool_conn = p:get()
tap.test('connection old api via pool', test_old_api, pool_conn)
p:put(pool_conn)
tap.test('test collection connections', test_gc, p)
tap.test('connection concurrent', test_conn_concurrent, p)
tap.test('int64', test_pg_int64, p)
tap.test('decimal', test_pg_decimal, p)
p:close()