diff --git a/sproto.lua b/sproto.lua index 271e692..911dc87 100644 --- a/sproto.lua +++ b/sproto.lua @@ -101,22 +101,40 @@ end function sproto:request_encode(protoname, tbl) local p = queryproto(self, protoname) - return core.encode(p.request,tbl) , p.tag + local request = p.request + if request then + return core.encode(request,tbl) , p.tag + else + return "" , p.tag + end end function sproto:response_encode(protoname, tbl) local p = queryproto(self, protoname) - return core.encode(p.response,tbl) + local response = p.response + if response then + return core.encode(response,tbl) + else + return "" + end end function sproto:request_decode(protoname, ...) local p = queryproto(self, protoname) - return core.decode(p.request,...) , p.name + local request = p.request + if request then + return core.decode(request,...) , p.name + else + return nil, p.name + end end function sproto:response_decode(protoname, ...) local p = queryproto(self, protoname) - return core.decode(p.response,...) + local response = p.response + if response then + return core.decode(response,...) + end end sproto.pack = core.pack @@ -128,9 +146,13 @@ function sproto:default(typename, type) else local p = queryproto(self, typename) if type == "REQUEST" then - return core.default(p.request) + if p.request then + return core.default(p.request) + end elseif type == "RESPONSE" then - return core.default(p.response) + if p.response then + return core.default(p.response) + end else error "Invalid type" end diff --git a/testrpc.lua b/testrpc.lua index b645eef..8a02f9a 100644 --- a/testrpc.lua +++ b/testrpc.lua @@ -25,7 +25,6 @@ foo 2 { bar 3 {} blackhole 4 { - request {} } ]] @@ -36,8 +35,15 @@ local client_proto = sproto.parse [[ } ]] +print("=== default table") + print_r(server_proto:default("package")) print_r(server_proto:default("foobar", "REQUEST")) +assert(server_proto:default("foo", "REQUEST")==nil) +assert(server_proto:request_encode("foo")=="") +server_proto:response_encode("foo", { ok = true }) +assert(server_proto:request_decode("blackhole")==nil) +assert(server_proto:response_decode("blackhole")==nil) print("=== test 1")