Skip to content

Commit d039af2

Browse files
committed
optinal pg decimal to tnt decimal cast
1 parent 615a031 commit d039af2

File tree

2 files changed

+50
-16
lines changed

2 files changed

+50
-16
lines changed

pg/driver.c

+41-15
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@
4949
#undef PACKAGE_VERSION
5050
#include <module.h>
5151

52+
struct dec_opt {
53+
char cast;
54+
int dnew_index;
55+
};
56+
typedef struct dec_opt dec_opt_t;
57+
5258
/**
5359
* Infinity timeout from tarantool_ev.c. I mean, this should be in
5460
* a module.h file.
@@ -97,7 +103,7 @@ lua_push_error(struct lua_State *L)
97103
* Parse pg values to lua
98104
*/
99105
static int
100-
parse_pg_value(struct lua_State *L, char dec_cast, PGresult *res, int row, int col)
106+
parse_pg_value(struct lua_State *L, dec_opt_t *dopt, PGresult *res, int row, int col)
101107
{
102108
if (PQgetisnull(res, row, col))
103109
return false;
@@ -108,9 +114,18 @@ parse_pg_value(struct lua_State *L, char dec_cast, PGresult *res, int row, int c
108114

109115
switch (PQftype(res, col)) {
110116
case NUMERICOID: {
111-
if (dec_cast == 's')
112-
{
117+
if (dopt->cast == 's') {
118+
lua_pushlstring(L, val, len);
119+
break;
120+
}
121+
else if (dopt->cast == 'd' && dopt->dnew_index != -1) {
122+
lua_rawgeti(L, LUA_REGISTRYINDEX, dopt->dnew_index);
113123
lua_pushlstring(L, val, len);
124+
int fail = lua_pcall(L, 1, 1, 0);
125+
if (fail) {
126+
lua_pop(L, 2);
127+
return false;
128+
}
114129
break;
115130
}
116131
// else fallthrough
@@ -148,15 +163,15 @@ static int
148163
safe_pg_parsetuples(struct lua_State *L)
149164
{
150165
PGresult *res = (PGresult *)lua_topointer(L, 1);
151-
const char dec_cast = (char)lua_tointeger(L, 2);
166+
dec_opt_t *dopt = (dec_opt_t *)lua_topointer(L, 2);
152167
int row, rows = PQntuples(res);
153168
int col, cols = PQnfields(res);
154169
lua_newtable(L);
155170
for (row = 0; row < rows; ++row) {
156171
lua_pushnumber(L, row + 1);
157172
lua_newtable(L);
158173
for (col = 0; col < cols; ++col)
159-
parse_pg_value(L, dec_cast, res, row, col);
174+
parse_pg_value(L, dopt, res, row, col);
160175
lua_settable(L, -3);
161176
}
162177
return 1;
@@ -213,7 +228,7 @@ pg_wait_for_result(PGconn *conn)
213228
* Appends result fom postgres to lua table
214229
*/
215230
static int
216-
pg_resultget(struct lua_State *L, const char dec_cast, PGconn *conn, int *res_no, int status_ok)
231+
pg_resultget(struct lua_State *L, dec_opt_t *dopt, PGconn *conn, int *res_no, int status_ok)
217232
{
218233
int wait_res = pg_wait_for_result(conn);
219234
if (wait_res != 1)
@@ -243,7 +258,7 @@ pg_resultget(struct lua_State *L, const char dec_cast, PGconn *conn, int *res_no
243258
lua_pushinteger(L, (*res_no)++);
244259
lua_pushcfunction(L, safe_pg_parsetuples);
245260
lua_pushlightuserdata(L, pg_res);
246-
lua_pushinteger(L, dec_cast);
261+
lua_pushlightuserdata(L, dopt);
247262
fail = lua_pcall(L, 2, 1, 0);
248263
if (!fail)
249264
lua_settable(L, -3);
@@ -338,19 +353,25 @@ lua_pg_execute(struct lua_State *L)
338353
{
339354
PGconn *conn = lua_check_pgconn(L, 1);
340355

341-
char dec_cast = 'n';
356+
dec_opt_t dopt = {'n', -1};
342357
if (lua_isstring(L, 2)) {
343358
const char *tmp = lua_tostring(L, 2);
344-
if (*tmp == 'n' || *tmp == 's') // TODO 'd' - decimal
345-
dec_cast = *tmp;
359+
if (*tmp == 'n' || *tmp == 's' || *tmp == 'd')
360+
dopt.cast = *tmp;
346361
}
347362

348-
if (!lua_isstring(L, 3)) {
363+
if (!lua_isstring(L, 4)) {
349364
safe_pushstring(L, "Second param should be a sql command");
350365
return lua_push_error(L);
351366
}
352-
const char *sql = lua_tostring(L, 3);
353-
int paramCount = lua_gettop(L) - 3;
367+
368+
if (lua_isfunction(L, 3)) {
369+
lua_pushvalue(L, 3);
370+
dopt.dnew_index = luaL_ref(L, LUA_REGISTRYINDEX);
371+
}
372+
373+
const char *sql = lua_tostring(L, 4);
374+
int paramCount = lua_gettop(L) - 4;
354375

355376
const char **paramValues = NULL;
356377
int *paramLengths = NULL;
@@ -371,7 +392,7 @@ lua_pg_execute(struct lua_State *L)
371392

372393
int idx;
373394
for (idx = 0; idx < paramCount; ++idx) {
374-
lua_parse_param(L, idx + 4, paramValues + idx,
395+
lua_parse_param(L, idx + 5, paramValues + idx,
375396
paramLengths + idx, paramTypes + idx);
376397
}
377398
res = PQsendQueryParams(conn, sql, paramCount, paramTypes,
@@ -383,14 +404,19 @@ lua_pg_execute(struct lua_State *L)
383404
if (res == -1) {
384405
lua_pushinteger(L, PQstatus(conn) == CONNECTION_BAD ? -1: 0);
385406
lua_pushstring(L, PQerrorMessage(conn));
407+
if (dopt.dnew_index != -1)
408+
luaL_unref(L, LUA_REGISTRYINDEX, dopt.dnew_index);
386409
return 2;
387410
}
388411
lua_pushinteger(L, 0);
389412
lua_newtable(L);
390413

391414
int res_no = 1;
392415
int status_ok = 1;
393-
while ((status_ok = pg_resultget(L, dec_cast, conn, &res_no, status_ok)));
416+
while ((status_ok = pg_resultget(L, &dopt, conn, &res_no, status_ok)));
417+
418+
if (dopt.dnew_index != -1)
419+
luaL_unref(L, LUA_REGISTRYINDEX, dopt.dnew_index);
394420

395421
return 2;
396422
}

pg/init.lua

+9-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@ local fiber = require('fiber')
44
local driver = require('pg.driver')
55
local ffi = require('ffi')
66

7+
local dnew
8+
do
9+
local ok, dec = pcall(require, "decimal")
10+
if ok then
11+
dnew = dec.new
12+
end
13+
end
14+
715
local pool_mt
816
local conn_mt
917

@@ -61,7 +69,7 @@ conn_mt = {
6169
self.queue:put(false)
6270
return get_error(self.raise.pool, 'Connection is broken')
6371
end
64-
local status, datas = self.conn:execute(self.dec_cast, sql, ...)
72+
local status, datas = self.conn:execute(self.dec_cast, dnew, sql, ...)
6573
if status ~= 0 then
6674
self.queue:put(status > 0)
6775
return error(datas)

0 commit comments

Comments
 (0)